Skip to content

Commit 4cff8b5

Browse files
tugsbayasgalan authored and pytorchmergebot committed
Add option to disable applying side effects in dynamo (pytorch#167239)
There are two motivating use cases for this change: 1) export — when we trace pytree calls into a graph, we don't want to accidentally trace the side-effect bytecode, which would pollute the initial state; we want to warn about side effects without actually applying them. 2) vLLM — they want to detect side effects and error out. We implement this with two configs: one controls whether we apply side effects (yes by default), and the other controls the reporting level for side effects (warn for export, error for vLLM). We intentionally ignore input side effects because they are captured in the graph, and export never traces the actual dynamo graph module when tracing the pytree calls. Pull Request resolved: pytorch#167239 Approved by: https://github.com/williamwen42, https://github.com/anijain2305
1 parent 4714eb7 commit 4cff8b5

File tree

6 files changed

+159
-7
lines changed

6 files changed

+159
-7
lines changed

test/dynamo/test_misc.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import pickle
2121
import random
22+
import re
2223
import sys
2324
import tempfile
2425
import threading
@@ -5635,6 +5636,115 @@ def f2(a, b):
56355636
self.assertTrue(same(res11, res12))
56365637
self.assertTrue(same(res21, res22))
56375638

5639+
def test_replay_side_effects_config(self):
5640+
# Test that replay_side_effects config controls mutation replay
5641+
def fn(x, lst):
5642+
lst.append(x + 1)
5643+
return x * 2
5644+
5645+
x = torch.tensor([5.0])
5646+
5647+
# Test with replay enabled (default)
5648+
lst_with_replay = []
5649+
opt_fn_with_replay = torch.compile(fn, backend="eager")
5650+
result1 = opt_fn_with_replay(x, lst_with_replay)
5651+
self.assertEqual(len(lst_with_replay), 1) # Mutation should be replayed
5652+
self.assertTrue(same(result1, x * 2))
5653+
5654+
torch._dynamo.reset()
5655+
5656+
# Test with replay disabled
5657+
lst_without_replay = []
5658+
with torch._dynamo.config.patch(
5659+
replay_side_effects=False, side_effect_replay_policy="warn"
5660+
):
5661+
opt_fn_without_replay = torch.compile(fn, backend="eager")
5662+
result2 = opt_fn_without_replay(x, lst_without_replay)
5663+
self.assertEqual(
5664+
len(lst_without_replay), 0
5665+
) # Mutation should NOT be replayed
5666+
self.assertTrue(same(result2, x * 2))
5667+
5668+
torch._dynamo.reset()
5669+
lst_without_replay = []
5670+
with torch._dynamo.config.patch(
5671+
replay_side_effects=False, side_effect_replay_policy="error"
5672+
):
5673+
opt_fn_without_replay = torch.compile(fn, backend="eager")
5674+
with self.assertRaisesRegex(
5675+
RuntimeError,
5676+
re.escape(
5677+
"While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['lst']\"]"
5678+
),
5679+
):
5680+
_ = opt_fn_without_replay(x, lst_without_replay)
5681+
5682+
def test_replay_side_effects_model_attr(self):
5683+
class Bar(torch.nn.Module):
5684+
def __init__(self):
5685+
super().__init__()
5686+
self.const = 4
5687+
5688+
def forward(self, x):
5689+
return x.cos()
5690+
5691+
class Foo(torch.nn.Module):
5692+
def __init__(self):
5693+
super().__init__()
5694+
self.const = 4
5695+
self.tensor = None
5696+
self.bar = Bar()
5697+
5698+
def forward(self, x):
5699+
self.const = 5
5700+
self.tensor = x.sin()
5701+
res = self.bar(x)
5702+
return x.cos() + res.sum() + self.tensor
5703+
5704+
with torch._dynamo.config.patch(
5705+
replay_side_effects=False, side_effect_replay_policy="error"
5706+
):
5707+
foo = Foo()
5708+
with self.assertRaisesRegex(
5709+
RuntimeError,
5710+
re.escape(
5711+
"While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['self']\"]"
5712+
),
5713+
):
5714+
torch.compile(foo, fullgraph=True)(torch.randn(4, 4))
5715+
5716+
with torch._dynamo.config.patch(
5717+
replay_side_effects=False, side_effect_replay_policy="silent"
5718+
):
5719+
foo_v2_compile = Foo()
5720+
foo_v2_eager = Foo()
5721+
inp = torch.randn(4, 4)
5722+
res = torch.compile(foo_v2_compile, fullgraph=True)(torch.randn(4, 4))
5723+
self.assertEqual(foo_v2_compile.tensor, None)
5724+
self.assertEqual(foo_v2_compile.const, 4)
5725+
self.assertEqual(foo_v2_compile.bar.const, 4)
5726+
same(res, foo_v2_eager(inp))
5727+
5728+
def test_replay_side_effects_input_mut(self):
5729+
class Foo(torch.nn.Module):
5730+
def __init__(self):
5731+
super().__init__()
5732+
self.const = 4
5733+
self.tensor = None
5734+
5735+
def forward(self, x):
5736+
x.add_(5)
5737+
return x.cos()
5738+
5739+
# This is ok because we actually capture the graph which
5740+
# has mutation. In export, we never retrace the actual
5741+
# gm so we won't see any mutation applied to inputs
5742+
with torch._dynamo.config.patch(
5743+
replay_side_effects=False, side_effect_replay_policy="error"
5744+
):
5745+
foo = Foo()
5746+
torch.compile(foo, fullgraph=True)(torch.randn(4, 4))
5747+
56385748
def test_list_append_return_none(self):
56395749
def fn(x):
56405750
alist = []

