@@ -208,7 +208,7 @@ def is_empty_grad_op_type(xpu_version, op, test_type):
208208 if grad_op not in xpu_op_list .keys ():
209209 return True
210210
211- grad_op_types = xpu_op_list [op ]
211+ grad_op_types = xpu_op_list [grad_op ]
212212 paddle_test_type = type_dict_str_to_paddle [test_type ]
213213 if paddle_test_type not in grad_op_types :
214214 return True
@@ -239,9 +239,11 @@ def create_test_class(func_globals,
239239 continue
240240 class_obj = test_class [1 ]
241241 cls_name = "{0}_{1}" .format (test_class [0 ], str (test_type ))
242- func_globals [cls_name ] = type (
243- cls_name , (class_obj , ),
244- {'in_type' : type_dict_str_to_numpy [test_type ]})
242+ func_globals [cls_name ] = type (cls_name , (class_obj , ), {
243+ 'in_type' : type_dict_str_to_numpy [test_type ],
244+ 'in_type_str' : test_type ,
245+ 'op_type_need_check_grad' : True
246+ })
245247
246248 if hasattr (test_class_obj , 'use_dynamic_create_class'
247249 ) and test_class_obj .use_dynamic_create_class :
@@ -250,6 +252,8 @@ def create_test_class(func_globals,
250252 cls_name = "{0}_{1}" .format (dy_class [0 ], str (test_type ))
251253 attr_dict = dy_class [1 ]
252254 attr_dict ['in_type' ] = type_dict_str_to_numpy [test_type ]
255+ attr_dict ['in_type_str' ] = test_type
256+ attr_dict ['op_type_need_check_grad' ] = True
253257 func_globals [cls_name ] = type (cls_name , (base_class , ), attr_dict )
254258
255259 record_op_test (op_name , test_type )
0 commit comments