11import  datetime 
2+ import  logging 
23import  threading 
34
45import  torch .distributed  as  dist 
5- from  torch .testing ._internal .distributed  import  multi_threaded_pg 
66from  torch_xla .distributed  import  xla_backend 
77from  torch_xla  import  runtime  as  xr 
8+ from  torch_xla ._internal  import  pjrt 
89from  torch_xla ._internal  import  tpu 
910import  torch_xla .utils .utils  as  xu 
1011
def _pjrt_rendezvous_handler(url: str,
                             timeout: datetime.timedelta = ...,
                             **kwargs):
  """Rendezvous handler for the ``pjrt://`` init method.

  Yields a ``(store, rank, world_size)`` tuple, as `torch.distributed`
  rendezvous handlers must. The backing store is a process-wide singleton
  guarded by the module-level `_store_lock`.
  """
  # Assume `xmp.spawn` has not been called when using torchrun, so the PJRT
  # runtime still needs multiprocess initialization for this process.
  if dist.is_torchelastic_launched():
    local_world_size = xu.getenv_as('LOCAL_WORLD_SIZE', int)
    local_rank = xu.getenv_as('LOCAL_RANK', int)
    pjrt.initialize_multiprocess(local_rank, local_world_size)

  master_ip = xu.getenv_as('MASTER_ADDR', str)
  if not master_ip:
    # NOTE(review): this conditional's tail fell inside the elided diff hunk;
    # reconstructed per upstream: discover the TPU master worker IP on TPU,
    # otherwise fall back to localhost — confirm against the full file.
    master_ip = tpu.discover_master_worker_ip() if xr.device_type(
    ) == 'TPU' else 'localhost'

  # NOTE(review): the master_port assignment also fell inside the elided
  # hunk; reconstructed with upstream's default port 12355 — confirm.
  master_port = xu.getenv_as('MASTER_PORT', int, 12355)
  with _store_lock:
    global _store
    if not _store:
      if xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True':
        # torchrun's elastic agent already hosts the TCPStore server, so
        # connect as a client only, and namespace keys by restart attempt so
        # state from a previous (failed) attempt is never reused.
        attempt = xu.getenv_as('TORCHELASTIC_RESTART_COUNT', int, defval=0)
        tcp_store = dist.TCPStore(
            master_ip, master_port, xr.process_count(), is_master=False)
        # Fix: the extracted source dropped the f-string's closing quote.
        _store = dist.PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
      else:
        # No agent store: process 0 hosts the TCPStore server itself.
        _store = dist.TCPStore(
            master_ip,
            master_port,
            xr.process_count(),
            is_master=xr.process_index() == 0)

  yield (_store, xr.global_ordinal(), xr.world_size())
3447
3548
# When TPU v2/v3 chips are present (version <= 3), patch torch.distributed
# to support multithreading. As the warnings below state, this support is
# experimental and does not work with torchrun.
if tpu.num_available_chips() > 0 and tpu.version() <= 3:
  from torch.testing._internal.distributed import multi_threaded_pg
  logging.warning('Patching torch.distributed state to support multithreading.')
  logging.warning('torch.distributed support on TPU v2 and v3 is experimental '
                  'and does not support torchrun.')
  multi_threaded_pg._install_threaded_pg()

# Make `init_method='pjrt://'` resolve to the handler defined above.
dist.register_rendezvous_handler('pjrt', _pjrt_rendezvous_handler)