
Conversation

@vanbasten23 (Collaborator) commented Jun 12, 2023

This PR is a simplified version of #4998. It sets has_symbolic_sizes_strides properly on XLATensorImpl. The ultimate goal is to throw an error when tensor.sizes() is called on a tensor with a dynamic dimension; this PR works together with the upstream PR to achieve that.

Note that pytorch/xla already checks has_symbolic_sizes_strides_ in XLATensorImpl::sizes_custom() (code). The remaining work is to set has_symbolic_sizes_strides_ properly.

The problem right now is that torch.fill_ has started to call XLANativeFunctions::empty_strided_symint. Before this PR and the PyTorch PR pytorch/pytorch#101634, torch.fill_ did not call XLANativeFunctions::empty_strided_symint.
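For context, here is a minimal, self-contained sketch of the behavior this PR is aiming for. It is illustrative only; FakeTensorImpl and its members are made-up names, not the actual XLATensorImpl API. The point is simply that once a tensor is flagged as having symbolic sizes/strides, asking for its concrete sizes should be an error.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// "FakeTensorImpl" is a made-up stand-in, not the real XLATensorImpl.
struct FakeTensorImpl {
  std::vector<int64_t> concrete_sizes;
  bool has_symbolic_sizes_strides = false;

  // Analogue of sizes()/sizes_custom(): only meaningful for static shapes.
  const std::vector<int64_t>& sizes() const {
    if (has_symbolic_sizes_strides) {
      throw std::runtime_error(
          "Cannot call sizes() on a tensor with symbolic (dynamic) sizes");
    }
    return concrete_sizes;
  }
};
```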

@vanbasten23 (Collaborator, Author) commented Jun 12, 2023

@JackCaoG, per our discussion, I tried your suggestion in this PR but it didn't work: pytorch/xla/test/ds/test_dynamic_shapes.py TestDynamicShapes.test_fill_ still crashes. All this PR does is set has_symbolic_sizes_strides on an XLATensorImpl, which causes PyTorch to switch from calling return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset()) to calling return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), value.dtype().itemsize(), value.sym_storage_offset()); (pytorch code).
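As a rough illustration of why that switch matters: a computeStorageNbytes-style helper derives the required storage size from sizes, strides, and storage offset, so a tensor with symbolic sizes has to go through the sym_sizes()/sym_strides() overload. The sketch below is paraphrased and self-contained (compute_storage_nbytes_sketch is a made-up name), not the actual ATen implementation; the real symint overload does the same arithmetic on c10::SymInt.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Paraphrased sketch of a computeStorageNbytes-style calculation, using plain
// int64_t sizes and strides.
size_t compute_storage_nbytes_sketch(const std::vector<int64_t>& sizes,
                                     const std::vector<int64_t>& strides,
                                     size_t itemsize,
                                     int64_t storage_offset) {
  int64_t max_index = storage_offset;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) {
      return 0;  // an empty tensor needs no storage
    }
    max_index += (sizes[i] - 1) * strides[i];
  }
  return itemsize * static_cast<size_t>(max_index + 1);
}
```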

But I did find something strange. Before this change, neither XLANativeFunctions::empty_strided_symint nor XLANativeFunctions::new_empty_strided_symint was called. With the change in this PR, XLANativeFunctions::empty_strided_symint is suddenly called.

      tensor_(c10::make_intrusive<XLATensor>(std::move(tensor))) {
    is_non_overlapping_and_dense_ = false;
    const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
    set_sizes_and_strides(
        c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size()),
        c10::fromIntArrayRefSlow(strides_default()));
Collaborator (review comment on the code above): what's this call for?

@vanbasten23 (Collaborator, Author) replied Jun 13, 2023

It sets has_symbolic_sizes_strides_ = true;:
https://github.com/pytorch/pytorch/blob/2c313e7b99b6070a3a7d640e4bc8bf2fe3acbcf1/c10/core/TensorImpl.cpp#L1152

And both PyTorch and pt/xla check that has_symbolic_sizes_strides_ is false (pytorch code) when doing tensor.size().

@vanbasten23 (Collaborator, Author) commented:
Hi @ezyang, can you please take a look at this PR, along with the PyTorch change, and see if my change makes sense?

Also, in our meeting, you mentioned that we can check whether the strides are contiguous; if so, we can have XLANativeFunctions::empty_strided_symint forward the call to XLANativeFunctions::empty_symint. Could you point me to how PyTorch implements this check? I found is_contiguous_strides (code), but it only supports static sizes, while in our case the sizes are dynamic and the strides are static.

Thanks!

vanbasten23 requested a review from ezyang, June 13, 2023 21:22
@vanbasten23 (Collaborator, Author) commented Jun 14, 2023

FWIW, XLANativeFunctions::empty_strided_symint is called from pytorch/aten/src/ATen/native/TensorFactories.cpp here.

I'm not sure why self.is_non_overlapping_and_dense() evaluates to true there, because torch_xla sets XLATensorImpl.is_non_overlapping_and_dense_ to false (here). Let me check further.

Also, I summarized the problem in the PR description.

@vanbasten23 (Collaborator, Author) commented:
I figured out why torch.fill_ starts to call XLANativeFunctions::empty_strided_symint after this PR, while before it did not call into this function. Here is why:

With this PR, empty_like here in TensorFactories.cpp evaluates self.is_non_overlapping_and_dense() to true and hence calls at::empty_strided_symint. Before the PR, self.is_non_overlapping_and_dense() evaluates to false and at::empty_symint is called here. The reason self.is_non_overlapping_and_dense() returns a different value is that, at https://github.com/pytorch/pytorch/blob/1985c490fe039278f14a874485118284420421db/c10/core/TensorImpl.h#L860-L867, this PR correctly sets has_symbolic_sizes_strides_ to true for dynamic tensors, whereas before the PR has_symbolic_sizes_strides_ defaulted to false. That is why the code takes a different branch in that function before and after the PR.

So should self.is_non_overlapping_and_dense() return false for torch_xla? Per our discussion with @JackCaoG and @miladm, it seems it should return true, even though it is currently hardcoded to false (code). See the conversation in this original PR. That PR should have been reverted, but we never got to it.
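To make the branch flip concrete, here is a self-contained sketch of the dispatch described above. It is paraphrased, not the actual TensorImpl or empty_like code; FakeImpl, empty_like_sketch, and their members are made-up names.

```cpp
#include <iostream>

struct FakeImpl {
  bool has_symbolic_sizes_strides = false;
  bool cached_flag = false;          // analogue of is_non_overlapping_and_dense_
  bool symbolic_meta_answer = true;  // analogue of the symbolic-shape metadata

  bool is_non_overlapping_and_dense() const {
    // Before this PR, has_symbolic_sizes_strides was (incorrectly) left false,
    // so the hardcoded cached_flag=false was returned and empty_like fell back
    // to the plain empty path. With the PR, the symbolic branch is taken.
    return has_symbolic_sizes_strides ? symbolic_meta_answer : cached_flag;
  }
};

void empty_like_sketch(const FakeImpl& self) {
  if (self.is_non_overlapping_and_dense()) {
    std::cout << "would call at::empty_strided_symint(sym_sizes, sym_strides)\n";
  } else {
    std::cout << "would call at::empty_symint(sym_sizes)\n";
  }
}
```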

@vanbasten23 (Collaborator, Author) commented:
Hi @ezyang, in our meeting you mentioned that if the strides are contiguous, we can skip calling XLANativeFunctions::empty_strided_symint and call XLANativeFunctions::empty_symint instead here. Can you point us to how PyTorch checks whether strides are contiguous? I found an is_contiguous_strides implementation (code), but it only supports static sizes/strides, while in our case the sizes are dynamic and the strides are static.
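For reference, the idea behind such a check is just the row-major stride recurrence. Below is a minimal sketch (illustrative only; is_row_major_contiguous is a made-up name, not the PyTorch implementation), written for plain int64_t. With symbolic sizes the same loop would have to run on c10::SymInt, where each comparison may introduce a guard, which is exactly the difficulty for dynamic dimensions.

```cpp
#include <cstdint>
#include <vector>

// Row-major contiguity check over concrete sizes/strides (sketch only).
bool is_row_major_contiguous(const std::vector<int64_t>& sizes,
                             const std::vector<int64_t>& strides) {
  int64_t expected = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    // A dimension of size 1 can have any stride without breaking contiguity.
    if (sizes[i] != 1 && strides[i] != expected) {
      return false;
    }
    expected *= sizes[i];
  }
  return true;
}
```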

@ezyang (Collaborator) commented Jun 20, 2023

Oh oops, I spoke too loosely. What you are hoping for is that the call site calls empty directly, instead of empty_strided with contiguous inputs. You usually have to fix the call site if it is messing this up, but the only place I really remember this being troublesome is pointwise TensorIterator ops.

@vanbasten23 (Collaborator, Author) commented:
> Oh oops, I spoke too loosely. What you are hoping for is that the call site calls empty directly, instead of empty_strided with contiguous inputs. You usually have to fix the call site if it is messing this up, but the only place I really remember this being troublesome is pointwise TensorIterator ops.

Thanks @ezyang for the response. I did check the call site. The reason the change in this PR starts calling empty_strided instead of the previous empty_symint is that tensor.is_non_overlapping_and_dense() starts evaluating to true (call site in pytorch). The reason for that is that TensorImpl::is_non_overlapping_and_dense depends on has_symbolic_sizes_strides_: code. That said, if this PR correctly sets has_symbolic_sizes_strides_ to true, is it expected for PyTorch to call empty_strided?

On a side note, we found PR #2682, which incorrectly sets is_non_overlapping_and_dense_ to false for XLATensorImpl. If reverting that PR is the correct thing to do, then due to this logic it seems is_non_overlapping_and_dense should return true, and PyTorch is expected to call empty_strided.

With my hypothesis, do you think the call site is doing the right thing by calling empty_strided?

@ezyang (Collaborator) commented Jun 21, 2023

Interesting. It seems like we have never handled this correctly. A PyTorch change which does something like pytorch/pytorch#95003 should work.

@JackCaoG (Collaborator) commented:

Sorry, I am a bit lost here. @ezyang, the PR you pointed to uses memory_format to set is_non_overlapping_and_dense_, which makes sense. Currently, is_non_overlapping_and_dense_ is always set to false by

    is_non_overlapping_and_dense_ = false;

which is incorrect.

My questions are:

  1. Where should we insert the logic to set is_non_overlapping_and_dense_ based on memory_format? Or should that logic live upstream?
  2. Given that an XLATensor should always be contiguous (based on https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_impl.cpp#L162-L166), is there any shortcut we can take?

@vanbasten23 also tried to undo the logic that overwrites is_non_overlapping_and_dense_ in #5215, and it failed some dynamic shape tests. @ezyang, correct me if I am wrong, but I believe the right thing to do is not to always overwrite is_non_overlapping_and_dense_ to false; that was a short-term fix for Ailing.

@ezyang (Collaborator) commented Jun 23, 2023

It is a bit painful to give guidance here, because I only really understand how everything is supposed to work in a universe where XLATensor DOES support full strides (and you rely on functionalization to undo the strides before they get to XLA). Then your problem reduces to an easier one: whenever there is a direct call to empty_strided, is there a simpler call we can make that doesn't require performing compute on unbacked symints? In the case of empty_like, when the input is non-overlapping and dense, we can just directly paste the sizes/strides/contiguity fields from the input tensor, because we are guaranteed to preserve everything, and we don't have to worry about a contiguity recompute poking one of the fields wrong.

But you're in some weird halfway state. I agree that it is probably a good idea to revert Ailing's change but I don't know how complicated it is to do so.

To answer this question:

> Where should we insert the logic to set is_non_overlapping_and_dense_ based on memory_format? Or should that logic live upstream?

I think it needs to live upstream in empty_like.

> Given that an XLATensor should always be contiguous (based on https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_impl.cpp#L162-L166), is there any shortcut we can take?

If the XLA tensor is always contiguous, you can compute the default contiguous strides and then manually overwrite the contiguity fields with their expected values, similar to the PR I linked. This wouldn't need an upstream change.
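A minimal sketch of that shortcut follows (illustrative only, not PyTorch source; default_contiguous_strides is a made-up name, and plain int64_t is used, whereas for dynamic shapes the same computation would run on c10::SymInt): derive the default row-major strides from the sizes, then set the contiguity-related flags to their known values instead of recomputing them.

```cpp
#include <cstdint>
#include <vector>

// Default row-major ("contiguous") strides for the given sizes; size-0/size-1
// edge cases are ignored for brevity.
std::vector<int64_t> default_contiguous_strides(
    const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}

// With the strides known to be contiguous by construction, the TensorImpl's
// contiguity fields (e.g. is_non_overlapping_and_dense_) could then be set
// directly to true rather than recomputed from symbolic sizes.
```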
