chore: updates
peri044 committed Feb 21, 2025
commit 684f424ad8abe9b5f9c3b2145d9ab00720c1f7ca
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -17,6 +17,7 @@
     aten.addcmul,
     aten.addcmul_,
     aten.addr,
+    aten.addmm,
     aten.aminmax,
     aten.arange.default,
     aten.arange.start,
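
Context for this hunk (a reviewer note, not part of the diff): adding `aten.addmm` to this decomposition group asks the lowering to apply PyTorch's stock ATen decomposition, which rewrites `addmm(bias, mat1, mat2)` into a matmul followed by an add (with extra scaling when `alpha`/`beta` differ from 1). A minimal, self-contained sketch of that effect using only core PyTorch APIs; the function `f` and the tensor shapes are made up for illustration:

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# PyTorch ships a core decomposition for addmm:
#   addmm(bias, mat1, mat2) ~> bias + mm(mat1, mat2)  (plus alpha/beta scaling)
decomp_table = get_decompositions([torch.ops.aten.addmm])

def f(bias, mat1, mat2):
    return torch.ops.aten.addmm.default(bias, mat1, mat2)

# Trace with the decomposition table active; shapes are arbitrary examples.
gm = make_fx(f, decomposition_table=decomp_table)(
    torch.randn(3), torch.randn(2, 4), torch.randn(4, 3)
)
print(gm.graph)  # contains aten.mm + aten.add; no aten.addmm node survives
```

This is the same split the hand-written `split_addmm_nodes` pass used to perform, which is why that pass can be deleted in the next file.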
25 changes: 0 additions & 25 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,28 +9,6 @@
 logger = logging.getLogger(__name__)
 
 
-def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    target = torch.ops.aten.addmm.default
-    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
-    for addmm_node in addmm_nodes:
-        bias, mat1, mat2 = addmm_node.all_input_nodes
-
-        with gm.graph.inserting_before(addmm_node):
-            mm_node = gm.graph.call_function(
-                torch.ops.aten.mm.default,
-                args=(mat1, mat2),
-            )
-            add_node = gm.graph.call_function(
-                torch.ops.aten.add.Tensor,
-                args=(bias, mm_node),
-            )
-
-        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
-        gm.graph.erase_node(addmm_node)
-
-    return gm
-
-
 def accumulate_fp32_matmul(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
@@ -41,9 +19,6 @@ def accumulate_fp32_matmul(
         torch.ops.aten.bmm.default,
     ]
 
-    # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
-    split_addmm_nodes(gm)
-
     matmul_nodes = [
         node for node in gm.graph.nodes if node.target in matmul_targets
     ]
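
Reviewer note on this removal: with `aten.addmm` now decomposed upstream, no `addmm` nodes reach this pass anymore, so the manual FX graph surgery in `split_addmm_nodes` is redundant and `accumulate_fp32_matmul` only needs to handle `mm`/`bmm`. A hedged sketch of the per-node rewrite the surviving pass is assumed to perform; the helper name `cast_matmul_to_fp32` and the `node.meta["val"]` dtype lookup are illustrative assumptions, not code from this file:

```python
import torch

def cast_matmul_to_fp32(gm: torch.fx.GraphModule, node: torch.fx.Node) -> None:
    """Wrap one mm/bmm node so it accumulates in FP32 (illustrative sketch)."""
    out_dtype = node.meta["val"].dtype  # original output dtype, e.g. torch.float16

    # Cast both matmul operands up to float32.
    with gm.graph.inserting_before(node):
        node.args = tuple(
            gm.graph.call_function(
                torch.ops.aten._to_copy.default,
                args=(arg,),
                kwargs={"dtype": torch.float32},
            )
            for arg in node.args
        )

    # Cast the FP32 result back down and reroute the matmul's users to it,
    # keeping the cast node itself pointing at the original matmul.
    with gm.graph.inserting_after(node):
        cast_back = gm.graph.call_function(
            torch.ops.aten._to_copy.default,
            args=(node,),
            kwargs={"dtype": out_dtype},
        )
    node.replace_all_uses_with(
        cast_back, delete_user_cb=lambda user: user is not cast_back
    )
```

Casting operands up and the result back down keeps the matmul accumulating in FP32 while the rest of the graph stays in reduced precision, which is the whole point of the pass.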