
Commit 8b8b652

[LoweringContext] Support an optimized parameter mapping for SPMD
1 parent 39e67b5 commit 8b8b652

4 files changed: +96 -5 lines changed
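At a glance, the commit adds a device_parameter_id_tensor_mapping() method to the LoweringContext Python binding as a transfer-free alternative to the existing parameter_id_tensor_mapping(). A minimal usage sketch, assuming an XLA device is available (the tensor shapes and the "Example" name are illustrative, not part of the commit):

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    t = torch.randn(4, 4).to(device)
    result = t + 1

    ctx = torch_xla._XLAC.lowering.LoweringContext("Example")
    ctx.build([result])

    # Existing API: copies parameter values back to the host.
    host_mapping = ctx.parameter_id_tensor_mapping()
    # New API from this commit: values stay as device-backed XLA tensors,
    # so no device-to-host transfer is needed (important when data is sharded).
    device_mapping = ctx.device_parameter_id_tensor_mapping()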

test/run_tests.sh

Lines changed: 2 additions & 0 deletions
@@ -203,6 +203,7 @@ function run_xla_op_tests1 {
   run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
   run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
   run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_save_tensor_hlo "$CDIR/spmd/test_spmd_lowering_context.py"
 }
 
 function run_xla_op_tests2 {
@@ -247,6 +248,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$CDIR/spmd/test_mp_input_sharding.py"
+  run_test "$CDIR/spmd/test_spmd_lowering_context.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
test/spmd/test_spmd_lowering_context.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import sys
+
+import unittest
+
+import test_xla_sharding_base
+
+import torch
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.distributed.spmd as xs
+import torch_xla.core.xla_model as xm
+import contextlib
+
+
+class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_device_parameter_id_tensor_mapping(self):
+    met.clear_all()
+
+    model_axis = min(8, self.n_devices)
+    data_axis = self.n_devices // model_axis
+    mesh_shape = (data_axis, model_axis)
+    spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
+
+    device = xm.xla_device()
+    a = torch.randn([32, 2048]).to(device)
+    xs.mark_sharding(a, spmd_mesh, ('x', 'y'))
+    b = torch.ones(2048).to(device)
+    xs.mark_sharding(b, spmd_mesh, ('x',))
+
+    def fn(a, b):
+      return a + b
+
+    result = fn(a, b)
+    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
+    ctx.build([result])
+    torch_xla.sync()
+
+    mapping = ctx.device_parameter_id_tensor_mapping()
+    num_params = len(mapping)
+    self.assertEqual(num_params, 2)
+    self.assertNotEqual(ctx.tensor_parameter_id(a), -1)
+    self.assertNotEqual(ctx.tensor_parameter_id(b), -1)
+    self.assertEqual(met.counter_value("VirtualDeviceUsage"), num_params)
+
+    # Ensure that the parameter mapping does not require transferring data
+    # from the device to the host when sharded.
+    self.assertFalse(met.metric_data("TransferFromDeviceTime"))
+    self.assertFalse(met.counter_value("ReplicateShardedData"))
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
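The test checks the mapping only by size and verifies that each input has a valid parameter id. A hypothetical follow-up check (not part of the commit, reusing ctx, a, b, and mapping from the test body above) would confirm that the keys line up with tensor_parameter_id() and that the values remain on the XLA device:

    # Hypothetical extra assertions, continuing from the test body above.
    self.assertIn(ctx.tensor_parameter_id(a), mapping)
    self.assertIn(ctx.tensor_parameter_id(b), mapping)
    for param_id, tensor in mapping.items():
      self.assertEqual(tensor.device.type, 'xla')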

test/test_operations.py

Lines changed: 4 additions & 4 deletions
@@ -2642,14 +2642,14 @@ def test_api(self):
 
     ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
     ctx.build([result])
-    hlo = ctx.hlo()
     hlo_text = ctx.hlo_text()
     self.assertIn('MyCustomName', hlo_text)
-    self.assertIn('opcode: "parameter"', hlo_text)
-    self.assertIn('opcode: "parameter"', hlo_text)
+    self.assertTrue(hlo_text.count('opcode: "parameter"'), 2)
     self.assertIn('opcode: "add"', hlo_text)
     mapping = ctx.parameter_id_tensor_mapping()
-    self.assertEqual(len(mapping), 2)
+    num_params = len(mapping)
+    self.assertEqual(num_params, 2)
+    self.assertTrue(met.metric_data("TransferFromDeviceTime"))
 
   def test_get_parameters_scalar(self):
     """Scalar tensors parameters may be shared in the HLO graph if their

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 32 additions & 1 deletion
@@ -1096,10 +1096,39 @@ class PyLoweringContext {
     return results;
   }
 
+  // Returns a mapping from HLO parameter IDs to their corresponding
+  // device-backed Tensors. This version only returns parameters that were
+  // explicitly allocated on device data, accessible via GetTensorParameterId().
+  // Unlike GetParameterIdTensorMapping(), it avoids transferring data from
+  // device to host, making it more efficient especially for SPMD scenarios
+  // where data may be sharded.
+  std::unordered_map<int64_t, at::Tensor> GetDeviceParameterIdTensorMapping() {
+    // Find parameters in the lowering
+    const std::vector<torch::lazy::BackendDataPtr>& device_data =
+        lowering_ctx.GetParametersData();
+
+    // Create a mapping from parameter id to the tensor data
+    std::unordered_map<int64_t, at::Tensor> param_to_tensor;
+    param_to_tensor.reserve(device_data.size());
+
+    for (const auto& data : device_data) {
+      std::optional<int64_t> param_id = lowering_ctx.GetParameterId(data);
+      XLA_CHECK(param_id.has_value())
+          << "Parameter ID must exist for device data";
+
+      at::Tensor tensor =
+          bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(data));
+      param_to_tensor.emplace(param_id.value(), std::move(tensor));
+    }
+    return param_to_tensor;
+  }
+
   // Get the parameter identifier of a given tensor. If the tensor is not a
   // parameter this will always return -1. This is useful in conjunction with
   // GetParameterIdTensorMapping to identify which values can be baked into
-  // the graph and which values must remain parameters.
+  // the graph and which values must remain parameters. Note that in
+  // conjunction with GetDeviceParameterIdTensorMapping, all tensors are
+  // parameters with a valid parameter id.
   int64_t GetTensorParameterId(at::Tensor tensor) {
     // Convert tensor into the backing lazy node
     XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
@@ -1201,6 +1230,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("hlo_json", &PyLoweringContext::GetHloJsonText)
      .def("parameter_id_tensor_mapping",
            &PyLoweringContext::GetParameterIdTensorMapping)
+      .def("device_parameter_id_tensor_mapping",
+           &PyLoweringContext::GetDeviceParameterIdTensorMapping)
       .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
       .def("set_name_string", &PyLoweringContext::SetNameString)
       .def("get_name_string", &PyLoweringContext::GetNameString);
