
Conversation


@ManfeiBai ManfeiBai commented Feb 19, 2024

Follow-up re-enable PR for #6537 and #6554.


Passed local tests:

```
# pytest test/test_core_aten_ops.py -k test_aten_replication_pad3d_0
=============================== test session starts ===============================
platform linux -- Python 3.10.13, pytest-8.0.0, pluggy-1.4.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.97.4
collected 499 items / 498 deselected / 1 selected

test/test_core_aten_ops.py
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708454923.760699 843883 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1708454923.760777 843883 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1708454923.760792 843883 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
.                                                                           [100%]
======================== 1 passed, 498 deselected in 6.46s ========================

# pytest test/test_core_aten_ops.py -k test_aten_replication_pad3d_1
=============================== test session starts ===============================
platform linux -- Python 3.10.13, pytest-8.0.0, pluggy-1.4.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.97.4
collected 499 items / 498 deselected / 1 selected

test/test_core_aten_ops.py
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708454989.804666 845455 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1708454989.804751 845455 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1708454989.804773 845455 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
.                                                                           [100%]
======================== 1 passed, 498 deselected in 6.60s ========================
```

Metrics printed locally: https://gist.github.com/ManfeiBai/e661eab6fae8a10a1828369a2a016b8e

@ManfeiBai ManfeiBai changed the title Update aten_xla_type.cpp lower replication_pad3d and replication_pad3d_backward Feb 19, 2024
@ManfeiBai ManfeiBai requested a review from wonjoo-wj February 20, 2024 21:58
@ManfeiBai ManfeiBai marked this pull request as ready for review February 20, 2024 21:58

ManfeiBai commented Feb 20, 2024

CI failed with "ERROR: Failed to query remote execution capabilities: Unexpected error refreshing access token". Local tests passed; the output is pasted in the description.


@wonjoo-wj wonjoo-wj left a comment


Thanks, @ManfeiBai! I didn't know we could directly re-use the existing replication pad lowering logic, but it seems like we can. Nice!

Also, why did we lower the _backward variant? Is it required?
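
For context, the "re-use" discussed above usually shows up as a thin forwarding override. The snippet below is a minimal sketch modeled on the existing replication_pad1d/2d overrides in aten_xla_type.cpp, not the exact diff in this PR; the TORCH_LAZY_FN_COUNTER macro and the torch::lazy::ToVector helper are assumptions based on the surrounding codebase.

```cpp
// Sketch only: forwarding pattern modeled on the replication_pad1d/2d
// overrides. Macro and helper names are assumptions, not this PR's diff.
at::Tensor XLANativeFunctions::replication_pad3d(const at::Tensor& self,
                                                 at::IntArrayRef padding) {
  TORCH_LAZY_FN_COUNTER("xla::");
  // Re-use the shared replication-pad lowering via tensor_methods, the same
  // entry point the 1d/2d variants already go through.
  return bridge::AtenFromXlaTensor(tensor_methods::replication_pad3d(
      bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(padding)));
}
```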


XLATensorPtr replication_pad3d(const XLATensorPtr& input,
std::vector<int64_t> padding);
XLATensorPtr replication_pad3d_backward(const XLATensorPtr& grad_output,
Collaborator


nit: should add an empty newline here

Collaborator Author


Thanks, Wonjoo! According to https://github.com/pytorch/xla/blob/0192ff75324d51d748d76f7717bbccabc15d1db8/torch_xla/csrc/tensor_methods.h#L729C1-L745C71, do we want to keep the same style as the 1d and 2d variants, which have no empty newline here?

Collaborator


Ah, then let's just keep it as is. Thanks!

@ManfeiBai
Collaborator Author

> _backward

Yes. We don't have a test in test/test_core_aten_ops.py for the backward op; we lowered the _backward variant because replication_pad1d and replication_pad2d also do so. Should I add the same for replication_pad3d, along with the matching C++ tests, in this PR too?
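
For reference, if the backward lowering mirrors the 1d/2d variants, the corresponding override would plausibly look like the sketch below. This is an illustration only: the aten backward signature (grad_output, self, padding) is standard, but the forwarding details are assumptions modeled on the 1d/2d code rather than the actual diff.

```cpp
// Sketch only: backward forwarding modeled on the replication_pad1d/2d
// _backward overrides; details are assumptions, not this PR's diff.
at::Tensor XLANativeFunctions::replication_pad3d_backward(
    const at::Tensor& grad_output, const at::Tensor& self,
    at::IntArrayRef padding) {
  TORCH_LAZY_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(tensor_methods::replication_pad3d_backward(
      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
      torch::lazy::ToVector<int64_t>(padding)));
}
```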


@wonjoo-wj wonjoo-wj left a comment


SGTM, thanks! Feel free to include the newline change in another PR; I don't want to block this PR on a nit comment.
