Skip to content

Commit 1343aaa

Browse files
committed
support ernie quant model with interleaved
1 parent 3b53b92 commit 1343aaa

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ PDNode* MultiHeadMatmulPattern::operator()() {
314314
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
315315
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
316316
->assert_is_op_output("reshape2");
317+
reshape2_qkv_out_var->assert_is_op_input("mul");
317318

318319
// Second path to matmul
319320
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
@@ -499,7 +500,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
499500
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
500501
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
501502
->assert_is_op_output("reshape2");
502-
503+
reshape2_qkv_out_var->assert_is_ops_input(matmul_ops);
503504
// Second path to matmul
504505
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops);
505506
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())

0 commit comments

Comments
 (0)