Skip to content

Conversation

@ManfeiBai
Copy link
Collaborator

@ManfeiBai ManfeiBai commented Feb 23, 2024

For fori_loop implementation with while_loop, this PR is for lowering body/cond to replace formal placeholder

This is the step two PR, and father PR(#6532), child PR(#6529), source PR(#6563)


some issue fixed:

  • (before)body fn is -, not a torch func, will test later (after)tried torch.sub(a, b), passed too locally
  • current code has changed many logic of lowering, let's move these logics to a new function without affecting the existing functions
  • (before)input are limited to list/tuple (after)this match torch._higher_order_ops.while_loop required
  • (before)input was trans from list to not list after torch.compile, TODO, add the same logic like torch.compile to use inputs, not like currently create a duplicated tensor in the fori_loop.py file
@ManfeiBai ManfeiBai changed the title Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py [Do Not Merge] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Feb 23, 2024
@ManfeiBai ManfeiBai changed the title [Do Not Merge] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py [Do Not Merge][Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Feb 24, 2024
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
Copy link
Collaborator Author

@ManfeiBai ManfeiBai Feb 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain: we need to skip tuple for cond/body computation creation to match the xla::While format check for cond, error log

@ManfeiBai ManfeiBai marked this pull request as ready for review February 26, 2024 19:34
@ManfeiBai
Copy link
Collaborator Author

ManfeiBai commented Feb 26, 2024

Hi, @JackCaoG, since this PR would add new function to PyLoweringContext, do we want to request review from aws too?

@ManfeiBai ManfeiBai changed the title [Do Not Merge][Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py [Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Feb 26, 2024
@JackCaoG
Copy link
Collaborator

@amithrm FYI

@ManfeiBai
Copy link
Collaborator Author

kokoro failure should be fixed on master branch, let's skip it now

@ManfeiBai ManfeiBai requested a review from yeounoh February 29, 2024 18:07
Copy link
Contributor

@yeounoh yeounoh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some suggestions

@ManfeiBai ManfeiBai requested a review from yeounoh February 29, 2024 22:27
@ManfeiBai ManfeiBai force-pushed the ManfeiBai-patch-73 branch from c35ac53 to c7f09d5 Compare March 4, 2024 17:47
&PyLoweringContext::GetParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
.def("set_name_string", &PyLoweringContext::SetNameString)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good, thanks!

if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should condition on get_name_string(). Add this check at the top and build for while loop `if get_name_string() == "condctx" or get_name_string() == "bodyctx"; otherwise, you can keep the original build logic.

Have your logic for while loop build in a separate private method, and call it if ``if get_name_string() == "condctx" or get_name_string() == "bodyctx"` is true.

So you can keep BuildXla() simple.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, make sense and updated in the newest commit, due to the logic for while loop is one simple line code, we run it directly without warping it in a separate private method

@ManfeiBai ManfeiBai force-pushed the ManfeiBai-patch-73 branch from 96ce688 to 173ff44 Compare March 8, 2024 18:44
@ManfeiBai ManfeiBai merged commit 6170df5 into master Mar 8, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

5 participants