@@ -207,9 +207,11 @@ def forward(self, x, w1, w2):
207207def replace_pattern_with_filters (
208208 gm : GraphModule ,
209209 pattern : Union [Callable , Graph , GraphModule ],
210- replacement : Union [Callable , Graph , GraphModule ] ,
210+ replacement : Union [Callable , Graph , GraphModule , None ] = None ,
211211 match_filters : Optional [List [Callable [["InternalMatch" , Graph , Graph ], bool ]]] = None ,
212212 ignore_literals : bool = False ,
213+ # Placed at the end to avoid breaking backward compatibility
214+ replacement_callback : Optional [Callable [["InternalMatch" , Graph , Graph ], Graph ]] = None ,
213215) -> List [ReplacedPatterns ]:
214216 """
215217 See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -219,17 +221,22 @@ def replace_pattern_with_filters(
219221 (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
220222 whether the match satisfies the condition.
221223 See matcher_utils.py for definition of InternalMatch.
224+ ``replacement_callback``: A function that takes in a match and returns a
225+ Graph to be used as the replacement. This allows you to construct a
226+ replacement graph based on the match.
222227 """
223228
224- return _replace_pattern (gm , pattern , replacement , match_filters , ignore_literals )
229+ return _replace_pattern (gm , pattern , replacement , match_filters , ignore_literals , replacement_callback )
225230
226231
227232def _replace_pattern (
228233 gm : GraphModule ,
229234 pattern : Union [Callable , Graph , GraphModule ],
230- replacement : Union [Callable , Graph , GraphModule ] ,
235+ replacement : Union [Callable , Graph , GraphModule , None ] = None ,
231236 match_filters : Optional [List [Callable [["InternalMatch" , Graph , Graph ], bool ]]] = None ,
232237 ignore_literals : bool = False ,
238+ # Placed at the end to avoid breaking backward compatibility
239+ replacement_callback : Optional [Callable [["InternalMatch" , Graph , Graph ], Graph ]] = None ,
233240) -> List [ReplacedPatterns ]:
234241
235242 from torch .fx .passes .utils .matcher_utils import SubgraphMatcher , InternalMatch
@@ -247,13 +254,6 @@ def _replace_pattern(
247254 else :
248255 pattern_graph = symbolic_trace (pattern ).graph
249256
250- if isinstance (replacement , GraphModule ):
251- replacement_graph = replacement .graph
252- elif isinstance (replacement , Graph ):
253- replacement_graph = replacement
254- else :
255- replacement_graph = symbolic_trace (replacement ).graph
256-
257257 matcher = SubgraphMatcher (pattern_graph , match_output = False , match_placeholder = False ,
258258 remove_overlapping_matches = True , ignore_literals = ignore_literals )
259259 _matches : List [InternalMatch ] = matcher .match (original_graph )
@@ -265,13 +265,27 @@ def _replace_pattern(
265265 for match_filter in match_filters )
266266 ]
267267
268- replacement_placeholders = [n for n in replacement_graph .nodes if n .op == "placeholder" ]
268+ if isinstance (replacement , GraphModule ):
269+ common_replacement_graph = replacement .graph
270+ elif isinstance (replacement , Graph ):
271+ common_replacement_graph = replacement
272+ elif callable (replacement ):
273+ common_replacement_graph = symbolic_trace (replacement ).graph
274+ else :
275+ assert replacement_callback is not None , "Must provide either a replacement GraphModule or a replacement callback"
276+ common_replacement_graph = None
269277
270278 # As we progressively replace nodes, we'll need to keep track of how the match results should change
271279 match_changed_node : Dict [Node , Node ] = {}
272280
273281 match_and_replacements = []
274- for match in _matches :
282+ for i , match in enumerate (_matches ):
283+ if replacement_callback is not None :
284+ replacement_graph = replacement_callback (match , original_graph , pattern_graph )
285+ else :
286+ assert common_replacement_graph is not None , "Must provide either a replacement GraphModule or a replacement callback"
287+ replacement_graph = common_replacement_graph
288+ replacement_placeholders = [n for n in replacement_graph .nodes if n .op == "placeholder" ]
275289
276290 # Build connecting between replacement graph's input and original graph input producer node
277291
0 commit comments