|
19 | 19 | import os |
20 | 20 | import pickle |
21 | 21 | import random |
| 22 | +import re |
22 | 23 | import sys |
23 | 24 | import tempfile |
24 | 25 | import threading |
@@ -5635,6 +5636,115 @@ def f2(a, b): |
5635 | 5636 |         self.assertTrue(same(res11, res12)) |
5636 | 5637 |         self.assertTrue(same(res21, res22)) |
5637 | 5638 |
|
| 5639 | +    def test_replay_side_effects_config(self): |
| 5640 | +        # Test that replay_side_effects config controls mutation replay |
| 5641 | +        def fn(x, lst): |
| 5642 | +            lst.append(x + 1) |
| 5643 | +            return x * 2 |
| 5644 | + |
| 5645 | +        x = torch.tensor([5.0]) |
| 5646 | + |
| 5647 | +        # Test with replay enabled (default) |
| 5648 | +        lst_with_replay = [] |
| 5649 | +        opt_fn_with_replay = torch.compile(fn, backend="eager") |
| 5650 | +        result1 = opt_fn_with_replay(x, lst_with_replay) |
| 5651 | +        self.assertEqual(len(lst_with_replay), 1)  # Mutation should be replayed |
| 5652 | +        self.assertTrue(same(result1, x * 2)) |
| 5653 | + |
| 5654 | +        torch._dynamo.reset() |
| 5655 | + |
| 5656 | +        # Test with replay disabled |
| 5657 | +        lst_without_replay = [] |
| 5658 | +        with torch._dynamo.config.patch( |
| 5659 | +            replay_side_effects=False, side_effect_replay_policy="warn" |
| 5660 | +        ): |
| 5661 | +            opt_fn_without_replay = torch.compile(fn, backend="eager") |
| 5662 | +            result2 = opt_fn_without_replay(x, lst_without_replay) |
| 5663 | +            self.assertEqual( |
| 5664 | +                len(lst_without_replay), 0 |
| 5665 | +            )  # Mutation should NOT be replayed |
| 5666 | +            self.assertTrue(same(result2, x * 2)) |
| 5667 | + |
| 5668 | +        torch._dynamo.reset() |
| 5669 | +        lst_without_replay = [] |
| 5670 | +        with torch._dynamo.config.patch( |
| 5671 | +            replay_side_effects=False, side_effect_replay_policy="error" |
| 5672 | +        ): |
| 5673 | +            opt_fn_without_replay = torch.compile(fn, backend="eager") |
| 5674 | +            with self.assertRaisesRegex( |
| 5675 | +                RuntimeError, |
| 5676 | +                re.escape( |
| 5677 | +                    "While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['lst']\"]" |
| 5678 | +                ), |
| 5679 | +            ): |
| 5680 | +                _ = opt_fn_without_replay(x, lst_without_replay) |
| 5681 | + |
| 5682 | +    def test_replay_side_effects_model_attr(self): |
| 5683 | +        class Bar(torch.nn.Module): |
| 5684 | +            def __init__(self): |
| 5685 | +                super().__init__() |
| 5686 | +                self.const = 4 |
| 5687 | + |
| 5688 | +            def forward(self, x): |
| 5689 | +                return x.cos() |
| 5690 | + |
| 5691 | +        class Foo(torch.nn.Module): |
| 5692 | +            def __init__(self): |
| 5693 | +                super().__init__() |
| 5694 | +                self.const = 4 |
| 5695 | +                self.tensor = None |
| 5696 | +                self.bar = Bar() |
| 5697 | + |
| 5698 | +            def forward(self, x): |
| 5699 | +                self.const = 5 |
| 5700 | +                self.tensor = x.sin() |
| 5701 | +                res = self.bar(x) |
| 5702 | +                return x.cos() + res.sum() + self.tensor |
| 5703 | + |
| 5704 | +        with torch._dynamo.config.patch( |
| 5705 | +            replay_side_effects=False, side_effect_replay_policy="error" |
| 5706 | +        ): |
| 5707 | +            foo = Foo() |
| 5708 | +            with self.assertRaisesRegex( |
| 5709 | +                RuntimeError, |
| 5710 | +                re.escape( |
| 5711 | +                    "While compiling, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: [\"L['self']\"]" |
| 5712 | +                ), |
| 5713 | +            ): |
| 5714 | +                torch.compile(foo, fullgraph=True)(torch.randn(4, 4)) |
| 5715 | + |
| 5716 | +        with torch._dynamo.config.patch( |
| 5717 | +            replay_side_effects=False, side_effect_replay_policy="silent" |
| 5718 | +        ): |
| 5719 | +            foo_v2_compile = Foo() |
| 5720 | +            foo_v2_eager = Foo() |
| 5721 | +            inp = torch.randn(4, 4) |
| 5722 | +            res = torch.compile(foo_v2_compile, fullgraph=True)(inp) |
| 5723 | +            self.assertEqual(foo_v2_compile.tensor, None) |
| 5724 | +            self.assertEqual(foo_v2_compile.const, 4) |
| 5725 | +            self.assertEqual(foo_v2_compile.bar.const, 4) |
| 5726 | +            self.assertTrue(same(res, foo_v2_eager(inp))) |
| 5727 | + |
| 5728 | +    def test_replay_side_effects_input_mut(self): |
| 5729 | +        class Foo(torch.nn.Module): |
| 5730 | +            def __init__(self): |
| 5731 | +                super().__init__() |
| 5732 | +                self.const = 4 |
| 5733 | +                self.tensor = None |
| 5734 | + |
| 5735 | +            def forward(self, x): |
| 5736 | +                x.add_(5) |
| 5737 | +                return x.cos() |
| 5738 | + |
| 5739 | +        # This is OK because the graph we capture already contains the input |
| 5740 | +        # mutation, so nothing needs to be replayed. In export, we never |
| 5741 | +        # retrace the actual gm, so we won't see any mutation applied to inputs. |
| 5742 | +        with torch._dynamo.config.patch( |
| 5743 | +            replay_side_effects=False, side_effect_replay_policy="error" |
| 5744 | +        ): |
| 5745 | +            foo = Foo() |
| 5746 | +            torch.compile(foo, fullgraph=True)(torch.randn(4, 4)) |
| 5747 | + |
5638 | 5748 |     def test_list_append_return_none(self): |
5639 | 5749 |         def fn(x): |
5640 | 5750 |             alist = [] |
|
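Taken together, the new tests pin down three side_effect_replay_policy modes ("warn", "error", "silent") for the replay_side_effects knob. As a minimal usage sketch, assuming only the torch._dynamo.config names exercised in the diff above (the helper function and assertions below are illustrative, not part of this commit):

import torch

def fn(x, lst):
    lst.append(x + 1)  # Python-level side effect observed during tracing
    return x * 2

# With replay disabled under the "warn" policy, the compiled call should
# still return the correct value but skip replaying the append, mirroring
# what test_replay_side_effects_config asserts.
with torch._dynamo.config.patch(
    replay_side_effects=False, side_effect_replay_policy="warn"
):
    lst = []
    out = torch.compile(fn, backend="eager")(torch.tensor([5.0]), lst)
    assert len(lst) == 0  # the mutation was not replayed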