|
1 | 1 | import logging |
| 2 | +import operator |
| 3 | +from collections import defaultdict |
| 4 | +from typing import Set |
2 | 5 |
|
3 | 6 | import torch |
| 7 | +from torch.fx import GraphModule |
| 8 | +from torch.fx.passes.backends.cudagraphs import partition_cudagraphs |
| 9 | +from torch.multiprocessing.reductions import StorageWeakRef |
| 10 | +from torch.nn import Module |
| 11 | +from torch.utils._pytree import tree_map |
4 | 12 |
|
| 13 | +import torchdynamo |
5 | 14 | from torchdynamo import config |
6 | 15 | from torchdynamo.utils import clone_inputs |
7 | 16 | from torchdynamo.utils import count_calls |
@@ -59,7 +68,7 @@ def __init__(self, gm: torch.fx.GraphModule, example_inputs): |
59 | 68 | # - data mutation of inputs (fixed when we stop recording the |
60 | 69 | # copy_ directly into the graph) |
61 | 70 | # - metadata mutation of inputs (fixed if we do an extra partition |
62 | | - # to avoid AOTAutograd on the mutated inputs, or if we some how |
| 71 | + # to avoid AotAutograd on the mutated inputs, or if we somehow |
63 | 72 | # get custom autograd function to reflect metadata changes to the |
64 | 73 | # original tensor) |
65 | 74 | mutated = has_mutation(self.gm, self.example_inputs, inputs_only=True) |
@@ -249,6 +258,153 @@ def candidate(self): |
249 | 258 | aot_prims_nvfuser = AotPrimsNvfuser.compile_fn |
250 | 259 |
|
251 | 260 |
|
| 261 | +def cloner(t): |
| 262 | + if isinstance(t, torch.Tensor): |
| 263 | + return t.clone() |
| 264 | + else: |
| 265 | + return t |
| 266 | + |
| 267 | + |
| 268 | +class CudaGraphModule(Module): |
| 269 | + gm: GraphModule |
| 270 | + mutated_inputs: Set[int] |
| 271 | + |
| 272 | + def __init__(self, gm, mutated_inputs): |
| 273 | + super().__init__() |
| 274 | + self.gm = gm |
| 275 | + self.mutated_inputs = mutated_inputs |
| 276 | + |
| 277 | + warmed_up = False |
| 278 | + |
| 279 | + # these are all None or all filled |
| 280 | + graph = None |
| 281 | + static_inputs = None |
| 282 | + static_outputs = None |
| 283 | + |
| 284 | + # NB: we override __call__ as we don't need any nn.Module machinery |
| 285 | + # and to reduce overhead |
| 286 | + def __call__(self, *args): |
| 287 | + # TODO: once we've recorded here, we'd like to replace the __call__ |
| 288 | + # implementation with compiled bytecode that copies into static, replays |
| 289 | + # the cuda graph, then copies out. The first branch is the hot path |
| 290 | + # and needs optimizing |
| 291 | + if self.graph is not None: |
| 292 | + assert len(args) == len(self.static_inputs) |
| 293 | + for dst, src in zip(self.static_inputs, args): |
| 294 | + dst.copy_(src) |
| 295 | + self.graph.replay() |
| 296 | + for i in self.mutated_inputs: |
| 297 | + args[i].copy_(self.static_inputs[i]) |
| 298 | + return tree_map(cloner, self.static_outputs) |
| 299 | + |
| 300 | + elif self.warmed_up: |
| 301 | + # record |
| 302 | + self.static_inputs = [x.clone() for x in args] |
| 303 | + self.graph = torch.cuda.CUDAGraph() |
| 304 | + with torch.cuda.graph(self.graph): |
| 305 | + self.static_outputs = self.gm(*self.static_inputs) |
| 306 | + # NB: recording doesn't actually run the operations, so |
| 307 | + # now we immediately replay the graph to serve up the result |
| 308 | + self.graph.replay() |
| 309 | + for i in self.mutated_inputs: |
| 310 | + args[i].copy_(self.static_inputs[i]) |
| 311 | + return tree_map(cloner, self.static_outputs) |
| 312 | + |
| 313 | + else: |
| 314 | + # warmup |
| 315 | + stream = torch.cuda.Stream() |
| 316 | + stream.wait_stream(torch.cuda.current_stream()) |
| 317 | + with torch.cuda.stream(stream): |
| 318 | + r = self.gm(*args) |
| 319 | + torch.cuda.current_stream().wait_stream(stream) |
| 320 | + self.warmed_up = True |
| 321 | + return r |
| 322 | + |
| 323 | + |
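As a rough usage sketch of the class above: across successive calls it moves through warmup, record, and replay phases. The toy module, shapes, and CUDA requirement below are illustrative assumptions, not part of this patch.

```python
# Hedged sketch: exercising CudaGraphModule's warmup -> record -> replay phases.
# Requires a CUDA device; the traced Toy module is made up for illustration.
import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1


gm = symbolic_trace(Toy())
wrapped = CudaGraphModule(gm, mutated_inputs=set())

x = torch.randn(8, device="cuda")
out1 = wrapped(x)  # 1st call: warmup, runs eagerly on a side stream
out2 = wrapped(x)  # 2nd call: records a CUDA graph, then replays it once
out3 = wrapped(x)  # later calls: copy into static inputs and replay the graph
```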
| 324 | +# Interpreter versions of these passes can be found at |
| 325 | +# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23 |
| 326 | + |
| 327 | + |
| 328 | +def find_input_mutations(g): |
| 329 | + FK = "fake_result" |
| 330 | + inputs = defaultdict(set) |
| 331 | + input_idx = 0 |
| 332 | + mutated_inputs = set() |
| 333 | + for n in g.nodes: |
| 334 | + if n.op == "placeholder": |
| 335 | + inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx) |
| 336 | + input_idx += 1 |
| 337 | + elif n.op == "call_function": |
| 338 | + if n.target is operator.getitem: |
| 339 | + continue |
| 340 | + schema = n.target._schema |
| 341 | + for i, arg in enumerate(schema.arguments): |
| 342 | + if i < len(n.args): |
| 343 | + argument = n.args[i] |
| 344 | + else: |
| 345 | + if arg.name not in n.kwargs: |
| 346 | + continue |
| 347 | + argument = n.kwargs[arg.name] |
| 348 | + mut_arg = False |
| 349 | + if arg.alias_info: |
| 350 | + if arg.alias_info.is_write: |
| 351 | + mut_arg = True |
| 352 | + if mut_arg: |
| 353 | + # TODO: not correct for args that contain tensors in a struct |
| 354 | + # like list |
| 355 | + mutated_inputs |= inputs[ |
| 356 | + StorageWeakRef(argument.meta[FK].storage()) |
| 357 | + ] |
| 358 | + # TODO: error on unrecognized nodes |
| 359 | + return mutated_inputs |
| 360 | + |
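For intuition about the alias_info check in find_input_mutations: in-place ATen ops mark the argument they mutate as a write in their schema, and that annotation is what the pass keys on. A quick probe, using torch.ops.aten.add_.Tensor as the example op:

```python
# The in-place op's schema flags its `self` argument as a write (alias_info.is_write).
import torch

schema = torch.ops.aten.add_.Tensor._schema
for arg in schema.arguments:
    is_mutated = bool(arg.alias_info and arg.alias_info.is_write)
    print(arg.name, is_mutated)  # `self` prints True; the other arguments print False
```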
| 361 | + |
| 362 | +# Mutates input graph |
| 363 | +def apply_cuda_graphs(gm): |
| 364 | + for n in gm.graph.nodes: |
| 365 | + if n.op == "call_module": |
| 366 | + assert not n.kwargs |
| 367 | + submod = gm.get_submodule(n.target) |
| 368 | + gm.delete_submodule(n.target) |
| 369 | + mutated_inputs = find_input_mutations(submod.graph) |
| 370 | + gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs)) |
| 371 | + # NB: we didn't actually change the graph, no need for recompile |
| 372 | + |
| 373 | + |
| 374 | +def cudagraphs(model, inputs): |
| 375 | + model = partition_cudagraphs(model, inputs) |
| 376 | + apply_cuda_graphs(model) |
| 377 | + return model |
| 378 | + |
| 379 | + |
| 380 | +def raw_aot_autograd_cudagraphs(model, inputs): |
| 381 | + kwargs = { |
| 382 | + # these are taken from memory_efficient_fusion() |
| 383 | + "fw_compiler": cudagraphs, |
| 384 | + "bw_compiler": cudagraphs, |
| 385 | + "hasher_type": "StaticShapeHasher", |
| 386 | + } |
| 387 | + |
| 388 | + def _wrapped_bw_compiler(*args, **kwargs): |
| 389 | + # stop TorchDynamo from trying to compile our generated backwards pass |
| 390 | + return torchdynamo.disable(bw_compiler(*args, **kwargs)) # type: ignore[operator] |
| 391 | + |
| 392 | + bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] |
| 393 | + kwargs["bw_compiler"] = _wrapped_bw_compiler |
| 394 | + |
| 395 | + from functorch.compile import aot_module_simplified # type: ignore[import] |
| 396 | + |
| 397 | + return aot_module_simplified(model, **kwargs) |
| 398 | + |
| 399 | + |
| 400 | +class AotAutogradCudaGraphs(AotAutogradStrategy): |
| 401 | + def candidate(self): |
| 402 | + return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs) |
| 403 | + |
| 404 | + |
| 405 | +aot_cudagraphs = AotAutogradCudaGraphs.compile_fn |
| 406 | + |
| 407 | + |
252 | 408 | def create_aot_backends(): |
253 | 409 | """ |
254 | 410 | Register aliases for the AOT backends |
@@ -280,3 +436,7 @@ def create_aot_backends(): |
280 | 436 |     # without worrying about the impact of decompositions. More details at |
281 | 437 | # https://github.com/pytorch/torchdynamo/issues/611 |
282 | 438 | BACKENDS["aot_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp |
| 439 | + |
| 440 | + # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful |
| 441 | + # for debugging and can serve as a perf baseline. |
| 442 | + BACKENDS["aot_cudagraphs"] = aot_cudagraphs |
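A hedged end-to-end sketch of using the newly registered backend; the toy model, shapes, and CUDA requirement are illustrative assumptions rather than part of this patch:

```python
# Run a model under TorchDynamo with the "aot_cudagraphs" backend registered above.
# Assumes a CUDA device; the Linear model and shapes are placeholders.
import torch
import torchdynamo

model = torch.nn.Linear(128, 64).cuda()
inputs = torch.randn(32, 128, device="cuda")

with torchdynamo.optimize("aot_cudagraphs"):
    out = model(inputs)
    out.sum().backward()
```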