Skip to content

Commit ea4d08d

Browse files
Author: xutianbing (committed)
Commit message: update interface of context projection functions, Tensor -> Matrix/Vector
1 parent 2c37ad7 · commit ea4d08d

File tree: 6 files changed, +207 additions, -181 deletions (lines changed)

paddle/function/ContextProjectionOp.cpp

Lines changed: 110 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,17 @@ limitations under the License. */
1919
namespace paddle {
2020

2121
template <>
22-
void ContextProjectionForward<DEVICE_TYPE_CPU>(Tensor& output,
23-
const Tensor& input,
24-
const Tensor& weight,
25-
const Tensor& sequence,
22+
void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
23+
const CpuMatrix* input_mat,
24+
const CpuMatrix* weight_mat,
25+
const CpuIVector& seq_vec,
2626
size_t context_length,
2727
int context_start,
28-
size_t begin_pad,
29-
bool is_padding) {
30-
CHECK(output.getData() && input.getData() && sequence.getData());
31-
CHECK_EQ(output.dims_.size(), 2);
32-
CHECK_EQ(input.dims_.size(), 2);
33-
CHECK_EQ(weight.dims_.size(), 2);
34-
CHECK_EQ(sequence.dims_.size(), 1);
35-
36-
auto out_mat = std::make_shared<CpuMatrix>(
37-
output.getData(), output.dims_[0], output.dims_[1]);
38-
const auto in_mat = std::make_shared<CpuMatrix>(
39-
input.getData(), input.dims_[0], input.dims_[1]);
40-
const auto weight_mat =
41-
!weight.getData()
42-
? nullptr
43-
: std::make_shared<CpuMatrix>(
44-
weight.getData(), weight.dims_[0], weight.dims_[1]);
45-
CpuIVector seq_vec(sequence.dims_[0],
46-
reinterpret_cast<int*>(sequence.getData()));
47-
CHECK_EQ(out_mat->getWidth(), in_mat->getWidth() * context_length);
48-
28+
size_t begin_pad) {
4929
const int* starts = seq_vec.getData();
5030
const size_t num_sequences = seq_vec.getSize() - 1;
31+
auto w_mat = const_cast<CpuMatrix*>(weight_mat);
32+
auto in_mat = const_cast<CpuMatrix*>(input_mat);
5133
for (size_t i = 0; i < num_sequences; ++i) {
5234
for (size_t j = 0; j < context_length; ++j) {
5335
int begin = starts[i] + context_start + j;
@@ -58,8 +40,8 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(Tensor& output,
5840
int64_t pad_size =
5941
std::min(starts[i] - begin, starts[i + 1] - starts[i]);
6042
MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size);
61-
if (is_padding && weight_mat) {
62-
MatrixPtr sub = weight_mat->subMatrix(j, pad_size);
43+
if (w_mat) {
44+
MatrixPtr sub = w_mat->subMatrix(j, pad_size);
6345
mat->addAtOffset(*sub, j * in_mat->getWidth());
6446
}
6547
dst_begin = starts[i] + pad_size;
@@ -69,8 +51,8 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(Tensor& output,
6951
int64_t pad_size =
7052
std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
7153
MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
72-
if (is_padding && weight_mat) {
73-
MatrixPtr sub = weight_mat->subMatrix(
54+
if (w_mat) {
55+
MatrixPtr sub = w_mat->subMatrix(
7456
begin_pad + context_start + j - pad_size, pad_size);
7557
mat->addAtOffset(*sub, j * in_mat->getWidth());
7658
}
@@ -98,7 +80,6 @@ class ContextProjectionForwardFunc : public FunctionBase {
9880
context_length_ = config.get<size_t>("context_length");
9981
context_start_ = config.get<int>("context_start");
10082
begin_pad_ = config.get<size_t>("begin_pad");
101-
is_padding_ = config.get<bool>("is_padding");
10283
}
10384

10485
void calc(const Arguments& inputs,
@@ -108,59 +89,58 @@ class ContextProjectionForwardFunc : public FunctionBase {
10889
CHECK_EQ(1, outputs.size());
10990
CHECK_EQ(0, inouts.size());
11091

111-
ContextProjectionForward<Device>((Tensor&)outputs[0],
112-
inputs[0],
113-
inputs[1],
114-
inputs[2],
92+
CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData());
93+
CHECK_EQ(outputs[0].dims_.size(), 2);
94+
CHECK_EQ(inputs[0].dims_.size(), 2);
95+
CHECK_EQ(inputs[1].dims_.size(), 2);
96+
CHECK_EQ(inputs[2].dims_.size(), 1);
97+
/// dim of output = dim of input * context_length
98+
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
99+
/// dim of input == dim of weight
100+
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
101+
/// input and output has the same batch_size
102+
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
103+
104+
auto out_mat = std::make_shared<typename MatrixT<Device>::type>(
105+
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
106+
const auto in_mat = std::make_shared<typename MatrixT<Device>::type>(
107+
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
108+
const auto w_mat =
109+
!inputs[1].getData()
110+
? nullptr
111+
: std::make_shared<typename MatrixT<Device>::type>(
112+
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
113+
typename SequenceT<Device>::type seq_vec(
114+
inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
115+
116+
ContextProjectionForward<Device>(out_mat.get(),
117+
in_mat.get(),
118+
w_mat.get(),
119+
seq_vec,
115120
context_length_,
116121
context_start_,
117-
begin_pad_,
118-
is_padding_);
122+
begin_pad_);
119123
}
120124

121125
private:
122126
size_t context_length_;
123127
int context_start_;
124128
size_t begin_pad_;
125-
bool is_padding_;
126129
};
127130

128131
template <>
129-
void ContextProjectionBackward<DEVICE_TYPE_CPU>(Tensor& out_grad,
130-
Tensor& in_grad,
131-
Tensor& w_grad,
132-
const Tensor& sequence,
132+
void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
133+
CpuMatrix* in_grad_mat,
134+
CpuMatrix* w_grad_mat,
135+
const CpuIVector& seq_vec,
133136
size_t context_length,
134137
int context_start,
135138
size_t begin_pad,
136139
bool is_padding,
137140
size_t total_pad) {
138-
CHECK(out_grad.getData() && sequence.getData());
139-
CHECK_EQ(out_grad.dims_.size(), 2);
140-
CHECK_EQ(in_grad.dims_.size(), 2);
141-
CHECK_EQ(w_grad.dims_.size(), 2);
142-
CHECK_EQ(sequence.dims_.size(), 1);
143-
144-
auto out_grad_mat = std::make_shared<CpuMatrix>(
145-
out_grad.getData(), out_grad.dims_[0], out_grad.dims_[1]);
146-
const auto in_grad_mat =
147-
!in_grad.getData()
148-
? nullptr
149-
: std::make_shared<CpuMatrix>(
150-
in_grad.getData(), in_grad.dims_[0], in_grad.dims_[1]);
151-
const auto w_grad_mat =
152-
!w_grad.getData()
153-
? nullptr
154-
: std::make_shared<CpuMatrix>(
155-
w_grad.getData(), w_grad.dims_[0], w_grad.dims_[1]);
156-
CpuIVector seq_vec(sequence.dims_[0],
157-
reinterpret_cast<int*>(sequence.getData()));
158-
CHECK_EQ(out_grad_mat->getWidth(), in_grad_mat->getWidth() * context_length);
159-
141+
CHECK(out_grad_mat);
160142
size_t input_dim = in_grad_mat ? in_grad_mat->getWidth()
161143
: w_grad_mat ? w_grad_mat->getWidth() : 0;
162-
CHECK_EQ(out_grad_mat->getWidth(), input_dim * context_length);
163-
164144
const int* starts = seq_vec.getData();
165145
size_t num_sequences = seq_vec.getSize() - 1;
166146
for (size_t i = 0; i < num_sequences; ++i) {
@@ -226,10 +206,38 @@ class ContextProjectionBackwardFunc : public FunctionBase {
226206
CHECK_EQ(1, outputs.size());
227207
CHECK_EQ(0, inouts.size());
228208

229-
ContextProjectionBackward<Device>((Tensor&)outputs[0],
230-
(Tensor&)inputs[0],
231-
(Tensor&)inputs[1],
232-
inputs[2],
209+
CHECK(outputs[0].getData() && inputs[2].getData());
210+
CHECK_EQ(outputs[0].dims_.size(), 2);
211+
CHECK_EQ(inputs[0].dims_.size(), 2);
212+
CHECK_EQ(inputs[1].dims_.size(), 2);
213+
CHECK_EQ(inputs[2].dims_.size(), 1);
214+
215+
/// dim of input == dim of weight
216+
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
217+
/// input and output has the same batch_size
218+
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
219+
/// dim of output = dim of input * context_length
220+
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
221+
222+
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
223+
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
224+
auto in_grad_mat =
225+
!inputs[0].getData()
226+
? nullptr
227+
: std::make_shared<typename MatrixT<Device>::type>(
228+
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
229+
auto w_grad_mat =
230+
!inputs[1].getData()
231+
? nullptr
232+
: std::make_shared<typename MatrixT<Device>::type>(
233+
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
234+
typename SequenceT<Device>::type seq_vec(
235+
inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
236+
237+
ContextProjectionBackward<Device>(out_grad_mat.get(),
238+
in_grad_mat ? in_grad_mat.get() : nullptr,
239+
w_grad_mat ? w_grad_mat.get() : nullptr,
240+
seq_vec,
233241
context_length_,
234242
context_start_,
235243
begin_pad_,
@@ -264,10 +272,24 @@ class ContextProjectionBackwardDataFunc : public FunctionBase {
264272
CHECK_EQ(2, inputs.size());
265273
CHECK_EQ(1, outputs.size());
266274
CHECK_EQ(0, inouts.size());
275+
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
276+
CHECK_EQ(outputs[0].dims_.size(), 2);
277+
CHECK_EQ(inputs[0].dims_.size(), 2);
278+
CHECK_EQ(inputs[1].dims_.size(), 1);
279+
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
280+
/// input and output has the same batch_size
281+
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
267282

268-
ContextProjectionBackwardData<Device>((Tensor&)outputs[0],
269-
(Tensor&)inputs[0],
270-
inputs[1],
283+
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
284+
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
285+
const auto in_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
286+
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
287+
typename SequenceT<Device>::type seq_vec(
288+
inputs[1].dims_[0], reinterpret_cast<int*>(inputs[1].getData()));
289+
290+
ContextProjectionBackwardData<Device>(out_grad_mat.get(),
291+
in_grad_mat.get(),
292+
seq_vec,
271293
context_length_,
272294
context_start_);
273295
}
@@ -299,9 +321,22 @@ class ContextProjectionBackwardWeightFunc : public FunctionBase {
299321
CHECK_EQ(1, outputs.size());
300322
CHECK_EQ(0, inouts.size());
301323

302-
ContextProjectionBackwardWeight<Device>((Tensor&)outputs[0],
303-
(Tensor&)inputs[0],
304-
inputs[1],
324+
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
325+
CHECK_EQ(outputs[0].dims_.size(), 2);
326+
CHECK_EQ(inputs[0].dims_.size(), 2);
327+
CHECK_EQ(inputs[1].dims_.size(), 1);
328+
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
329+
330+
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
331+
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
332+
auto w_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
333+
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
334+
typename SequenceT<Device>::type seq_vec(
335+
inputs[1].dims_[0], reinterpret_cast<int*>(inputs[1].getData()));
336+
337+
ContextProjectionBackwardWeight<Device>(out_grad_mat.get(),
338+
w_grad_mat.get(),
339+
seq_vec,
305340
context_length_,
306341
context_start_,
307342
total_pad_,

paddle/function/ContextProjectionOp.h

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,13 @@ namespace paddle {
3232
*
3333
*/
3434
template <DeviceType Device>
35-
void ContextProjectionForward(Tensor& output,
36-
const Tensor& input,
37-
const Tensor& weight,
38-
const Tensor& sequence,
35+
void ContextProjectionForward(typename MatrixT<Device>::type* output,
36+
const typename MatrixT<Device>::type* input,
37+
const typename MatrixT<Device>::type* weight,
38+
const typename SequenceT<Device>::type& sequence,
3939
size_t context_length,
4040
int context_start,
41-
size_t begin_pad,
42-
bool is_padding);
41+
size_t begin_pad);
4342

4443
/**
4544
* \brief Context Projection Backward.
@@ -55,30 +54,32 @@ void ContextProjectionForward(Tensor& output,
5554
*
5655
*/
5756
template <DeviceType Device>
58-
void ContextProjectionBackward(Tensor& out_grad,
59-
Tensor& in_grad,
60-
Tensor& w_grad,
61-
const Tensor& sequence,
57+
void ContextProjectionBackward(typename MatrixT<Device>::type* out_grad,
58+
typename MatrixT<Device>::type* in_grad,
59+
typename MatrixT<Device>::type* w_grad,
60+
const typename SequenceT<Device>::type& seq_vec,
6261
size_t context_length,
6362
int context_start,
6463
size_t begin_pad,
6564
bool is_padding,
6665
size_t total_pad);
6766

6867
template <DeviceType Device>
69-
void ContextProjectionBackwardData(Tensor& out_grad,
70-
Tensor& in_grad,
71-
const Tensor& sequence,
72-
size_t context_length,
73-
int context_start);
68+
void ContextProjectionBackwardData(
69+
typename MatrixT<Device>::type* out_grad,
70+
typename MatrixT<Device>::type* in_grad,
71+
const typename SequenceT<Device>::type& sequence,
72+
size_t context_length,
73+
int context_start);
7474

7575
template <DeviceType Device>
76-
void ContextProjectionBackwardWeight(Tensor& out_grad,
77-
Tensor& w_grad,
78-
const Tensor& sequence,
79-
size_t context_length,
80-
int context_start,
81-
size_t total_pad,
82-
size_t begin_pad);
76+
void ContextProjectionBackwardWeight(
77+
typename MatrixT<Device>::type* out_grad,
78+
typename MatrixT<Device>::type* w_grad,
79+
const typename SequenceT<Device>::type& seq_vec,
80+
size_t context_length,
81+
int context_start,
82+
size_t total_pad,
83+
size_t begin_pad);
8384

8485
} // namespace paddle

0 commit comments — Comments (0)