[Fori_loop|While_loop] Enable fori_loop with add/sub test case #6603
Conversation
7a8f9f5 to c0d3359 Compare
torch_xla/csrc/lowering_context.cpp Outdated
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
  xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
  xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
Explanation: we need to skip the tuple when creating the cond/body computations so that they match the xla::While format check for cond; see the error log.
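For background, a minimal sketch (not from this PR; BuildWhile, cond_computation, and body_computation are hypothetical placeholders, and the header path depends on the XLA version) of how xla::While consumes such computations. The cond computation must return a scalar PRED over the loop carry and the body must return a value shaped like the carry, which is why a tuple-wrapped single root would fail the format check:

#include <cstdint>
#include "xla/client/xla_builder.h"  // header path varies across XLA versions

// cond_computation/body_computation stand in for the XlaComputations produced
// by the cond/body lowering contexts.
xla::XlaOp BuildWhile(xla::XlaBuilder* builder,
                      const xla::XlaComputation& cond_computation,
                      const xla::XlaComputation& body_computation) {
  // Loop carry: a single scalar counter in this toy example.
  xla::XlaOp init = xla::ConstantR0<int32_t>(builder, 0);
  // xla::While expects cond to yield a scalar PRED and body to yield a value
  // with the same shape as init.
  return xla::While(cond_computation, body_computation, init);
}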
Hi @JackCaoG, since this PR would add a new function to
@amithrm FYI
The kokoro failure should be fixed on the master branch; let's skip it for now.
Left some suggestions
c35ac53 to c7f09d5 Compare
      &PyLoweringContext::GetParameterIdTensorMapping)
-  .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
+  .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
+  .def("set_name_string", &PyLoweringContext::SetNameString)
Good, thanks!
torch_xla/csrc/lowering_context.cpp Outdated
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
  xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
  xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
I think we should condition on get_name_string(). Add this check at the top and take the while-loop build path if `get_name_string() == "condctx" or get_name_string() == "bodyctx"`; otherwise, keep the original build logic.
Put your while-loop build logic in a separate private method, and call it when `get_name_string() == "condctx" or get_name_string() == "bodyctx"` is true.
That way you keep BuildXla() simple.
Thanks, that makes sense; updated in the newest commit. Since the while-loop build logic is a single line of code, we run it directly instead of wrapping it in a separate private method.
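For illustration, a minimal sketch of the structure this discussion converges on, assuming the C++ accessor is named GetNameString() and that the surrounding BuildXla() body matches the diff quoted above; this is not the merged code, and the exact signature and member names may differ:

// Sketch: gate the while-loop path on the context name; otherwise keep the
// original tuple-wrapping build. Names (GetNameString, root_tuple_, builder())
// are assumed from the surrounding diff.
xla::XlaComputation LoweringContext::BuildXla() {
  xla::StatusOr<xla::XlaComputation> xla;
  if (GetNameString() == "condctx" || GetNameString() == "bodyctx") {
    // While-loop cond/body computations must return their single root
    // directly; wrapping it in xla::Tuple would fail xla::While's check.
    xla = builder()->Build(root_tuple_.at(0));
  } else if (!root_tuple_.empty()) {
    xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
    xla = builder()->Build(root);
  } else {
    xla = builder()->Build();
  }
  return std::move(xla.value());
}

The reviewer's alternative of a separate private method would simply move the first branch's body into its own helper and call it from here.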
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Outdated
96ce688 to 173ff44 Compare
For the fori_loop implementation with while_loop, this PR lowers body/cond to replace the formal placeholder.
This is the step-two PR; parent PR: #6532, child PR: #6529, source PR: #6563.
Some issues fixed:
- (before) body fn is … (after) tried torch.sub(a, b), which also passed locally; `-` is not a torch func, will test later
- (before) inputs are limited to list/tuple (after) this matches what torch._higher_order_ops.while_loop requires
- (before) input was transformed from a list to a non-list after torch.compile; TODO: add the same logic as torch.compile to use the inputs directly, instead of creating a duplicated tensor in fori_loop.py as currently done