2121"""
2222import itertools
2323from dataclasses import dataclass
24- from typing import Any
24+ from typing import Any , Callable , Dict , Optional , Tuple , Union
2525
2626import sympy
2727
@@ -47,11 +47,11 @@ class SymPyOps:
4747 """
4848
4949 @staticmethod
50- def identity (value ) :
50+ def identity (value : Any ) -> Any :
5151 return value
5252
5353 @staticmethod
54- def constant (value , dtype ) :
54+ def constant (value : Union [ int , float , bool ], dtype : torch . dtype ) -> TypedExpr :
5555 if is_boolean_dtype (dtype ):
5656 expr = sympy .Integer (bool (value ))
5757 elif is_integer_dtype (dtype ):
@@ -61,13 +61,13 @@ def constant(value, dtype):
6161 return TypedExpr (expr , dtype )
6262
6363 @staticmethod
64- def index_expr (value , dtype ) :
64+ def index_expr (value : sympy . Expr , dtype : torch . dtype ) -> Union [ int , TypedExpr ] :
6565 if isinstance (value , int ):
6666 value = sympy .Integer (value )
6767 return TypedExpr (value , dtype )
6868
6969 @staticmethod
70- def to_dtype (value , dtype ) :
70+ def to_dtype (value : Any , dtype : torch . dtype ) -> Union [ int , TypedExpr ] :
7171 if isinstance (value .expr , (sympy .Integer , sympy .Float )):
7272 return SymPyOps .constant (value .expr , dtype )
7373 elif is_integer_dtype (dtype ) and is_integer_dtype (value .dtype ):
@@ -77,38 +77,38 @@ def to_dtype(value, dtype):
7777 return NotImplemented
7878
7979 @staticmethod
80- def square (x ) :
80+ def square (x : TypedExpr ) -> TypedExpr :
8181 return TypedExpr (x .expr * x .expr , x .dtype )
8282
8383 @staticmethod
84- def add (x , y ) :
84+ def add (x : TypedExpr , y : TypedExpr ) -> TypedExpr :
8585 result_type = torch .promote_types (x .dtype , y .dtype )
8686 return TypedExpr (x .expr + y .expr , result_type )
8787
8888 @staticmethod
89- def sub (x , y ) :
89+ def sub (x : TypedExpr , y : TypedExpr ) -> TypedExpr :
9090 result_type = torch .promote_types (x .dtype , y .dtype )
9191 return TypedExpr (x .expr - y .expr , result_type )
9292
9393 @staticmethod
94- def mul (x , y ) :
94+ def mul (x : TypedExpr , y : TypedExpr ) -> TypedExpr :
9595 result_type = torch .promote_types (x .dtype , y .dtype )
9696 return TypedExpr (x .expr * y .expr , result_type )
9797
9898 @staticmethod
99- def neg (x ) :
99+ def neg (x : TypedExpr ) -> TypedExpr :
100100 return TypedExpr (- x .expr , x .dtype )
101101
102102 @staticmethod
103- def floordiv (x , y ) :
103+ def floordiv (x : TypedExpr , y : TypedExpr ) -> TypedExpr :
104104 result_type = torch .promote_types (x .dtype , y .dtype )
105105 if not is_integer_dtype (result_type ):
106106 return NotImplemented
107107
108108 return TypedExpr (FloorDiv (x .expr , y .expr ), result_type )
109109
110110 @staticmethod
111- def remainder (x , y ) :
111+ def remainder (x : TypedExpr , y : TypedExpr ) -> Optional [ TypedExpr ] :
112112 result_type = torch .promote_types (x .dtype , y .dtype )
113113 if not is_integer_dtype (result_type ):
114114 return NotImplemented
@@ -117,12 +117,12 @@ def remainder(x, y):
117117 return TypedExpr (result_expr , result_type )
118118
119119 @staticmethod
120- def minimum (x , y ) :
120+ def minimum (x : TypedExpr , y : TypedExpr ) -> TypedExpr :
121121 result_type = torch .promote_types (x .dtype , y .dtype )
122122 return TypedExpr (sympy .Min (x .expr , y .expr ), result_type )
123123
124124 @staticmethod
125- def maximum (x , y ) :
125+ def maximum (x : TypedExpr , y : TypedExpr ) -> TypedExpr :
126126 result_type = torch .promote_types (x .dtype , y .dtype )
127127 return TypedExpr (sympy .Max (x .expr , y .expr ), result_type )
128128
@@ -150,18 +150,18 @@ class IndexPropagation:
150150
151151 """
152152
153- def __init__ (self , inner ):
153+ def __init__ (self , inner : Any ):
154154 self ._inner = inner
155155
156- def materialize_expr (self , expr , dtype ) :
156+ def materialize_expr (self , expr : sympy . Expr , dtype : torch . dtype ) -> Any :
157157 # Construct a new constant/index_expr from the SymPy expression
158158 if isinstance (expr , sympy .Integer ):
159159 return self ._inner .constant (int (expr ), dtype )
160160 elif not expr .free_symbols :
161161 return self ._inner .constant (float (expr ), dtype )
162162 return self ._inner .index_expr (expr , dtype )
163163
164- def unwrap (self , a ) :
164+ def unwrap (self , a : Union [ Any , IndexPropVar ]) -> Any :
165165 if not isinstance (a , IndexPropVar ):
166166 return a
167167
@@ -171,15 +171,17 @@ def unwrap(self, a):
171171
172172 return a .value
173173
174- def fallback (self , name , args , kwargs ) :
174+ def fallback (self , name : str , args : Tuple , kwargs : Dict [ str , Any ]) -> IndexPropVar :
175175 # Fallback to the wrapped handler
176176 new_args = [self .unwrap (a ) for a in args ]
177177 new_kwargs = {k : self .unwrap (v ) for k , v in kwargs .items ()}
178178 return IndexPropVar (getattr (self ._inner , name )(* new_args , ** new_kwargs ))
179179
180- def propagate_sympy (self , name , args , kwargs ):
180+ def propagate_sympy (
181+ self , name : str , args : Tuple , kwargs : Dict [str , Any ]
182+ ) -> Union [TypedExpr , IndexPropVar ]:
181183 # Build a new SymPy expression from this ops call
182- def unwrap (a ) :
184+ def unwrap (a : Union [ Any , IndexPropVar ]) -> Any :
183185 if not isinstance (a , IndexPropVar ):
184186 return a
185187 return a .value
@@ -191,8 +193,8 @@ def unwrap(a):
191193 return self .fallback (name , args , kwargs )
192194 return IndexPropVar .new_symbolic (new_expr )
193195
194- def __getattr__ (self , name ) :
195- def inner (* args , ** kwargs ) :
196+ def __getattr__ (self , name : str ) -> Callable [..., Union [ Any , IndexPropVar ]] :
197+ def inner (* args : Any , ** kwargs : Any ) -> Union [ Any , IndexPropVar ] :
196198 if not hasattr (SymPyOps , name ):
197199 return self .fallback (name , args , kwargs )
198200
@@ -208,7 +210,9 @@ def inner(*args, **kwargs):
208210
209211 return inner
210212
211- def indirect_indexing (self , index , size , check = True ):
213+ def indirect_indexing (
214+ self , index : Union [Any , IndexPropVar ], size : Any , check : bool = True
215+ ) -> Any :
212216 # indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
213217 if isinstance (index , IndexPropVar ) and index .is_symbolic :
214218 return index .value .expr
0 commit comments