File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed
paddle/fluid/framework/ir Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change @@ -314,6 +314,7 @@ PDNode* MultiHeadMatmulPattern::operator()() {
314314 pattern->NewNode (reshape2_qkv_repr ())->assert_is_op (" reshape2" );
315315 auto * reshape2_qkv_out_var = pattern->NewNode (reshape2_qkv_out_repr ())
316316 ->assert_is_op_output (" reshape2" );
317+ reshape2_qkv_out_var->assert_is_op_input (" mul" );
317318
318319 // Second path to matmul
319320 auto * mul1 = pattern->NewNode (mul1_repr ())->assert_is_op (" mul" );
@@ -499,7 +500,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
499500 pattern->NewNode (reshape2_qkv_repr ())->assert_is_op (" reshape2" );
500501 auto * reshape2_qkv_out_var = pattern->NewNode (reshape2_qkv_out_repr ())
501502 ->assert_is_op_output (" reshape2" );
502-
503+ reshape2_qkv_out_var-> assert_is_ops_input (matmul_ops);
503504 // Second path to matmul
504505 auto * mul1 = pattern->NewNode (mul1_repr ())->assert_is_ops (matmul_ops);
505506 auto * mul1_w_var = pattern->NewNode (mul1_w_repr ())
You can’t perform that action at this time.
0 commit comments