Skip to content

Commit aa05988

Browse files
author
heliqi
authored
Add unittest for flatten2_matmul squeeze2_matmul reshape2_matmul pass (#37644)
* add flatten2_matmul squeeze2_matmul reshape2_matmul test case * modify skip func to ignore_pass_case func * rebuild CI * add test_xx_matmul_fuse_pass timeout * add test_map_xx_pass timeout * add max_duration of test cast * add trt skip * add timeout * del commented code
1 parent bb38b6a commit aa05988

File tree

5 files changed

+554
-8
lines changed

5 files changed

+554
-8
lines changed

paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,14 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
169169
.AddInput("X")
170170
.IsTensor()
171171
.End()
172-
.AddInput("Y")
173-
.IsTensor()
174-
.End()
175172
.AddOutput("Out")
176173
.IsTensor()
177174
.End()
178175
.AddOutput("XShape")
179176
.IsTensor()
180177
.End()
181178
.AddAttr("axis")
182-
.IsNumGE(0)
179+
.IsNumEQ(1)
183180
.End();
184181

185182
AddOpCompat(OpCompat("mul"))
@@ -222,7 +219,7 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
222219
.IsBoolEQ(false)
223220
.End();
224221

225-
AddOpCompat(OpCompat("Squeeze2"))
222+
AddOpCompat(OpCompat("squeeze2"))
226223
.AddInput("X")
227224
.IsTensor()
228225
.End()
@@ -593,10 +590,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
593590
.IsNumLT(1.00001f)
594591
.End()
595592
.AddAttr("transpose_X")
596-
.IsBoolEQ("False")
593+
.IsBoolEQ(false)
597594
.End()
598595
.AddAttr("transpose_Y")
599-
.IsBoolEQ("False")
596+
.IsBoolEQ(false)
600597
.End();
601598

602599
AddOpCompat(OpCompat("mul"))

python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
8383
set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
8484
set_tests_properties(test_conv_elementwise_add2_act_fuse_pass PROPERTIES TIMEOUT 120)
8585
set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 120)
86-
set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 90)
8786
set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 60)
8887
set_tests_properties(test_matmul_v2_scale_fuse_pass PROPERTIES TIMEOUT 60)
88+
set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
89+
set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
90+
set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
8991
endif()
9092

