Commit 4d88aff

Chillee authored and pytorchmergebot committed
Ported proxy tensor tests over to core (pytorch#78890)
Will fill out later

Pull Request resolved: pytorch#78890
Approved by: https://github.com/ezyang, https://github.com/zou3519
1 parent e806d13 commit 4d88aff

3 files changed: +215 -42 lines changed

test/test_fx_experimental.py

Lines changed: 0 additions & 41 deletions

@@ -29,7 +29,6 @@
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
 import torch.fx.experimental.meta_tracer
-from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
 from torch.fx.operator_schemas import (
@@ -701,46 +700,6 @@ def forward(self, x):
 
         torch.testing.assert_close(loaded(x), mttm(x))
 
-    def test_proxy_tensor(self):
-        def f_grad(x):
-            val = x.cos().cos().sum()
-            return torch.autograd.grad(val, x)
-
-        def f_backward(x):
-            val = x.cos().cos().sum()
-            val.backward()
-            return x.grad
-
-        for f in [f_grad, f_backward]:
-            traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
-            inp = torch.randn(3, requires_grad=True)
-            traced_graph_out = traced_graph(inp)
-            assert inp.grad is None
-            torch.testing.assert_close(traced_graph_out, f(inp))
-
-    def test_mode_tracing_factory_function(self):
-        def f(x):
-            return x + torch.randn(x.shape)
-
-        traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
-        self.assertTrue(
-            any(
-                isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
-                for node in traced.graph.nodes
-            )
-        )
-
-    def test_mode_tracing_factory_function_default_behavior(self):
-        def f(x):
-            return x + torch.randn(x.shape)
-
-        traced = make_fx(f)(torch.randn(3))  # default behavior should not trace factory functions
-        self.assertFalse(
-            any(
-                isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
-                for node in traced.graph.nodes
-            )
-        )
 
     def test_call_to_assert_with_msg(self):
         class M(torch.nn.Module):

test/test_proxy_tensor.py

Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@
+# Owner(s): ["oncall: fx"]
+
+from torch.testing._internal.common_utils import TestCase, run_tests
+import torch
+import unittest
+import warnings
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_methods_invocations import op_db
+
+from torch.testing._internal.common_device_type import ops
+from torch.fx.experimental.proxy_tensor import make_fx
+
+# Copied from functorch
+def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
+    return (op_name, variant_name, device_type, dtypes, True)
+
+
+def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
+    return (op_name, variant_name, device_type, dtypes, False)
+
+
+def skipOps(test_case_name, base_test_name, to_skip):
+    all_opinfos = op_db
+    for xfail in to_skip:
+        op_name, variant_name, device_type, dtypes, expected_failure = xfail
+        matching_opinfos = [o for o in all_opinfos
+                            if o.name == op_name and o.variant_test_name == variant_name]
+        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
+        for opinfo in matching_opinfos:
+            decorators = list(opinfo.decorators)
+            if expected_failure:
+                decorator = DecorateInfo(unittest.expectedFailure,
+                                         test_case_name, base_test_name,
+                                         device_type=device_type, dtypes=dtypes)
+                decorators.append(decorator)
+            else:
+                decorator = DecorateInfo(unittest.skip("Skipped!"),
+                                         test_case_name, base_test_name,
+                                         device_type=device_type, dtypes=dtypes)
+                decorators.append(decorator)
+            opinfo.decorators = tuple(decorators)
+
+    # This decorator doesn't modify fn in any way
+    def wrapped(fn):
+        return fn
+    return wrapped
+
+
+USE_TORCHVISION = False
+try:
+    import torchvision
+    USE_TORCHVISION = True
+except ImportError:
+    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
+                  "to install it with commands from pytorch.org, post-fixed with "
+                  "`--no-deps` to avoid overwriting the pytorch installation",
+                  UserWarning)
+
+
+class TestProxyTensor(TestCase):
+    def test_make_fx(self, device):
+        def f(x):
+            return torch.sin(x)
+        inp = torch.randn(3)
+        fx_f = make_fx(f)(inp)
+
+        new_inp = torch.randn(3)
+        self.assertEqual(fx_f(new_inp), f(new_inp))
+
+    def test_scalar_device(self, device):
+        def f(a, b):
+            return a + b
+        inps = [torch.randn(3, device=device), torch.tensor(5)]
+        fx_f = make_fx(f)(*inps)
+        self.assertEqual(fx_f(*inps), f(*inps))
+
+
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_resnet18_backward_trace(self, device):
+        mod = torchvision.models.resnet18()
+
+        def f(x):
+            out = mod(x)
+            out.sum().backward()
+            return [a.grad for a in mod.parameters()]
+
+        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
+        grads = f(inp)
+
+        mod.zero_grad()
+        mod(inp).sum().backward()
+        grads2 = [a.grad for a in mod.parameters()]
+        self.assertEqual(grads, grads2)
+
+    def test_proxy_tensor(self):
+        def f_grad(x):
+            val = x.cos().cos().sum()
+            return torch.autograd.grad(val, x)
+
+        def f_backward(x):
+            val = x.cos().cos().sum()
+            val.backward()
+            return x.grad
+
+        for f in [f_grad, f_backward]:
+            traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
+            inp = torch.randn(3, requires_grad=True)
+            traced_graph_out = traced_graph(inp)
+            assert inp.grad is None
+            torch.testing.assert_close(traced_graph_out, f(inp))
+
+    def test_mode_tracing_factory_function(self):
+        def f(x):
+            return x + torch.randn(x.shape)
+
+        traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
+        self.assertTrue(
+            any(
+                isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
+                for node in traced.graph.nodes
+            )
+        )
+
+    def test_mode_tracing_factory_function_default_behavior(self):
+        def f(x):
+            return x + torch.randn(x.shape)
+
+        traced = make_fx(f)(torch.randn(3))  # default behavior should not trace factory functions
+        self.assertFalse(
+            any(
+                isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
+                for node in traced.graph.nodes
+            )
+        )
+
+make_fx_failures = {
+    xfail('allclose'),
+    xfail('nn.functional.dropout'),
+    xfail('linalg.eigvals'),
+    xfail('nn.functional.max_pool1d', device_type='cpu'),  # precision problems?
+    xfail('randn_like'),  # randomness
+    xfail('rand_like'),  # randomness
+    xfail('randint_like'),  # randomness
+    skip('new_empty'),  # nondeterministic
+    skip('empty_like'),  # nondeterministic
+    skip('linalg.lstsq', 'grad_oriented'),  # flaky
+    xfail('normal', '', device_type='cpu'),
+    xfail('normal', 'number_mean', device_type='cpu'),
+    xfail('multinomial', device_type='cpu'),
+    xfail('nn.functional.feature_alpha_dropout', 'with_train', device_type='cpu'),
+    xfail('bernoulli', device_type='cpu'),
+    xfail('nn.functional.dropout2d', device_type='cpu'),
+    skip('nn.functional.max_unpool1d', '', device_type='cpu'),  # flaky
+    skip('nn.functional.max_unpool2d', '', device_type='cpu'),  # flaky
+    skip('nn.functional.max_unpool3d', '', device_type='cpu'),  # flaky
+    skip('empty'),  # nondeterministic
+    skip('linalg.lstsq'),  # flaky, probably just a precision issue
+    xfail('histogram'),
+    xfail('scatter'),
+    # data-dependent control flow
+    xfail('cov'),
+    xfail('istft'),
+    xfail('nanquantile'),
+    xfail('nn.functional.gaussian_nll_loss'),
+    xfail('quantile'),
+    xfail('tensor_split'),
+    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
+    xfail('sparse.sampled_addmm'),
+}
+
+
+class TestProxyTensorOpInfo(TestCase):
+    @ops(op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
+    def test_make_fx_exhaustive(self, device, dtype, op):
+
+        def f(args, kwargs):
+            return op.op(*args, **kwargs)
+        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+        new_f = None
+        for sample_input in sample_inputs_itr:
+            args = [sample_input.input] + list(sample_input.args)
+            kwargs = sample_input.kwargs
+
+            new_f = make_fx(f)(args, kwargs)
+            for arg in args:
+                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
+                    arg.uniform_(0, 1)
+            try:
+                old_out = f(args, kwargs)
+            except Exception:
+                continue
+            new_out = new_f(args, kwargs)
+            self.assertEqual(new_out, old_out)
+
+
+only_for = ("cpu")
+instantiate_device_type_tests(
+    TestProxyTensor,
+    globals(),
+    only_for=only_for,
+)
+instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
+
+
+if __name__ == '__main__':
+    run_tests()
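
For context, the API these new tests exercise is make_fx from torch.fx.experimental.proxy_tensor (imported above). A minimal sketch of the usage pattern the tests rely on, under the assumption (consistent with test_make_fx above) that tracing records the dispatched ops into a torch.fx.GraphModule that can be re-run on fresh inputs:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return x.cos().cos().sum()

    # Tracing runs f once on a proxy-wrapped input and records each
    # dispatched op; the result is a GraphModule replaying the same ops.
    traced = make_fx(f)(torch.randn(3))
    print(traced.code)             # should show aten-level calls (cos, sum)
    print(traced(torch.randn(3)))  # runs the recorded graph on a new input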

torch/fx/experimental/proxy_tensor.py

Lines changed: 5 additions & 1 deletion

@@ -94,6 +94,8 @@ class ProxyTensor(torch.Tensor):
     def __new__(cls, elem, proxy, *, requires_grad=None):
         # Hack to deal with super().__new__ not working for sparse tensors
         if elem.is_sparse or requires_grad is not None:
+            if requires_grad is None:
+                requires_grad = False
             r = torch.Tensor._make_subclass(cls, elem, requires_grad)
         else:
             r = super().__new__(cls, elem)  # type: ignore[call-arg]
@@ -177,7 +179,9 @@ def wrapped(*args):
         for idx, arg in enumerate(flat_args):
             if isinstance(flat_inps[idx], torch.Tensor):
                 with no_dispatch():
-                    flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf)
+                    flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=(
+                        flat_inps[idx].is_leaf and flat_inps[idx].requires_grad
+                    ))
             else:
                 flat_args[idx] = flat_inps[idx]
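
The second hunk tightens the condition under which a traced input is re-wrapped as a gradient-requiring leaf. A minimal sketch of why is_leaf alone was too loose, using plain tensors as stand-ins for the traced inputs (illustrative only, not part of the commit):

    import torch

    inp = torch.randn(3)  # user-created leaf tensor that does NOT require grad

    old_flag = inp.is_leaf                        # True: the old condition would
                                                  # wrap it with requires_grad=True
    new_flag = inp.is_leaf and inp.requires_grad  # False: the new condition keeps
                                                  # the input's original setting
    print(old_flag, new_flag)  # prints: True False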
