@@ -37,118 +37,156 @@ class MatMulOp : public framework::OperatorWithKernel {
     bool transpose_x = context->Attrs().Get<bool>("transpose_X");
     bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
 
+    int x_num_col_dims = context->Attrs().Get<int>("x_num_col_dims");
+    int y_num_col_dims = context->Attrs().Get<int>("y_num_col_dims");
+
     PADDLE_ENFORCE_GE(dim_x.size(), 1,
                       "Input tensor X must be at least 1-dimensional.");
     PADDLE_ENFORCE_GE(dim_y.size(), 1,
                       "Input tensor Y must be at least 1-dimensional.");
 
-    std::vector<int64_t> out_dim;
-    int64_t batch_count = 1;
-    if (dim_x.size() > 3) {
-      PADDLE_ENFORCE_EQ(
-          dim_y.size(), dim_x.size(),
-          "The dimensions of X and Y must be the same, and both of "
-          "them should be %d-dimensional.",
-          dim_x.size());
-
-      // The first rank-2 dimensions are accumulated on the batch_count, and the
-      // last two dimensions are used for matrix multiplication.
-      for (int j = 0; j < dim_x.size() - 2; ++j) {
-        PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
-                          "The %d-th dimension of X and Y must be the same.",
-                          j);
-        out_dim.push_back(dim_x[j]);
-        batch_count *= dim_x[j];
+    std::vector<int64_t> dim_out;
+    if (x_num_col_dims == 0 && y_num_col_dims == 0) {
+      std::vector<int64_t> out_dim;
+      int64_t batch_count = 1;
+      if (dim_x.size() > 3) {
+        PADDLE_ENFORCE_EQ(
+            dim_y.size(), dim_x.size(),
+            "The dimensions of X and Y must be the same, and both of "
+            "them should be %d-dimensional.",
+            dim_x.size());
+
+        // The first rank-2 dimensions are accumulated on the batch_count,
+        // and the last two dimensions are used for matrix multiplication.
+        for (int j = 0; j < dim_x.size() - 2; ++j) {
+          PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
+                            "The %d-th dimension of X and Y must be the same.",
+                            j);
+          out_dim.push_back(dim_x[j]);
+          batch_count *= dim_x[j];
+        }
       }
-    }
 
-    int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
-    bool remove_initial_dim = false, remove_final_dim = false;
+      int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
+      bool remove_initial_dim = false, remove_final_dim = false;
 
-    switch (dim_x.size()) {
-      case 1:
-        if (transpose_x) {
-          M = dim_x[0];
-          KX = 1;
-        } else {
-          M = 1;
-          KX = dim_x[0];
-          remove_initial_dim = true;
-        }
-        break;
-      case 2:
-        M = transpose_x ? dim_x[1] : dim_x[0];
-        KX = transpose_x ? dim_x[0] : dim_x[1];
-        break;
-      case 3:
-        batchCountX = dim_x[0];
-        M = transpose_x ? dim_x[2] : dim_x[1];
-        KX = transpose_x ? dim_x[1] : dim_x[2];
-        break;
-      default:
-        batchCountX = batch_count;
-        size_t mat_s = dim_x.size() - 2;
-        M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
-        KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
-        break;
-    }
+      switch (dim_x.size()) {
+        case 1:
+          if (transpose_x) {
+            M = dim_x[0];
+            KX = 1;
+          } else {
+            M = 1;
+            KX = dim_x[0];
+            remove_initial_dim = true;
+          }
+          break;
+        case 2:
+          M = transpose_x ? dim_x[1] : dim_x[0];
+          KX = transpose_x ? dim_x[0] : dim_x[1];
+          break;
+        case 3:
+          batchCountX = dim_x[0];
+          M = transpose_x ? dim_x[2] : dim_x[1];
+          KX = transpose_x ? dim_x[1] : dim_x[2];
+          break;
+        default:
+          batchCountX = batch_count;
+          size_t mat_s = dim_x.size() - 2;
+          M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
+          KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
+          break;
+      }
 
-    switch (dim_y.size()) {
-      case 1:
-        if (transpose_y) {
-          N = dim_y[0];
-          KY = 1;
+      switch (dim_y.size()) {
+        case 1:
+          if (transpose_y) {
+            N = dim_y[0];
+            KY = 1;
+          } else {
+            N = 1;
+            KY = dim_y[0];
+            remove_final_dim = true;
+          }
+          break;
+        case 2:
+          KY = transpose_y ? dim_y[1] : dim_y[0];
+          N = transpose_y ? dim_y[0] : dim_y[1];
+          break;
+        case 3:
+          batchCountY = dim_y[0];
+          KY = transpose_y ? dim_y[2] : dim_y[1];
+          N = transpose_y ? dim_y[1] : dim_y[2];
+          break;
+        default:
+          batchCountY = batch_count;
+          size_t mat_s = dim_y.size() - 2;
+          KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
+          N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
+      }
+
+      PADDLE_ENFORCE_EQ(
+          KX, KY,
+          "First matrix's width must be equal with second matrix's height.");
+      if (batchCountX && batchCountY) {
+        PADDLE_ENFORCE_EQ(
+            batchCountX, batchCountY,
+            "When Input(X) and Input(Y) are both three dimensional, they "
+            "must have the same batch dimension.");
+      }
+      int batchCount = std::max(batchCountX, batchCountY);
+
+      if (batchCount) {
+        if (dim_x.size() > 3) {
+          dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
         } else {
-          N = 1;
-          KY = dim_y[0];
-          remove_final_dim = true;
+          dim_out.push_back(batchCount);
         }
-        break;
-      case 2:
-        KY = transpose_y ? dim_y[1] : dim_y[0];
-        N = transpose_y ? dim_y[0] : dim_y[1];
-        break;
-      case 3:
-        batchCountY = dim_y[0];
-        KY = transpose_y ? dim_y[2] : dim_y[1];
-        N = transpose_y ? dim_y[1] : dim_y[2];
-        break;
-      default:
-        batchCountY = batch_count;
-        size_t mat_s = dim_y.size() - 2;
-        KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
-        N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
-    }
+      }
+      if (!remove_initial_dim) {
+        dim_out.push_back(M);
+      }
+      if (!remove_final_dim) {
+        dim_out.push_back(N);
+      }
+      if (dim_out.size() == 0) {
+        // We don't support 0-dimensional Tensors (scalars), so instead
+        // treat the output as a Tensor of shape (1, ) in this case.
+        dim_out.push_back(1);
+      }
+    } else {
+      if (x_num_col_dims == 0) {
+        x_num_col_dims = 1;
+      }
+      if (y_num_col_dims == 0) {
+        y_num_col_dims = 1;
+      }
+      PADDLE_ENFORCE_GT(
+          dim_x.size(), x_num_col_dims,
+          "The input tensor X's rank of MatMulOp should be larger than "
+          "x_num_col_dims.");
+      PADDLE_ENFORCE_GT(
+          dim_y.size(), y_num_col_dims,
+          "The input tensor Y's rank of MatMulOp should be larger than "
+          "y_num_col_dims.");
+
+      auto x_mat_dims = framework::flatten_to_2d(dim_x, x_num_col_dims);
+      auto y_mat_dims = framework::flatten_to_2d(dim_y, y_num_col_dims);
 
-    PADDLE_ENFORCE_EQ(
-        KX, KY,
-        "First matrix's width must be equal with second matrix's height.");
-    if (batchCountX && batchCountY) {
       PADDLE_ENFORCE_EQ(
-          batchCountX, batchCountY,
-          "When Input(X) and Input(Y) are both three dimensional, they "
-          "must have the same batch dimension.");
-    }
-    int batchCount = std::max(batchCountX, batchCountY);
+          x_mat_dims[1], y_mat_dims[0],
+          "First matrix's width must be equal with second matrix's height.");
 
-    std::vector<int64_t> dim_out;
-    if (batchCount) {
-      if (dim_x.size() > 3) {
-        dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
-      } else {
-        dim_out.push_back(batchCount);
+      dim_out.reserve(
+          static_cast<size_t>(x_num_col_dims + dim_y.size() - y_num_col_dims));
+
+      for (int i = 0; i < x_num_col_dims; ++i) {
+        dim_out.push_back(dim_x[i]);
+      }
+
+      for (int i = y_num_col_dims; i < dim_y.size(); ++i) {
+        dim_out.push_back(dim_y[i]);
       }
-    }
-    if (!remove_initial_dim) {
-      dim_out.push_back(M);
-    }
-    if (!remove_final_dim) {
-      dim_out.push_back(N);
-    }
-    if (dim_out.size() == 0) {
-      // We don't support 0-dimensional Tensors (scalars), so instead
-      // treat the output as a Tensor of shape (1, ) in this case.
-      dim_out.push_back(1);
     }
     context->SetOutputDim("Out", framework::make_ddim(dim_out));
     context->ShareLoD("X", /*->*/ "Out");
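
Note: the hunk above keeps the original batched-matmul shape inference as the first branch, taken when both num_col_dims attributes are 0. As a rough illustration of that branch, here is a minimal standalone sketch (not part of the commit; it uses plain std::vector<int64_t> in place of framework::DDim) showing how the output shape is assembled for rank > 3 inputs, where every dimension except the last two is treated as a batch dimension:

// Standalone sketch: output-shape inference for the batched branch above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> dim_x{5, 2, 3, 4};  // batch dims {5, 2}, matrix 3 x 4
  std::vector<int64_t> dim_y{5, 2, 4, 6};  // batch dims {5, 2}, matrix 4 x 6
  bool transpose_x = false, transpose_y = false;

  std::vector<int64_t> dim_out;
  // The leading rank-2 dimensions become the batch part of the output shape.
  for (size_t j = 0; j + 2 < dim_x.size(); ++j) {
    dim_out.push_back(dim_x[j]);
  }
  size_t mx = dim_x.size() - 2, my = dim_y.size() - 2;
  int64_t M = transpose_x ? dim_x[mx + 1] : dim_x[mx];  // rows of X: 3
  int64_t N = transpose_y ? dim_y[my] : dim_y[my + 1];  // cols of Y: 6
  dim_out.push_back(M);
  dim_out.push_back(N);
  for (int64_t d : dim_out) std::cout << d << ' ';  // prints: 5 2 3 6
  std::cout << '\n';
  return 0;
}
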
@@ -162,6 +200,37 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of MatMul op");
     AddInput("Y", "The second input of MatMul op");
     AddOutput("Out", "The output of MatMul op");
+    AddAttr<int>(
+        "x_num_col_dims",
+        R"DOC((int, default 0), The matmul_op can take tensors with more than two
+              dimensions as its inputs. If the input $X$ is a tensor with
+              more than two dimensions, $X$ will be flattened into a
+              two-dimensional matrix first. The flattening rule is: the
+              first `x_num_col_dims` dimensions are flattened to form the
+              first dimension of the final matrix (its height), and the
+              remaining `rank(X) - x_num_col_dims` dimensions are flattened
+              to form the second dimension (its width). As a result, the
+              height of the flattened matrix equals the product of $X$'s
+              first `x_num_col_dims` dimensions' sizes, and the width equals
+              the product of $X$'s remaining dimensions' sizes. For example,
+              suppose $X$ is a 5-dimensional tensor with the shape
+              [2, 3, 4, 5, 6] and `x_num_col_dims` = 3. Then the flattened
+              matrix has the shape [2 x 3 x 4, 5 x 6] = [24, 30]. The default
+              value 0 means $X$ is used as-is, without flattening.
+        )DOC")
+        .SetDefault(0)
+        .EqualGreaterThan(0);
+    AddAttr<int>(
+        "y_num_col_dims",
+        R"DOC((int, default 0), The matmul_op can take tensors with more than
+              two dimensions as its inputs. If the input $Y$ is a tensor with
+              more than two dimensions, $Y$ will be flattened into a
+              two-dimensional matrix first. The attribute `y_num_col_dims`
+              determines how $Y$ is flattened.
+              See the comments of `x_num_col_dims` for more details.
+        )DOC")
+        .SetDefault(0)
+        .EqualGreaterThan(0);
     AddAttr<bool>("transpose_X",
                   R"DOC(If true, use the transpose of `X`.
         )DOC")
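
Note: the x_num_col_dims / y_num_col_dims attribute docs above describe the flattening rule applied in the new else branch of InferShape. Below is a minimal standalone sketch of that rule (not part of the commit; FlattenTo2D is a plain-vector stand-in for framework::flatten_to_2d), reproducing the worked example from the documentation:

// Standalone sketch: the first num_col_dims dimensions form the height of the
// flattened matrix, the remaining dimensions form the width.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                 int num_col_dims) {
  int64_t height = std::accumulate(dims.begin(), dims.begin() + num_col_dims,
                                   int64_t{1}, std::multiplies<int64_t>());
  int64_t width = std::accumulate(dims.begin() + num_col_dims, dims.end(),
                                  int64_t{1}, std::multiplies<int64_t>());
  return {height, width};
}

int main() {
  // Shape [2, 3, 4, 5, 6] with num_col_dims = 3 flattens to
  // [2 * 3 * 4, 5 * 6] = [24, 30], as in the attribute documentation.
  auto mat = FlattenTo2D({2, 3, 4, 5, 6}, 3);
  std::cout << mat[0] << " x " << mat[1] << '\n';  // prints: 24 x 30
  return 0;
}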