Referring to the "SPMD user guide" here: https://github.com/pytorch/xla/blob/535d398b9a7d2952d04abe6395307897352664d2/docs/spmd.md
What's the right approach to use `torch.distributed.checkpoint` (and the distributed `CheckpointManager`) on Cloud TPUs?
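For context, what I'd like to end up with is roughly the following (a sketch based on my reading of the guide; the tiny model, the checkpoint directory, and the `SPMDSavePlanner` usage are my own assumptions):

```python
import torch
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# Stand-ins for the real sharded model and checkpoint location.
model = torch.nn.Linear(8, 8)
state_dict = {"model": model.state_dict()}

# Assumes a process group has already been initialized -- which is
# exactly what the question below is about.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("/tmp/ckpt"),
    planner=xc.SPMDSavePlanner(),
)
```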
I understand I should call `init_process_group` to use the `torch.distributed.*` functions. When I use:

```python
import torch.distributed as dist
import torch_xla.runtime as xr

xr.use_spmd()
dist.init_process_group("xla", init_method="pjrt://")
```

I run into this assertion:

```
AssertionError: XLA backend is not supported with SPMD. Please use a CPU process group instead.
```
Do I need to init a CPU process group (with rank, world size, etc.) in "parallel" with how I normally use `torch_xla.core.xla_model` for distribution on TPUs?
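In other words, is the intended pattern something like this? This is my guess, assuming the `gloo` backend works together with the `xla://` init_method registered by `torch_xla.distributed.xla_backend`:

```python
import torch.distributed as dist
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # registers the `xla://` init_method

xr.use_spmd()

# A CPU (gloo) process group instead of the XLA backend; presumably the
# `xla://` init_method discovers rank and world size automatically on TPUs.
dist.init_process_group("gloo", init_method="xla://")
```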