
Commit 3f2decb
1 parent: 8c7dd9e

Move aot_cudagraphs backend here (#757)

Previously this backend lived in pytorch/pytorch, but it depends more closely on torchdynamo code, so this repository seems like the logical place for it. Previously at pytorch/pytorch#80566.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
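For orientation: once registered, the backend is selected by name through torchdynamo.optimize. Below is a minimal usage sketch mirroring the pattern of the tests added in this commit (the model and shapes are illustrative only):

import torch
import torchdynamo

def model(x, y):
    return (x + y) * y

# "aot_cudagraphs" is the alias registered in create_aot_backends().
with torchdynamo.optimize("aot_cudagraphs"):
    for _ in range(5):
        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        model(x, y).sum().backward()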

File tree

2 files changed: 343 additions & 1 deletion

tests/test_aot_cudagraphs.py

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@ (new file)

# Owner(s): ["module: cuda graphs"]

import functools
import unittest
from unittest.mock import patch

import torch

import torchdynamo
import torchdynamo.testing
from torchdynamo.testing import same


def composed(*decs):
    def deco(f):
        for dec in reversed(decs):
            f = dec(f)
        return f

    return deco


def assert_aot_autograd_counter(ok=True):
    def deco(f):
        @functools.wraps(f)
        def wrap(self, *args, **kwargs):
            torchdynamo.utils.counters.clear()
            r = f(self, *args, **kwargs)
            c_ok = torchdynamo.utils.counters["aot_autograd"]["ok"]
            c_not_ok = torchdynamo.utils.counters["aot_autograd"]["not_ok"]
            if ok:
                self.assertGreater(c_ok, 0)
                self.assertEqual(c_not_ok, 0)
            else:
                self.assertEqual(c_ok, 0)
                self.assertGreater(c_not_ok, 0)
            return r

        return wrap

    return deco


def patch_all(ok=True):
    return composed(
        patch("torchdynamo.config.verify_correctness", True),
        assert_aot_autograd_counter(ok),
    )


N_ITERS = 5


@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torchdynamo.testing.TestCase):
    @patch_all()
    def test_basic(self):
        def model(x, y):
            return (x + y) * y

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                x = torch.randn(3, device="cuda", requires_grad=True)
                y = torch.randn(3, device="cuda")
                loss = model(x, y).sum()
                loss.backward()

    @patch_all()
    def test_dtoh(self):
        def model(x, y):
            a = x + y
            b = a.cpu() * 3
            return b

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                x = torch.randn(3, device="cuda", requires_grad=True)
                y = torch.randn(3, device="cuda")
                loss = model(x, y).sum()
                loss.backward()

    @patch_all()
    def test_htod(self):
        def model(x, y):
            a = x + y
            return a * 3

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                x = torch.randn(3, device="cuda", requires_grad=True)
                y = torch.randn((), device="cpu")
                loss = model(x, y).sum()
                loss.backward()

    @patch("functorch._src.config.use_functionalize", True)
    @patch_all(ok=False)  # input mutation not supported yet
    def test_mutate_input(self):
        def model(x, y):
            y.add_(3)
            return x * y

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                with self.subTest(i):
                    x = torch.randn(3, device="cuda", requires_grad=True)
                    y = torch.randn(3, device="cuda")
                    y_orig = y.clone()
                    loss = model(x, y).sum()
                    self.assertTrue(same(y, y_orig + 3))
                    loss.backward()

    @patch_all()
    def test_mutate_constant(self):
        def model(x, y):
            c = torch.tensor(1)
            c.add_(2)
            return x * y * 0 + c

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                with self.subTest(i):
                    x = torch.randn(1, device="cuda", requires_grad=True)
                    y = torch.randn(1, device="cuda")
                    loss = model(x, y).sum()
                    self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
                    loss.backward()

    @patch_all()
    def test_factory(self):
        def model(y):
            x = torch.zeros(3, device="cuda:0")
            x.add_(3)
            return x * y

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                with self.subTest(i):
                    y = torch.randn(3, device="cuda:0", requires_grad=True)
                    loss = model(y).sum()
                    loss.backward()

    # Internal resize_ inside models appears to be broken right now
    @unittest.expectedFailure
    @patch("functorch._src.config.use_functionalize", True)
    @patch_all()
    def test_mutated_metadata(self):
        # more tortured example at
        # https://github.com/pytorch/pytorch/issues/81385
        def model(x):
            x = x.clone()
            x.resize_(20)
            x.fill_(2)
            return x

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                with self.subTest(i):
                    x = torch.empty(0, device="cuda:0")
                    rx = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))

    @patch("functorch._src.config.use_functionalize", True)
    @patch_all()
    def test_dead_fill(self):
        def model(x):
            x = x.clone()
            y = x[0:0]
            x.fill_(2)
            y.fill_(3)
            return x, y

        with torchdynamo.optimize("aot_cudagraphs"):
            for i in range(N_ITERS):
                with self.subTest(i):
                    x = torch.empty(20, device="cuda:0")
                    rx, ry = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
                    self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))


if __name__ == "__main__":
    unittest.main()

torchdynamo/optimizations/training.py

Lines changed: 161 additions & 1 deletion

@@ -1,7 +1,16 @@
 import logging
+import operator
+from collections import defaultdict
+from typing import Set
 
 import torch
+from torch.fx import GraphModule
+from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
+from torch.multiprocessing.reductions import StorageWeakRef
+from torch.nn import Module
+from torch.utils._pytree import tree_map
 
+import torchdynamo
 from torchdynamo import config
 from torchdynamo.utils import clone_inputs
 from torchdynamo.utils import count_calls

@@ -59,7 +68,7 @@ def __init__(self, gm: torch.fx.GraphModule, example_inputs):
         # - data mutation of inputs (fixed when we stop recording the
         #   copy_ directly into the graph)
         # - metadata mutation of inputs (fixed if we do an extra partition
-        #   to avoid AOTAutograd on the mutated inputs, or if we some how
+        #   to avoid AotAutograd on the mutated inputs, or if we some how
         #   get custom autograd function to reflect metadata changes to the
         #   original tensor)
         mutated = has_mutation(self.gm, self.example_inputs, inputs_only=True)

@@ -249,6 +258,153 @@ def candidate(self):
 aot_prims_nvfuser = AotPrimsNvfuser.compile_fn
 
 
+def cloner(t):
+    if isinstance(t, torch.Tensor):
+        return t.clone()
+    else:
+        return t
+
+
+class CudaGraphModule(Module):
+    gm: GraphModule
+    mutated_inputs: Set[int]
+
+    def __init__(self, gm, mutated_inputs):
+        super().__init__()
+        self.gm = gm
+        self.mutated_inputs = mutated_inputs
+
+    warmed_up = False
+
+    # these are all None or all filled
+    graph = None
+    static_inputs = None
+    static_outputs = None
+
+    # NB: we override __call__ as we don't need any nn.Module machinery
+    # and to reduce overhead
+    def __call__(self, *args):
+        # TODO: once we've recorded here, we'd like to replace the __call__
+        # implementation with compiled bytecode that copies into static, replays
+        # the cuda graph, then copies out. First condition is the hotpath,
+        # needs optimizing
+        if self.graph is not None:
+            assert len(args) == len(self.static_inputs)
+            for dst, src in zip(self.static_inputs, args):
+                dst.copy_(src)
+            self.graph.replay()
+            for i in self.mutated_inputs:
+                args[i].copy_(self.static_inputs[i])
+            return tree_map(cloner, self.static_outputs)
+
+        elif self.warmed_up:
+            # record
+            self.static_inputs = [x.clone() for x in args]
+            self.graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(self.graph):
+                self.static_outputs = self.gm(*self.static_inputs)
+            # NB: recording doesn't actually run the operations, so
+            # now we immediately replay the graph to serve up the result
+            self.graph.replay()
+            for i in self.mutated_inputs:
+                args[i].copy_(self.static_inputs[i])
+            return tree_map(cloner, self.static_outputs)
+
+        else:
+            # warmup
+            stream = torch.cuda.Stream()
+            stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(stream):
+                r = self.gm(*args)
+            torch.cuda.current_stream().wait_stream(stream)
+            self.warmed_up = True
+            return r
+
+
+# Interpreter versions of these passes can be found at
+# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23
+
+
+def find_input_mutations(g):
+    FK = "fake_result"
+    inputs = defaultdict(set)
+    input_idx = 0
+    mutated_inputs = set()
+    for n in g.nodes:
+        if n.op == "placeholder":
+            inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx)
+            input_idx += 1
+        elif n.op == "call_function":
+            if n.target is operator.getitem:
+                continue
+            schema = n.target._schema
+            for i, arg in enumerate(schema.arguments):
+                if i < len(n.args):
+                    argument = n.args[i]
+                else:
+                    if arg.name not in n.kwargs:
+                        continue
+                    argument = n.kwargs[arg.name]
+                mut_arg = False
+                if arg.alias_info:
+                    if arg.alias_info.is_write:
+                        mut_arg = True
+                if mut_arg:
+                    # TODO: not correct for args that contain tensors in a struct
+                    # like list
+                    mutated_inputs |= inputs[
+                        StorageWeakRef(argument.meta[FK].storage())
+                    ]
+        # TODO: error on unrecognized nodes
+    return mutated_inputs
+
+
+# Mutates input graph
+def apply_cuda_graphs(gm):
+    for n in gm.graph.nodes:
+        if n.op == "call_module":
+            assert not n.kwargs
+            submod = gm.get_submodule(n.target)
+            gm.delete_submodule(n.target)
+            mutated_inputs = find_input_mutations(submod.graph)
+            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
+    # NB: we didn't actually change the graph, no need for recompile
+
+
+def cudagraphs(model, inputs):
+    model = partition_cudagraphs(model, inputs)
+    apply_cuda_graphs(model)
+    return model
+
+
+def raw_aot_autograd_cudagraphs(model, inputs):
+    kwargs = {
+        # these are taken from memory_efficient_fusion()
+        "fw_compiler": cudagraphs,
+        "bw_compiler": cudagraphs,
+        "hasher_type": "StaticShapeHasher",
+    }
+
+    def _wrapped_bw_compiler(*args, **kwargs):
+        # stop TorchDynamo from trying to compile our generated backwards pass
+        return torchdynamo.disable(bw_compiler(*args, **kwargs))  # type: ignore[operator]
+
+    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
+    kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+    from functorch.compile import aot_module_simplified  # type: ignore[import]
+
+    return aot_module_simplified(model, **kwargs)
+
+
+class AotAutogradCudaGraphs(AotAutogradStrategy):
+    def candidate(self):
+        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
+
+
+aot_cudagraphs = AotAutogradCudaGraphs.compile_fn
+
+
 def create_aot_backends():
     """
     Register aliases for the AOT backends

@@ -280,3 +436,7 @@ def create_aot_backends():
     # without worrying about the impact of decomposisitons. More details at
     # https://github.com/pytorch/torchdynamo/issues/611
     BACKENDS["aot_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp
+
+    # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
+    # for debugging and can serve as a perf baseline.
+    BACKENDS["aot_cudagraphs"] = aot_cudagraphs
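As background for CudaGraphModule above: it implements the standard warmup/record/replay protocol for CUDA graphs. A minimal standalone sketch of that protocol using only the public torch.cuda.CUDAGraph API (not part of this commit; the tensors and ops are illustrative):

import torch

static_x = torch.randn(3, device="cuda")

# 1) Warmup on a side stream, so lazily initialized CUDA state is
#    created outside of graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_x * 2
torch.cuda.current_stream().wait_stream(s)

# 2) Record: kernels are captured into the graph, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x * 2

# 3) Replay: copy a fresh input into the static buffer, then replay;
#    static_y is overwritten with the result for the new input.
static_x.copy_(torch.randn(3, device="cuda"))
g.replay()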
