There was an error while loading. Please reload this page.
1 parent 1e473ed commit e9796e1Copy full SHA for e9796e1
torchao/quantization/pt2e/qat_utils.py
@@ -888,7 +888,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
888
and node.args[0].op == "get_attr"
889
and node.args[1] == 1
890
and torch.nn.modules.batchnorm.BatchNorm2d
891
- in [val[1] for val in node.meta["source_fn_stack"]]
+ in [val[1] for _, val in node.meta["nn_module_stack"].items()]
892
):
893
m.graph.erase_node(node)
894
0 commit comments