Skip to content

Conversation

@ysiraichi
Copy link
Collaborator

Fix: #5719

This PR introduces a base_ attribute for XLATensor. It keeps track of the tensor whose storage would be aliased by the outer tensor due to a view operation.

a = torch.rand(5, device=xm.xla_device()) # base_ is undefined b = a + a # base_ is undefined c = b[2:] # base_ is b d = c.as_strided((5,), (1,), 0) # uses b (the base_ tensor), instead of c, as input # base_ is b
@JackCaoG
Copy link
Collaborator

lol do you mind resolve the conflict?

@ysiraichi ysiraichi force-pushed the fix-asstrided-fallback branch from f9958ef to e66a809 Compare November 22, 2023 20:00
@ysiraichi
Copy link
Collaborator Author

@JackCaoG I think this is ready for another round of reviews. Could you take a look at it?

const at::Tensor& GetRootBase(const at::Tensor& tensor);
// Sets the base tensor of a given XLATensor. Convenient function
// to be used when returning tensors.
XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we ever expect base to be on non-xla device? If not can we add an explict check?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I don't think so, since we got to a XLA dispatched kernel. Will add the check.

@ysiraichi ysiraichi merged commit 4ac255a into pytorch:master Nov 28, 2023
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
…pytorch#5914) * Add test. * Create `base_` tensor for views. * Use base tensor in `as_strided` operation. * Set base tensor of `as_strided`. * Fix lint errors. * Fix for disabled functionalization. * Address review.
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
…pytorch#5914) * Add test. * Create `base_` tensor for views. * Use base tensor in `as_strided` operation. * Set base tensor of `as_strided`. * Fix lint errors. * Fix for disabled functionalization. * Address review.
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
…pytorch#5914) * Add test. * Create `base_` tensor for views. * Use base tensor in `as_strided` operation. * Set base tensor of `as_strided`. * Fix lint errors. * Fix for disabled functionalization. * Address review.
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
…#5914) * Add test. * Create `base_` tensor for views. * Use base tensor in `as_strided` operation. * Set base tensor of `as_strided`. * Fix lint errors. * Fix for disabled functionalization. * Address review.
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
…#5914) * Add test. * Create `base_` tensor for views. * Use base tensor in `as_strided` operation. * Set base tensor of `as_strided`. * Fix lint errors. * Fix for disabled functionalization. * Address review.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2 participants