Skip to content

Conversation

@ZzSean
Copy link
Contributor

@ZzSean ZzSean commented Mar 10, 2022

PR types

Others

PR changes

Others

Describe

[Phi]Move kron kernel to phi

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

// limitations under the License.

#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h"
#include "paddle/phi/kernels/kron_grad_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

头文件放到第一行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

// limitations under the License.

#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
#include "paddle/phi/kernels/kron_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

namespace phi {

namespace ops = paddle::operators;
namespace plat = paddle::platform;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否可以用phi::dtype 命名空间下的complex ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Comment on lines 21 to 22
namespace ops = paddle::operators;
namespace plat = paddle::platform;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在phi下最好还是不使用paddle::xxx相关namespace的别名,会增加后续替换的难度

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

const plat::complex<T>* dout_;
const plat::complex<T>* A_;
const plat::complex<T>* B_;
plat::complex<T>* dout_a_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plat->phi::dtype

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

p_dout_y = dout_y.data<T>();
}

plat::ForRange<Context> for_range(dev_ctx, numel);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用phi下的ForRange

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Comment on lines 239 to 244
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
}
if (dy) {
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TensorReduceImpl可以使用phi下的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Comment on lines 20 to 21
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op_registry.h这里应该不需要了
for_range.h使用phi下的

#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reduce_op.cu.h可以使用paddle/phi/kernels/funcs/reduce_function.h代替

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Comment on lines 19 to 21
KernelSignature KronOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron", {"X", "Y"}, {}, {"Out"});
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的前向ArgumentMapping看上去没有特殊case,感觉可以不写,试试直接使用默认的参数映射能不能work?

auto stream = dev_ctx.stream(); // it is a cuda device_context
auto* ctx = reinterpret_cast<const plat::CUDADeviceContext*>(&dev_ctx);
if (dx) {
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ops::TensorReduceImpl 已迁移,这里可用 phi::funcs::ReduceKernel

*ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
}
if (dy) {
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

phi目录下用不到原来的op 注册头文件

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

引用这个 paddle/phi/kernels/funcs/for_range.h

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ops::TensorReduceImpl 已迁移,这里可用 phi::funcs::ReduceKernel , 此头文件可以不用了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

p_shape_y = dim_y.Get();
#endif

paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用Phi下的ForRange替代

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Copy link
Contributor

@MingMingShangTian MingMingShangTian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZzSean ZzSean merged commit f181d47 into PaddlePaddle:develop Mar 15, 2022
@ZzSean ZzSean deleted the move_kron branch April 14, 2022 09:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants