
Commit 4038f8e

Correct the multinode training doc (#5747)
* fix Jon's comment
* add pjrt_distributed flag back.
* updated the doc
* fix typo
* fix typo
1 parent 83778f0 commit 4038f8e

3 files changed: +11 -5 lines changed


docs/pjrt.md

Lines changed: 3 additions & 3 deletions

@@ -206,7 +206,7 @@ PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --f
 You can also use `torchrun` to initiate the single-node multi-GPU training. For example,
 
 ```
-PJRT_DEVICE=GPU torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
+PJRT_DEVICE=GPU torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
 ```
 
 In the above example, `--nnodes` means how many machines (physical machines or VMs) to be used (it is 1 since we do single-node training). `--nproc-per-node` means how many GPU devices to be used.
@@ -245,10 +245,10 @@ On the second GPU machine, run
 --nnodes=2 \
 --node_rank=1 \
 --nproc_per_node=4 \
---rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet_torchrun.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
+--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
 ```
 
-the difference between the 2 commands above are `--node_rank` and potentially `--nproc_per_node` if you want to use different number of GPU devices on each machine. All the rest are identical.
+the difference between the 2 commands above are `--node_rank` and potentially `--nproc_per_node` if you want to use different number of GPU devices on each machine. All the rest are identical. For more information about `torchrun`, please refer to this [page](https://pytorch.org/docs/stable/elastic/run.html).
 
 ## Differences from XRT
 
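Note that the second hunk only shows the launch command for the second GPU machine; the matching first-machine command sits above this hunk in docs/pjrt.md and is unchanged by the commit. Going by the doc's own note that the two commands differ only in `--node_rank` (and possibly `--nproc_per_node`), the machine-0 launch presumably looks like the sketch below; the `PJRT_DEVICE=GPU` prefix and the flag values are assumptions carried over from the doc's examples, not lines quoted from this diff.

```
# Hedged sketch of the first-machine (node_rank 0) launch implied by the doc;
# only --node_rank differs from the machine-1 command shown in the hunk above.
PJRT_DEVICE=GPU torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
```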

test/test_train_mp_imagenet.py

Lines changed: 4 additions & 1 deletion

@@ -31,6 +31,9 @@
     '--ddp': {
         'action': 'store_true',
     },
+    '--pjrt_distributed': {
+        'action': 'store_true',
+    },
     '--profile': {
         'action': 'store_true',
     },
@@ -175,7 +178,7 @@ def _train_update(device, step, loss, tracker, epoch, writer):
 
 
 def train_imagenet():
-  if FLAGS.ddp:
+  if FLAGS.ddp or FLAGS.pjrt_distributed:
     dist.init_process_group('xla', init_method='xla://')
 
   print('==> Preparing data..')
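The change above (mirrored in test/test_train_mp_mnist.py below) makes either `--ddp` or the re-added `--pjrt_distributed` flag trigger process-group initialization over the XLA backend. Below is a minimal standalone sketch of that pattern, assuming a PJRT-enabled torch_xla install and a torchrun (or similar) launch; the script body and the print are illustrative, only the flag handling and the `init_process_group` call come from this commit.

```
# Minimal sketch of the flag-gated XLA process-group init added in this commit.
# Assumes torch_xla with PJRT support; launched e.g. via torchrun.
import argparse

import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend with torch.distributed


def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--ddp', action='store_true')
  parser.add_argument('--pjrt_distributed', action='store_true')
  flags = parser.parse_args()

  # Same gating as the diff: either flag initializes the process group
  # over the XLA backend using the xla:// init method.
  if flags.ddp or flags.pjrt_distributed:
    dist.init_process_group('xla', init_method='xla://')

  device = xm.xla_device()
  rank = dist.get_rank() if dist.is_initialized() else 0
  print(f'rank {rank} using device {device}')


if __name__ == '__main__':
  main()
```

Keeping `--ddp` in the condition preserves the old behavior, while `--pjrt_distributed` presumably lets a run set up the process group without also enabling the DDP model wrapping that `--ddp` controls elsewhere in the test script.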

test/test_train_mp_mnist.py

Lines changed: 4 additions & 1 deletion

@@ -5,6 +5,9 @@
     '--ddp': {
         'action': 'store_true',
     },
+    '--pjrt_distributed': {
+        'action': 'store_true',
+    },
 }
 
 FLAGS = args_parse.parse_common_options(
@@ -73,7 +76,7 @@ def _train_update(device, step, loss, tracker, epoch, writer):
 
 
 def train_mnist(flags, **kwargs):
-  if flags.ddp:
+  if flags.ddp or flags.pjrt_distributed:
     dist.init_process_group('xla', init_method='xla://')
 
   torch.manual_seed(1)
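Since the MNIST script receives the identical change, it can be exercised the same way. A hypothetical single-node launch, mirroring the torchrun example from docs/pjrt.md above (the `PJRT_DEVICE=GPU` setting and the device-count variable come from that doc, not from this hunk):

```
# Hypothetical run of the updated MNIST test with the re-added flag,
# following the single-node torchrun example from docs/pjrt.md.
PJRT_DEVICE=GPU torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_mnist.py --pjrt_distributed
```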
