Skip to content

Commit 99c6f44

Browse files
committed
follow comments
1 parent dcb3da5 commit 99c6f44

File tree

8 files changed

+90
-104
lines changed

8 files changed

+90
-104
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
128128
op_library(sum_op DEPS net_op)
129129
op_library(pool_op DEPS pooling)
130130
op_library(pool_with_index_op DEPS pooling)
131-
op_library(sequence_conv_op DEPS sequence_project)
131+
op_library(sequence_conv_op DEPS context_project)
132132
op_library(lstm_op DEPS sequence2batch lstm_compute)
133133

134134
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})

paddle/operators/math/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ if(WITH_GPU)
99
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
1010
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
1111
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
12-
nv_library(sequence_project SRCS sequence_project.cc sequence_project.cu DEPS device_context)
12+
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
1313
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
1414
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
1515
else()
@@ -19,7 +19,7 @@ else()
1919
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
2020
cc_library(pooling SRCS pooling.cc DEPS device_context)
2121
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
22-
cc_library(sequence_project SRCS sequence_project.cc DEPS device_context)
22+
cc_library(context_project SRCS context_project.cc DEPS device_context)
2323
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
2424
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
2525
endif()

paddle/operators/math/sequence_project.cc renamed to paddle/operators/math/context_project.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/operators/math/sequence_project.h"
15+
#include "paddle/operators/math/context_project.h"
1616

1717
namespace paddle {
1818
namespace operators {
1919
namespace math {
2020

21-
template class SequenceProjectFunctor<platform::CPUPlace, float>;
22-
template class SequenceProjectFunctor<platform::CPUPlace, double>;
21+
template class ContextProjectFunctor<platform::CPUPlace, float>;
22+
template class ContextProjectFunctor<platform::CPUPlace, double>;
2323

2424
} // namespace math
2525
} // namespace operators

paddle/operators/math/sequence_project.cu renamed to paddle/operators/math/context_project.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ limitations under the License. */
1414

1515
#define EIGEN_USE_GPU
1616

17-
#include "paddle/operators/math/sequence_project.h"
17+
#include "paddle/operators/math/context_project.h"
1818

1919
namespace paddle {
2020
namespace operators {
2121
namespace math {
2222

23-
template class SequenceProjectFunctor<platform::GPUPlace, float>;
24-
template class SequenceProjectFunctor<platform::GPUPlace, double>;
23+
template class ContextProjectFunctor<platform::GPUPlace, float>;
24+
template class ContextProjectFunctor<platform::GPUPlace, double>;
2525

2626
} // namespace math
2727
} // namespace operators

paddle/operators/math/sequence_project.h renamed to paddle/operators/math/context_project.h

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,29 @@ namespace paddle {
2323
namespace operators {
2424
namespace math {
2525

26-
// template <typename T, int MajorType = Eigen::RowMajor,
27-
// typename IndexType = Eigen::DenseIndex>
28-
// using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
29-
3026
template <typename T, int MajorType = Eigen::RowMajor,
3127
typename IndexType = Eigen::DenseIndex>
3228
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
3329
/*
34-
* \brief SequenceProject projects features of context_length time-steps of each
35-
* instance.
36-
*
30+
* \brief Context projection concatenate features in adjacent time steps in
31+
* a sequence. The i-th row of the output is the concatenation of
32+
* context_length rows of the input. The context_length rows are the
33+
* consecutive rows from the i+shift_start row.
34+
3735
* \param in Input data.
38-
* \param inShape The shape of Input data,
36+
* \param Shape The shape of Input data,
3937
* [minibatch, number_of_input_features].
40-
* \param inShape A float LoDTensor.
38+
* \param type A float LoDTensor.
4139
*
4240
* \param padding_data Padding data.
43-
* \param inShape The shape of Padding data,
41+
* \param Shape The shape of Padding data,
4442
* [up_pad + down_pad, number_of_input_features].
45-
* \param inShape A float LoDTensor.
43+
* \param type A float Tensor.
4644
*
4745
* \param col Col data.
48-
* \param inShape The shape of Col data,
49-
* [minibatch, 1].
50-
* \param inShape A float LoDTensor.
46+
* \param Shape The shape of Col data,
47+
* [minibatch, context_length * number_of_input_features].
48+
* \param type A float Tensor.
5149
*
5250
* For a mini-batch of 2 variable lengths sentences, containing 3, and 1
5351
* time-steps:
@@ -87,7 +85,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
8785
*/
8886

8987
template <typename Place, typename T>
90-
class SequenceProjectFunctor {
88+
class ContextProjectFunctor {
9189
public:
9290
void operator()(const platform::DeviceContext& context,
9391
framework::LoDTensor& in, framework::Tensor& padding_data,
@@ -147,8 +145,7 @@ class SequenceProjectFunctor {
147145
/*stride_height*/ context_stride, /*stride_width*/ 1,
148146
up_pad, down_pad, 0, 0);
149147
}
150-
out_t.Resize(framework::make_ddim(
151-
{sequence_height, context_length * sequence_width}));
148+
out_t.Resize({sequence_height, context_length * sequence_width});
152149
}
153150
}
154151
}
@@ -162,8 +159,7 @@ class SequenceProjectFunctor {
162159
sequence_height = static_cast<int>(out_t.dims()[0]);
163160

164161
// add up trainable data
165-
out_t.Resize(framework::make_ddim(
166-
{sequence_height * context_length, sequence_width}));
162+
out_t.Resize({sequence_height * context_length, sequence_width});
167163

168164
if (up_pad > 0) { // add up pad
169165
int padding_rows = std::min(
@@ -223,8 +219,7 @@ class SequenceProjectFunctor {
223219
}
224220
}
225221
}
226-
out_t.Resize(framework::make_ddim(
227-
{sequence_height, context_length * sequence_width}));
222+
out_t.Resize({sequence_height, context_length * sequence_width});
228223
}
229224
}
230225
}

paddle/operators/sequence_conv_op.cc

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ class SequenceConvOp : public framework::OperatorWithKernel {
3838
auto filter_dims = ctx->GetInputDim("Filter");
3939
PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
4040
"Input(X, Filter) should be 2-D tensor.");
41-
PADDLE_ENFORCE(
42-
filter_dims[0] == context_length && filter_dims[1] == in_dims[1],
43-
"Filter's shape should be (context_length x "
44-
"number_of_input_features).");
41+
PADDLE_ENFORCE(filter_dims[0] == context_length * in_dims[1],
42+
"Filter's height should be context_length * "
43+
"number_of_input_features .");
4544

4645
if (padding_trainable) {
4746
PADDLE_ENFORCE(
@@ -66,8 +65,9 @@ class SequenceConvOp : public framework::OperatorWithKernel {
6665
"and 'context_length'.");
6766
}
6867

69-
in_dims[1] = 1;
68+
in_dims[1] = filter_dims[1];
7069
ctx->SetOutputDim("Out", in_dims);
70+
ctx->ShareLoD("X", "Out");
7171
}
7272
};
7373

@@ -101,35 +101,51 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
101101
SequenceConvOpMaker(framework::OpProto* proto,
102102
framework::OpAttrChecker* op_checker)
103103
: OpProtoAndCheckerMaker(proto, op_checker) {
104-
AddInput("X",
105-
"(A float LoDTensor) the input of SequenceConvOp, a vector of "
106-
"2-D matrix of size (minibatch, number_of_input_features).");
104+
AddInput(
105+
"X",
106+
"(LoDTensor) the input(X) is a LodTensor, which support "
107+
"variable-time length input sequence. The underlying tensor in "
108+
"this LoDTensor is a matrix with shape (T, D), where, T is the "
109+
"total time steps in this mini-batch, D is the input feature size.");
107110
AddInput("PaddingData",
108-
"(Tensor) the input of SequenceConvOp, a vector of "
109-
"2-D matrix of size (up_pad + down_pad, "
110-
"number_of_input_features). ")
111+
"(Tensor, optional) the input(PaddingData) is an optional "
112+
"parameter, and it is learnable. "
113+
"This is a tensor with shape (N, D), where N is the "
114+
"top_pad + bottom_pad, D is the input feature size. In order to "
115+
"ensure the equal length of sequence before and after "
116+
"convolution, it is necessary to fill the top and bottom of each "
117+
"sequence according to context_length, context_stride and "
118+
"context_start")
111119
.AsDispensable();
112120
AddInput("Filter",
113-
"(Tensor) the input of SequenceConvOp, a vector of "
114-
"2-D matrix of size (context_length x number_of_input_features).");
115-
AddOutput("Out",
116-
"(A float LoDTensor) the output of SequenceConvOp, a vector "
117-
"of 2-D matrix of size (minibatch, 1).");
121+
"(Tensor) the input(Filter) is an learnable parameter."
122+
"This is a tensor with shape (N, D), where N is the "
123+
"context_length, D is the output feature size.");
124+
AddOutput(
125+
"Out",
126+
"(LoDTensor) the output(Out) is a LodTensor, which support "
127+
"variable-time length output sequence. The underlying tensor in "
128+
"this LoDTensor is a matrix with shape (T, D), where, T is the "
129+
"total time steps in this mini-batch, D is the output feature size.");
118130

