Skip to content

Commit 1db9cd4

Browse files
authored
fix xpu op test, *test=kunlun (PaddlePaddle#40862)
1 parent d43e843 commit 1db9cd4

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

python/paddle/fluid/tests/unittests/op_test_xpu.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from white_list import op_accuracy_white_list, check_shape_white_list, compile_vs_runtime_white_list, no_check_set_white_list
3939
from white_list import op_threshold_white_list, no_grad_set_white_list
4040
from op_test import OpTest, _set_use_system_allocator, get_numeric_gradient
41+
from xpu.get_test_cover_info import is_empty_grad_op_type
4142

4243

4344
class XPUOpTest(OpTest):
@@ -108,6 +109,13 @@ def check_grad_with_place(self,
108109
check_dygraph=True,
109110
numeric_place=None,
110111
check_eager=False):
112+
if hasattr(self, 'op_type_need_check_grad'):
113+
xpu_version = core.get_xpu_device_version(0)
114+
if is_empty_grad_op_type(xpu_version, self.op_type,
115+
self.in_type_str):
116+
self._check_grad_helper()
117+
return
118+
111119
if place == None:
112120
place = paddle.XPUPlace(0)
113121

python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def is_empty_grad_op_type(xpu_version, op, test_type):
208208
if grad_op not in xpu_op_list.keys():
209209
return True
210210

211-
grad_op_types = xpu_op_list[op]
211+
grad_op_types = xpu_op_list[grad_op]
212212
paddle_test_type = type_dict_str_to_paddle[test_type]
213213
if paddle_test_type not in grad_op_types:
214214
return True
@@ -239,9 +239,11 @@ def create_test_class(func_globals,
239239
continue
240240
class_obj = test_class[1]
241241
cls_name = "{0}_{1}".format(test_class[0], str(test_type))
242-
func_globals[cls_name] = type(
243-
cls_name, (class_obj, ),
244-
{'in_type': type_dict_str_to_numpy[test_type]})
242+
func_globals[cls_name] = type(cls_name, (class_obj, ), {
243+
'in_type': type_dict_str_to_numpy[test_type],
244+
'in_type_str': test_type,
245+
'op_type_need_check_grad': True
246+
})
245247

246248
if hasattr(test_class_obj, 'use_dynamic_create_class'
247249
) and test_class_obj.use_dynamic_create_class:
@@ -250,6 +252,8 @@ def create_test_class(func_globals,
250252
cls_name = "{0}_{1}".format(dy_class[0], str(test_type))
251253
attr_dict = dy_class[1]
252254
attr_dict['in_type'] = type_dict_str_to_numpy[test_type]
255+
attr_dict['in_type_str'] = test_type
256+
attr_dict['op_type_need_check_grad'] = True
253257
func_globals[cls_name] = type(cls_name, (base_class, ), attr_dict)
254258

255259
record_op_test(op_name, test_type)

0 commit comments

Comments
 (0)