Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,15 @@ def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx):
self.pycode_gen.gen_enable_eval_frame()

name_gen = NameGenerator("___graph_fn_saved_orig_")
stored_var_ids = set()

# here is not update changed values, it just give names to stack vars
# and want keep same interface as _build_compile_fn_with_name_store
for var in stack_vars[::-1]:
if var.id in stored_var_ids:
self.pycode_gen.gen_pop_top()
continue
stored_var_ids.add(var.id)
if not store_var_info.get(var.id, []):
name = name_gen.next()
store_var_info.setdefault(var.id, [])
Expand Down
14 changes: 14 additions & 0 deletions test/sot/test_min_graph_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ def get_arg_from_kwargs(x, **kwargs):
return x, y


def add_with_breakgraph(x, y):
sot.psdb.breakgraph()
return x + y


def restore_same_arg_when_fallback(x):
return add_with_breakgraph(x, x)


class TestMinGraphSize(TestCaseBase):
@min_graph_size_guard(10)
def test_cases(self):
Expand Down Expand Up @@ -116,6 +125,11 @@ def test_get_arg_from_kwargs(self):
self.assert_results(get_arg_from_kwargs, None)
self.assert_results(get_arg_from_kwargs, None, y=1)

@min_graph_size_guard(10)
def test_restore_same_arg_when_fallback(self):
x = paddle.to_tensor(1)
self.assert_results(restore_same_arg_when_fallback, x)


if __name__ == "__main__":
unittest.main()