Skip to content

Commit 450c22c

Browse files
msaroufimpytorchmergebot
authored andcommitted
mypy index propagation (pytorch#105622)
Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#105622 Approved by: https://github.com/eellison
1 parent fe7187b commit 450c22c

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

torch/_inductor/index_propagation.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"""
2222
import itertools
2323
from dataclasses import dataclass
24-
from typing import Any
24+
from typing import Any, Callable, Dict, Optional, Tuple, Union
2525

2626
import sympy
2727

@@ -47,11 +47,11 @@ class SymPyOps:
4747
"""
4848

4949
@staticmethod
50-
def identity(value):
50+
def identity(value: Any) -> Any:
5151
return value
5252

5353
@staticmethod
54-
def constant(value, dtype):
54+
def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
5555
if is_boolean_dtype(dtype):
5656
expr = sympy.Integer(bool(value))
5757
elif is_integer_dtype(dtype):
@@ -61,13 +61,13 @@ def constant(value, dtype):
6161
return TypedExpr(expr, dtype)
6262

6363
@staticmethod
64-
def index_expr(value, dtype):
64+
def index_expr(value: sympy.Expr, dtype: torch.dtype) -> Union[int, TypedExpr]:
6565
if isinstance(value, int):
6666
value = sympy.Integer(value)
6767
return TypedExpr(value, dtype)
6868

6969
@staticmethod
70-
def to_dtype(value, dtype):
70+
def to_dtype(value: Any, dtype: torch.dtype) -> Union[int, TypedExpr]:
7171
if isinstance(value.expr, (sympy.Integer, sympy.Float)):
7272
return SymPyOps.constant(value.expr, dtype)
7373
elif is_integer_dtype(dtype) and is_integer_dtype(value.dtype):
@@ -77,38 +77,38 @@ def to_dtype(value, dtype):
7777
return NotImplemented
7878

7979
@staticmethod
80-
def square(x):
80+
def square(x: TypedExpr) -> TypedExpr:
8181
return TypedExpr(x.expr * x.expr, x.dtype)
8282

8383
@staticmethod
84-
def add(x, y):
84+
def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
8585
result_type = torch.promote_types(x.dtype, y.dtype)
8686
return TypedExpr(x.expr + y.expr, result_type)
8787

8888
@staticmethod
89-
def sub(x, y):
89+
def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
9090
result_type = torch.promote_types(x.dtype, y.dtype)
9191
return TypedExpr(x.expr - y.expr, result_type)
9292

9393
@staticmethod
94-
def mul(x, y):
94+
def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
9595
result_type = torch.promote_types(x.dtype, y.dtype)
9696
return TypedExpr(x.expr * y.expr, result_type)
9797

9898
@staticmethod
99-
def neg(x):
99+
def neg(x: TypedExpr) -> TypedExpr:
100100
return TypedExpr(-x.expr, x.dtype)
101101

102102
@staticmethod
103-
def floordiv(x, y):
103+
def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
104104
result_type = torch.promote_types(x.dtype, y.dtype)
105105
if not is_integer_dtype(result_type):
106106
return NotImplemented
107107

108108
return TypedExpr(FloorDiv(x.expr, y.expr), result_type)
109109

110110
@staticmethod
111-
def remainder(x, y):
111+
def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
112112
result_type = torch.promote_types(x.dtype, y.dtype)
113113
if not is_integer_dtype(result_type):
114114
return NotImplemented
@@ -117,12 +117,12 @@ def remainder(x, y):
117117
return TypedExpr(result_expr, result_type)
118118

119119
@staticmethod
120-
def minimum(x, y):
120+
def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
121121
result_type = torch.promote_types(x.dtype, y.dtype)
122122
return TypedExpr(sympy.Min(x.expr, y.expr), result_type)
123123

124124
@staticmethod
125-
def maximum(x, y):
125+
def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
126126
result_type = torch.promote_types(x.dtype, y.dtype)
127127
return TypedExpr(sympy.Max(x.expr, y.expr), result_type)
128128

@@ -150,18 +150,18 @@ class IndexPropagation:
150150
151151
"""
152152

153-
def __init__(self, inner):
153+
def __init__(self, inner: Any):
154154
self._inner = inner
155155

156-
def materialize_expr(self, expr, dtype):
156+
def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
157157
# Construct a new constant/index_expr from the SymPy expression
158158
if isinstance(expr, sympy.Integer):
159159
return self._inner.constant(int(expr), dtype)
160160
elif not expr.free_symbols:
161161
return self._inner.constant(float(expr), dtype)
162162
return self._inner.index_expr(expr, dtype)
163163

164-
def unwrap(self, a):
164+
def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
165165
if not isinstance(a, IndexPropVar):
166166
return a
167167

@@ -171,15 +171,17 @@ def unwrap(self, a):
171171

172172
return a.value
173173

174-
def fallback(self, name, args, kwargs):
174+
def fallback(self, name: str, args: Tuple, kwargs: Dict[str, Any]) -> IndexPropVar:
175175
# Fallback to the wrapped handler
176176
new_args = [self.unwrap(a) for a in args]
177177
new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
178178
return IndexPropVar(getattr(self._inner, name)(*new_args, **new_kwargs))
179179

180-
def propagate_sympy(self, name, args, kwargs):
180+
def propagate_sympy(
181+
self, name: str, args: Tuple, kwargs: Dict[str, Any]
182+
) -> Union[TypedExpr, IndexPropVar]:
181183
# Build a new SymPy expression from this ops call
182-
def unwrap(a):
184+
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
183185
if not isinstance(a, IndexPropVar):
184186
return a
185187
return a.value
@@ -191,8 +193,8 @@ def unwrap(a):
191193
return self.fallback(name, args, kwargs)
192194
return IndexPropVar.new_symbolic(new_expr)
193195

194-
def __getattr__(self, name):
195-
def inner(*args, **kwargs):
196+
def __getattr__(self, name: str) -> Callable[..., Union[Any, IndexPropVar]]:
197+
def inner(*args: Any, **kwargs: Any) -> Union[Any, IndexPropVar]:
196198
if not hasattr(SymPyOps, name):
197199
return self.fallback(name, args, kwargs)
198200

@@ -208,7 +210,9 @@ def inner(*args, **kwargs):
208210

209211
return inner
210212

211-
def indirect_indexing(self, index, size, check=True):
213+
def indirect_indexing(
214+
self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
215+
) -> Any:
212216
# indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
213217
if isinstance(index, IndexPropVar) and index.is_symbolic:
214218
return index.value.expr

0 commit comments

Comments
 (0)