Skip to content

Commit 4a99b6f

Browse files
committed
fix for matmul_v2 6D x 2D
1 parent 71cb3ff commit 4a99b6f

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ class MatMulV2MKLDNNKernel
148148
if (x_dims.size() == 1) {
149149
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
150150
} else if (x_dims.size() == 2) {
151-
x_bd_dims[2] = x_dims[1];
152-
x_bd_dims[1] = x_dims[0];
151+
x_bd_dims[x_bd_dims.size() - 1] = x_dims[1];
152+
x_bd_dims[x_bd_dims.size() - 2] = x_dims[0];
153153
} else {
154154
for (size_t i = 0; i < x_dims.size(); ++i) {
155155
x_bd_dims[i] = x_dims[i];
@@ -158,8 +158,8 @@ class MatMulV2MKLDNNKernel
158158
if (y_dims.size() == 1) {
159159
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
160160
} else if (y_dims.size() == 2) {
161-
y_bd_dims[2] = y_dims[1];
162-
y_bd_dims[1] = y_dims[0];
161+
y_bd_dims[y_bd_dims.size() - 1] = y_dims[1];
162+
y_bd_dims[y_bd_dims.size() - 2] = y_dims[0];
163163
} else {
164164
for (size_t i = 0; i < y_dims.size(); ++i) {
165165
y_bd_dims[i] = y_dims[i];

python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,22 @@ def config(self):
235235
self.trans_y = True
236236

237237

238+
class TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
239+
def config(self):
240+
self.x_shape = (1, 1, 2, 1, 8, 9)
241+
self.y_shape = (9, 12)
242+
self.trans_x = False
243+
self.trans_y = False
244+
245+
246+
class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
247+
def config(self):
248+
self.x_shape = (20, 5)
249+
self.y_shape = (1, 2, 1, 5, 11)
250+
self.trans_x = False
251+
self.trans_y = False
252+
253+
238254
# BF16 TESTS
239255
def create_bf16_test_class(parent):
240256
@OpTestTool.skip_if_not_cpu_bf16()
@@ -274,7 +290,8 @@ def calculate_grads(self):
274290
2: [1, 0],
275291
3: [0, 2, 1],
276292
4: [0, 1, 3, 2],
277-
5: [0, 1, 2, 4, 3]
293+
5: [0, 1, 2, 4, 3],
294+
6: [0, 1, 2, 3, 5, 4]
278295
}
279296

280297
# expand vector so it will be a valid matrix for multiplication
@@ -370,6 +387,8 @@ def calculate_grads(self):
370387
create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeXTransposeYOneDNNOp)
371388
create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeY2OneDNNOp)
372389
create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp)
390+
create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp)
391+
create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp)
373392

374393
if __name__ == "__main__":
375394
paddle.enable_static()

0 commit comments

Comments
 (0)