Skip to content
2 changes: 2 additions & 0 deletions test/test_zero1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
class XlaZeRO1Test(unittest.TestCase):

@unittest.skipIf(pjrt.device_type() == 'TPU', "Crash on TPU")
@unittest.skipIf(pjrt.device_type() == 'GPU',
"TODO(alanwaketan): Fix it for the token change.")
def test_zero1(self):
device = xm.xla_device()

Expand Down
33 changes: 16 additions & 17 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def set_replication(device, devices):
else:
torch_xla._XLAC._xla_set_replication_devices([])
devctx.device_index = 0
devctx.all_reduce_token = None
torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
torch_xla._XLAC._xla_set_default_device(device)


Expand Down Expand Up @@ -412,10 +412,7 @@ def _fetch_gradients(optimizer):

def _get_all_reduce_token():
devctx = _get_device_context()
token = getattr(devctx, 'all_reduce_token', None)
if token is None:
token = torch_xla._XLAC._xla_create_token(devctx.device)
devctx.all_reduce_token = token
token = torch_xla._XLAC._get_all_reduce_token(devctx.device)
return token, devctx


Expand Down Expand Up @@ -452,11 +449,13 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
if isinstance(inputs, torch.Tensor):
result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, token, scale,
groups, pin_layout)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
results = [result[0]]
else:
devctx.all_reduce_token = torch_xla._XLAC._xla_all_reduce_inplace(
reduce_type, inputs, token, scale, groups, pin_layout)
torch_xla._XLAC._set_all_reduce_token(
devctx.device,
torch_xla._XLAC._xla_all_reduce_inplace(reduce_type, inputs, token,
scale, groups, pin_layout))
results = inputs

return results[0] if isinstance(inputs, torch.Tensor) else results
Expand Down Expand Up @@ -548,12 +547,12 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [],
pin_layout)
devctx.all_reduce_token = new_token
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
groups or [], pin_layout)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]


Expand Down Expand Up @@ -590,7 +589,7 @@ def all_to_all(value,
result = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension,
concat_dimension, split_count,
groups or [], pin_layout)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]


Expand All @@ -615,7 +614,7 @@ def collective_permute(value, pairs):
"""
token, devctx = _get_all_reduce_token()
result = torch_xla._XLAC._xla_collective_permute(value, token, pairs)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]


Expand Down Expand Up @@ -665,7 +664,7 @@ def send(value, channel_id):
# The input will be returned as result.
input_as_result, new_token = torch_xla._XLAC._xla_send(
value, token, channel_id)
devctx.all_reduce_token = new_token
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return input_as_result


Expand All @@ -680,7 +679,7 @@ def recv(output, channel_id):
"""
token, devctx = _get_all_reduce_token()
result, new_token = torch_xla._XLAC._xla_recv(output, token, channel_id)
devctx.all_reduce_token = new_token
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return result


Expand Down Expand Up @@ -729,13 +728,13 @@ def reduce_scatter(reduce_type,
scatter_dim,
shard_count, groups or
[], pin_layout)
devctx.all_reduce_token = new_token
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [], pin_layout)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]


Expand Down Expand Up @@ -805,7 +804,7 @@ def mark_step(wait=False):
if is_master_ordinal():
ms.save_metrics()
devctx = _run_step_closures()
devctx.all_reduce_token = None
torch_xla._XLAC._set_all_reduce_token(devctx.device, None)


def wait_device_ops(devices=[]):
Expand Down
34 changes: 34 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/token_handler.h"
#include "torch_xla/csrc/xla_graph_executor.h"

