@@ -40,11 +40,14 @@ def setUp(self):
4040 'mkldnn_data_type' : self .mkldnn_data_type
4141 }
4242
43+ self .sections = [self .x0 .shape [self .axis ]] * 2
44+ self .sections [1 ] += self .x1 .shape [self .axis ]
45+
4346 self .output = np .concatenate (
4447 (self .x0 , self .x1 , self .x2 ), axis = self .axis ).astype (np .uint16 )
4548 self .outputs = {'Out' : self .output }
4649
47- def calculate_grads (self ):
50+ def calculate_grads (self ):
4851 self .dout = self .outputs ['Out' ]
4952 self .dxs = np .split (self .dout , self .sections , self .axis )
5053
@@ -73,9 +76,9 @@ def init_axis(self):
7376 self .axis = 0
7477
7578 def init_shape (self ):
76- self .x0_shape = [2 , 2 , 1 , 2 ]
77- self .x1_shape = [1 , 2 , 1 , 2 ]
78- self .x2_shape = [3 , 2 , 1 , 2 ]
79+ self .x0_shape = [6 , 2 , 4 , 3 ]
80+ self .x1_shape = [7 , 2 , 4 , 3 ]
81+ self .x2_shape = [8 , 2 , 4 , 3 ]
7982
8083
8184# --------------------test concat bf16 in with axis 1--------------------
@@ -86,9 +89,9 @@ def init_axis(self):
8689 self .axis = 1
8790
8891 def init_shape (self ):
89- self .x0_shape = [1 , 1 , 5 , 5 ]
90- self .x1_shape = [1 , 2 , 5 , 5 ]
91- self .x2_shape = [1 , 3 , 5 , 5 ]
92+ self .x0_shape = [1 , 4 , 5 , 5 ]
93+ self .x1_shape = [1 , 8 , 5 , 5 ]
94+ self .x2_shape = [1 , 6 , 5 , 5 ]
9295
9396
9497# --------------------test concat bf16 in with axis 2--------------------
0 commit comments