Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions test/eager/test_eager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest
import sys

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm


class Eager(unittest.TestCase):
  """Tests that count eager-op compilations and executions via metrics."""

  @classmethod
  def setUpClass(cls):
    # Enable eager mode once, before any test in this class runs.
    torch_xla.experimental.eager_mode(True)

  def _check_eager_counters(self, compiles, executions):
    """Asserts element 0 of the eager compile and execute metrics.

    Element 0 of `met.metric_data(...)` is presumably the sample count --
    TODO confirm against the torch_xla metrics API.
    """
    self.assertEqual(met.metric_data("EagerOpCompileTime")[0], compiles)
    self.assertEqual(met.metric_data("EagerOpExecuteTime")[0], executions)

  def test_eager_basic(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    xla_device = torch_xla.device()

    # randn is observed to trigger a second execution as well (a [5, 5]
    # tensor full of 0), hence the expected count of 2 instead of 1.
    tensor = torch.randn(5, 5, device=xla_device)
    xm.wait_device_ops()
    self._check_eager_counters(2, 2)

    # The in-place multiply adds exactly one more compile and execute.
    tensor *= 5
    xm.wait_device_ops()
    self._check_eager_counters(3, 3)

  def test_eager_recompile(self):
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    xla_device = torch_xla.device()

    source = torch.randn(5, 5, device=xla_device)
    xm.wait_device_ops()
    # Reset the counters so tensor creation is not part of the counts below.
    met.clear_all()

    first = torch.logsumexp(source, 0)
    xm.wait_device_ops()
    self._check_eager_counters(1, 1)

    second = torch.logsumexp(source, 0)
    xm.wait_device_ops()
    # Same op again: the compile count must stay at 1 (no recompilation)
    # while the execute count goes up to 2.
    self._check_eager_counters(1, 2)

  def test_eager_in_place(self):
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    xla_device = torch_xla.device()

    tensor = torch.randn(5, 5, device=xla_device)
    xm.wait_device_ops()
    # Reset the counters so only the in-place op below is measured.
    met.clear_all()
    xm.optimization_barrier_([tensor])
    xm.wait_device_ops()
    self._check_eager_counters(1, 1)


if __name__ == '__main__':
  # unittest.main() calls sys.exit() itself unless exit=False is passed, which
  # would make the explicit exit-code line below unreachable.
  test = unittest.main(exit=False)
  # Exit 0 when every test passed, 1 otherwise.
  sys.exit(0 if test.result.wasSuccessful() else 1)
3 changes: 2 additions & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
run_test "$CDIR/eager/test_eager_with_torch_compile.py"
}
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,13 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
torch::lazy::MakeNode<Cast>(ir_value, xla_shape.get().element_type());
}
SetIrValue(std::move(ir_value), /*inplace=*/true);
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();

// in place update should also be triggered eagerly if configured
if (graph_executor->UseEagerMode()) {
std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
graph_executor->ApplyEagerSync(xtensors);
}
}

void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
Expand Down