Skip to content

Commit 1ad310b

Browse files
authored
Make sure printing an XLA tensor only executes the HLO once (#5721)
* Add a test to make sure printing a tensor only executes the graph once * Also check the HLO text * Fix test racing * Skip when in eager debug mode
1 parent 43cd6b5 commit 1ad310b

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

test/test_operations.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,13 @@ def _is_on_tpu():
5858
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'
5959

6060

def _is_on_eager_debug_mode():
  """Return True when the XLA_USE_EAGER_DEBUG_MODE env flag is set."""
  eager_debug_enabled = xu.getenv_as(
      'XLA_USE_EAGER_DEBUG_MODE', bool, defval=False)
  return eager_debug_enabled


# Decorators to skip tests on configurations where they are not supported.
skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
skipOnEagerDebug = unittest.skipIf(_is_on_eager_debug_mode(),
                                   'skip on eager debug mode')
6268

6369

6470
def _gen_tensor(*args, **kwargs):
@@ -1649,6 +1655,42 @@ def test_cached_addcdiv(self):
16491655
xm.mark_step()
16501656
self.assertEqual(met.metric_data("TransferToServerTime")[0], 4)
16511657

@skipOnEagerDebug
def test_print_executation(self):
  """Printing an XLA tensor must execute the HLO graph at most once.

  Covers three ways of materializing a tensor for printing; in each one
  the 'ExecuteTime' metric must stay at 1 no matter how many times the
  tensor is printed, and afterwards the tensor must be backed by
  'xla::device_data' (i.e. no pending graph is re-traced per print).

  NOTE(review): the method name has a typo ("executation"); it is kept
  unchanged because test names are discovered/filtered by string.
  """
  xla_device = xm.xla_device()
  # Flush pending work and reset metrics so ExecuteTime starts from 0.
  xm.mark_step()
  xm.wait_device_ops()
  met.clear_all()

  # case 1: mark_step before printing.
  t1 = torch.randn(1, 4, device=xla_device)
  xm.mark_step()
  xm.wait_device_ops()
  self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
  for _ in range(3):
    print(t1)
  # Repeated prints must not trigger additional executions.
  self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
  self.assertIn('xla::device_data',
                torch_xla._XLAC._get_xla_tensors_text([t1]))

  # case 2: no mark_step, print directly.
  met.clear_all()
  t1 = torch.randn(1, 4, device=xla_device)
  for _ in range(3):
    print(t1)
  self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
  self.assertIn('xla::device_data',
                torch_xla._XLAC._get_xla_tensors_text([t1]))

  # case 3: no mark_step, print via an explicit .cpu() copy.
  # (Was mislabeled "case 2" in the original.)
  met.clear_all()
  t1 = torch.randn(1, 4, device=xla_device)
  for _ in range(3):
    print(t1.cpu())
  self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
  self.assertIn('xla::device_data',
                torch_xla._XLAC._get_xla_tensors_text([t1]))
16521694
def test_index_types(self):
16531695

16541696
def test_fn(*indices):

0 commit comments

Comments
 (0)