Skip to content

Commit 47a73fe

Browse files
committed
fixed concat tests
1 parent 9361040 commit 47a73fe

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

python/paddle/fluid/tests/unittests/mkldnn/test_concat_bf16_mkldnn_op.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,14 @@ def setUp(self):
4040
'mkldnn_data_type': self.mkldnn_data_type
4141
}
4242

43+
self.sections = [self.x0.shape[self.axis]] * 2
44+
self.sections[1] += self.x1.shape[self.axis]
45+
4346
self.output = np.concatenate(
4447
(self.x0, self.x1, self.x2), axis=self.axis).astype(np.uint16)
4548
self.outputs = {'Out': self.output}
4649

47-
def calculate_grads(self):
50+
def calculate_grads(self):
4851
self.dout = self.outputs['Out']
4952
self.dxs = np.split(self.dout, self.sections, self.axis)
5053

@@ -73,9 +76,9 @@ def init_axis(self):
7376
self.axis = 0
7477

7578
def init_shape(self):
76-
self.x0_shape = [2, 2, 1, 2]
77-
self.x1_shape = [1, 2, 1, 2]
78-
self.x2_shape = [3, 2, 1, 2]
79+
self.x0_shape = [6, 2, 4, 3]
80+
self.x1_shape = [7, 2, 4, 3]
81+
self.x2_shape = [8, 2, 4, 3]
7982

8083

8184
# --------------------test concat bf16 in with axis 1--------------------
@@ -86,9 +89,9 @@ def init_axis(self):
8689
self.axis = 1
8790

8891
def init_shape(self):
89-
self.x0_shape = [1, 1, 5, 5]
90-
self.x1_shape = [1, 2, 5, 5]
91-
self.x2_shape = [1, 3, 5, 5]
92+
self.x0_shape = [1, 4, 5, 5]
93+
self.x1_shape = [1, 8, 5, 5]
94+
self.x2_shape = [1, 6, 5, 5]
9295

9396

9497
# --------------------test concat bf16 in with axis 2--------------------

python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def setUp(self):
4141
self.output = np.concatenate(
4242
(self.x0, self.x1, self.x2), axis=self.axis).astype(self.dtype)
4343

44-
self.sections = [self.x0.shape[self.axis]] * 2
45-
self.sections[1] += self.x1.shape[self.axis]
46-
4744
self.outputs = {'Out': self.output}
4845

4946
def configure_datatype(self):

0 commit comments

Comments
 (0)