@@ -1877,6 +1877,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
18771877 for a in flat_args
18781878 )
18791879 ):
1880+ log .debug ("propagate_real_tensors %s" , func )
18801881 real_flat_args = [maybe_to_real_tensor (a ) for a in flat_args ]
18811882 real_args , real_kwargs = pytree .tree_unflatten (real_flat_args , args_spec )
18821883 real_out = func (* real_args , ** real_kwargs )
@@ -1888,7 +1889,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
18881889 # However, if there's a bug in the condition above, this condition
18891890 # will also trigger.
18901891 log .debug (
1891- "propagate_real_tensors skipped %s(%s, %s) %s" ,
1892+ "SKIPPED propagate_real_tensors %s(%s, %s) %s" ,
18921893 func ,
18931894 flat_arg_fake_tensors ,
18941895 flat_args ,
@@ -1898,17 +1899,40 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
18981899 def maybe_propagate_real_tensors (fake_out : T ) -> T :
18991900 import sympy
19001901
1902+ log .debug ("maybe_propagate_real_tensors %s" , func )
1903+
19011904 def go (t : object , real_t : Tensor ) -> None :
19021905 if isinstance (t , FakeTensor ):
19031906 # NB: unconditionally overwrite
1907+ log .debug (
1908+ "maybe_propagate_real_tensors %s -> %s" , id (t ), id (real_t )
1909+ )
19041910 t .real_tensor = real_t
1911+ for s , real_s in zip (t .size (), real_t .size ()):
1912+ go (s , real_s ) # type: ignore[arg-type]
1913+ for s , real_s in zip (t .stride (), real_t .stride ()):
1914+ go (s , real_s ) # type: ignore[arg-type]
1915+ go (t .storage_offset (), real_t .storage_offset ()) # type: ignore[arg-type]
19051916 elif isinstance (t , py_sym_types ) and free_unbacked_symbols (t ):
19061917 if isinstance (t .node .expr , sympy .Symbol ):
19071918 assert self .shape_env is not None
19081919 self .shape_env .set_unbacked_var_to_val (t .node .expr , real_t )
19091920
19101921 if real_out is not nil :
1911- tree_map_ (go , fake_out , real_out )
1922+ if (
1923+ not isinstance (fake_out , Tensor )
1924+ and not isinstance (real_out , Tensor )
1925+ and type (fake_out ) != type (real_out )
1926+ ):
1927+ # This can happen when decompositions have different return types,
1928+ # e.g. namedtuple vs. tuple vs. list.
1929+ tree_map_ (
1930+ go ,
1931+ tuple (pytree .tree_flatten (fake_out )),
1932+ tuple (pytree .tree_flatten (real_out )),
1933+ )
1934+ else :
1935+ tree_map_ (go , fake_out , real_out )
19121936
19131937 # If a data-dependent op is used in a decomposition, we
19141938 # may need to get the unbacked settings "early"
@@ -1940,13 +1964,15 @@ def go(t: object, real_t: Tensor) -> None:
19401964 )
19411965 ):
19421966 with self :
1943- return decomposition_table [func ](* args , ** kwargs )
1967+ return maybe_propagate_real_tensors (
1968+ decomposition_table [func ](* args , ** kwargs )
1969+ )
19441970
19451971 with self :
19461972 # Decomposes CompositeImplicitAutograd ops
19471973 r = func .decompose (* args , ** kwargs )
19481974 if r is not NotImplemented :
1949- return r
1975+ return maybe_propagate_real_tensors ( r )
19501976
19511977 # prims already wrap FakeTensor inputs to FakeTensor outputs
19521978 # and do device logic, we dont need do anything but run them
0 commit comments