13 changes: 10 additions & 3 deletions test/run_tests.sh
@@ -98,9 +98,14 @@ function run_eager_debug {
   XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
 }
 
-function run_save_tensor_file {
+function run_save_tensor_ir {
   echo "Running in save tensor file mode: $@"
-  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" run_test "$@"
+  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$@"
+}
+
+function run_save_tensor_hlo {
+  echo "Running in save tensor file mode: $@"
+  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
 }
 
 function run_stablehlo_compile {
@@ -148,7 +153,7 @@ function run_xla_op_tests {
   run_test "$CDIR/dynamo/test_dynamo.py"
   run_test "$CDIR/dynamo/test_bridge.py"
   run_test "$CDIR/dynamo/test_num_output.py"
-  run_save_tensor_file "$CDIR/dynamo/test_dynamo_graph_dump.py"
+  run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_use_bf16 "$CDIR/test_data_type.py"
   run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
@@ -167,6 +172,8 @@ function run_xla_op_tests {
   run_test "$CDIR/spmd/test_xla_virtual_device.py"
   run_test "$CDIR/spmd/test_dynamo_spmd.py"
   run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
+  run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
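Note: the two new wrappers differ only in XLA_SAVE_TENSORS_FMT, which selects what torch_xla writes to the dump file: "text" saves the lazy-tensor IR, while "hlo" saves the lowered HLO computation (which is where the new output-sharding section is emitted, per the debug_util.cpp change below). A rough Python equivalent of run_save_tensor_hlo, for illustration only; the subprocess invocation and target script are assumptions, not part of this change:

import os
import subprocess

# Run one test with graph dumping enabled; torch_xla appends each synced
# graph to XLA_SAVE_TENSORS_FILE (suffixed with the device ordinal).
env = dict(os.environ)
env["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_test_save_ir.txt"
env["XLA_SAVE_TENSORS_FMT"] = "hlo"  # "text" would dump the lazy-tensor IR
subprocess.run(
    ["python", "test/spmd/test_spmd_graph_dump.py"], env=env, check=True)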
53 changes: 53 additions & 0 deletions test/spmd/test_spmd_graph_dump.py
@@ -0,0 +1,53 @@
import sys

import unittest
from unittest.mock import patch
import math
import numpy as np
import os

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import test_xla_sharding_base


class BasicShardingTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    os.environ["XLA_USE_SPMD"] = "1"
    super().setUpClass()

  def test_dump_with_output_sharding(self):
    save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
    save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
    if not save_file:
      assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
    should_dump_output_sharding = (save_format == 'hlo')
    # The dump file name is suffixed with the device ordinal.
    save_file += '.0'
    device = xm.xla_device()
    xla_x = torch.randn(8, 32).to(device)
    xla_y = torch.randn(8, 32).to(device)
    # Shard one of the input tensors.
    partition_spec = (0, 1)
    xla_sharded_x = xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)),
                                     partition_spec)
    xla_res = xla_x + xla_y
    # Count the lines already in the dump file, then let mark_step append
    # this step's graph and verify that new content arrived.
    with open(save_file, 'rb') as f:
      num_lines_before = sum(1 for line in f)
    with open(save_file, 'rb') as f:
      xm.mark_step()
      lines = f.readlines()
    self.assertGreater(len(lines), num_lines_before)
    if should_dump_output_sharding:
      self.assertIn('OUTPUT_SHARDING_END', str(lines[-2]))
    else:
      self.assertIn('END_GRAPH', str(lines[-3]))


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
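For reference, when run with XLA_SAVE_TENSORS_FMT=hlo the tail of the dump file that the assertions above inspect looks roughly like the following. The shapes match the two 8x32 test tensors; the exact HloSharding strings depend on the device mesh and are illustrative only:

#OUTPUT_SHARDING_BEGIN

f32[8,32] {devices=[1,4]0,1,2,3}
f32[8,32] {replicated}

#OUTPUT_SHARDING_END

With the default "text" format no sharding section is written, so the file instead ends with the END_GRAPH marker checked in the else branch.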
28 changes: 28 additions & 0 deletions torch_xla/csrc/debug_util.cpp
@@ -173,6 +173,34 @@ void DebugUtil::SaveTensorsGraphInfo(const char* name,
   }
 }
 
+void DebugUtil::SaveOutputShardingInfo(std::vector<XLATensorPtr>* tensors,
+                                       absl::Span<const size_t> indices) {
+  thread_local const std::string save_file =
+      runtime::sys_util::GetEnvOrdinalPath("XLA_SAVE_TENSORS_FILE", "",
+                                           GetCurrentDevice().ordinal());
+  std::string fmt_str =
+      runtime::sys_util::GetEnvString("XLA_SAVE_TENSORS_FMT", "text");
+  if (save_file.empty() || fmt_str != "hlo") {
+    return;
+  }
+  std::stringstream ss;
+  for (int i = 0; i < indices.size(); ++i) {
+    auto xtensor = (*tensors)[indices[i]];
+    ss << xtensor->shape().get().ToString() << " ";
+    if (xtensor->sharding_spec()) {
+      ss << xla::HloSharding::FromProto(xtensor->sharding_spec()->sharding)
+                ->ToString();
+    } else {
+      ss << xla::HloSharding::FromProto(xla::HloSharding::Replicate().ToProto())
+                ->ToString();
+    }
+    ss << "\n";
+  }
+  std::ofstream graph_file(save_file, std::ios_base::app);
+  graph_file << "\n#OUTPUT_SHARDING_BEGIN\n\n"
+             << ss.str() << "\n#OUTPUT_SHARDING_END\n\n";
+}
+
 bool DebugUtil::ExperimentEnabled(const std::string& name) {
   static const std::unordered_set<std::string>* xset = LoadExperiments();
   return xset->find(name) != xset->end();
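The function appends one entry per synced output tensor between the two marker lines: the tensor shape followed by its HloSharding string, falling back to the replicated sharding when no spec is attached. Because the markers delimit the section, a dump is easy to post-process. A minimal Python sketch, assuming the default test path from run_tests.sh and that at least one block has been written:

def read_output_shardings(path="/tmp/xla_test_save_ir.txt.0"):
  """Return the 'shape sharding' entries from the last #OUTPUT_SHARDING block."""
  with open(path) as f:
    lines = [line.strip() for line in f]
  # max() raises ValueError if no block was dumped; acceptable for a sketch.
  end = max(i for i, line in enumerate(lines) if line == "#OUTPUT_SHARDING_END")
  begin = max(i for i in range(end) if lines[i] == "#OUTPUT_SHARDING_BEGIN")
  return [line for line in lines[begin + 1:end] if line]

print(read_output_shardings())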
3 changes: 3 additions & 0 deletions torch_xla/csrc/debug_util.h
@@ -42,6 +42,9 @@ class DebugUtil {
                                    const std::vector<size_t>* indices,
                                    GraphFormat format = GetDefaultGraphFormat());
 
+  static void SaveOutputShardingInfo(std::vector<XLATensorPtr>* tensors,
+                                     absl::Span<const size_t> indices);
+
   static bool ExperimentEnabled(const std::string& name);
 };
 
1 change: 1 addition & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -1081,6 +1081,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
     ShardingUtil::PrepareOutputShardingPropagation(
         tensors, coll->indices, cached_computation->computation, &tensors_data,
         &sharding_specs);
+    DebugUtil::SaveOutputShardingInfo(tensors, coll->indices);
   }
 
   return ScheduleSyncTensorsGraph(
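The new call sits on the SPMD execution path, immediately after output sharding propagation, so the dumped specs reflect the shardings the compiler propagated to the outputs rather than only the user's mark_sharding annotations. A minimal end-to-end trigger mirroring the new test; the mesh shape, dump path, and runtime helper names are assumptions for illustration, not defined by this change:

import os
os.environ["XLA_USE_SPMD"] = "1"
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/dump.txt"
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

device = xm.xla_device()
x = torch.randn(8, 32).to(device)
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (1, n))
xs.mark_sharding(x, mesh, (0, 1))  # shard columns across all devices
y = x + 1
# mark_step syncs the graph; with the settings above this appends the HLO
# and an #OUTPUT_SHARDING block to /tmp/dump.txt.0.
xm.mark_step()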