
Commit 0601515

add x_num_col_dims and y_num_col_dims
1 parent 02d5fd7 commit 0601515

File tree

3 files changed: +380 -201 lines changed

paddle/operators/matmul_op.cc

Lines changed: 166 additions & 97 deletions
@@ -37,118 +37,156 @@ class MatMulOp : public framework::OperatorWithKernel {
     bool transpose_x = context->Attrs().Get<bool>("transpose_X");
     bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
 
+    int x_num_col_dims = context->Attrs().Get<int>("x_num_col_dims");
+    int y_num_col_dims = context->Attrs().Get<int>("y_num_col_dims");
+
     PADDLE_ENFORCE_GE(dim_x.size(), 1,
                       "Input tensor X must be at least 1-dimensional.");
     PADDLE_ENFORCE_GE(dim_y.size(), 1,
                       "Input tensor Y must be at least 1-dimensional.");
 
-    std::vector<int64_t> out_dim;
-    int64_t batch_count = 1;
-    if (dim_x.size() > 3) {
-      PADDLE_ENFORCE_EQ(
-          dim_y.size(), dim_x.size(),
-          "The dimensions of X and Y must be the same, and both of "
-          "them should be %d-dimensional.",
-          dim_x.size());
-
-      // The first rank-2 dimensions are accumulated on the batch_count, and the
-      // last two dimensions are used for matrix multiplication.
-      for (int j = 0; j < dim_x.size() - 2; ++j) {
-        PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
-                          "The %d-th dimension of X and Y must be the same.",
-                          j);
-        out_dim.push_back(dim_x[j]);
-        batch_count *= dim_x[j];
+    std::vector<int64_t> dim_out;
+    if (x_num_col_dims == 0 && y_num_col_dims == 0) {
+      std::vector<int64_t> out_dim;
+      int64_t batch_count = 1;
+      if (dim_x.size() > 3) {
+        PADDLE_ENFORCE_EQ(
+            dim_y.size(), dim_x.size(),
+            "The dimensions of X and Y must be the same, and both of "
+            "them should be %d-dimensional.",
+            dim_x.size());
+
+        // The first rank-2 dimensions are accumulated on the batch_count,
+        // and the last two dimensions are used for matrix multiplication.
+        for (int j = 0; j < dim_x.size() - 2; ++j) {
+          PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
+                            "The %d-th dimension of X and Y must be the same.",
+                            j);
+          out_dim.push_back(dim_x[j]);
+          batch_count *= dim_x[j];
+        }
       }
-    }
 
-    int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
-    bool remove_initial_dim = false, remove_final_dim = false;
+      int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
+      bool remove_initial_dim = false, remove_final_dim = false;
 
-    switch (dim_x.size()) {
-      case 1:
-        if (transpose_x) {
-          M = dim_x[0];
-          KX = 1;
-        } else {
-          M = 1;
-          KX = dim_x[0];
-          remove_initial_dim = true;
-        }
-        break;
-      case 2:
-        M = transpose_x ? dim_x[1] : dim_x[0];
-        KX = transpose_x ? dim_x[0] : dim_x[1];
-        break;
-      case 3:
-        batchCountX = dim_x[0];
-        M = transpose_x ? dim_x[2] : dim_x[1];
-        KX = transpose_x ? dim_x[1] : dim_x[2];
-        break;
-      default:
-        batchCountX = batch_count;
-        size_t mat_s = dim_x.size() - 2;
-        M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
-        KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
-        break;
-    }
+      switch (dim_x.size()) {
+        case 1:
+          if (transpose_x) {
+            M = dim_x[0];
+            KX = 1;
+          } else {
+            M = 1;
+            KX = dim_x[0];
+            remove_initial_dim = true;
+          }
+          break;
+        case 2:
+          M = transpose_x ? dim_x[1] : dim_x[0];
+          KX = transpose_x ? dim_x[0] : dim_x[1];
+          break;
+        case 3:
+          batchCountX = dim_x[0];
+          M = transpose_x ? dim_x[2] : dim_x[1];
+          KX = transpose_x ? dim_x[1] : dim_x[2];
+          break;
+        default:
+          batchCountX = batch_count;
+          size_t mat_s = dim_x.size() - 2;
+          M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
+          KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
+          break;
+      }
 
-    switch (dim_y.size()) {
-      case 1:
-        if (transpose_y) {
-          N = dim_y[0];
-          KY = 1;
+      switch (dim_y.size()) {
+        case 1:
+          if (transpose_y) {
+            N = dim_y[0];
+            KY = 1;
+          } else {
+            N = 1;
+            KY = dim_y[0];
+            remove_final_dim = true;
+          }
+          break;
+        case 2:
+          KY = transpose_y ? dim_y[1] : dim_y[0];
+          N = transpose_y ? dim_y[0] : dim_y[1];
+          break;
+        case 3:
+          batchCountY = dim_y[0];
+          KY = transpose_y ? dim_y[2] : dim_y[1];
+          N = transpose_y ? dim_y[1] : dim_y[2];
+          break;
+        default:
+          batchCountY = batch_count;
+          size_t mat_s = dim_y.size() - 2;
+          KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
+          N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
+      }
+
+      PADDLE_ENFORCE_EQ(
+          KX, KY,
+          "First matrix's width must be equal with second matrix's height.");
+      if (batchCountX && batchCountY) {
+        PADDLE_ENFORCE_EQ(
+            batchCountX, batchCountY,
+            "When Input(X) and Input(Y) are both three dimensional, they "
+            "must have the same batch dimension.");
+      }
+      int batchCount = std::max(batchCountX, batchCountY);
+
+      if (batchCount) {
+        if (dim_x.size() > 3) {
+          dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
         } else {
-          N = 1;
-          KY = dim_y[0];
-          remove_final_dim = true;
+          dim_out.push_back(batchCount);
         }
-        break;
-      case 2:
-        KY = transpose_y ? dim_y[1] : dim_y[0];
-        N = transpose_y ? dim_y[0] : dim_y[1];
-        break;
-      case 3:
-        batchCountY = dim_y[0];
-        KY = transpose_y ? dim_y[2] : dim_y[1];
-        N = transpose_y ? dim_y[1] : dim_y[2];
-        break;
-      default:
-        batchCountY = batch_count;
-        size_t mat_s = dim_y.size() - 2;
-        KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
-        N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
-    }
+      }
+      if (!remove_initial_dim) {
+        dim_out.push_back(M);
+      }
+      if (!remove_final_dim) {
+        dim_out.push_back(N);
+      }
+      if (dim_out.size() == 0) {
+        // We don't support 0-dimensional Tensors (scalars), so instead
+        // treat the output as a Tensor of shape (1, ) in this case.
+        dim_out.push_back(1);
+      }
+    } else {
+      if (x_num_col_dims == 0) {
+        x_num_col_dims = 1;
+      }
+      if (y_num_col_dims == 0) {
+        y_num_col_dims = 1;
+      }
+      PADDLE_ENFORCE_GT(
+          dim_x.size(), x_num_col_dims,
+          "The input tensor X's rank of MatMul op should be larger than "
+          "x_num_col_dims.");
+      PADDLE_ENFORCE_GT(
+          dim_y.size(), y_num_col_dims,
+          "The input tensor Y's rank of MatMul op should be larger than "
+          "y_num_col_dims.");
+
+      auto x_mat_dims = framework::flatten_to_2d(dim_x, x_num_col_dims);
+      auto y_mat_dims = framework::flatten_to_2d(dim_y, y_num_col_dims);
 
-    PADDLE_ENFORCE_EQ(
-        KX, KY,
-        "First matrix's width must be equal with second matrix's height.");
-    if (batchCountX && batchCountY) {
       PADDLE_ENFORCE_EQ(
-          batchCountX, batchCountY,
-          "When Input(X) and Input(Y) are both three dimensional, they "
-          "must have the same batch dimension.");
-    }
-    int batchCount = std::max(batchCountX, batchCountY);
+          x_mat_dims[1], y_mat_dims[0],
+          "First matrix's width must be equal with second matrix's height.");
 
-    std::vector<int64_t> dim_out;
-    if (batchCount) {
-      if (dim_x.size() > 3) {
-        dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
-      } else {
-        dim_out.push_back(batchCount);
+      dim_out.reserve(
+          static_cast<size_t>(x_num_col_dims + dim_y.size() - y_num_col_dims));
+
+      for (int i = 0; i < x_num_col_dims; ++i) {
+        dim_out.push_back(dim_x[i]);
+      }
+
+      for (int i = y_num_col_dims; i < dim_y.size(); ++i) {
+        dim_out.push_back(dim_y[i]);
       }
     }
-    if (!remove_initial_dim) {
-      dim_out.push_back(M);
-    }
-    if (!remove_final_dim) {
-      dim_out.push_back(N);
-    }
-    if (dim_out.size() == 0) {
-      // We don't support 0-dimensional Tensors (scalars), so instead
-      // treat the output as a Tensor of shape (1, ) in this case.
-      dim_out.push_back(1);
-    }
     context->SetOutputDim("Out", framework::make_ddim(dim_out));
     context->ShareLoD("X", /*->*/ "Out");
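
For illustration only, and not part of the commit: the new else branch above flattens each input to 2-D with framework::flatten_to_2d and then builds Out from X's first x_num_col_dims dimensions followed by Y's dimensions after y_num_col_dims. Below is a minimal standalone C++ sketch of that shape inference, written against plain std::vector<int64_t> with the hypothetical helpers product and infer_matmul_out_dims so it compiles without the Paddle headers.

// Standalone sketch of the num_col_dims branch of MatMulOp::InferShape above.
// product and infer_matmul_out_dims are hypothetical helpers, not Paddle APIs.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

static int64_t product(std::vector<int64_t>::const_iterator first,
                       std::vector<int64_t>::const_iterator last) {
  return std::accumulate(first, last, int64_t{1}, std::multiplies<int64_t>());
}

static std::vector<int64_t> infer_matmul_out_dims(
    const std::vector<int64_t>& dim_x, const std::vector<int64_t>& dim_y,
    int x_num_col_dims, int y_num_col_dims) {
  // Mirrors the PADDLE_ENFORCE_GT checks: each rank must exceed its num_col_dims.
  if (static_cast<int>(dim_x.size()) <= x_num_col_dims ||
      static_cast<int>(dim_y.size()) <= y_num_col_dims) {
    throw std::invalid_argument("rank must be larger than num_col_dims");
  }

  // Flatten to 2-D: [prod(first num_col_dims dims), prod(remaining dims)].
  int64_t x_width = product(dim_x.begin() + x_num_col_dims, dim_x.end());
  int64_t y_height = product(dim_y.begin(), dim_y.begin() + y_num_col_dims);
  assert(x_width == y_height &&
         "First matrix's width must equal second matrix's height.");

  // Out keeps X's first x_num_col_dims dims and Y's dims after y_num_col_dims.
  std::vector<int64_t> dim_out(dim_x.begin(), dim_x.begin() + x_num_col_dims);
  dim_out.insert(dim_out.end(), dim_y.begin() + y_num_col_dims, dim_y.end());
  return dim_out;
}

int main() {
  // X: [2, 3, 4] with x_num_col_dims = 2 flattens to [6, 4];
  // Y: [4, 5]    with y_num_col_dims = 1 flattens to [4, 5];
  // so Out is [2, 3, 5].
  auto out = infer_matmul_out_dims({2, 3, 4}, {4, 5}, 2, 1);
  assert((out == std::vector<int64_t>{2, 3, 5}));
  return 0;
}

Compiled as an ordinary C++ program, the example asserts that X = [2, 3, 4] with x_num_col_dims = 2 and Y = [4, 5] with y_num_col_dims = 1 yield an output shape of [2, 3, 5], matching the dim_out loops in the diff.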
@@ -162,6 +200,37 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of MatMul op");
     AddInput("Y", "The second input of MatMul op");
     AddOutput("Out", "The output of MatMul op");
+    AddAttr<int>(
+        "x_num_col_dims",
+        R"DOC((int, default 0), The matmul_op can take tensors with more than two
+              dimensions as its inputs. If the input $X$ is a tensor with more
+              than two dimensions, $X$ will be flattened into a two-dimensional
+              matrix first. The flattening rule is: the first `num_col_dims`
+              dimensions will be flattened to form the first dimension of the
+              final matrix (the height of the matrix), and the rest
+              `rank(X) - num_col_dims` dimensions are flattened to form the
+              second dimension of the final matrix (the width of the matrix).
+              As a result, the height of the flattened matrix is equal to the
+              product of $X$'s first `x_num_col_dims` dimensions' sizes, and
+              the width of the flattened matrix is equal to the product of
+              $X$'s last `rank(X) - num_col_dims` dimensions' sizes. For
+              example, suppose $X$ is a 5-dimensional tensor with the shape
+              [2, 3, 4, 5, 6], and `x_num_col_dims` = 3. Then the flattened
+              matrix will have the shape [2 x 3 x 4, 5 x 6] = [24, 30]. The
+              default value 0 indicates the input is a 2-D matrix.
+        )DOC")
+        .SetDefault(0)
+        .EqualGreaterThan(0);
+    AddAttr<int>(
+        "y_num_col_dims",
+        R"DOC((int, default 0), The matmul_op can take tensors with more than
+              two dimensions as its inputs. If the input $Y$ is a tensor with
+              more than two dimensions, $Y$ will be flattened into a
+              two-dimensional matrix first. The attribute `y_num_col_dims`
+              determines how $Y$ is flattened.
+              See the comments of `x_num_col_dims` for more details.
+        )DOC")
+        .SetDefault(0)
+        .EqualGreaterThan(0);
     AddAttr<bool>("transpose_X",
                   R"DOC(If true, use the transpose of `X`.
         )DOC")
