@@ -162,6 +162,34 @@ def elementwise_pow(name : str, x, y, axis, in_dtype):
162162
163163 return outs [0 ]
164164
165+
166+ def elementwise_floordiv (name : str , x , y , axis , in_dtype ):
167+ import paddle
168+ paddle .enable_static ()
169+
170+ with paddle .static .program_guard (paddle .static .Program (), paddle .static .Program ()):
171+ node_x = paddle .static .data (name = 'x' , shape = x .shape , dtype = in_dtype )
172+ node_y = paddle .static .data (name = 'y' , shape = y .shape , dtype = in_dtype )
173+ if paddle .__version__ == "1.8" :
174+ out = paddle .fluid .layers .nn .elementwise_floordiv (node_x , node_y , axis = axis )
175+ else :
176+ if axis != - 1 :
177+ pass
178+ out = paddle .floor_divide (node_x , node_y )
179+
180+ cpu = paddle .static .cpu_places (1 )
181+ exe = paddle .static .Executor (cpu [0 ])
182+
183+ # startup program will call initializer to initialize the parameters.
184+ exe .run (paddle .static .default_startup_program ())
185+ outs = exe .run (
186+ feed = {'x' : x , 'y' : y },
187+ fetch_list = [out ])
188+ saveModel (name , exe , feedkeys = ['x' , 'y' ], fetchlist = [out ], inputs = [x , y ], outputs = [outs [0 ]], target_dir = sys .argv [1 ])
189+
190+ return outs [0 ]
191+
192+
165193def elementwise_ops (name : str , data_x , data_y , axis , in_dtype ):
166194 elementwise_add ("elementwise_add" + name , data_x , data_y , axis , in_dtype )
167195 elementwise_sub ("elementwise_sub" + name , data_x , data_y , axis , in_dtype )
@@ -193,5 +221,29 @@ def main():
193221 axis = 0
194222 elementwise_ops ("4" , data_x , data_y , axis , in_dtype )
195223
224+ # test for elementwise_floordiv, support int and int64
225+ # paddle1.8 support axis = [0, x_last_dims]
226+ # paddle2.x only support axis = -1
227+ floordiv_support_dtype = ['int64' , 'int32' ]
228+ data_x = np .array ([- 4 , 0 , - 8 ])
229+
230+ data_y = np .array ([3 , 5 , 3 ])
231+ axis = - 1
232+ for dtype in floordiv_support_dtype :
233+ elementwise_floordiv ("elementwise_floordiv_" + dtype + "_1" ,
234+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
235+
236+ data_x = np .random .randint (- 10 , 10 , [2 , 5 , 3 , 4 ])
237+ data_y = np .random .randint (1 , 5 , [3 , 4 ])
238+ for dtype in floordiv_support_dtype :
239+ elementwise_floordiv ("elementwise_floordiv_" + dtype + "_2" ,
240+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
241+
242+ data_y = np .random .randint (1 , 5 , [5 , 3 , 4 ])
243+ for dtype in floordiv_support_dtype :
244+ elementwise_floordiv ("elementwise_floordiv_" + dtype + "_3" ,
245+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
246+
247+
196248if __name__ == "__main__" :
197249 main ()
0 commit comments