
Commit 294610a

Fix the missing parameter error when running mp_imagenet with torchrun (#5729)
* Fix the missing parameter error when running mp_imagenet with torchrun
* Pass the local rank as the index argument
1 parent ecc0e23 commit 294610a

1 file changed: 2 additions, 1 deletion


test/test_train_mp_imagenet.py

Lines changed: 2 additions & 1 deletion
@@ -85,6 +85,7 @@
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.debug.profiler as xp
 import torch_xla.utils.utils as xu
+import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.test.test_utils as test_utils
@@ -375,6 +376,6 @@ def _mp_fn(index, flags):

 if __name__ == '__main__':
   if dist.is_torchelastic_launched():
-    _mp_fn(FLAGS)
+    _mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS)
   else:
     xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
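
For context on why the second hunk adds a first argument: `_mp_fn(index, flags)` expects the per-process worker index, which `xmp.spawn` normally supplies. Under torchrun no spawner runs, so the index has to come from the LOCAL_RANK environment variable that torchrun exports. Below is a minimal, self-contained sketch of that launch pattern, not the real test code: it uses plain os.environ in place of torch_xla's `xu.getenv_as(xenv.LOCAL_RANK, int)` helper, and the worker body and flags are hypothetical placeholders.

# Sketch of the launch pattern this commit fixes (illustrative only).
import os


def _mp_fn(index, flags):
  # index is the per-host worker id; every launch path must supply it,
  # which is exactly the argument that was missing under torchrun.
  print(f"worker {index} running with flags={flags}")


if __name__ == '__main__':
  flags = {'num_cores': 8}  # hypothetical stand-in for the parsed FLAGS
  if os.environ.get('TORCHELASTIC_RUN_ID') is not None:
    # Launched via torchrun: this process already is one worker, so call the
    # worker function directly and take its index from LOCAL_RANK.
    _mp_fn(int(os.environ['LOCAL_RANK']), flags)
  else:
    # No elastic launcher: a spawner (xmp.spawn in the real test) creates the
    # worker processes and passes the index as the first argument; simulated
    # here with a plain loop.
    for rank in range(flags['num_cores']):
      _mp_fn(rank, flags)

When started with torchrun --nproc_per_node=N, each of the N processes sees its own LOCAL_RANK value, so every worker receives a distinct index.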
