Commit e5a766e

mlazos authored and pytorchmergebot committed

[user-streams] Insert backward syncs (pytorch#167747)

Pull Request resolved: pytorch#167747
Approved by: https://github.com/soulitzer

1 parent a5f36a8 · commit e5a766e

File tree: 3 files changed, +68 −1 lines changed


test/dynamo/test_streams.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -585,6 +585,10 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
         # Annotation: {'stream': 1}
         mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
 
+        # No stacktrace found for following nodes
+        record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None
+        wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None
+
         # Annotation: {'stream': 0}
         add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
         return (add_3, add_2)
```
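In the updated expected graph, event `2` is recorded on stream `1` right after `mul_3` and waited on from stream `0` before `add_3` reads it. As a rough eager-mode analogue of that handshake (a sketch using the public `torch.cuda` stream/event API; the captured graph uses integer stream and event indices rather than objects):

```python
import torch

# Sketch: the record/wait pattern the pass inserts, written with eager
# CUDA streams. producer_stream plays the role of "stream 1" and
# consumer_stream the role of "stream 0" from the expected graph above.
producer_stream = torch.cuda.Stream()
consumer_stream = torch.cuda.Stream()
done = torch.cuda.Event()  # stands in for event index 2

tangents_1 = torch.randn(2, 2, device="cuda")
mul_2 = torch.randn(2, 2, device="cuda")

with torch.cuda.stream(producer_stream):
    mul_3 = tangents_1 * 2        # producer work on its own stream

done.record(producer_stream)      # record_event.default(2, 1)
consumer_stream.wait_event(done)  # wait_event.default(2, 0)

with torch.cuda.stream(consumer_stream):
    add_3 = mul_2 + mul_3         # safe: consumer is ordered after the event
```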

torch/_functorch/_aot_autograd/graph_capture.py
Lines changed: 3 additions & 1 deletion

```diff
@@ -33,7 +33,7 @@
     handle_effect_tokens_fn,
 )
 from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
-from .streams import assign_backward_streams
+from .streams import assign_backward_streams, insert_backward_syncs
 from .utils import (
     call_and_expect_output_descs,
     copy_fwd_metadata_to_bw_nodes,
@@ -477,6 +477,8 @@ def aot_dispatch_autograd_graph(
     # After copying metadata, assign streams to gradient accumulation nodes
     assign_backward_streams(fx_g)
 
+    insert_backward_syncs(fx_g)
+
     fx_g.graph.eliminate_dead_code()
     if not aot_config.disable_functionalization:
         # There should be *NO* mutating ops in the graph at this point.
```
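Ordering matters here: `assign_backward_streams` writes the stream annotations that `insert_backward_syncs` then compares, so the sync pass must run second, and both run before dead-code elimination. A hypothetical standalone driver (the helper name `apply_stream_passes` is illustrative, not part of this patch) would chain them like so:

```python
import torch.fx

from torch._functorch._aot_autograd.streams import (
    assign_backward_streams,
    insert_backward_syncs,
)


def apply_stream_passes(fx_g: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # 1) Annotate gradient-accumulation nodes with stream indices.
    assign_backward_streams(fx_g)
    # 2) Insert record/wait event pairs wherever a backward node consumes
    #    a value produced under a different stream annotation.
    insert_backward_syncs(fx_g)
    # 3) Drop dead nodes, then regenerate the module's Python code.
    fx_g.graph.eliminate_dead_code()
    fx_g.recompile()
    return fx_g
```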

torch/_functorch/_aot_autograd/streams.py
Lines changed: 61 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@
 import torch.fx
 import torch.fx.traceback
 from torch._dynamo.graph_utils import _get_flat_args
+from torch._dynamo.variables.streams import get_current_stream, new_event
 
 
 Node: TypeAlias = torch.fx.Node
@@ -12,6 +13,14 @@ def is_gradient_acc(node: Node) -> bool:
     return node.meta.get("is_gradient_acc", False)
 
 
+def is_bwd_node(node: Node) -> bool:
+    return node.meta.get("partitioner_tag") == "is_backward"
+
+
+def get_device(node: Node) -> torch.device:
+    return node.meta["val"].device
+
+
 def get_stream(node: Node) -> Optional[int]:
     maybe_annotation = node.meta.get("custom", None)
     if maybe_annotation is not None:
@@ -20,13 +29,50 @@ def get_stream(node: Node) -> Optional[int]:
     return None
 
 
+def get_stream_or_current_stream(node: Node) -> int:
+    ind = get_stream(node)
+    if ind is None:
+        ind = get_current_stream(get_device(node))
+    return ind
+
+
 def set_stream(node: Node, ind: int) -> None:
     if "custom" in node.meta:
         node.meta["custom"].update({"stream": ind})
     else:
         node.meta["custom"] = {"stream": ind}
 
 
+def insert_sync(
+    graph: torch.fx.Graph,
+    consumer: Node,
+    producer: Node,
+    node_to_wait_event_ind: dict[Node, int],
+) -> None:
+    if producer not in node_to_wait_event_ind:
+        node_to_wait_event_ind[producer] = new_event()
+
+    with graph.inserting_after(producer):
+        node = graph.call_function(
+            torch.ops.streams.record_event.default,
+            (
+                node_to_wait_event_ind[producer],
+                get_stream_or_current_stream(producer),
+            ),
+        )
+        node.meta["partitioner_tag"] = "must_be_in_backward"
+
+    with graph.inserting_before(consumer):
+        node = graph.call_function(
+            torch.ops.streams.wait_event.default,
+            (
+                node_to_wait_event_ind[producer],
+                get_stream_or_current_stream(consumer),
+            ),
+        )
+        node.meta["partitioner_tag"] = "must_be_in_backward"
+
+
 def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
     """Assigns backward streams to gradient accumulation nodes"""
 
@@ -51,3 +97,18 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
         if ind is not None:
             set_stream(node, ind)
             break
+
+
+def insert_backward_syncs(gm: torch.fx.GraphModule) -> None:
+    """Inserts stream syncs for backward nodes if consumer and producer are on different streams"""
+    node_to_wait_event_ind = {}
+    for node in gm.graph.nodes:
+        if is_bwd_node(node):
+            flat_args = _get_flat_args(node, {})
+            cur_node_stream = get_stream(node)
+
+            for arg in flat_args:
+                if is_bwd_node(arg):
+                    arg_stream = get_stream(arg)
+                    if arg_stream != cur_node_stream and get_device(arg).type != "cpu":
+                        insert_sync(gm.graph, node, arg, node_to_wait_event_ind)
```
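`insert_sync` leans on two FX graph-editing primitives: `graph.inserting_after` places the `record_event` node directly behind the producer, and `graph.inserting_before` places the matching `wait_event` directly ahead of the consumer; tagging both with `must_be_in_backward` keeps the partitioner from moving these side-effecting nodes out of the backward graph. A self-contained sketch of just those insertion mechanics, with a hypothetical `sync_op` standing in for the `torch.ops.streams` operators:

```python
import operator

import torch.fx


def sync_op(tag: str) -> None:
    """Hypothetical stand-in for torch.ops.streams.record_event/wait_event."""


graph = torch.fx.Graph()
x = graph.placeholder("x")
producer = graph.call_function(operator.mul, (x, 2))
consumer = graph.call_function(operator.add, (producer, 1))
graph.output(consumer)

# Mirror insert_sync: record immediately after the producer...
with graph.inserting_after(producer):
    record = graph.call_function(sync_op, ("record",))
    record.meta["partitioner_tag"] = "must_be_in_backward"

# ...and wait immediately before the consumer.
with graph.inserting_before(consumer):
    wait = graph.call_function(sync_op, ("wait",))
    wait.meta["partitioner_tag"] = "must_be_in_backward"

graph.lint()
print(graph)  # order: x -> mul -> sync_op('record') -> sync_op('wait') -> add
```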
