Skip to content

Conversation

@ZHUI
Copy link
Collaborator

@ZHUI ZHUI commented Mar 2, 2022

PR types

Others

PR changes

OPs

Describe

Move segment_pool to phi.

@paddle-bot-old
Copy link

paddle-bot-old bot commented Mar 2, 2022

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h"
#include "paddle/phi/kernels/segment_pool_grad_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

segment_pool_grad_kernel.h建议放在开头

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h"
#include "paddle/phi/kernels/segment_pool_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));

auto index_type = paddle::framework::TransToProtoVarType(segment_ids.type());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不需要转成proto::VarType了,直接用DataType类型做判断就行


namespace phi {

using Tensor = DenseTensor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using Tensor可以移除

length.Resize(phi::make_ddim({1}));
IndexT* length_data = dev_ctx.template HostAlloc<IndexT>(&length);

// IndexT* length_data = length.data<IndexT>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释可以去掉

const std::string& pooltype,
DenseTensor* out,
DenseTensor* summed_ids) {
auto index_type = paddle::framework::TransToProtoVarType(segment_ids.dtype());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用转ProtoVarType

Comment on lines 19 to 23
KernelSignature SegmentPoolOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"segment_pool", {"X", "SegmentIds"}, {"pooltype"}, {"Out", "SummedIds"});
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里前向ArgumentMapping感觉可以不写,用默认的应该也能work

Copy link
Contributor

@Avin0323 Avin0323 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for unity_build_rule.cmake

@ZHUI ZHUI requested a review from XiaoguangHu01 March 10, 2022 02:32
Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for PADDLE_THROW

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZHUI ZHUI merged commit a07f19e into PaddlePaddle:develop Mar 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

6 participants