Commit b6a03d9

Add tooling to explain why a graph execution happens (#5723)

* Initial commit for debugging tool
* minor format tweak
* Only master process should print the execution frame info
* add execution cause
* handle dynamo and everything else
* add test
* linter
* add test to the script

1 parent 595ebcf
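To see the feature end to end, here is a minimal usage sketch (the file path is illustrative; output falls back to stderr when `PT_XLA_DEBUG_FILE` is unset, and the variables are set before `torch_xla` imports so the C++ side picks them up):

```python
# Minimal sketch of the new PT_XLA_DEBUG mode. The debug file path below is
# illustrative; without PT_XLA_DEBUG_FILE the analysis is printed to stderr.
import os
os.environ['PT_XLA_DEBUG'] = '1'
os.environ['PT_XLA_DEBUG_FILE'] = '/tmp/pt_xla_debug.txt'

import torch
import torch_xla.core.xla_model as xm

t1 = torch.randn(2, 2, device=xm.xla_device())
xm.mark_step()  # execution logged with cause "user mark_step"
print(t1)       # logged as user code accessing a tensor value before mark_step
```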

File tree — 5 files changed: +215 −0 lines changed

* test/run_tests.sh
* test/test_pt_xla_debug.py
* torch_xla/csrc/debug_util.cpp
* torch_xla/csrc/debug_util.h
* torch_xla/csrc/xla_graph_executor.cpp


test/run_tests.sh — 6 additions, 0 deletions

@@ -108,6 +108,11 @@ function run_save_tensor_hlo {
   XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
 }
 
+function run_pt_xla_debug {
+  echo "Running in PT_XLA_DEBUG mode: $@"
+  PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
+}
+
 function run_stablehlo_compile {
   echo "Running in StableHlo Compile mode: $@"
   XLA_STABLEHLO_COMPILE=1 run_test "$@"
@@ -156,6 +161,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/test_grad_checkpoint.py"
   run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
+  run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
   run_test "$CDIR/test_async_closures.py"
   run_test "$CDIR/test_profiler.py"
   run_test "$CDIR/pjrt/test_runtime.py"

test/test_pt_xla_debug.py — 119 additions, 0 deletions (new file)

import os
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl


def check_env_flag(name, default=''):
  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def extract_execution_cause(lines):
  # The cause text sits on the line immediately after the 'Execution Cause'
  # header written by debug_util.cpp.
  causes = []
  for i in range(len(lines)):
    if 'Execution Cause' in lines[i].decode():
      causes.append(lines[i + 1].decode())
  return causes


class PtXLADebugTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    if not check_env_flag('PT_XLA_DEBUG'):
      assert False, "This test should be run with PT_XLA_DEBUG"
    cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
    if not cls.debug_file_name:
      assert False, "This test should be run with PT_XLA_DEBUG_FILE"
    open(cls.debug_file_name, 'w').close()

  def test_user_mark_step(self):
    device = xm.xla_device()
    t1 = torch.randn(2, 2, device=device)
    xm.mark_step()
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
      causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 1)
    self.assertIn('user mark_step', causes[0])
    open(self.debug_file_name, 'w').close()

  def test_step_trace(self):
    device = xm.xla_device()
    with xp.StepTrace('train_pt_xla_debug'):
      t1 = torch.randn(2, 2, device=device)
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
      causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 1)
    self.assertIn('mark_step when exiting a profiler StepTrace region',
                  causes[0])
    open(self.debug_file_name, 'w').close()

  def test_dynamo(self):
    device = xm.xla_device()
    t1 = torch.randn(2, 2, device=device)

    def toy_program(t1):
      return t1 + t1

    compiled = torch.compile(toy_program, backend="openxla")
    res = compiled(t1)
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
      causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 3)
    self.assertIn('mark_step when dynamo processing input graphs', causes[0])
    self.assertIn('mark_step when dynamo processing input graphs', causes[1])
    self.assertIn('dynamo compiles FX graph to HLO', causes[2])
    open(self.debug_file_name, 'w').close()

  def test_parallel_loader(self):
    device = xm.xla_device()

    train_dataset_len = 100
    batch_size = 10
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 128,
                          128), torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // 10)

    train_device_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        loader_prefetch_size=8,
        device_prefetch_size=4,
        host_to_device_transfer_threads=1)

    for step, (data, target) in enumerate(train_device_loader):
      pass

    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
      causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), batch_size + 2)
    for cause in causes:
      self.assertIn('mark_step in parallel loader at step end', cause)
    open(self.debug_file_name, 'w').close()

  def test_print(self):
    device = xm.xla_device()
    t1 = torch.randn(2, 2, device=device)
    print(t1)
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
      causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 1)
    self.assertIn('user code trying to access tensor value', causes[0])
    open(self.debug_file_name, 'w').close()


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
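For reference, a minimal self-contained sketch of the debug-file format these tests parse: `extract_execution_cause` assumes the cause sits on the line right after the 'Execution Cause' header. The sample lines below are illustrative, modeled on the format emitted by debug_util.cpp:

```python
# Illustrative sample of the debug file format (byte strings, since the tests
# read the file in 'rb' mode); the exact separator and frame lines may differ.
def extract_execution_cause(lines):
  causes = []
  for i in range(len(lines)):
    if 'Execution Cause' in lines[i].decode():
      causes.append(lines[i + 1].decode())
  return causes


sample = [
    b'Execution Analysis: =======================================\n',
    b'Execution Analysis: Execution Cause\n',
    b'Execution Analysis:   user mark_step\n',
    b'Execution Analysis: Python Frame Triggered Execution: \n',
]
print(extract_execution_cause(sample))
# ['Execution Analysis:   user mark_step\n']
```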

torch_xla/csrc/debug_util.cpp — 83 additions, 0 deletions

@@ -4,6 +4,7 @@
 #include <torch/csrc/lazy/python/python_util.h>
 
 #include <fstream>
+#include <iostream>
 #include <mutex>
 #include <sstream>
 #include <unordered_set>
@@ -209,4 +210,86 @@ bool DebugUtil::ExperimentEnabled(const std::string& name) {
   return xset->find(name) != xset->end();
 }
 
+// Helper function until we move to C++20.
+static bool endsWith(const std::string& str, const std::string& suffix) {
+  return str.size() >= suffix.size() &&
+         0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+}
+
+void DebugUtil::analyze_graph_execution_python_frame() {
+  static bool is_master_process =
+      (runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
+  static std::string debug_file_name =
+      runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
+  static std::string debug_output_prefix = "Execution Analysis: ";
+  // TODO: Make this configurable.
+  if (!is_master_process) {
+    return;
+  }
+  std::vector<torch::lazy::SourceLocation> frames =
+      torch::lazy::GetPythonFrames();
+  // There must be at least one python frame.
+  XLA_CHECK_GE(frames.size(), 1);
+  std::stringstream ss;
+  ss << "\n"
+     << debug_output_prefix
+     << "======================================================================"
+        "=========="
+     << "\n";
+  ss << debug_output_prefix << "Execution Cause\n";
+  if (frames[0].function == "mark_step") {
+    if (frames[1].function == "next" &&
+        endsWith(frames[1].file, "parallel_loader.py")) {
+      ss << debug_output_prefix
+         << "  mark_step in parallel loader at step end\n";
+    } else if (frames[1].function == "__exit__" &&
+               endsWith(frames[1].file, "profiler.py")) {
+      ss << debug_output_prefix
+         << "  mark_step when exiting a profiler StepTrace region\n";
+    } else if ((frames[1].function == "extract_compiled_graph" ||
+                frames[1].function == "extract_internal") &&
+               endsWith(frames[1].file, "dynamo_bridge.py")) {
+      ss << debug_output_prefix
+         << "  mark_step when dynamo processing input graphs\n";
+    } else {
+      ss << debug_output_prefix << "  user mark_step\n";
+    }
+  } else if (frames[0].function == "extract_graph_helper" &&
+             endsWith(frames[0].file, "dynamo_bridge.py")) {
+    ss << debug_output_prefix << "  dynamo compiles FX graph to HLO\n";
+  } else {
+    // TODO(JackCaoG): be more specific about execution caused by printing
+    // tensor or fallback or some weird indexing.
+    ss << debug_output_prefix
+       << "  most likely user code trying to access tensor value before "
+          "mark_step\n";
+  }
+
+  // TODO(JackCaoG): make number of frames printed configurable
+  ss << debug_output_prefix << "Python Frame Triggered Execution: \n";
+  for (auto& location : frames) {
+    ss << debug_output_prefix << "  " << location.function << " ("
+       << location.file << ":" << location.line << ")\n";
+  }
+  ss << debug_output_prefix
+     << "----------------------------------------------------------------------"
+        "----------"
+     << "\n";
+  ss << debug_output_prefix
+     << "======================================================================"
+        "=========="
+     << "\n";
+
+  // TODO(JackCaoG): print more information about the graph that is about to
+  // get executed.
+  if (debug_file_name == "") {
+    // Print to stderr by default.
+    std::cerr << ss.str();
+  } else {
+    std::ofstream outFile;
+    outFile.open(debug_file_name, std::ios_base::app);
+    outFile << ss.rdbuf();
+  }
+}
+
 }  // namespace torch_xla
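Putting the pieces together, the block appended for a user-initiated mark_step looks roughly like the following. The frame path and line numbers are illustrative, not taken from a real run:

```
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   user mark_step
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_step (torch_xla/core/xla_model.py:808)
Execution Analysis:   <module> (train.py:42)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
```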

torch_xla/csrc/debug_util.h — 4 additions, 0 deletions

@@ -46,6 +46,10 @@ class DebugUtil {
                                absl::Span<const size_t> indices);
 
   static bool ExperimentEnabled(const std::string& name);
+
+  // Warning: this function should only be called when a graph execution is
+  // about to happen.
+  static void analyze_graph_execution_python_frame();
 };
 
 }  // namespace torch_xla

torch_xla/csrc/xla_graph_executor.cpp — 3 additions, 0 deletions

@@ -1322,6 +1322,9 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
     const SyncTensorsConfig& config, bool warm_up_cache_only) {
   tsl::profiler::TraceMe activity("SyncTensorsGraphInternal",
                                   tsl::profiler::TraceMeLevel::kInfo);
+  if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
+    DebugUtil::analyze_graph_execution_python_frame();
+  }
   SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
   if (coll.indices.empty()) {
     // Ensure previous execution is complete before exiting this