Add sequence_conv_op and sequence_projection functor #4814
Conversation
Because seq_project is only used in seq_conv, seq_project should be written in functor form.
I think we can merge it first and review the code at the same time. @chengduoZH Please continue to polish the code based on the comments. And please split future PRs into small ones; such a big PR takes a long time to review.
@dzhwinter Ok!
```cpp
 * \param col Col data.
 * \param inShape The shape of Col data,
 *                [minibatch, 1].
 * \param inShape A float LoDTensor.
```
Why are there so many inShape parameters?
fixed.
```cpp
 * \param inShape A float LoDTensor.
 *
 * For a mini-batch of 2 variable lengths sentences, containing 3, and 1
 * time-steps:
```
Line 34 says this function is used for one sequence, but the example here has variable-length sentences. Please keep them consistent.
Done
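For reference, the variable-length layout being discussed can be sketched in numpy, assuming the offset-based LoD convention Paddle used at the time (the names here are illustrative, not the actual framework API):

```python
import numpy as np

# A mini-batch of 2 variable-length sequences with 3 and 1 time-steps,
# input hidden size D = 4. All time-steps are stacked into one matrix.
D = 4
x = np.random.rand(3 + 1, D)   # shape (T, D) with T = 4 total time-steps
lod = [[0, 3, 4]]              # offsets: sequence 0 = rows 0..2, sequence 1 = row 3

seq0 = x[lod[0][0]:lod[0][1]]  # first sequence, shape (3, D)
seq1 = x[lod[0][1]:lod[0][2]]  # second sequence, shape (1, D)
```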
```cpp
                  sequence_width});  // output_height, output_width,
                                     // input_channels, filter_height, filter_width

out_t.Resize(framework::make_ddim(output_shape));
```
You can remove framework::make_ddim, since std::vector can be automatically converted to DDim. Same below.
Done
paddle/operators/sequence_conv_op.cc Outdated
```cpp
PADDLE_ENFORCE(
    filter_dims[0] == context_length && filter_dims[1] == in_dims[1],
    "Filter's shape should be (context_length x "
    "number_of_input_features).");
```
The filter shape is not right. Suppose context_length = 3, the input hidden size is D, and the output hidden size is H. Then the filter shape should be [3D, H].
Done
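A minimal numpy sketch of the shape relationship described above (sizes are hypothetical): after the input is expanded so each row holds its full context window, its width is context_length * D, so the filter must be [3D, H] for the matmul to produce H output features:

```python
import numpy as np

context_length, D, H, T = 3, 4, 5, 6          # hypothetical sizes; T = total time-steps
col = np.random.rand(T, context_length * D)   # context-expanded input, width 3*D = 12
filt = np.random.rand(context_length * D, H)  # filter shape [3D, H]
out = col @ filt                              # one matmul per mini-batch
```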
```cpp
}

in_dims[1] = 1;
ctx->SetOutputDim("Out", in_dims);
```
The output shape is not right. Following the assumptions above, the output dims[1] should be H.
The LoD should also be set for the output.
Done
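In numpy terms (a sketch of the shape and LoD rules being requested, not the Paddle API): the output keeps one row per input time-step, its second dimension becomes H, and the sequence offsets carry over from the input unchanged:

```python
import numpy as np

T, D, H, context_length = 4, 3, 5, 3
x = np.random.rand(T, D)
lod = [[0, 3, 4]]                        # two sequences: 3 and 1 time-steps
w = np.random.rand(context_length * D, H)

# Expected output metadata: same row count T, H columns, LoD shared from input.
out_shape = (T, H)
out_lod = lod
```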
paddle/operators/sequence_conv_op.h Outdated
```cpp
// Because if padding_trainable is false, padding data should be zeros.
auto temp = framework::EigenVector<T>::Flatten(col);
temp.device(context.GetEigenDevice<Place>()) =
    temp.constant(static_cast<T>(0));
```
Use math::SetConstant to set it to zero: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/math/math_function.h#L97
Done
paddle/operators/sequence_conv_op.h Outdated
```cpp
filter.Resize(framework::make_ddim({context_length * sequence_width, 1}));
math::matmul<Place, T>(context.device_context(), col, false, filter, false,
                       T(1.0), out, T(0.0));
```
T(1.0) -> static_cast<T>(1.0)
Done
paddle/operators/sequence_conv_op.h Outdated
```cpp
// Because if padding_trainable is false, padding data should be zeros.
auto temp = framework::EigenVector<T>::Flatten(col);
temp.device(context.GetEigenDevice<Place>()) =
    temp.constant(static_cast<T>(0));
```
Use math::SetConstant to set it to zero: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/math/math_function.h#L97
Done
paddle/operators/sequence_conv_op.h Outdated
```cpp
functor(context.device_context(), filter_g, 0);

Tensor filter_grad_ = *filter_g;
LoDTensor out_grad_ = *out_g;
```
out_grad_ -> out_grad
```python
output_dim = self.outputs['Out'].shape
filter.shape = filter_dim[0] * filter_dim[1]
self.outputs['Out'].shape = (output_dim[0], )
np.dot(out, filter, out=self.outputs['Out'])
```
For the forward implementation in the Python unit test, I think it should avoid mirroring the C++ code. Instead of the expand-then-matmul form, it could use the original convolution formulation.
The Python unit test was adapted from the previous Paddle implementation. context_project_functor first applies im2col and then a matrix multiplication, so the two approaches are not quite the same.
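To illustrate the two routes being discussed (a hedged sketch with hypothetical names, not the actual test code): the im2col route expands each time-step's context window into one row and applies a single matmul with the flattened filter, which for stride 1 and zero padding agrees with a direct sliding-window convolution that sums per-offset matmuls:

```python
import numpy as np

def context_project(x, context_start, context_length):
    """im2col-style expansion: row t holds the context window around t, zero-padded."""
    T, D = x.shape
    col = np.zeros((T, context_length * D))
    for t in range(T):
        for j in range(context_length):
            src = t + context_start + j
            if 0 <= src < T:
                col[t, j * D:(j + 1) * D] = x[src]
    return col

T, D, H = 5, 3, 2
x = np.random.rand(T, D)
w = np.random.rand(3 * D, H)            # context_length = 3, filter [3D, H]

# Route 1: expand then one matmul (what context_project_functor does conceptually).
out_matmul = context_project(x, context_start=-1, context_length=3) @ w

# Route 2: direct convolution, summing per-offset matmuls over the window.
out_direct = np.zeros((T, H))
for j in range(3):
    wj = w[j * D:(j + 1) * D]           # filter slice for context offset j
    for t in range(T):
        src = t - 1 + j                 # context_start = -1
        if 0 <= src < T:
            out_direct[t] += x[src] @ wj
```

Both routes produce the same result; the disagreement in the thread is only about which formulation the Python test should use as an independent reference.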
Since the Python API needs this op, I'm approving it, but it still needs to be modified later.
```cpp
framework::Tensor& col, bool padding_trainable,
int context_start, int context_length, int context_stride,
int up_pad, int down_pad, bool gradient, bool input_grad,
bool pad_grad) {
```
I think mixing the projection and un-projection processes together makes the code logic unclear.
Writing them separately is also possible, but the code would look a bit redundant. Let me think about it.
```cpp
 * \param in Input data.
 * \param Shape The shape of Input data,
 *              [minibatch, number_of_input_features].
 * \param type A float LoDTensor.
```
Remove type; it has no meaning here. The argument types in the function below are already clear.
Done. #5130
```cpp
 * \param in Input data.
 * \param Shape The shape of Input data,
 *              [minibatch, number_of_input_features].
```
number_of_input_features -> input_hidden_size
Done. #5130
```cpp
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the output feature size.");

AddAttr<bool>("padding_trainable",
```
paddingTrainable; please see our naming convention.
Done. #5130
```cpp
             "(bool, default false) the padding data of SequenceConvOp "
             "is trainable or not.")
    .SetDefault(false);
AddAttr<int>("context_length",
```
contextLength
Done. #5130
```cpp
             "height of the convolution kernel.")
    .SetDefault(3)
    .GreaterThan(0);
AddAttr<int>("context_start",
```
contextStart
Done. #5130
```cpp
             "represents the beginning of the convolution of the number of "
             "rows of sequence, which can be negative.")
    .SetDefault(0);
AddAttr<int>("context_stride",
```
contextStride
Done. #5130
```python
del idx[0]
self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
            [self.input_size[0]]]
self.output_represention = 8  # output feature size
```
A unit test is needed for the case self.context_stride > 1.
Currently, seq_conv_op only supports self.context_stride = 1.
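For reference, one plausible semantics for a general stride (a hypothetical sketch, not what the op implements): the context window for output step t would start at t * context_stride + context_start, and the stride-1 case the op currently supports reduces to the familiar per-time-step window:

```python
def window_indices(t, context_start, context_length, context_stride):
    """Hypothetical: input row indices covered by output time-step t."""
    start = t * context_stride + context_start
    return list(range(start, start + context_length))

# With context_stride = 1 (the only supported case), output step t reads
# rows [t + context_start, t + context_start + context_length).
stride1 = window_indices(2, -1, 3, 1)
stride2 = window_indices(2, -1, 3, 2)
```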
fix #4899
fix #5045