Skip to content

Commit 69cf71c

Browse files
authored
[SOT] Skip restore same stack arg when fallback (#72367)
1 parent 9ba9a69 commit 69cf71c

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,15 @@ def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx):
390390
self.pycode_gen.gen_enable_eval_frame()
391391

392392
name_gen = NameGenerator("___graph_fn_saved_orig_")
393+
stored_var_ids = set()
393394

394395
# here is not update changed values, it just give names to stack vars
395396
# and want keep same interface as _build_compile_fn_with_name_store
396397
for var in stack_vars[::-1]:
398+
if var.id in stored_var_ids:
399+
self.pycode_gen.gen_pop_top()
400+
continue
401+
stored_var_ids.add(var.id)
397402
if not store_var_info.get(var.id, []):
398403
name = name_gen.next()
399404
store_var_info.setdefault(var.id, [])

test/sot/test_min_graph_size.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ def get_arg_from_kwargs(x, **kwargs):
8484
return x, y
8585

8686

87+
def add_with_breakgraph(x, y):
88+
sot.psdb.breakgraph()
89+
return x + y
90+
91+
92+
def restore_same_arg_when_fallback(x):
93+
return add_with_breakgraph(x, x)
94+
95+
8796
class TestMinGraphSize(TestCaseBase):
8897
@min_graph_size_guard(10)
8998
def test_cases(self):
@@ -116,6 +125,11 @@ def test_get_arg_from_kwargs(self):
116125
self.assert_results(get_arg_from_kwargs, None)
117126
self.assert_results(get_arg_from_kwargs, None, y=1)
118127

128+
@min_graph_size_guard(10)
129+
def test_restore_same_arg_when_fallback(self):
130+
x = paddle.to_tensor(1)
131+
self.assert_results(restore_same_arg_when_fallback, x)
132+
119133

120134
if __name__ == "__main__":
121135
unittest.main()

0 commit comments

Comments
 (0)