93 changes: 74 additions & 19 deletions torchax/test/test_interop.py
@@ -1,34 +1,40 @@
+import functools
 import torch
 import unittest
+import torchax
 from torchax import interop
-import torchax
-
-class M1(torch.nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.x = torch.ones(10, 10)
-
-
-class M(torch.nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.a = torch.nn.Linear(100, 100)
-    self.b = torch.nn.Parameter(
-        torch.ones(10, 10)
-    )
-    c = torch.ones(10, 10)
-    self.register_buffer('c', c)
-    self.register_buffer('c2', c, persistent=False)
-    self.d = torch.ones(10, 10)
-    self.m1 = M1()
-
-
-class InteropTest(unittest.TestCase):
-
-  def test_mod_attr(self):
-    m = M()
+
+
+class InteropTest(unittest.TestCase):
+
+  def setUp(self):
+    torchax.enable_globally()
+
+  def test_mod_attr(self):
+
+    class Child(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.x = torch.ones(10, 10)
+
+    class ModuleWithUnregisteredTensor(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.a = torch.nn.Linear(100, 100)
+        self.b = torch.nn.Parameter(
+            torch.ones(10, 10)
+        )
+        c = torch.ones(10, 10)
+        self.register_buffer('c', c)
+        self.register_buffer('c2', c, persistent=False)
+        self.d = torch.ones(10, 10)
+        self.m1 = Child()
+
+    m = ModuleWithUnregisteredTensor()
     params, buffers = interop.extract_all_buffers(m)
     self.assertEqual(
         set(params.keys()), {'a.weight', 'a.bias', 'b'}
@@ -75,6 +81,55 @@ def fn(x):
    expected = torch.ones(2, 2) * 2
    torch.testing.assert_close(x.grad, expected, check_device=False)

  def test_module_with_shared_weights(self):

    # arrange
    class ModuleWithSharedWeights(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10)
        self.b = self.a

      def forward(self, x):
        return self.a(self.b(x))

    m = ModuleWithSharedWeights().to('jax')

    m_jitted = interop.JittableModule(m, dedup_parameters=True)

    # a's weights and bias and b's weights and bias
    self.assertEqual(len(m.state_dict()), 4)

    # b's weights and bias are deduped
    self.assertEqual(len(m_jitted.params), 2)
    x = torch.randn(10, 10).to('jax')
    expected = m(x)

    # act
    actual = m_jitted(x)

    # assert
    torch.testing.assert_close(actual, expected)

    # arrange
    # make sure buffer donation works
    functional_forward = interop.jax_jit(
        functools.partial(m_jitted.functional_call, 'forward'),
        kwargs_for_jax_jit={
            'donate_argnums': (0, )
        }
    )

    # act
    actual = functional_forward(m_jitted.params, m_jitted.buffers, x)
    # assert
    torch.testing.assert_close(actual, expected)


if __name__ == '__main__':
  unittest.main()
24 changes: 22 additions & 2 deletions torchax/torchax/interop.py
@@ -1,3 +1,4 @@
import collections
import copy
import functools
import torch
@@ -51,14 +52,27 @@ def set_one(module, prefix):

class JittableModule(torch.nn.Module):

  def __init__(self, m: torch.nn.Module, extra_jit_args={}):
  def __init__(self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True):
    super().__init__()
    self.params, self.buffers = extract_all_buffers(m)
    self._model = m
    self._jitted = {}

    self._extra_jit_args = extra_jit_args

    self._extra_dumped_weights = {}

    if dedup_parameters:
      temp = collections.defaultdict(list)
      for k, v in self.params.items():
        temp[id(v)].append(k)

      for v in temp.values():
        if len(v) > 1:
          # duplicated weights with different names
          self._extra_dumped_weights[v[0]] = v[1:]
          for extra_keys in v[1:]:
            del self.params[extra_keys]

  def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)
@@ -69,6 +83,10 @@ def functional_call(
    kwargs = kwargs or {}
    params_copy = copy.copy(params)
    params_copy.update(buffers)
    # re-inflate the state dict so there are no missing keys
    for k, v in self._extra_dumped_weights.items():
      for new_key in v:
        params_copy[new_key] = params_copy[k]
    with torch_stateless._reparametrize_module(self._model, params_copy):
      res = getattr(self._model, method_name)(*args, **kwargs)
    return res
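To make the intent of the two changes above easier to follow, here is a minimal, self-contained sketch of the same idea outside of JittableModule: parameter names that alias the same tensor object are collapsed to one canonical entry via id(), and the alias names are replayed before the module is reparametrized so no key is reported missing. The helper names dedup_params and reinflate are illustrative only and are not part of the torchax API.

import collections

import torch


def dedup_params(params):
  # Group parameter names by tensor identity; aliased entries share one id().
  groups = collections.defaultdict(list)
  for name, tensor in params.items():
    groups[id(tensor)].append(name)

  deduped = dict(params)
  aliases = {}  # canonical name -> list of alias names
  for names in groups.values():
    if len(names) > 1:
      aliases[names[0]] = names[1:]
      for alias in names[1:]:
        del deduped[alias]
  return deduped, aliases


def reinflate(params, aliases):
  # Re-add alias entries so a later reparametrize call sees no missing keys.
  full = dict(params)
  for canonical, names in aliases.items():
    for name in names:
      full[name] = full[canonical]
  return full


linear = torch.nn.Linear(10, 10)
shared = {
    'a.weight': linear.weight, 'a.bias': linear.bias,
    'b.weight': linear.weight, 'b.bias': linear.bias,
}
deduped, aliases = dedup_params(shared)
assert set(deduped) == {'a.weight', 'a.bias'}
assert reinflate(deduped, aliases).keys() == shared.keys()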
@@ -285,11 +303,13 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
  return torch_view(jitted)


def jax_jit(torch_function, kwargs_for_jax_jit=None):
def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
  return wrap_jax_jit(torch_function, jax_jit_func=jax.jit,
                      kwargs_for_jax=kwargs_for_jax_jit)


def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
  return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
                      kwargs_for_jax=kwargs_for_jax_shard_map)
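As a usage sketch of the jax_jit path above (mirroring the new test, and assuming the torchax 'jax' device and torchax.enable_globally() shown there): kwargs_for_jax_jit is forwarded verbatim to jax.jit, so buffer donation is requested by passing donate_argnums. The model and variable names below are illustrative, not prescribed by the library.

import functools

import torch
import torchax
from torchax import interop

torchax.enable_globally()

# Any nn.Module moved to the 'jax' device works the same way; Linear is just an example.
model = torch.nn.Linear(10, 10).to('jax')
jittable = interop.JittableModule(model)

# kwargs_for_jax_jit is passed straight through to jax.jit, so donate_argnums=(0,)
# donates the params pytree (argument 0 of functional_call after the partial binding).
step = interop.jax_jit(
    functools.partial(jittable.functional_call, 'forward'),
    kwargs_for_jax_jit={'donate_argnums': (0,)},
)

out = step(jittable.params, jittable.buffers, torch.randn(2, 10).to('jax'))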
2 changes: 2 additions & 0 deletions torchax/torchax/ops/jaten.py
@@ -1590,12 +1590,14 @@ def _aten_bitwise_not(self):


# aten.bitwise_left_shift
@op(torch.ops.aten.__lshift__)
@op(torch.ops.aten.bitwise_left_shift)
def _aten_bitwise_left_shift(input, other):
  return jnp.left_shift(input, other)


# aten.bitwise_right_shift
@op(torch.ops.aten.__rshift__)
@op(torch.ops.aten.bitwise_right_shift)
def _aten_bitwise_right_shift(input, other):
  return jnp.right_shift(input, other)
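A quick sanity sketch of what the two new registrations enable, assuming a torchax environment with enable_globally() as in the test file above: the Python shift operators on tensors lower to aten.__lshift__ / aten.__rshift__, which now route to the jnp shift functions.

import torch
import torchax

torchax.enable_globally()

a = torch.tensor([1, 2, 4]).to('jax')

# Tensor << and >> dispatch to aten.__lshift__ / aten.__rshift__, which the new
# @op registrations map onto jnp.left_shift / jnp.right_shift.
print(a << 1)  # expected values: [2, 4, 8]
print(a >> 1)  # expected values: [0, 1, 2]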