test/export/test_experimental.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,18 @@ def generate(self, *, input_tensor, input_tensor2):
349349
res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
350350
self.assertTrue(torch.allclose(res, res2))
351351

352+
def test_side_effect(self):
353+
global_env = []
354+
355+
class Foo(torch.nn.Module):
356+
def forward(self, x):
357+
global_env.append(x)
358+
return x.sin()
359+
360+
with torch._dynamo.config.patch(replay_side_effects=False):
361+
_ = dynamo_graph_capture_for_export(Foo())(torch.randn(4, 4))
362+
self.assertEqual(len(global_env), 0)
363+
352364
def test_export_add_in_out_info(self):
353365
class Foo(torch.nn.Module):
354366
def forward(self, dct, lst, bleh):

torch/_dynamo/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@
4444
# turn on/off DCE pass (deprecated: always true)
4545
dead_code_elimination = True
4646

47+
# Enable or disable side effect replay after graph execution.
48+
# When False, mutations to Python objects (lists, dicts, attributes) won't be
49+
# replayed after the compiled graph runs. This can cause correctness issues
50+
# if your code depends on these mutations being visible. This should probably
51+
# never be False by default. At the moment, only export will need it.
52+
replay_side_effects = True
53+
54+
# Configure side effect warning level
55+
# If `silent`, we silently allow side effects
56+
# If `warn`, we warn side effects
57+
# If `error`, we error on side effects
58+
side_effect_replay_policy = "silent"
59+
4760
# disable (for a function) when cache reaches this size
4861

4962
# controls the maximum number of cache entries with a guard on same ID_MATCH'd

torch/_dynamo/output_graph.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1845,7 +1845,7 @@ def compile_subgraph(
18451845
[create_instruction("DELETE_FAST", argval=graph_output_var)]
18461846
)
18471847

1848-
if self.export:
1848+
if torch._dynamo.config.side_effect_replay_policy in ["warn", "error"]:
18491849
from torch.export._trace import _ExportModuleSpecTrackerDict
18501850

18511851
potential_side_effects = []
@@ -1881,10 +1881,16 @@ def compile_subgraph(
18811881
]
18821882

18831883
if side_effect_refs:
1884-
warnings.warn(
1885-
f"While exporting, we found certain side effects happened in the model.forward. "
1886-
f"Here are the list of potential sources you can double check: {side_effect_refs}"
1887-
)
1884+
if torch._dynamo.config.side_effect_replay_policy == "warn":
1885+
warnings.warn(
1886+
f"While compiling, we found certain side effects happened in the model.forward. "
1887+
f"Here are the list of potential sources you can double check: {side_effect_refs}"
1888+
)
1889+
else:
1890+
raise RuntimeError(
1891+
f"While compiling, we found certain side effects happened in the model.forward. "
1892+
f"Here are the list of potential sources you can double check: {side_effect_refs}"
1893+
)
18881894

18891895
return all_stack_locals_metas
18901896

@@ -1930,7 +1936,8 @@ def codegen_suffix(
19301936
assert self.backward_state_var is not None
19311937
cg.append_output(cg.create_load(self.backward_state_var))
19321938
cg.store_attr(name)
1933-
self.side_effects.codegen_hooks(cg)
1939+
if config.replay_side_effects:
1940+
self.side_effects.codegen_hooks(cg)
19341941

19351942
# TODO get debug_locals working for nested graph breaks
19361943
# Return variables used for logging at the end
@@ -1945,7 +1952,8 @@ def codegen_suffix(
19451952
self.codegen_cells(tx, cg)
19461953

19471954
cg.restore_stack(stack_values, value_from_source=not tx.export)
1948-
self.side_effects.codegen_update_mutated(cg)
1955+
if config.replay_side_effects:
1956+
self.side_effects.codegen_update_mutated(cg)
19491957

19501958
def cleanup_graph(self) -> None:
19511959
"""

torch/_dynamo/variables/builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,13 @@ class GraphArg:
359359
# stash a strong reference too.
360360
example_strong_ref: Optional[torch.Tensor] = None
361361

362+
def __setattr__(self, name, value):
363+
# Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception.
364+
# This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal
365+
# GraphArg creation can be traced, and with replay_side_effects=False,
366+
# normal STORE_ATTR bytecode only records mutations without applying them.
367+
object.__setattr__(self, name, value)
368+
362369
@property
363370
def example(self):
364371
if isinstance(self._example, TensorWeakRef):

torch/export/_trace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class ExportDynamoConfig:
140140
capture_dynamic_output_shape_ops: bool = True
141141
capture_scalar_outputs: bool = True
142142
prefer_deferred_runtime_asserts_over_guards: bool = False
143+
replay_side_effects: bool = False
144+
side_effect_replay_policy: str = "warn"
143145

144146

145147
@dataclasses.dataclass

0 commit comments

Comments
 (0)