@ysiraichi
Collaborator

Fix: #7083

This PR adds data-type promotion to the stack operation. Previously there was none, so the kernel implicitly expected all arguments to have the same data-type. That might not be the case when using AMP (automatic mixed precision).

cc @miladm @JackCaoG

@ysiraichi ysiraichi requested a review from JackCaoG May 22, 2024 00:42
at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
at::ScalarType result_type = at::native::result_type(tensors);
std::vector<at::Tensor> c_tensors(tensors.size());
Collaborator

Is stack expecting the input tensors to be on CPU? std::vector<at::Tensor> c_tensors will give a list of tensors on CPU, right?

Collaborator Author

@ysiraichi ysiraichi May 22, 2024

I don't think so. Unless I'm missing something, they are cast tensors, on XLA.

Collaborator

Then I am a bit confused. Reading your code, you initialize the c_tensors vector, which I assume holds CPU tensors since you didn't provide a device type. In the later code you only update the dtype of these c_tensors; I don't know when they are moved to the XLA device.

Collaborator Author

Here's a summary of what this code is doing: given the arguments tensors (a list of XLA tensors) and dim, the function:

  1. Computes the common data-type of all tensors: result_type
  2. Converts each tensor to the common data-type, storing the result in c_tensors (as in "cast tensors")
  3. Calls tensor_methods::stack with the cast tensors
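
To make steps 1 and 2 concrete, here is a minimal, self-contained sketch of the pattern. It is not the PR's code and it does not use ATen: ToyDtype, ToyTensor, toy_result_type and toy_cast_all are hypothetical stand-ins, and the "widest type wins" rule is a simplification of the real promotion rules. What it is meant to show is that the pre-sized c_tensors vector only holds placeholders, and that std::transform overwrites each slot with a cast copy of the corresponding input, so the device comes from the input tensor.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy stand-ins (assumptions for illustration, NOT the ATen/XLA types):
// just enough to model a tensor's dtype and device.
enum class ToyDtype { Half = 0, Float = 1, Double = 2 };  // ordered by width

struct ToyTensor {
  ToyDtype dtype;
  std::string device;  // e.g. "xla:0"

  // Returns a copy with the requested dtype; the device is carried over.
  ToyTensor to(ToyDtype target) const { return ToyTensor{target, device}; }
};

// Step 1: compute the common dtype. Simplified to "widest wins"; the real
// promotion rules cover more type categories than this toy ordering.
ToyDtype toy_result_type(const std::vector<ToyTensor>& tensors) {
  assert(!tensors.empty());
  ToyDtype result = tensors.front().dtype;
  for (const ToyTensor& t : tensors) {
    result = std::max(result, t.dtype);
  }
  return result;
}

// Step 2: cast every input to the common dtype.
std::vector<ToyTensor> toy_cast_all(const std::vector<ToyTensor>& tensors) {
  const ToyDtype result_type = toy_result_type(tensors);
  // Pre-sized with default-constructed placeholders; every slot is
  // overwritten by the transform below.
  std::vector<ToyTensor> c_tensors(tensors.size());
  std::transform(tensors.begin(), tensors.end(), c_tensors.begin(),
                 [=](const ToyTensor& t) { return t.to(result_type); });
  return c_tensors;
}
```

Step 3 (the actual stacking) is left out of the sketch. Calling toy_cast_all on a Half and a Float tensor that both live on "xla:0" yields two Float tensors, still on "xla:0".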
Collaborator

Oh, I see. transform is called with tensors.begin().

@ysiraichi ysiraichi force-pushed the ysiraichi/fix-stack-dtype-promotion branch from 57352fb to 5fbcdd9 Compare May 22, 2024 14:18
@ysiraichi ysiraichi requested a review from JackCaoG May 22, 2024 18:54