9193
if (WITH_MKLDNN)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from auto_scan_test import PassAutoScanTest, IgnoreReasons
16+
from program_config import TensorConfig, ProgramConfig, OpConfig
17+
import numpy as np
18+
import paddle.inference as paddle_infer
19+
from functools import partial
20+
from typing import Optional, List, Callable, Dict, Any, Set
21+
import unittest
22+
23+
import hypothesis
24+
from hypothesis import given, settings, seed, example, assume, reproduce_failure
25+
import hypothesis.strategies as st
26+
27+
28+
class TestFlatten2MatmulFusePass(PassAutoScanTest):
29+
"""
30+
x_var
31+
|
32+
flatten2
33+
\
34+
flatten2_out_var y_var
35+
\ /
36+
matmul bias_var
37+
\ /
38+
elementwise_add
39+
"""
40+
41+
def sample_predictor_configs(self, program_config):
42+
# TRT
43+
# config = self.create_trt_inference_config()
44+
# config.enable_tensorrt_engine(
45+
# max_batch_size=10,
46+
# workspace_size=102400,
47+
# min_subgraph_size=0,
48+
# precision_mode=paddle_infer.PrecisionType.Float32,
49+
# use_static=False,
50+
# use_calib_mode=False)
51+
# yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5)
52+
53+
# cpu
54+
config = self.create_inference_config(use_gpu=False)
55+
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
56+
57+
# for gpu
58+
config = self.create_inference_config(use_gpu=True)
59+
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
60+
61+
def add_ignore_pass_case(self):
62+
# Here we put some skip rules to avoid known bugs
63+
def teller1(program_config, predictor_config):
64+
if predictor_config.tensorrt_engine_enabled():
65+
# On 3080, the results of MatMul and Mul are different
66+
# When the input Y is weight
67+
return True
68+
69+
# On TRT when the input Y is weight, Mul is converted to FC
70+
if "matmul_y" not in program_config.weights \
71+
or "bias" not in program_config.weights:
72+
return True
73+
74+
y_shape = list(program_config.weights["matmul_y"].shape)
75+
bias_shape = program_config.weights["bias"].shape
76+
axis = program_config.ops[2].attrs["axis"]
77+
# bias should be [mul_y_shape[-1]]
78+
if axis == 0 or bias_shape[0] != y_shape[1] or len(
79+
bias_shape) != 1:
80+
return True
81+
return False
82+
83+
self.add_ignore_check_case(
84+
teller1,
85+
IgnoreReasons.PASS_ACCURACY_ERROR,
86+
"The pass error on TRT while shape of bias is not [out_size].", )
87+
88+
def sample_program_config(self, draw):
89+
# 1. Generate shape and attr of flatten2
90+
x_shape = draw(
91+
st.lists(
92+
st.integers(
93+
min_value=1, max_value=10), min_size=4, max_size=4))
94+
# [a, b, c, d] => [a, b*c*d]
95+
flatten_axis = 1
96+
flatten_shape = [x_shape[0], x_shape[1] * x_shape[2] * x_shape[3]]
97+
98+
# 2. Generate attr:transpose_X/transpose_Y/alpha of matmul
99+
alpha = 1.0
100+
transpose_X = False
101+
transpose_Y = False
102+
103+
# 3. Generate legal shape of input:Y of matmul
104+
y_shape = draw(
105+
st.lists(
106+
st.integers(
107+
min_value=1, max_value=8), min_size=2, max_size=2))
108+
y_shape[0] = flatten_shape[1]
109+
110+
# 4. Generate legal attr:axis of elementwise_add
111+
axis = draw(st.integers(min_value=-1, max_value=1))
112+
if axis == 0:
113+
bias_shape = [flatten_shape[0], ]
114+
elif axis == 1:
115+
bias_shape = [y_shape[1]]
116+
else:
117+
bias_shape = [flatten_shape[0], y_shape[1]]
118+
if draw(st.booleans()):
119+
bias_shape[1] = 1
120+
121+
flatten2_op = OpConfig(
122+
"flatten2",
123+
inputs={"X": ["flatten2_x"], },
124+
axis=flatten_axis,
125+
outputs={"Out": ["flatten2_out"],
126+
"XShape": ["xshape"]}, )
127+
matmul_op = OpConfig(
128+
"matmul",
129+
inputs={"X": ["flatten2_out"],
130+
"Y": ["matmul_y"]},
131+
outputs={"Out": ["matmul_out"]},
132+
alpha=alpha,
133+
transpose_X=transpose_X,
134+
transpose_Y=transpose_Y,
135+
fused_reshape_X=[],
136+
fused_reshape_Y=[],
137+
fused_transpose_X=[],
138+
fused_transpose_Y=[],
139+
fused_reshape_Out=[],
140+
fused_transpose_Out=[], )
141+
142+
add_op = OpConfig(
143+
"elementwise_add",
144+
inputs={"X": ["matmul_out"],
145+
"Y": ["bias"]},
146+
outputs={"Out": ["add_out"]},
147+
axis=axis, )
148+
149+
ops = [flatten2_op, matmul_op, add_op]
150+
151+
if draw(st.integers(min_value=1, max_value=10)) <= 8:
152+
program_config = ProgramConfig(
153+
ops=ops,
154+
weights={
155+
"matmul_y": TensorConfig(shape=y_shape),
156+
"bias": TensorConfig(shape=bias_shape),
157+
},
158+
inputs={"flatten2_x": TensorConfig(shape=x_shape), },
159+
outputs=ops[-1].outputs["Out"], )
160+
else:
161+
program_config = ProgramConfig(
162+
ops=ops,
163+
weights={},
164+
inputs={
165+
"flatten2_x": TensorConfig(shape=x_shape),
166+
"matmul_y": TensorConfig(shape=y_shape),
167+
"bias": TensorConfig(shape=bias_shape),
168+
},
169+
outputs=ops[-1].outputs["Out"], )
170+
return program_config
171+
172+
def test(self):
173+
self.run_and_statis(
174+
quant=False,
175+
max_examples=50,
176+
max_duration=1000,
177+
passes=["flatten2_matmul_fuse_pass"])
178+
179+
180+
if __name__ == "__main__":
181+
unittest.main()

0 commit comments

Comments
 (0)