import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args
+from torch._dynamo.variables.streams import get_current_stream, new_event


Node: TypeAlias = torch.fx.Node
@@ -12,6 +13,14 @@ def is_gradient_acc(node: Node) -> bool:
    return node.meta.get("is_gradient_acc", False)


+def is_bwd_node(node: Node) -> bool:
+    # The partitioner tags nodes belonging to the backward graph.
+    return node.meta.get("partitioner_tag") == "is_backward"
+
+
+def get_device(node: Node) -> torch.device:
+    # node.meta["val"] holds the node's example (fake) tensor value.
+    return node.meta["val"].device
+
+
def get_stream(node: Node) -> Optional[int]:
    maybe_annotation = node.meta.get("custom", None)
    if maybe_annotation is not None:
@@ -20,13 +29,50 @@ def get_stream(node: Node) -> Optional[int]:
    return None


+def get_stream_or_current_stream(node: Node) -> int:
+    # Fall back to the device's current stream when the node carries
+    # no explicit stream annotation.
+    ind = get_stream(node)
+    if ind is None:
+        ind = get_current_stream(get_device(node))
+    return ind
+
+
def set_stream(node: Node, ind: int) -> None:
    if "custom" in node.meta:
        node.meta["custom"].update({"stream": ind})
    else:
        node.meta["custom"] = {"stream": ind}


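+# Illustrative shape of the "custom" stream annotation that set_stream
+# writes and get_stream (presumably, given its elided body) reads back.
+# Inferred from this file; not a stable API:
+#
+#   node.meta["custom"] = {"stream": 2}   # set_stream(node, 2)
+#   get_stream(node)  -> 2
+#   get_stream(fresh) -> None             # no annotation yet
+
+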
+def insert_sync(
+    graph: torch.fx.Graph,
+    consumer: Node,
+    producer: Node,
+    node_to_wait_event_ind: dict[Node, int],
+) -> None:
+    # Allocate one event per producer so every consumer of that producer
+    # waits on the same recorded event.
+    if producer not in node_to_wait_event_ind:
+        node_to_wait_event_ind[producer] = new_event()
+
+    # Record the event on the producer's stream right after it runs.
+    with graph.inserting_after(producer):
+        node = graph.call_function(
+            torch.ops.streams.record_event.default,
+            (
+                node_to_wait_event_ind[producer],
+                get_stream_or_current_stream(producer),
+            ),
+        )
+        node.meta["partitioner_tag"] = "must_be_in_backward"
+
+    # Make the consumer's stream wait on that event just before it runs.
+    with graph.inserting_before(consumer):
+        node = graph.call_function(
+            torch.ops.streams.wait_event.default,
+            (
+                node_to_wait_event_ind[producer],
+                get_stream_or_current_stream(consumer),
+            ),
+        )
+        node.meta["partitioner_tag"] = "must_be_in_backward"
+
+
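+# Sketch of the graph shape insert_sync produces for a cross-stream
+# producer/consumer pair (stream indices A and B are illustrative):
+#
+#   producer                  # runs on stream A
+#   record_event(evt, A)      # stream A marks producer's completion
+#   ...
+#   wait_event(evt, B)        # stream B blocks until evt is recorded
+#   consumer                  # runs on stream B
+#
+# One event per producer lets multiple consumers wait on the same record.
+
+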
def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""

@@ -51,3 +97,18 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
        if ind is not None:
            set_stream(node, ind)
            break
+
+
+def insert_backward_syncs(gm: torch.fx.GraphModule) -> None:
+    """Inserts stream syncs between backward nodes when a producer and its
+    consumer are assigned to different streams."""
+    node_to_wait_event_ind: dict[Node, int] = {}
+    for node in gm.graph.nodes:
+        if is_bwd_node(node):
+            flat_args = _get_flat_args(node, {})
+            cur_node_stream = get_stream(node)
+
+            for arg in flat_args:
+                if is_bwd_node(arg):
+                    arg_stream = get_stream(arg)
+                    # CPU tensors are stream-agnostic, so only sync device
+                    # tensors whose producer runs on a different stream.
+                    if arg_stream != cur_node_stream and get_device(arg).type != "cpu":
+                        insert_sync(gm.graph, node, arg, node_to_wait_event_ind)
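+
+
+# Hypothetical end-to-end usage (assumed; the call sites are not part of
+# this diff):
+#
+#   assign_backward_streams(gm)   # propagate stream annotations to grad-acc nodes
+#   insert_backward_syncs(gm)     # then insert event syncs across stream edges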