
Commit 9db8fac

[PJRT] Support torchrun with pjrt:// init_method (#5438)
* Support torchrun with `pjrt://` `init_method`
* move import
* fix error
* Fix NameError
* Fix path
* Remove from TPU CI
1 parent 5b88b5f commit 9db8fac

3 files changed: +68, -10 lines
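
For context, the usage this commit enables looks roughly like the sketch below. It is illustrative and not taken from the repository: the script body and the all_reduce payload are assumptions. A script like this would be launched with torchrun, for example `torchrun --nproc_per_node=<N> script.py`, with the process count chosen to match the host's devices.

# Hedged sketch: a torchrun-launched script using the pjrt:// init_method.
# The tensor payload and printout are illustrative assumptions.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt_backend  # registers the pjrt:// rendezvous handler


def main():
  dist.init_process_group('xla', init_method='pjrt://')
  # One rank per replicated XLA device across all torchrun processes.
  t = torch.ones(2, device=xm.xla_device()) * dist.get_rank()
  dist.all_reduce(t)  # sums the per-rank tensors
  xm.mark_step()
  print(f'rank {dist.get_rank()}/{dist.get_world_size()}: {t.cpu()}')


if __name__ == '__main__':
  main()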

test/pjrt/test_torchrun.py (new file: 38 additions, 0 deletions)

@@ -0,0 +1,38 @@
+from absl.testing import absltest
+from absl import logging
+import torch
+import torch.distributed as dist
+import torch_xla.core.xla_model as xm
+import torch_xla.experimental.pjrt_backend
+import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu
+
+
+class TestTorchrun(absltest.TestCase):
+
+  def test_all_gather(self):
+    dist.init_process_group('xla', init_method='pjrt://')
+
+    dist_world_size = xu.getenv_as('WORLD_SIZE', int)
+    devices_per_thread = xr.addressable_device_count()
+
+    expected_world_size = dist_world_size * devices_per_thread
+
+    rank = torch.tensor([dist.get_rank()],
+                        dtype=torch.float32,
+                        device=xm.xla_device())
+    output = [rank.clone() for _ in range(expected_world_size)]
+    dist.all_gather(output, rank)
+    result = torch.concat(output)
+    xm.mark_step()
+
+    expected = torch.arange(0, expected_world_size, step=1, dtype=torch.float32)
+    torch.testing.assert_close(result.cpu(), expected)
+
+
+if __name__ == '__main__':
+  if not dist.is_torchelastic_launched():
+    logging.error('Test must be launched with torchrun!')
+    exit(1)
+
+  absltest.main()
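
A note on the test's arithmetic: the pjrt:// handler yields one rank per replicated device, so the group's world size is torchrun's WORLD_SIZE multiplied by xr.addressable_device_count() in each process, which is exactly what expected_world_size computes. Since the test exits unless it is running under torchelastic, it would be launched along the lines of `torchrun --nproc_per_node=<local device count> test/pjrt/test_torchrun.py`; the exact process count to use is not specified by this commit.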

torch_xla/_internal/pjrt.py (5 additions, 3 deletions)

@@ -56,7 +56,6 @@ def _run_thread_per_device(
   initializer_fn(local_rank, local_world_size)

   devices = xm.get_xla_supported_devices()
-  xm.set_replication(xm.xla_device(), devices)
   num_threads = len(devices)

   @functools.wraps(fn)
@@ -104,13 +103,16 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:


 @runtime.requires_pjrt
-def _initialize_multiprocess(local_rank: int, local_world_size: int):
+def initialize_multiprocess(local_rank: int, local_world_size: int):
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size))

   if runtime.device_type() == 'TPU':
     tpu.configure_topology(local_rank, local_world_size)

+  devices = xm.get_xla_supported_devices()
+  xm.set_replication(xm.xla_device(), devices)
+

 @runtime.requires_pjrt
 def run_multiprocess(fn: Callable[..., R],
@@ -148,7 +150,7 @@ def run_multiprocess(fn: Callable[..., R],
         _run_thread_per_device,
         local_world_size=num_processes,
         fn=functools.partial(fn, *args, **kwargs),
-        initializer_fn=_initialize_multiprocess)
+        initializer_fn=initialize_multiprocess)
     process_results = executor.map(mp_fn, range(num_processes))
     replica_results = list(
         itertools.chain.from_iterable(
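
Design note: xm.set_replication is moved out of _run_thread_per_device and into the initializer, which is also renamed from the private _initialize_multiprocess to the public initialize_multiprocess. With that, the existing xmp.spawn / run_multiprocess path and the new torchrun path (via the pjrt:// rendezvous handler in the next file) run the same per-process setup, including replication across the addressable devices.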

Third changed file (25 additions, 7 deletions)

@@ -1,10 +1,11 @@
 import datetime
+import logging
 import threading

 import torch.distributed as dist
-from torch.testing._internal.distributed import multi_threaded_pg
 from torch_xla.distributed import xla_backend
 from torch_xla import runtime as xr
+from torch_xla._internal import pjrt
 from torch_xla._internal import tpu
 import torch_xla.utils.utils as xu

@@ -15,6 +16,12 @@
 def _pjrt_rendezvous_handler(url: str,
                              timeout: datetime.timedelta = ...,
                              **kwargs):
+  # Assume `xmp.spawn` has not been called when using torchrun
+  if dist.is_torchelastic_launched():
+    local_world_size = xu.getenv_as('LOCAL_WORLD_SIZE', int)
+    local_rank = xu.getenv_as('LOCAL_RANK', int)
+    pjrt.initialize_multiprocess(local_rank, local_world_size)
+
   master_ip = xu.getenv_as('MASTER_ADDR', str)
   if not master_ip:
     master_ip = tpu.discover_master_worker_ip() if xr.device_type(
@@ -24,15 +31,26 @@ def _pjrt_rendezvous_handler(url: str,
   with _store_lock:
     global _store
     if not _store:
-      _store = dist.TCPStore(
-          master_ip,
-          master_port,
-          xr.process_count(),
-          is_master=xr.process_index() == 0)
+      if xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True':
+        attempt = xu.getenv_as('TORCHELASTIC_RESTART_COUNT', int, defval=0)
+        tcp_store = dist.TCPStore(
+            master_ip, master_port, xr.process_count(), is_master=False)
+        _store = dist.PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
+      else:
+        _store = dist.TCPStore(
+            master_ip,
+            master_port,
+            xr.process_count(),
+            is_master=xr.process_index() == 0)

   yield (_store, xr.global_ordinal(), xr.world_size())


-multi_threaded_pg._install_threaded_pg()
+if tpu.num_available_chips() > 0 and tpu.version() <= 3:
+  from torch.testing._internal.distributed import multi_threaded_pg
+  logging.warning('Patching torch.distributed state to support multithreading.')
+  logging.warning('torch.distributed support on TPU v2 and v3 is experimental '
+                  'and does not support torchrun.')
+  multi_threaded_pg._install_threaded_pg()

 dist.register_rendezvous_handler('pjrt', _pjrt_rendezvous_handler)
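
Design note: the rendezvous handler now distinguishes two store setups. When torchrun's elastic agent already hosts the TCPStore (TORCHELASTIC_USE_AGENT_STORE is 'True'), each worker connects as a client (is_master=False) and wraps the store in a PrefixStore keyed by TORCHELASTIC_RESTART_COUNT, so keys written after a restart do not collide with stale ones from a previous attempt; otherwise it keeps the original behavior of creating a TCPStore with process 0 acting as master. The multithreaded process-group patch is also narrowed: multi_threaded_pg is now imported and installed only on TPU v2/v3, where, per the new warnings, torch.distributed support is experimental and torchrun is not supported.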
