@@ -58,7 +58,13 @@ def _is_on_tpu():
5858 return  'XRT_TPU_CONFIG'  in  os .environ  or  xr .device_type () ==  'TPU' 
5959
6060
61+ def  _is_on_eager_debug_mode ():
62+  return  xu .getenv_as ('XLA_USE_EAGER_DEBUG_MODE' , bool , defval = False )
63+ 
64+ 
6165skipOnTpu  =  unittest .skipIf (_is_on_tpu (), 'Not supported on TPU' )
66+ skipOnEagerDebug  =  unittest .skipIf (_is_on_eager_debug_mode (),
67+  'skip on eager debug mode' )
6268
6369
6470def  _gen_tensor (* args , ** kwargs ):
@@ -1649,6 +1655,42 @@ def test_cached_addcdiv(self):
16491655 xm .mark_step ()
16501656 self .assertEqual (met .metric_data ("TransferToServerTime" )[0 ], 4 )
16511657
1658+  @skipOnEagerDebug  
1659+  def  test_print_executation (self ):
1660+  xla_device  =  xm .xla_device ()
1661+  xm .mark_step ()
1662+  xm .wait_device_ops ()
1663+  met .clear_all ()
1664+ 
1665+  # case 1 mark_step 
1666+  t1  =  torch .randn (1 , 4 , device = xla_device )
1667+  xm .mark_step ()
1668+  xm .wait_device_ops ()
1669+  self .assertEqual (met .metric_data ('ExecuteTime' )[0 ], 1 )
1670+  for  _  in  range (3 ):
1671+  print (t1 )
1672+  self .assertEqual (met .metric_data ('ExecuteTime' )[0 ], 1 )
1673+  self .assertIn ('xla::device_data' ,
1674+  torch_xla ._XLAC ._get_xla_tensors_text ([t1 ]))
1675+ 
1676+  # case 2 no mark_step, directly print 
1677+  met .clear_all ()
1678+  t1  =  torch .randn (1 , 4 , device = xla_device )
1679+  for  _  in  range (3 ):
1680+  print (t1 )
1681+  self .assertEqual (met .metric_data ('ExecuteTime' )[0 ], 1 )
1682+  self .assertIn ('xla::device_data' ,
1683+  torch_xla ._XLAC ._get_xla_tensors_text ([t1 ]))
1684+ 
1685+  # case 2 no mark_step, print with .cpu 
1686+  met .clear_all ()
1687+  t1  =  torch .randn (1 , 4 , device = xla_device )
1688+  for  _  in  range (3 ):
1689+  print (t1 .cpu ())
1690+  self .assertEqual (met .metric_data ('ExecuteTime' )[0 ], 1 )
1691+  self .assertIn ('xla::device_data' ,
1692+  torch_xla ._XLAC ._get_xla_tensors_text ([t1 ]))
1693+ 
16521694 def  test_index_types (self ):
16531695
16541696 def  test_fn (* indices ):
0 commit comments