test/legacy_test/op_test.py (29 additions, 15 deletions)
@@ -477,7 +477,7 @@ def is_empty_grad_op(op_type):
             all_op_kernels = core._get_all_register_op_kernels()
             grad_op = op_type + '_grad'
             if grad_op in all_op_kernels.keys():
-                if is_mkldnn_op_test():
+                if is_onednn_op_test():
                     grad_op_kernels = all_op_kernels[grad_op]
                     for grad_op_kernel in grad_op_kernels:
                         if 'MKLDNN' in grad_op_kernel:
@@ -489,8 +489,10 @@ def is_empty_grad_op(op_type):
         def is_xpu_op_test():
             return hasattr(cls, "use_xpu") and cls.use_xpu
 
-        def is_mkldnn_op_test():
-            return hasattr(cls, "use_mkldnn") and cls.use_mkldnn
+        def is_onednn_op_test():
+            return (hasattr(cls, "use_mkldnn") and cls.use_mkldnn) or (
+                hasattr(cls, "use_onednn") and cls.use_onednn
+            )
 
         def is_rocm_op_test():
             return core.is_compiled_with_rocm()
@@ -534,7 +536,7 @@ def is_complex_test():
                 not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST
                 and not hasattr(cls, 'exist_fp64_check_grad')
                 and not is_xpu_op_test()
-                and not is_mkldnn_op_test()
+                and not is_onednn_op_test()
                 and not is_rocm_op_test()
                 and not is_custom_device_op_test()
                 and not cls.check_prim
@@ -578,6 +580,10 @@ def is_bfloat16_op(self):
                 and 'mkldnn_data_type' in self.attrs
                 and self.attrs['mkldnn_data_type'] == 'bfloat16'
             )
+            or (
+                hasattr(self, 'onednn_data_type')
+                and self.onednn_data_type == "bfloat16"
+            )
         )
 
     def is_float16_op(self):
@@ -599,13 +605,21 @@ def is_float16_op(self):
                 and 'mkldnn_data_type' in self.attrs
                 and self.attrs['mkldnn_data_type'] == 'float16'
             )
+            or (
+                hasattr(self, 'onednn_data_type')
+                and self.onednn_data_type == "float16"
+            )
         )
 
-    def is_mkldnn_op(self):
-        return (hasattr(self, "use_mkldnn") and self.use_mkldnn) or (
-            hasattr(self, "attrs")
-            and "use_mkldnn" in self.attrs
-            and self.attrs["use_mkldnn"]
+    def is_onednn_op(self):
+        return (
+            (hasattr(self, "use_mkldnn") and self.use_mkldnn)
+            or (hasattr(self, "use_onednn") and self.use_onednn)
+            or (
+                hasattr(self, "attrs")
+                and "use_mkldnn" in self.attrs
+                and self.attrs["use_mkldnn"]
+            )
         )
 
     def is_xpu_op(self):
@@ -867,7 +881,7 @@ def _append_ops(self, block):
         self.__class__.op_type = (
             self.op_type
         )  # for ci check, please not delete it for now
-        if self.is_mkldnn_op():
+        if self.is_onednn_op():
             self.__class__.use_mkldnn = True
 
         if self.is_xpu_op():
@@ -2204,7 +2218,7 @@ def check_output_with_place(
     ):
         core._set_prim_all_enabled(False)
         core.set_prim_eager_enabled(False)
-        if not self.is_mkldnn_op():
+        if not self.is_onednn_op():
             set_flags({"FLAGS_use_mkldnn": False})
 
         if hasattr(self, "use_custom_device") and self.use_custom_device:
@@ -2683,7 +2697,7 @@ def infer_and_compare_symbol(self):
         atol = 0
 
         if self.is_bfloat16_op():
-            if self.is_mkldnn_op():
+            if self.is_onednn_op():
                 check_dygraph = False
 
             if (
@@ -2940,7 +2954,7 @@ def check_output(
         check_symbol_infer=True,
     ):
         self.__class__.op_type = self.op_type
-        if self.is_mkldnn_op():
+        if self.is_onednn_op():
            self.__class__.use_mkldnn = True
 
        if self.is_xpu_op():
@@ -3273,7 +3287,7 @@ def check_grad_with_place(
         if hasattr(self, "use_custom_device") and self.use_custom_device:
             check_dygraph = False
 
-        if not self.is_mkldnn_op():
+        if not self.is_onednn_op():
             set_flags({"FLAGS_use_mkldnn": False})
 
         core._set_prim_all_enabled(False)
@@ -3382,7 +3396,7 @@ def check_grad_with_place(
         op_attrs = self.attrs if hasattr(self, "attrs") else {}
         self._check_grad_helper()
         if self.is_bfloat16_op():
-            if self.is_mkldnn_op():
+            if self.is_onednn_op():
                 check_dygraph = False
                 atol = max(atol, 0.01)
 
test/mkldnn/test_conv3d_mkldnn_op.py (9 additions, 9 deletions)
@@ -27,55 +27,55 @@
 
 class TestMKLDNN(TestConv3DOp):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
 
 
 class TestMKLDNNCase1(TestCase1):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
 
 
 class TestMKLDNNGroup1(TestWithGroup1):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
 
 
 class TestMKLDNNGroup2(TestWithGroup2):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
 
 
 class TestMKLDNNWith1x1(TestWith1x1):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
 
 
 class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
 
 
 class TestConv3DOp_AsyPadding_MKLDNN(TestConv3DOp):
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
@@ -91,7 +91,7 @@ def init_paddings(self):
         self.padding_algorithm = "SAME"
 
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True
@@ -103,7 +103,7 @@ def init_paddings(self):
         self.padding_algorithm = "VALID"
 
     def init_kernel_type(self):
-        self.use_mkldnn = True
+        self.use_onednn = True
         self.data_format = "NCHW"
         self.dtype = np.float32
         self.check_pir_onednn = True