93 changes: 49 additions & 44 deletions test/dynamo/test_dynamo.py
@@ -10,6 +10,7 @@
import torch_xla.core.xla_env_vars as xenv
from torch_xla import runtime as xr
import torch_xla.debug.profiler as xp
from torch_xla._dynamo import dynamo_backend2
import torch.optim as optim
import torch.nn as nn
import torch._dynamo as dynamo
@@ -38,31 +39,33 @@ def _is_on_neuron():
skipOnNeuron = unittest.skipIf(_is_on_neuron(), 'Not supported on NEURON')


class DynamoInPlaceTest(unittest.TestCase):
class DynamoInPlaceTest(parameterized.TestCase):

def inplace_update(self, a):
a += 1
return a

def test_inplace_update_correctness(self):
@parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
def test_inplace_update_correctness(self, backend):
dynamo_inplace = torch.compile(
self.inplace_update, backend="openxla", fullgraph=True)
self.inplace_update, backend=backend, fullgraph=True)
t = torch.tensor([0, 1, 2], device=xm.xla_device())
for i in range(10):
t = dynamo_inplace(t)
self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12]))))


class DynamRandomOpTest(unittest.TestCase):
class DynamRandomOpTest(parameterized.TestCase):

def random_op(self, a):
return torch.randn(5, 5, device=a.device) + a

def test_random_op_different_result_each_run(self):
@parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
def test_random_op_different_result_each_run(self, backend):
xm.wait_device_ops()
met.clear_all()
dynamo_random_op = torch.compile(
self.random_op, backend="openxla", fullgraph=True)
self.random_op, backend=backend, fullgraph=True)
t = torch.randn(5, 5).to(xm.xla_device())
dynamo_res_1 = dynamo_random_op(t)
dynamo_res_2 = dynamo_random_op(t)
@@ -75,7 +78,7 @@ def test_random_op_different_result_each_run(self):
self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))


class DynamoLTCInteractionTest(unittest.TestCase):
class DynamoLTCInteractionTest(parameterized.TestCase):

def index_copy_inplace(self, cache, update_indices, xk):
cache.index_copy_(0, update_indices, xk)
@@ -104,21 +107,22 @@ def test_mark_step_after_dynamo(self):
xm.wait_device_ops()
self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])

def test_copy_op(self):
@parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
def test_copy_op(self, backend):

def copy_a_to_b(a):
res = a.cos()
copy = torch.ops.aten.copy.default(a, res)
copy = torch.ops.aten.copy_.default(a, res)
return copy

device = torch_xla.device()
compiled_copy = torch.compile(copy_a_to_b, backend="openxla")
compiled_copy = torch.compile(copy_a_to_b, backend=backend)
a = torch.randn(2, 9).to(device)
res = compiled_copy(a)
self.assertTrue(torch.allclose(res, a))


class DynamoProfilerTest(unittest.TestCase):
class DynamoProfilerTest(parameterized.TestCase):

def dummy_fn(self, a):
return torch.sin(a) + a
@@ -253,11 +257,10 @@ def fn_without_input(device):
res_xla_dynamo = compiled_fn(device)
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

@parameterized.parameters(
True,
False,
)
def test_simple_model_with_in_place_ops(self, initialize_on_cuda):
@parameterized.product(
initialize_on_cuda=[True, False],
backend=['openxla', dynamo_backend2.dynamo_backend])
def test_simple_model_with_in_place_ops(self, initialize_on_cuda, backend):

class TestModel(nn.Module):

@@ -286,7 +289,7 @@ def forward(self, index, copy_tensor, input_tensor, op_name):

cpu_model = TestModel()
device_model = TestModel(device).to(device)
compiled_model = torch.compile(device_model, backend='openxla')
compiled_model = torch.compile(device_model, backend=backend)

input_tensor = torch.ones(3)
copy_tensor = torch.rand(5, 3)
@@ -306,11 +309,10 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
op_name=in_place_op)
self.assertTrue(torch.allclose(res_cpu, res_device_dynamo.cpu()))

@parameterized.parameters(
True,
False,
)
def test_einsum(self, initialize_on_cuda):
@parameterized.product(
initialize_on_cuda=[True, False],
backend=['openxla', dynamo_backend2.dynamo_backend])
def test_einsum(self, initialize_on_cuda, backend):
# einsum currently does not have a meta function to compute the output shape, so it
# will fall back to XLA with FakeTensor inputs to infer the output shape.
def einsum_mm(a, b):
@@ -321,7 +323,7 @@ def einsum_mm(a, b):
b = torch.randn(4, 4, 4, 4).to(device)
xm.mark_step()

dynamo_einsum_mm = torch.compile(einsum_mm, backend="openxla")
dynamo_einsum_mm = torch.compile(einsum_mm, backend=backend)
res_device_dynamo = dynamo_einsum_mm(a, b)
res_device_non_dynamo = einsum_mm(a, b)
self.assertTrue(
@@ -368,11 +370,10 @@ def get_loader(self, device, sample_count, batch_size=4):

@skipOnTpu
@skipOnNeuron
@parameterized.parameters(
True,
False,
)
def test_resnet18(self, initialize_on_cuda):
@parameterized.product(
initialize_on_cuda=[True, False],
backend=['openxla', dynamo_backend2.dynamo_backend])
def test_resnet18(self, initialize_on_cuda, backend):
device = self._choose_proper_device(initialize_on_cuda)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self.get_loader(device, sample_count, batch_size=4)
@@ -386,19 +387,21 @@ def test_resnet18(self, initialize_on_cuda):
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
dynamo_resnet18 = torch.compile(device_resnet18, backend=backend)
for data, _ in loader:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
# We only expect one graph for the resnet18 inference.
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
self.assertEqual(
met.metric_data('RunCachedGraphInputData')[0], sample_count)
self.assertEqual(
met.metric_data('RunCachedGraphOutputData')[0], sample_count)
if backend == 'openxla':
# backend2 doesn't populate these metrics
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
self.assertEqual(
met.metric_data('RunCachedGraphInputData')[0], sample_count)
self.assertEqual(
met.metric_data('RunCachedGraphOutputData')[0], sample_count)

@skipOnNeuron
def test_resnet18_lazy_vs_dynamo(self):
@@ -428,7 +431,7 @@ def test_resnet18_lazy_vs_dynamo(self):
# mess up the counter check.


class DynamoCpuFallbackTest(unittest.TestCase):
class DynamoCpuFallbackTest(parameterized.TestCase):

def test_operator_fallback(self):

@@ -509,7 +512,7 @@ def fn_fallback(t):
self.assertEqual(met.metric_data('ExecuteTime')[0], 3)


class DynamoTrainingBasicTest(unittest.TestCase):
class DynamoTrainingBasicTest(parameterized.TestCase):

@classmethod
def setUpClass(self):
@@ -613,7 +616,7 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 2)


class DynamoTrainingOptimizerTest(unittest.TestCase):
class DynamoTrainingOptimizerTest(parameterized.TestCase):

@classmethod
def setUpClass(self):
@@ -719,7 +722,7 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


class DynamoErrorMessageTest(unittest.TestCase):
class DynamoErrorMessageTest(parameterized.TestCase):

def test_mixed_cpu_tensor(self):
device = xm.xla_device()
@@ -758,17 +761,18 @@ def test_all_cpu_tensor(self):
self.assertLessEqual(len(met.counter_names()), 1)


class DynamoOperationsTests(test_utils.XlaTestCase):
class DynamoOperationsTest(test_utils.XlaTestCase, parameterized.TestCase):

def test_new_with_sizes(self):
@parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
def test_new_with_sizes(self, backend):

# The addition operation is needed here, since the error only occurs when FakeTensorMode
# checks the device of the arguments of some operation. If there's no operation using the
# result of Tensor.new, this comparison never occurs.
def foo(x):
return x.new(*x.size()) + x

optfoo = torch.compile(backend="openxla")(foo)
optfoo = torch.compile(backend=backend)(foo)

t = torch.arange(9)
Xt = t.to(xm.xla_device())
@@ -782,12 +786,13 @@ def foo(x):
self.assertEqual(expected.dtype, actual.dtype)
self.assertEqual(expected.device, actual.device)

def test_return_expand(self):
@parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
def test_return_expand(self, backend):

def foo(x):
return x.expand(2, -1)

optfoo = torch.compile(backend="openxla")(foo)
optfoo = torch.compile(backend=backend)(foo)

t = torch.arange(10)
Xt = t.to(xm.xla_device())
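Note: the tests above run each case against both the legacy `openxla` backend string and the new `dynamo_backend2.dynamo_backend` callable via absl's `parameterized`. A minimal sketch of the `parameterized.product` pattern used here (assuming `absl-py` is installed):

```python
from absl.testing import parameterized


class ExampleTest(parameterized.TestCase):

  # Generates one test method per combination of the two parameters,
  # e.g. (initialize_on_cuda=True, backend='openxla'), and so on.
  @parameterized.product(
      initialize_on_cuda=[True, False],
      backend=['openxla', 'backend2'])
  def test_combination(self, initialize_on_cuda, backend):
    self.assertIsInstance(initialize_on_cuda, bool)
    self.assertIn(backend, ('openxla', 'backend2'))
```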
64 changes: 64 additions & 0 deletions torch_xla/_dynamo/dynamo_backend2.py
@@ -0,0 +1,64 @@
import functools
from typing import Any
import torch
from torch.utils import _pytree as pytree
from torch_xla.core import xla_builder as xb
import torch_xla

from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func


def _dynamo_backend(model: torch.fx.GraphModule, sample_args: Any):
"""A dynamo backend that compiles a FX graph to HLO using JAX and torchax.

It takes an FX graph as input and returns a compiled PyTorch function. The FX graph
is traced into a JAX function using torchax, and that JAX function is lowered to HLO.

Args:
model: the graph to be compiled
sample_args: a tuple or list of sample inputs, i.e. model(*sample_args) produces
the model output

Returns:
Another callable f such that f(*sample_args) computes the same thing as model.
"""

try:
import torchax.interop
from torchax.export import JaxInterpreter
import jax
except ImportError:
print('To use this dynamo backend, please install torchax')
raise

jax.config.update("jax_enable_x64", True)
env = torchax.default_env()
xla_device = torch_xla.device()

def run_jax(*args, initial_rng_key):
args_t = torchax.interop.torch_view(args)
env.manual_seed(initial_rng_key)
with env:
res = model(*args_t)
return torchax.interop.jax_view(res)

initial_rng_key = torch.tensor(0, device=xla_device, dtype=torch.uint32)
computation = xb.jax_func_to_xla_computation(
run_jax, sample_args, {'initial_rng_key': initial_rng_key}, 'dynamo_jax')

def equivalent(*args, **kwargs):
kwargs['initial_rng_key'] = torch.randint(
0, 2**32, (), dtype=torch.uint32, device=xla_device)
flattened, _ = pytree.tree_flatten((args, kwargs))
res = computation(flattened)
if not isinstance(res, (list, tuple)):
return (res,)
return res

return make_boxed_func(equivalent)


def dynamo_backend(fx, args):
from functorch.compile import aot_function
return aot_function(fx, fw_compiler=_dynamo_backend)
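For context, a minimal sketch of how this backend is exercised by the tests above, passed as a callable to `torch.compile` (assuming `torch_xla`, `torchax`, and `jax` are installed):

```python
import torch
import torch_xla
from torch_xla._dynamo import dynamo_backend2


def f(x):
  return torch.sin(x) + x


# Dynamo hands the traced FX graph to dynamo_backend, which AOT-traces it
# and compiles it to HLO through JAX/torchax.
compiled = torch.compile(f, backend=dynamo_backend2.dynamo_backend)

x = torch.randn(3, 3, device=torch_xla.device())
print(compiled(x))
```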
4 changes: 2 additions & 2 deletions torch_xla/core/xla_builder.py
@@ -911,8 +911,8 @@ def get_hlo():
import torch_xla.debug.profiler as xp
# If we see this trace span in the profiler, we'll know that there's a cache miss.
with xp.Trace('jax_to_hlo'):
hlo_ir = jax.jit(
fn, keep_unused=True).lower(*sample_tensor_args).compiler_ir('hlo')
lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args)
hlo_ir = lowered.compiler_ir('hlo')

# Get a protobuf representation of the HLO. `as_serialized_hlo_module_proto` is
# mentioned at https://github.com/jax-ml/jax/discussions/22266
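The change above splits the chained `jax.jit(...).lower(...).compiler_ir('hlo')` call into two steps. A standalone sketch of that lowering path in plain JAX; `as_serialized_hlo_module_proto` is the accessor mentioned in the surrounding code:

```python
import jax
import jax.numpy as jnp


def fn(x):
  return jnp.sin(x) + x


# Lower without executing; keep_unused=True stops JAX from pruning unused args.
lowered = jax.jit(fn, keep_unused=True).lower(jnp.ones((3, 3)))
hlo_ir = lowered.compiler_ir('hlo')  # an XlaComputation
proto_bytes = hlo_ir.as_serialized_hlo_module_proto()  # HloModule proto bytes
print(hlo_ir.as_hlo_text())
```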
8 changes: 4 additions & 4 deletions torchax/test/test_context.py
@@ -39,23 +39,23 @@ def test_mode_decorator(self):

def test_same_manual_seed(self):
with xla_env:
torch.manual_seed(1234)
xla_env.manual_seed(1234)
x = torch.randn((3, 3))
self.assertIsInstance(x, tensor.Tensor)

torch.manual_seed(1234)
xla_env.manual_seed(1234)
y = torch.randn((3, 3))
self.assertIsInstance(y, tensor.Tensor)

self.assertTrue(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))

def test_different_manual_seed(self):
with xla_env:
torch.manual_seed(1234)
xla_env.manual_seed(1234)
x = torch.randn((3, 3))
self.assertIsInstance(x, tensor.Tensor)

torch.manual_seed(12345)
xla_env.manual_seed(12345)
y = torch.randn((3, 3))
self.assertIsInstance(y, tensor.Tensor)

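The switch from `torch.manual_seed` to `xla_env.manual_seed` above matches how the new backend seeds randomness (see `env.manual_seed(initial_rng_key)` in `dynamo_backend2.py`): inside a torchax env, reproducibility appears to come from the env's own PRNG state rather than torch's global seed. A minimal sketch of the pattern, assuming torchax is installed:

```python
import torch
import torchax

env = torchax.default_env()
with env:
  env.manual_seed(1234)    # seed the torchax env's PRNG
  x = torch.randn((3, 3))  # drawn from the env's PRNG state

  env.manual_seed(1234)    # re-seeding reproduces the same values
  y = torch.randn((3, 3))
```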
7 changes: 7 additions & 0 deletions torchax/torchax/ops/jaten.py
@@ -1351,6 +1351,13 @@ def reduce_fn(a, b):

return y, indices

try:
@op(torch.ops.xla.max_pool2d_forward)
def _xla_max_pool2d_foward(*args, **kwargs):
return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
except AttributeError:
pass


# TODO add more ops