namespace torch_xla {
namespace {

// For V3-8 + PJRT, we have 4 processes and each process has 2 threads to manage
// the 8 cores. Therefore, we need different tokens for different threads.
std::unordered_map<int64_t, std::shared_ptr<torch::lazy::Value>>
g_all_reduce_tokens;

struct PerTypeContext {
std::vector<xla::XlaOp> ops;
std::vector<size_t> indices;
Expand Down Expand Up @@ -82,6 +88,19 @@ std::vector<xla::ReplicaGroup> CreateReduceGroups(
return reduce_groups;
}

std::shared_ptr<torch::lazy::Value> CreateToken(
const torch::lazy::BackendDevice& device) {
// This should be using xla::CreateToken() once we have added Token support to
// XLA AllReduce(). Meanwhile we use a constant as token, and we handle it
// accordingly in cross_replica_reduces.cpp.
// This needs to be device data (hence coming in as XLA computation parameter)
// as otherwise the XLA compiler passes will remove it, vanishing its
// sequencing effects.
torch::lazy::Value ir_value = XLAGraphExecutor::Get()->GetDeviceDataIrValue(
0.0, xla::PrimitiveType::F32, device);
return std::make_shared<torch::lazy::Value>(std::move(ir_value));
}

} // namespace

std::vector<xla::XlaOp> BuildAllReduce(
Expand Down Expand Up @@ -259,4 +278,19 @@ ReduceScatterResult BuildReduceScatter(
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

const torch::lazy::Value& GetAllReduceToken(
const torch::lazy::BackendDevice& device) {
auto it = g_all_reduce_tokens.find(device.ordinal());
if (it == g_all_reduce_tokens.end() || it->second == nullptr) {
g_all_reduce_tokens[device.ordinal()] = CreateToken(device);
return *g_all_reduce_tokens[device.ordinal()];
}
return *it->second;
}

void SetAllReduceToken(const torch::lazy::BackendDevice& device,
const std::shared_ptr<torch::lazy::Value>& token) {
g_all_reduce_tokens[device.ordinal()] = token;
}

} // namespace torch_xla
7 changes: 7 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "torch/csrc/lazy/core/ir.h"
#include "torch_xla/csrc/device.h"

namespace torch_xla {

Expand Down Expand Up @@ -77,4 +79,9 @@ ReduceScatterResult BuildReduceScatter(
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

const torch::lazy::Value& GetAllReduceToken(
const torch::lazy::BackendDevice& device);
void SetAllReduceToken(const torch::lazy::BackendDevice& device,
const std::shared_ptr<torch::lazy::Value>& token);

} // namespace torch_xla
26 changes: 11 additions & 15 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,19 +443,6 @@ std::vector<at::Tensor> GetXlaTensorsFromAten(
return xla_tensors;
}

std::shared_ptr<torch::lazy::Value> CreateToken(const std::string& device_str) {
// This should be using xla::CreateToken() once we have added Token support to
// XLA AllReduce(). Meanwhile we use a constant as token, and we handle it
// accordingly in cross_replica_reduces.cpp.
// This needs to be device data (hence coming in as XLA computation parameter)
// as otherwise the XLA compiler passes will remove it, vanishing its
// sequencing effects.
torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
torch::lazy::Value ir_value = XLAGraphExecutor::Get()->GetDeviceDataIrValue(
0.0, xla::PrimitiveType::F32, device);
return std::make_shared<torch::lazy::Value>(std::move(ir_value));
}

at::Tensor GetXlaTensorDimensionSize(const at::Tensor& tensor, int64_t dim) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
return bridge::AtenFromXlaTensor(
Expand Down Expand Up @@ -1011,8 +998,6 @@ void InitXlaModuleBindings(py::module m) {

py::class_<torch::lazy::Value, std::shared_ptr<torch::lazy::Value>>(
m, "IrValue");
m.def("_xla_create_token",
[](const std::string& device) { return CreateToken(device); });
m.def(
"_xla_all_reduce_inplace",
[](const std::string& reduce_type, const std::vector<at::Tensor>& tensors,
Expand Down Expand Up @@ -1616,6 +1601,17 @@ void InitXlaModuleBindings(py::module m) {
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
});
m.def("_get_all_reduce_token",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to call it _get_cc_token or _get_xla_token, although it is currently only for all_reduce. We can also do this after we convert second cc op to use cpp token

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just following the traditional in the python layer where it's named as all_reduce_token. haha.

[](const std::string& device_str) -> const torch::lazy::Value& {
auto device = GetDeviceOrCurrent(device_str);
return GetAllReduceToken(device);
});
m.def("_set_all_reduce_token",
[](const std::string& device_str,
const std::shared_ptr<torch::lazy::Value>& token) {
auto device = GetDeviceOrCurrent(device_str);
SetAllReduceToken(device, token);
});

/* The distributed runtime service is used by the PjRt GPU client. */
py::class_<xla::DistributedRuntimeService,
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "torch_xla/csrc/tensor.h"

#include <algorithm>
#include <atomic>
#include <cmath>
#include <condition_variable>
#include <exception>
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <Python.h>

#include <algorithm>
#include <atomic>
#include <cmath>
#include <condition_variable>
#include <exception>
Expand Down