Skip to content

Commit 6f54f6f

Browse files
committed
support paddle elementwise_floordiv
1 parent 2bc32a1 commit 6f54f6f

File tree

4 files changed

+75
-0
lines changed

4 files changed

+75
-0
lines changed

src/core/tests/frontend/paddle/op_fuzzy.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ static const std::vector<std::string> models{
126126
std::string("elementwise_mul4"),
127127
std::string("elementwise_pow4"),
128128
std::string("elementwise_sub4"),
129+
std::string("elementwise_floordiv_int32_1"),
130+
std::string("elementwise_floordiv_int32_2"),
131+
std::string("elementwise_floordiv_int32_3"),
132+
std::string("elementwise_floordiv_int64_1"),
133+
std::string("elementwise_floordiv_int64_2"),
134+
std::string("elementwise_floordiv_int64_3"),
129135
std::string("embedding_0"),
130136
std::string("embedding_sparse"),
131137
std::string("embedding_none_weight"),

src/core/tests/frontend/paddle/test_models/gen_scripts/generate_elementwise_ops.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,34 @@ def elementwise_pow(name : str, x, y, axis, in_dtype):
162162

163163
return outs[0]
164164

165+
166+
def elementwise_floordiv(name : str, x, y, axis, in_dtype):
167+
import paddle
168+
paddle.enable_static()
169+
170+
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
171+
node_x = paddle.static.data(name = 'x', shape = x.shape, dtype = in_dtype)
172+
node_y = paddle.static.data(name = 'y', shape = y.shape, dtype = in_dtype)
173+
if paddle.__version__ == "1.8":
174+
out = paddle.fluid.layers.nn.elementwise_floordiv(node_x, node_y, axis=axis)
175+
else:
176+
if axis != -1:
177+
pass
178+
out = paddle.floor_divide(node_x, node_y)
179+
180+
cpu = paddle.static.cpu_places(1)
181+
exe = paddle.static.Executor(cpu[0])
182+
183+
# startup program will call initializer to initialize the parameters.
184+
exe.run(paddle.static.default_startup_program())
185+
outs = exe.run(
186+
feed={'x': x, 'y': y},
187+
fetch_list=[out])
188+
saveModel(name, exe, feedkeys=['x', 'y'], fetchlist=[out], inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])
189+
190+
return outs[0]
191+
192+
165193
def elementwise_ops(name : str, data_x, data_y, axis, in_dtype):
166194
elementwise_add("elementwise_add" + name, data_x, data_y, axis, in_dtype)
167195
elementwise_sub("elementwise_sub" + name, data_x, data_y, axis, in_dtype)
@@ -193,5 +221,29 @@ def main():
193221
axis = 0
194222
elementwise_ops("4", data_x, data_y, axis, in_dtype)
195223

224+
# test for elementwise_floordiv, support int and int64
225+
# paddle1.8 support axis = [0, x_last_dims]
226+
# paddle2.x only support axis = -1
227+
floordiv_support_dtype = ['int64', 'int32']
228+
data_x = np.array([-4, 0, -8])
229+
230+
data_y = np.array([3, 5, 3])
231+
axis = -1
232+
for dtype in floordiv_support_dtype:
233+
elementwise_floordiv("elementwise_floordiv_" + dtype + "_1",
234+
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)
235+
236+
data_x = np.random.randint(-10, 10, [2, 5, 3, 4])
237+
data_y = np.random.randint(1, 5, [3, 4])
238+
for dtype in floordiv_support_dtype:
239+
elementwise_floordiv("elementwise_floordiv_" + dtype + "_2",
240+
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)
241+
242+
data_y = np.random.randint(1, 5, [5, 3, 4])
243+
for dtype in floordiv_support_dtype:
244+
elementwise_floordiv("elementwise_floordiv_" + dtype + "_3",
245+
data_x.astype(dtype), data_y.astype(dtype), axis, dtype)
246+
247+
196248
if __name__ == "__main__":
197249
main()

src/frontends/paddle/src/op/elementwise_ops.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ NamedOutputs elementwise_greater_equal(const NodeContext& node_context) {
4646
return elementwise_ops<default_opset::GreaterEqual>(node_context);
4747
}
4848

49+
NamedOutputs elementwise_floordiv(const NodeContext& node_context) {
50+
auto x = node_context.get_input("X");
51+
auto y = node_context.get_input("Y");
52+
auto axis = -1;
53+
if (node_context.has_attribute("axis")) {
54+
axis = node_context.get_attribute<int>("axis");
55+
}
56+
return node_context.default_single_output_mapping(
57+
{std::make_shared<default_opset::Divide>(x,
58+
y,
59+
false,
60+
ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, axis))},
61+
{"Out"});
62+
}
63+
4964
} // namespace op
5065
} // namespace paddle
5166
} // namespace frontend

src/frontends/paddle/src/op_table.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ OP_CONVERTER(dropout);
2525
OP_CONVERTER(elementwise_add);
2626
OP_CONVERTER(elementwise_div);
2727
OP_CONVERTER(elementwise_equal);
28+
OP_CONVERTER(elementwise_floordiv);
2829
OP_CONVERTER(elementwise_greater_equal);
2930
OP_CONVERTER(elementwise_max);
3031
OP_CONVERTER(elementwise_min);
@@ -123,6 +124,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
123124
{"dropout", op::dropout},
124125
{"elementwise_add", op::elementwise_add},
125126
{"elementwise_div", op::elementwise_div},
127+
{"elementwise_floordiv", op::elementwise_floordiv},
126128
{"elementwise_max", op::elementwise_max},
127129
{"elementwise_min", op::elementwise_min},
128130
{"elementwise_mul", op::elementwise_mul},

0 commit comments

Comments
 (0)