Commit e7ad09e
Add ToString method for both PjrtData and PjrtShardedData (#5265)
* Add ToString method for both PjrtData and PjrtShardedData
* On CPU the same config becomes replicated, so don't check the actual op sharding type
1 parent 6f9cb61 commit e7ad09e
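
For context, here is a minimal sketch of how the new ToString() output surfaces on the Python side, modeled on the test added in this commit. The Mesh construction and the world-size helper are assumptions and may differ across torch_xla versions:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumption: xrt_world_size() is used here as the addressable device count.
num_devices = xm.xrt_world_size()
# Assumption: Mesh takes a flat device-id array and a logical mesh shape.
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

t = torch.rand(1, 8, device=xm.xla_device())
st = xs.mark_sharding(t, mesh, (0, 1))

# The C++ ToString() output is embedded in the tensor debug info and
# contains lines such as 'XLAShardedData', 'OpSharding:', 'NumShards: <n>'.
print(torch_xla._XLAC._get_xla_tensor_debug_info(st.global_tensor))
```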

4 files changed: +42 −4 lines changed

test/spmd/test_xla_sharding.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -37,6 +37,20 @@ def test_xla_sharded_tensor(self):
     # TODO(244003536) add more tests for XLAShardedTensror.
     self.assertTrue(isinstance(xst1, XLAShardedTensor))
 
+  def test_sharded_tensor_debug_info(self):
+    partition_spec = (0, 1)
+    xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
+                       dtype=torch.float,
+                       device=xm.xla_device())
+    xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
+                            partition_spec)
+
+    debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(xst1.global_tensor)
+    self.assertIn('XLAShardedData', debug_info)
+    self.assertIn('Data Device: SPMD:0', debug_info)
+    self.assertIn('OpSharding:', debug_info)
+    self.assertIn('NumShards: %s' % (self.n_devices), debug_info)
+
   def test_xla_shards(self):
     num_element = self.n_devices
     mesh = self._get_mesh((self.n_devices,))
```

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 4 deletions
```diff
@@ -378,13 +378,11 @@ std::string GetXLATensorDebugInfo(const at::Tensor& tensor) {
   }
 
   torch::lazy::BackendDataPtr handle = xtensor->CurrentDataHandle();
-  ss << "XLAData: ";
   if (handle) {
     auto data = UnwrapXlaData(handle);
-    ss << "\n  Data Device: " << data->device() << "\n";
-    ss << "  Data Shape: " << data->shape().ToString() << "\n";
+    ss << data->ToString();
   } else {
-    ss << "None\n";
+    ss << "XLAData: None\n";
   }
 
   auto at_tensor = xtensor->CurrentTensorData();
```

torch_xla/csrc/runtime/computation_client.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -51,6 +51,8 @@ class ComputationClient {
 
   virtual bool HasValue() const = 0;
 
+  virtual std::string ToString() const = 0;
+
  private:
   std::string device_;
   xla::Shape shape_;
```

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 24 additions & 0 deletions
```diff
@@ -173,6 +173,20 @@ class PjRtComputationClient : public ComputationClient {
       return buffer != nullptr && !buffer->IsDeleted();
     };
 
+    std::string ToString() const override {
+      std::stringstream ss;
+      ss << "XLAData: \n";
+      ss << "  Data Device: " << device() << "\n";
+      ss << "  Data Shape: " << shape().ToString() << "\n";
+      ss << "  Data Handle: ";
+      if (HasValue()) {
+        ss << reinterpret_cast<std::uintptr_t>(buffer.get()) << "\n";
+      } else {
+        ss << "None\n";
+      }
+      return ss.str();
+    }
+
     std::shared_ptr<xla::PjRtBuffer> buffer;
   };
 
@@ -212,6 +226,16 @@ class PjRtComputationClient : public ComputationClient {
       return true;
     }
 
+    std::string ToString() const override {
+      std::stringstream ss;
+      ss << "XLAShardedData: \n";
+      ss << "  Data Device: " << device() << "\n";
+      ss << "  Data Shape: " << shape().ToString() << "\n";
+      ss << "  OpSharding: " << sharding.type() << "\n";
+      ss << "  NumShards: " << shards.size() << "\n";
+      return ss.str();
+    }
+
     xla::OpSharding GetSharding() { return sharding; }
 
     std::vector<std::shared_ptr<PjRtData>> shards;
```
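
Putting it together, the debug info for a sharded tensor now prints roughly as follows. This is illustrative output, not captured from a run: the shape and shard count depend on the topology, and the OpSharding line presumably shows the integer value of the xla::OpSharding type enum (which the test deliberately does not pin down, since the same config becomes replicated on CPU):

```
XLAShardedData: 
  Data Device: SPMD:0
  Data Shape: f32[1,8]
  OpSharding: 3
  NumShards: 4
```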
