Labels: distributed (SPMD and other distributed things), enhancement (New feature or request), good first issue (Good for newcomers)
Description
🚀 Feature
PyTorch/XLA's xs.mark_sharding is an in-place operation that adds a sharding annotation to an XLA tensor. However, the gradients later computed for that tensor do not receive the same sharding annotation.
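For illustration, a minimal sketch of the current behavior (the mesh shape and tensor sizes here are made up for the example):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Build a simple 1-D device mesh over all available devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))

# Annotate the tensor in place: this shards `t` itself ...
t = torch.randn(8, 4).to(xm.xla_device()).requires_grad_(True)
xs.mark_sharding(t, mesh, ("data", None))

# ... but the gradient produced by backward() carries no annotation
# unless GSPMD happens to propagate one on its own.
loss = (t * t).sum()
loss.backward()
```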
Motivation
In some cases, GSPMD fails to propagate the sharding annotation from a tensor to its gradient, so it is useful to shard both the tensor and its gradient with the same annotation.
Pitch
We could write a torch.autograd.Function implementation that applies the same sharding annotation in both the forward and backward passes, as sketched below.
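A minimal sketch of what that could look like (the class name MarkShardingFunction is hypothetical; mesh and partition_spec mirror the arguments of xs.mark_sharding):

```python
import torch
import torch_xla.distributed.spmd as xs

class MarkShardingFunction(torch.autograd.Function):
    """Shards a tensor like xs.mark_sharding, and re-applies the same
    annotation to the incoming gradient during the backward pass."""

    @staticmethod
    def forward(ctx, tensor, mesh, partition_spec):
        # Save the annotation so backward can reuse it.
        ctx.mesh = mesh
        ctx.partition_spec = partition_spec
        # mark_sharding annotates the tensor in place.
        xs.mark_sharding(tensor, mesh, partition_spec)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Annotate the gradient with the same sharding as the tensor.
        xs.mark_sharding(grad_output, ctx.mesh, ctx.partition_spec)
        # mesh and partition_spec are non-tensor inputs: no gradients.
        return grad_output, None, None
```

Usage would then replace a direct xs.mark_sharding call, e.g. `t = MarkShardingFunction.apply(t, mesh, ("data", None))`.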
Additional context
JAX's equivalent of mark_sharding (jax.lax.with_sharding_constraint) shards the gradients too.
cc @bhavya01