Skip to content

Commit 4e4f8cc

Browse files
committed
add matmul_broadcast_unittest
1 parent 5a7bd4a commit 4e4f8cc

File tree

3 files changed

+60
-22
lines changed

3 files changed

+60
-22
lines changed

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
198198
// "embedding_fc_lstm_fuse_pass", //
199199
// TODO(wilber): fix correctness problem.
200200
// "fc_lstm_fuse_pass", //
201-
"mul_lstm_fuse_pass", //
202-
"fc_gru_fuse_pass", //
203-
"mul_gru_fuse_pass", //
204-
"seq_concat_fc_fuse_pass", //
205-
"squeeze2_matmul_fuse_pass", //
206-
"reshape2_matmul_fuse_pass", //
207-
"flatten2_matmul_fuse_pass", //
208-
"map_matmul_v2_to_mul_pass", //
209-
"map_matmul_v2_to_matmul_pass", //
201+
"mul_lstm_fuse_pass", //
202+
"fc_gru_fuse_pass", //
203+
"mul_gru_fuse_pass", //
204+
"seq_concat_fc_fuse_pass", //
205+
"squeeze2_matmul_fuse_pass", //
206+
"reshape2_matmul_fuse_pass", //
207+
"flatten2_matmul_fuse_pass", //
208+
"map_matmul_v2_to_mul_pass", //
209+
// "map_matmul_v2_to_matmul_pass", //
210210
"map_matmul_to_mul_pass", //
211211
"fc_fuse_pass", //
212212
"repeated_fc_relu_fuse_pass", //

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -340,19 +340,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
340340
return false;
341341
}
342342

343-
for (auto& param_name : desc.Inputs()) {
344-
for (auto& var_name : param_name.second) {
345-
auto* var_desc = block->FindVar(var_name);
346-
const auto shape = var_desc->GetShape();
347-
if (shape.size() < 3) {
348-
VLOG(3)
349-
<< "matmul op dims < 3 not supported in tensorrt, but got dims "
350-
<< shape.size() << ", so jump it.";
351-
return false;
352-
}
353-
}
354-
}
355-
356343
// not support broadcast
357344
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
358345
auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
@@ -371,6 +358,19 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
371358
return false;
372359
}
373360
}
361+
362+
for (auto& param_name : desc.Inputs()) {
363+
for (auto& var_name : param_name.second) {
364+
auto* var_desc = block->FindVar(var_name);
365+
const auto shape = var_desc->GetShape();
366+
if (shape.size() < 3) {
367+
VLOG(3)
368+
<< "matmul op dims < 3 not supported in tensorrt, but got dims "
369+
<< shape.size() << ", so jump it.";
370+
return false;
371+
}
372+
}
373+
}
374374
}
375375
if (op_type == "softmax") {
376376
auto* block = desc.Block();

python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,43 @@ def set_params(self):
107107
self.alpha = 2.0
108108

109109

110+
class TensorRTMatMulBroadcastTest(InferencePassTest):
111+
def setUp(self):
112+
self.set_params()
113+
place = fluid.CPUPlace()
114+
with fluid.program_guard(self.main_program, self.startup_program):
115+
data_x = fluid.data(
116+
name="data_x", shape=[-1, 6, 24], dtype="float32")
117+
data_y = fluid.data(name="data_y", shape=[24, 16], dtype="float32")
118+
matmul_out = fluid.layers.matmul(
119+
x=data_x,
120+
y=data_y,
121+
transpose_x=self.transpose_x,
122+
transpose_y=self.transpose_y,
123+
alpha=self.alpha)
124+
out = fluid.layers.batch_norm(matmul_out, is_test=True)
125+
126+
self.feeds = {
127+
"data_x": np.ones([2, 6, 24]).astype("float32"),
128+
"data_y": np.ones([24, 16]).astype("float32")
129+
}
130+
self.enable_trt = True
131+
self.trt_parameters = TensorRTMatMulBroadcastTest.TensorRTParam(
132+
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
133+
self.fetch_list = [out]
134+
135+
def set_params(self):
136+
self.transpose_x = False
137+
self.transpose_y = False
138+
self.alpha = 1.0
139+
140+
def test_check_output(self):
141+
if core.is_compiled_with_cuda():
142+
use_gpu = True
143+
self.check_output_with_option(use_gpu)
144+
self.assertTrue(
145+
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
146+
147+
110148
if __name__ == "__main__":
111149
unittest.main()

0 commit comments

Comments (0)