16 changes: 16 additions & 0 deletions test/test_mp_all_gather.py
@@ -6,6 +6,10 @@
import torch_xla.distributed.xla_multiprocessing as xmp


def all_gather(tensor, dim):
return xm.all_gather(tensor, dim=dim)


def _mp_fn(index):
device = xm.xla_device()
world_size = xm.xrt_world_size()
@@ -14,6 +18,18 @@ def _mp_fn(index):
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)

cpu_result = result.cpu()
expected = torch.arange(0, world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print('xm.all_gather() produced wrong reductions', file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

compiled_all_gather = torch.compile(
all_gather, backend='torchxla_trace_once', fullgraph=True)
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = compiled_all_gather(ordinal_tensor, dim=0)

cpu_result = result.cpu()
expected = torch.arange(0, world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
33 changes: 31 additions & 2 deletions torch_xla/core/xla_model.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
import torch_xla
from torch_xla.experimental import pjrt
from torch_xla.experimental import tpu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.debug.metrics_saver as ms
import torch_xla.utils.utils as xu
@@ -26,6 +27,27 @@
_DEVICE_CONTEXTS = dict()
_DEVICE_CONTEXTS_LOCK = threading.Lock()

# Note [Dynamo WORLD_SIZE and ORDINAL]
# The globals below are a workaround to cache the ordinal and world_size so that
# Dynamo won't insert graph breaks when xm.xrt_world_size() and xm.get_ordinal() are called.
_WORLD_SIZE = None
_ORDINAL = None


def _init_world_size_ordinal():
global _WORLD_SIZE, _ORDINAL

if not pjrt.using_pjrt():
return

# We don't support V3-8. See Note [V3-8 Threading]
if pjrt.device_type() == 'TPU' and tpu.version() < 4:
return

if _WORLD_SIZE is None:
_WORLD_SIZE = xrt_world_size()
_ORDINAL = get_ordinal()
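
A minimal, hypothetical sketch (not part of this diff) of the behavior the cache is meant to enable: once _init_world_size_ordinal() has run, xrt_world_size() and get_ordinal() return the cached Python values, so a function like the one below should trace under Dynamo without a graph break. The function name and the scaling it does are made up for illustration; the backend name matches the test in this PR.

import torch
import torch_xla.core.xla_model as xm

def scale_by_rank(tensor):
  # Both calls hit the cached _WORLD_SIZE/_ORDINAL globals and behave like
  # plain Python constants during Dynamo tracing.
  return tensor * xm.get_ordinal() / xm.xrt_world_size()

compiled_scale = torch.compile(
    scale_by_rank, backend='torchxla_trace_once', fullgraph=True)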


class DeviceContext(object):

@@ -90,6 +112,10 @@ def xrt_world_size(defval=1):
Returns:
The number of devices taking part in the replication.
"""
global _WORLD_SIZE
if _WORLD_SIZE is not None:
return _WORLD_SIZE

if pjrt.using_pjrt():
return pjrt.world_size()

@@ -109,6 +135,10 @@ def get_ordinal(defval=0):
Returns:
The replication ordinal of the current thread.
"""
global _ORDINAL
if _ORDINAL is not None:
return _ORDINAL

if pjrt.using_pjrt():
return pjrt.global_ordinal()

@@ -533,8 +563,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
A tensor which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
if pin_layout and xla_device_hw(
value.device) in ('TPU', 'GPU', 'XPU') and output == None:
Collaborator
I think we had it because CPU was not supported at some point. Do you need to remove it because it will break dynamo?

Collaborator Author
Yea.

if pin_layout and output == None:
# There is no easy way to pin the all_gather layout on TPU and GPU, so use the
# all_reduce-based all_gather for this purpose.
return _all_gather_using_all_reduce(
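
For context on the `_all_gather_using_all_reduce` fallback called above, here is a conceptual sketch of how an all_gather can be emulated with an all_reduce: each replica pads its shard into the slot matching its ordinal in a zero-filled buffer, and summing the buffers across replicas fills every slot. This is only an illustration under that assumption, not torch_xla's actual implementation (which also handles groups, pin_layout, and other details).

import torch
import torch_xla.core.xla_model as xm

def all_gather_via_all_reduce(value, dim=0):
  ordinal = xm.get_ordinal()
  world_size = xm.xrt_world_size()
  # F.pad's padding list is ordered from the last dimension backwards; pad
  # `dim` so the local shard lands at this replica's offset.
  padding = [0] * (2 * value.dim())
  padding[2 * (value.dim() - 1 - dim)] = ordinal * value.size(dim)
  padding[2 * (value.dim() - 1 - dim) + 1] = (
      world_size - 1 - ordinal) * value.size(dim)
  padded = torch.nn.functional.pad(value, padding)
  # Summing across replicas fills every slot, yielding the gathered tensor.
  return xm.all_reduce(xm.REDUCE_SUM, padded)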
1 change: 1 addition & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -18,6 +18,7 @@
namespace torch_xla {
namespace {

// Note [V3-8 Threading]
// For V3-8 + PJRT, we have 4 processes and each process has 2 threads to manage
// the 8 cores. Therefore, we need different tokens for different threads.
std::unordered_map<int64_t, std::shared_ptr<torch::lazy::Value>>
3 changes: 3 additions & 0 deletions torch_xla/experimental/pjrt.py
@@ -235,6 +235,9 @@ def _run_thread_per_device(
def _thread_fn(device: torch.device):
torch_xla._XLAC._xla_set_default_device(device)

# See Note [Dynamo WORLD_SIZE and ORDINAL].
xm._init_world_size_ordinal()

return fn()

with concurrent.futures.ThreadPoolExecutor(