Skip to content

Commit f3022df

Browse files
authored
add elementwise sub and elementwise div in tensorrt op teller (#40806)
* add elementwise sub and elementwise div in tensorrt op teller * add unittest of elementwise mul, sub and div
1 parent c544a18 commit f3022df

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ struct SimpleOpTypeSetTeller : public Teller {
7474
"tanh",
7575
"pad",
7676
"elementwise_add",
77+
"elementwise_sub",
7778
"elementwise_mul",
79+
"elementwise_div",
7880
"dropout",
7981
"prelu",
8082
"conv2d_transpose",
@@ -133,7 +135,9 @@ struct SimpleOpTypeSetTeller : public Teller {
133135
"tanh",
134136
"pad",
135137
"elementwise_add",
138+
"elementwise_sub",
136139
"elementwise_mul",
140+
"elementwise_div",
137141
"dropout",
138142
"prelu",
139143
"conv2d_transpose",
@@ -958,7 +962,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
958962
}
959963
}
960964

961-
if (op_type == "elementwise_add" || op_type == "elementwise_mul") {
965+
if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||
966+
op_type == "elementwise_sub" || op_type == "elementwise_div") {
962967
if (desc.Input("X").size() != 1) {
963968
VLOG(3) << "The input op's Input(\"X\").size() "
964969
"should equal to 1, but received Input(\"X\").size() = "

python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def generate_input(shape):
150150
return np.random.random(shape).astype(np.float32)
151151

152152
for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]:
153-
for op_type in ["elementwise_add", "elementwise_mul"]:
153+
for op_type in [
154+
"elementwise_add", "elementwise_mul", "elementwise_sub",
155+
"elementwise_div"
156+
]:
154157
for axis in [0, -1]:
155158
self.dims = len(shape)
156159
dics = [{"axis": axis}]
@@ -306,7 +309,10 @@ def generate_input(shape):
306309
input1_shape = input1_shape_list[i]
307310
for j in range(6):
308311
input2_shape = input2_shape_list[j][i]
309-
for op_type in ["elementwise_add", "elementwise_mul"]:
312+
for op_type in [
313+
"elementwise_add", "elementwise_mul", "elementwise_sub",
314+
"elementwise_div"
315+
]:
310316
for axis in axis_list[j][i]:
311317
self.shape1 = input1_shape
312318
self.shape2 = input2_shape

python/paddle/fluid/tests/unittests/ir/inference/test_trt_elementwise_op.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,23 @@ def test_check_output(self):
5656
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
5757

5858

59+
class TensorRTSubgraphPassElementwiseBroadcastTest1(
60+
TensorRTSubgraphPassElementwiseBroadcastTest):
61+
def append_eltwise(self, data1, data2):
62+
return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0)
63+
64+
65+
class TensorRTSubgraphPassElementwiseBroadcastTest2(
66+
TensorRTSubgraphPassElementwiseBroadcastTest):
67+
def append_eltwise(self, data1, data2):
68+
return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0)
69+
70+
71+
class TensorRTSubgraphPassElementwiseBroadcastTest3(
72+
TensorRTSubgraphPassElementwiseBroadcastTest):
73+
def append_eltwise(self, data1, data2):
74+
return fluid.layers.elementwise_div(x=data1, y=data2, axis=0)
75+
76+
5977
if __name__ == "__main__":
6078
unittest.main()

0 commit comments

Comments
 (0)