
Conversation

@zpcore (Member) commented Aug 29, 2024

Fix the in-place copy path where an extra mark_step was being conducted, e.g. for:

@torch.compile(backend='openxla')
def cc(arg0_1):
  x = torch.randn([1])
  copy = torch.ops.aten.copy.default(arg0_1, x)
  return copy
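
One way to observe the extra execution this fixes is to compare torch_xla's debug metrics around the call. A rough sketch, assuming the standard torch_xla.debug.metrics module; the exact 'ExecuteTime' count also depends on compilation-time executions, so it is mainly useful as a before/after comparison, not part of this PR:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


@torch.compile(backend='openxla')
def cc(arg0_1):
  x = torch.randn([1])
  copy = torch.ops.aten.copy.default(arg0_1, x)
  return copy


device = torch_xla.device()
arg = torch.randn([1], device=device)
xm.mark_step()  # flush the pending work from creating the input
met.clear_all()  # start counting from a clean slate

cc(arg)

# 'ExecuteTime' records one sample per device execution; running this with and
# without the fix shows whether an extra execution was triggered.
data = met.metric_data('ExecuteTime')
print('device executions:', data[0] if data else 0)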
@zpcore zpcore added the dynamo label Aug 29, 2024
@JackCaoG (Collaborator)

The issue is that InputCollector might also trigger in-place ops; we just need to clear them one more time:

--- a/torch_xla/_dynamo/dynamo_bridge.py
+++ b/torch_xla/_dynamo/dynamo_bridge.py
@@ -723,6 +723,18 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   return extract_compiled_graph_helper(xla_model, xla_args)
 
 
+def _clear_pending_irs_on_args(args_tensor_only, cloned_args):
+  # if args_tensor_only has pending IR which means there is a in place operations
+  # happened. We don't want to execute that operation yet, so we will replace the
+  # pending IR with the cloned arg.
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      args_tensor_only)
+
+  for i, need_update in enumerate(args_need_update_bool):
+    if need_update and isinstance(args_tensor_only[i], torch.Tensor):
+      args_tensor_only[i].copy_(cloned_args[i])
+
+
 def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
                                         all_xla_args_tensor_only):
   # below logic will try to partition the fx graph based on the fallback ops.
@@ -739,18 +751,8 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
     print('Dynamo fallback ops are' + str(unsupported_nodes) +
           '. Please open a GitHub issue with the above op lowering requests.')
 
-  # This logic, needed for supporting in-place operations, is a duplicate of
-  # the one in the main `extract_internal` function above. We need to do this
-  # check for fetching fallback ops as well.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      all_xla_args_tensor_only)
-
-  # Again, same logic in the `extract_internal` above to support in-place operations.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  for i, need_update in enumerate(args_need_update_bool):
-    if need_update and isinstance(all_xla_args_tensor_only[i], torch.Tensor):
-      all_xla_args_tensor_only[i].copy_(cloned_args[i])
+  # UnsupportedNodesCollector might trigger in place ops, need to clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
 
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
@@ -775,6 +777,9 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)
 
+  # InputCollector might trigger in place ops, need to clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
+
   # compile each submodule and replace it with a call
   for node in partitioned_graph.graph.nodes:
     if node.op == "call_module" and "fused_" in node.name:

The test I used:

@torch.compile(backend='openxla')
def cc(arg0_1):
  new_arg = arg0_1 * 2
  copy = torch.ops.aten.copy.default(arg0_1, new_arg)
  return copy

device = torch_xla.device()
input = torch.randn([1], device=device)
print(input)
res = cc(input)
print(res)

You can add it to test_dynamo.py and make sure that the *2 only gets executed once.
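
For reference, a minimal sketch of how that check could look as a test in test_dynamo.py; the class name and assertion style here are illustrative, not the exact test this PR adds. Assuming an extra execution of the pending in-place IR would double the input a second time, comparing the result against a CPU snapshot of the original input catches it:

import unittest

import torch
import torch_xla


class InplaceCopyTest(unittest.TestCase):

  def test_multiply_executed_once(self):

    @torch.compile(backend='openxla')
    def cc(arg0_1):
      new_arg = arg0_1 * 2
      copy = torch.ops.aten.copy.default(arg0_1, new_arg)
      return copy

    device = torch_xla.device()
    input = torch.randn([1], device=device)
    expected = input.cpu() * 2  # snapshot before cc mutates the input in place
    res = cc(input)
    # If the pending in-place IR had been executed an extra time, res would end
    # up as the original input * 4 instead of * 2.
    torch.testing.assert_close(res.cpu(), expected)


if __name__ == '__main__':
  unittest.main()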

@zpcore (Member, Author) commented Aug 30, 2024

Thanks for the help, this works now!

@zpcore zpcore marked this pull request as ready for review August 31, 2024 20:38
@zpcore zpcore requested a review from JackCaoG September 3, 2024 16:24
@zpcore zpcore merged commit 989ac69 into master Sep 3, 2024
@zpcore zpcore deleted the piz/inplace-cp branch September 3, 2024 18:21
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024