Skip to content

Commit 6df91b5

Browse files
pianpwkpytorchmergebot
authored andcommitted
real tensor prop for composite ops (pytorch#135717)
Fixes pytorch#135632 Adds real tensor propagation for decompositions, checking any symbols on their outputs Pull Request resolved: pytorch#135717 Approved by: https://github.com/ezyang
1 parent 0cdc6a8 commit 6df91b5

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

test/export/test_export.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,25 @@ def forward(self, x):
890890
torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21)
891891
)
892892

893+
@testing.expectedFailureTrainingIRToRunDecompNonStrict # TODO(pianpwk): user_output signature
894+
def test_real_tensor_for_max_op(self):
895+
class Foo(torch.nn.Module):
896+
def forward(self, x, y):
897+
x = x[x > 0]
898+
y = y[y > 0]
899+
return max(x.shape[0], y.shape[0])
900+
901+
model = Foo()
902+
inputs = (torch.randn(64), torch.randn(64))
903+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
904+
ep = export(model, inputs)
905+
906+
self.assertEqual(ep.module()(*inputs), model(*inputs))
907+
x = torch.zeros(64)
908+
y = torch.ones(64)
909+
self.assertEqual(ep.module()(x, x), model(x, x))
910+
self.assertEqual(ep.module()(x, y), model(x, y))
911+
893912
def test_export_script_module(self):
894913
class Foo(torch.nn.Module):
895914
def forward(self, rv: torch.Tensor, t: torch.Tensor):

torch/_subclasses/fake_tensor.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,6 +1877,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
18771877
for a in flat_args
18781878
)
18791879
):
1880+
log.debug("propagate_real_tensors %s", func)
18801881
real_flat_args = [maybe_to_real_tensor(a) for a in flat_args]
18811882
real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec)
18821883
real_out = func(*real_args, **real_kwargs)
@@ -1888,7 +1889,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
18881889
# However, if there's a bug in the condition above, this condition
18891890
# will also trigger.
18901891
log.debug(
1891-
"propagate_real_tensors skipped %s(%s, %s) %s",
1892+
"SKIPPED propagate_real_tensors %s(%s, %s) %s",
18921893
func,
18931894
flat_arg_fake_tensors,
18941895
flat_args,
@@ -1898,17 +1899,40 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
18981899
def maybe_propagate_real_tensors(fake_out: T) -> T:
18991900
import sympy
19001901

1902+
log.debug("maybe_propagate_real_tensors %s", func)
1903+
19011904
def go(t: object, real_t: Tensor) -> None:
19021905
if isinstance(t, FakeTensor):
19031906
# NB: unconditionally overwrite
1907+
log.debug(
1908+
"maybe_propagate_real_tensors %s -> %s", id(t), id(real_t)
1909+
)
19041910
t.real_tensor = real_t
1911+
for s, real_s in zip(t.size(), real_t.size()):
1912+
go(s, real_s) # type: ignore[arg-type]
1913+
for s, real_s in zip(t.stride(), real_t.stride()):
1914+
go(s, real_s) # type: ignore[arg-type]
1915+
go(t.storage_offset(), real_t.storage_offset()) # type: ignore[arg-type]
19051916
elif isinstance(t, py_sym_types) and free_unbacked_symbols(t):
19061917
if isinstance(t.node.expr, sympy.Symbol):
19071918
assert self.shape_env is not None
19081919
self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t)
19091920

19101921
if real_out is not nil:
1911-
tree_map_(go, fake_out, real_out)
1922+
if (
1923+
not isinstance(fake_out, Tensor)
1924+
and not isinstance(real_out, Tensor)
1925+
and type(fake_out) != type(real_out)
1926+
):
1927+
# This can happen when decompositions have different return types,
1928+
# e.g. namedtuple vs. tuple vs. list.
1929+
tree_map_(
1930+
go,
1931+
tuple(pytree.tree_flatten(fake_out)),
1932+
tuple(pytree.tree_flatten(real_out)),
1933+
)
1934+
else:
1935+
tree_map_(go, fake_out, real_out)
19121936

19131937
# If a data-dependent op is used in a decomposition, we
19141938
# may need to get the unbacked settings "early"
@@ -1940,13 +1964,15 @@ def go(t: object, real_t: Tensor) -> None:
19401964
)
19411965
):
19421966
with self:
1943-
return decomposition_table[func](*args, **kwargs)
1967+
return maybe_propagate_real_tensors(
1968+
decomposition_table[func](*args, **kwargs)
1969+
)
19441970

19451971
with self:
19461972
# Decomposes CompositeImplicitAutograd ops
19471973
r = func.decompose(*args, **kwargs)
19481974
if r is not NotImplemented:
1949-
return r
1975+
return maybe_propagate_real_tensors(r)
19501976

19511977
# prims already wrap FakeTensor inputs to FakeTensor outputs
19521978
# and do device logic, we dont need do anything but run them

torch/fx/experimental/symbolic_shapes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2726,6 +2726,7 @@ def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr):
27262726
def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None:
27272727
"""Used only when propagate_real_tensors; registers a value for an
27282728
unbacked symbol, which can be used last resort to resolve hints."""
2729+
log.info("set_unbacked_var_to_val %s = %s", k, v)
27292730
self.unbacked_var_to_val[k] = sympy.sympify(v)
27302731

27312732
# Unlike set_replacement, this records a shapeenv event

0 commit comments

Comments
 (0)