Add fx passes to support unbounded dynamism in torch op arg #6653
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Add FX passes to support dynamism on torch ops that take symbolic dim size as parameters (e.g. `aten.view(x, [sym_size, -1]). The FX passes groups the generation of symbolic dim size and the torch op into a new XLA op. Grouping them together is necessary to lowering these ops with dynamism, because the operations on symbolic size cannot be traced in LTC (context in #6393). Once the source of the symbolic size and the consuming torch ops are captured in a single XLA op, it becomes feasible to lowered to HLO/StableHLO with dynamism semantics.
The FX passes will run automatically if the exported program has symbolic shape input.
The following ops are fused into XLA ops:
sym_size.int+aten.expand=>xla.dynamic_expandsym_size.int+ (mul) +aten.view=>xla.dynamic_viewSome torch ops are generating
aten.viewwith symbolic dim size during decomposition in upstream PyTorch, or existing lowering logic in torch_xla. FX passes are introduce to handle these ops as well.aten.native_layer_normaten.group_normaten.selectaten.unsqueezeOther changes:
rsubmeanandvartorch_xla/test/stablehlo/utils.pytotorch_xla/utils/stablehlo_test_utils.py, to fix path not found issue when running tests with pytest.Test: