Skip to content

Commit f576960

Browse files
isurufpytorchmergebot
authored andcommitted
do not expand in replace/simplify if no changes (pytorch#135863)
Pull Request resolved: pytorch#135863 Approved by: https://github.com/ezyang
1 parent 1aba224 commit f576960

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4592,7 +4592,10 @@ def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
45924592
# assumption queries if expr has a relational node.
45934593
if not r.is_Symbol or r != s:
45944594
replacements[s] = r
4595-
return safe_expand(expr.xreplace(replacements))
4595+
if replacements:
4596+
return safe_expand(expr.xreplace(replacements))
4597+
else:
4598+
return expr
45964599

45974600
@_lru_cache
45984601
def _update_divisible(self):
@@ -4626,8 +4629,9 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
46264629
if self.replace(Mod(base, divisor)) in self.divisible and \
46274630
base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible:
46284631
div_replacements[atom] = divisor1
4629-
expr = expr.xreplace(div_replacements)
4630-
expr = safe_expand(expr)
4632+
if div_replacements:
4633+
expr = expr.xreplace(div_replacements)
4634+
expr = safe_expand(expr)
46314635
if expr.has(FloorDiv):
46324636
div_replacements = {}
46334637
pows = expr.atoms(sympy.Pow)
@@ -4636,13 +4640,14 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
46364640
base, divisor = fd.args
46374641
if self.replace(Mod(base, divisor)) in self.divisible:
46384642
div_replacements[fd] = CleanDiv(base, divisor)
4639-
new_expr = expr.xreplace(div_replacements)
4640-
new_expr = safe_expand(new_expr)
4641-
new_pows = new_expr.atoms(sympy.Pow)
4642-
new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer))
4643-
# divisions simplified away
4644-
if new_pows.issubset(pows) and new_rationals.issubset(rationals):
4645-
expr = new_expr
4643+
if div_replacements:
4644+
new_expr = expr.xreplace(div_replacements)
4645+
new_expr = safe_expand(new_expr)
4646+
new_pows = new_expr.atoms(sympy.Pow)
4647+
new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer))
4648+
# divisions simplified away
4649+
if new_pows.issubset(pows) and new_rationals.issubset(rationals):
4650+
expr = new_expr
46464651
return expr
46474652

46484653
@lru_cache(256)

0 commit comments

Comments
 (0)