
Conversation

@iwknow
Collaborator

@iwknow iwknow commented Mar 12, 2025

This change concludes #8678

@iwknow
Collaborator Author

iwknow commented Mar 13, 2025

It seems that I cannot trigger a re-run of the testing workflow. Is it a permission issue, or am I missing something? @tengyifei

@tengyifei
Collaborator

tengyifei commented Mar 13, 2025

@iwknow For security reasons, only repo writers can run workflows. I just ran it for you.

@iwknow
Collaborator Author

iwknow commented Mar 13, 2025

I am a little bit confused about what to expect here. In my test, I have:

xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                   dtype=torch.float,
                   device=xm.xla_device(),
                   requires_grad=True)
xst1 = xs.mark_sharding_with_gradients(xt1, mesh, partition_spec)
xst1.retain_grad()
output = xst1.sum()
output.retain_grad()
output.backward()

There are three tensors (xt1, xst1, output) and their corresponding gradients. I expect xst1, xst1.grad, output, output.grad, and xt1.grad to have a "sharding" section in their HLO. However, my experiment shows that xt1, xt1.grad, xst1, and output have the "sharding" section in their HLO, but xst1.grad and output.grad do not. Is this expected? Is it related to retain_grad (by default, these grads are not retained)? Or am I missing something?

@tengyifei
Collaborator

tengyifei commented Mar 14, 2025

> I am a little bit confused about what to expect here. In my test, I have:
>
> xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, device=xm.xla_device(), requires_grad=True)
> xst1 = xs.mark_sharding_with_gradients(xt1, mesh, partition_spec)
> xst1.retain_grad()
> output = xst1.sum()
> output.retain_grad()
> output.backward()
>
> There are three tensors (xt1, xst1, output) and their corresponding gradients. I expect xst1, xst1.grad, output, output.grad, and xt1.grad to have a "sharding" section in their HLO. However, my experiment shows that xt1, xt1.grad, xst1, and output have the "sharding" section in their HLO, but xst1.grad and output.grad do not. Is this expected? Is it related to retain_grad (by default, these grads are not retained)? Or am I missing something?

There's a difference between "having sharding in their HLO" and "having sharding in torch_xla._XLAC._get_xla_sharding_spec(my_tensor)".

If you check the HLO of a tensor, it will contain not just the HLO corresponding to the tensor itself, but also the HLO of any input tensors and their inputs, recursively, until you hit device data tensors. It's better to check torch_xla._XLAC._get_xla_sharding_spec, which returns the sharding spec of the tensor and nothing else. I think if you search for this function's usage in the codebase, you'll find how to write tests with it.
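
A minimal sketch of that difference (illustrative only, not code from this PR; it assumes an SPMD-enabled multi-device run and a hypothetical 1 x N mesh with axis names ('x', 'y'), since the thread doesn't show the mesh or partition_spec definitions):

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # assumed: SPMD mode enabled for this process
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (1, n), ('x', 'y'))

xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                   dtype=torch.float,
                   device=xm.xla_device(),
                   requires_grad=True)
xst1 = xs.mark_sharding_with_gradients(xt1, mesh, ('x', 'y'))
output = xst1.sum()

# The HLO dump of `output` contains the whole upstream graph, so the custom
# sharding placed on xst1 shows up in it even though `output` itself was
# never marked.
print('sharding' in torch_xla._XLAC._get_xla_tensors_hlo([output]))

# The per-tensor spec reports only what was explicitly annotated:
print(torch_xla._XLAC._get_xla_sharding_spec(xst1))    # non-empty spec string
print(torch_xla._XLAC._get_xla_sharding_spec(output))  # expected: '' (empty)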

In our snippet above, I'd expect the following:

  • xst1 has sharding annotations if you look in _get_xla_sharding_spec
  • xst1.grad has sharding annotations if you look in _get_xla_sharding_spec
  • everything else won't have sharding annotations

That's because if we don't call torch_xla.sync(), the GSPMD sharding propagation is not run. Then only the tensors on which we explicitly called mark_sharding will have a sharding spec.

I do think xst1.retain_grad() is required to keep the xst1.grad node around, though; otherwise it would be cleared by PyTorch after output.backward().

@iwknow
Collaborator Author

iwknow commented Mar 15, 2025

Thanks for the detailed explanation! One more thing:

> In our snippet above, I'd expect the following:
>
>   • xst1 has sharding annotations if you look in _get_xla_sharding_spec
>   • xst1.grad has sharding annotations if you look in _get_xla_sharding_spec

Do you mean xt1 instead of xst1 in your second bullet point? My experiment shows that only xst1 and xt1.grad have the sharding annotation. This also aligns with the assertions in other tests (e.g. test_mark_sharding_autograd, test_mark_sharding_aot_compile). My understanding is that MarkShardingFunction is a custom operation in the computation graph. Then we have:

FORWARD:  xt1(sharding=None) ---- MarkShardingFunction.forward ----> xst1(sharding={user_defined_sharding})
BACKWARD: xt1.grad(sharding={user_defined_sharding}) <---- MarkShardingFunction.backward ---- xst1.grad(sharding=None)
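
A small sketch that checks this behavior with _get_xla_sharding_spec, mirroring the snippet above (illustrative only; it assumes the same SPMD setup and 1 x N mesh as in the earlier sketch, and the expected outputs follow the conclusions of this thread):

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (1, n), ('x', 'y'))
partition_spec = ('x', 'y')

xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                   dtype=torch.float,
                   device=xm.xla_device(),
                   requires_grad=True)
xst1 = xs.mark_sharding_with_gradients(xt1, mesh, partition_spec)
xst1.retain_grad()        # keep xst1.grad around so it can be inspected
output = xst1.sum()
output.backward()

spec = torch_xla._XLAC._get_xla_sharding_spec
# Forward: the op annotates its output, not its input.
print(spec(xst1))       # expected: the user-defined sharding
print(spec(xt1))        # expected: '' (no explicit annotation)
# Backward: the op annotates the gradient it returns (xt1.grad), while the
# incoming xst1.grad is left un-annotated.
print(spec(xt1.grad))   # expected: the user-defined sharding
print(spec(xst1.grad))  # expected: ''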
@tengyifei
Collaborator

Ah, you're right.

@iwknow iwknow requested review from bhavya01 and tengyifei March 15, 2025 23:58
@tengyifei tengyifei merged commit 5caaaee into pytorch:master Mar 17, 2025
23 checks passed