119131
AddAttr<bool>("padding_trainable",
120132
"(bool, default false) the padding data of SequenceConvOp "
121133
"is trainable or not.")
122134
.SetDefault(false);
123135
AddAttr<int>("context_length",
124-
"(int, default 3) the context_length of SequenceConvOp.")
136+
"(int, default 3) the context_length of SequenceConvOp is the "
137+
"height of the convolution kernel.")
125138
.SetDefault(3)
126139
.GreaterThan(0);
127140
AddAttr<int>("context_start",
128-
"(int, default 0) the context_start of SequenceConvOp.")
141+
"(int, default 0) the context_start of SequenceConvOp "
142+
"represents the beginning of the convolution of the number of "
143+
"rows of sequence, which can be negative.")
129144
.SetDefault(0);
130145
AddAttr<int>("context_stride",
131-
"(int, default 1) the context_stride of SequenceConvOp. "
132-
"Currently, sequence_project_op only support "
146+
"(int, default 1) the context_stride of SequenceConvOp "
147+
"represents the step length of convolution. "
148+
"Currently, SequenceConvOp only supports"
133149
"context_stride=1.")
134150
.SetDefault(1)
135151
.GreaterThan(0);
@@ -139,14 +155,10 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
139155
context_length time-steps of each instance.
140156
The convolution operation calculates the output based on the input, filter
141157
and strides, paddings parameters. The size of each dimension of the
142-
parameters is checked in the infer-shape.
143-
144-
Example:
145-
Input:
146-
X shape: (minibatch, number_of_input_features)
147-
Filter shape: (context_length, number_of_input_features)
148-
Output:
149-
Out shape: (minibatch, 1)
158+
parameters is checked in the infer-shape. In order to ensure the equal
159+
length of sequence before and after convolution, it is necessary to fill
160+
the top and bottom of each sequence according to context_length,
161+
context_stride and context_start.
150162
)DOC");
151163
}
152164
};

0 commit comments

Comments
 (0)