Collaborator

@sneaxiy sneaxiy commented Dec 20, 2021

PR types

Others

PR changes

OPs

Describe

Support FP16 for mean and reduce_mean ops.

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

};

template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
Contributor

@AnnaTrainingG AnnaTrainingG Dec 21, 2021

Both mean and reduce_mean currently call the reduce implementation in pten, which chen tian yu is modifying. Please confirm whether this change actually takes effect.

Collaborator Author

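Confirmed with logging: this code path is indeed reached, because PTen has no FP16 registration. A follow-up PR will add the registration in PTen.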

template <typename Tx, typename Ty = Tx>
struct CustomMean {
- using Transformer = kps::DivideFunctor<Tx>;
+ using Transformer = kps::DivideFunctor<Tx, Ty>;
Contributor

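CustomMean is not used in the latest reduce_op.cu.h. It was planned for removal, but pten still uses it. Changing it only here will not change the reduce implementation in pten, although they seem to be moving that code.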

Collaborator Author

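Confirmed with logging: this code path is indeed reached, because PTen has no FP16 registration. A follow-up PR will add the registration in PTen.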
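
For readers following the `DivideFunctor<Tx, Ty>` change, here is a minimal sketch of what a two-type divide functor can look like and why the second type parameter may matter for a mean. It is not code from this PR and not Paddle's `kps::DivideFunctor`: the names `DivideBy` and `Mean` are hypothetical, and `float`/`double` stand in for a low-precision input type and a wider accumulation type.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a two-type divide functor.
// Tx is the element type that is read; Ty is the type handed to the reduction.
template <typename Tx, typename Ty = Tx>
struct DivideBy {
  explicit DivideBy(int n) : n_inv(static_cast<Ty>(1.0 / n)) { assert(n > 0); }

  // Convert the input to Ty first, then scale by 1/n in Ty precision.
  Ty operator()(const Tx& x) const { return static_cast<Ty>(x) * n_inv; }

  Ty n_inv;
};

// Mean computed as sum(x_i / n): each element is scaled by the functor,
// the running sum is kept in Acc, and the result is cast back to T once.
template <typename T, typename Acc = T>
T Mean(const std::vector<T>& data) {
  assert(!data.empty());
  DivideBy<T, Acc> scale(static_cast<int>(data.size()));
  Acc sum = Acc(0);
  for (const T& x : data) sum += scale(x);
  return static_cast<T>(sum);
}
```

Under these assumptions, `Mean<float, double>(values)` reads `float` inputs, accumulates in `double`, and returns a `float`. With a single type parameter the functor's output would be tied to the input type, so a low-precision input would also force low-precision intermediate values.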

reduce_mean,
ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
kps::DivideFunctor>,
Contributor

chentianyu says that pten already registers mean and sum internally. Please confirm whether the change here can actually invoke the fp16 kernel.

Collaborator Author

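Confirmed with logging: this code path is indeed reached, because PTen has no FP16 registration. A follow-up PR will add the registration in PTen.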
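
The behavior described in this thread (the kernel registered here is reached because the newer library has no FP16 entry) can be pictured with a small dispatch sketch. Everything in it is hypothetical and greatly simplified: `Registry`, `KernelKey`, and `SelectKernel` are illustrative names, not Paddle or PTen APIs, and the real selection logic may differ.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical registry: (op name, dtype) -> description of the kernel.
using KernelKey = std::pair<std::string, std::string>;
using Registry = std::map<KernelKey, std::string>;

// Prefer a kernel from the new library; fall back to the legacy
// registration when the new library has no entry for this dtype.
std::string SelectKernel(const Registry& new_lib, const Registry& legacy,
                         const std::string& op, const std::string& dtype) {
  assert(!op.empty());
  const KernelKey key(op, dtype);
  Registry::const_iterator it = new_lib.find(key);
  if (it != new_lib.end()) return it->second;
  it = legacy.find(key);
  if (it != legacy.end()) return it->second;
  return "no kernel registered";
}
```

In this sketch, if the new registry holds only a float32 entry for `reduce_mean` while the legacy registry also holds a float16 entry, a float16 request resolves to the legacy kernel. Once a float16 entry is added to the new registry, the same request would resolve there instead.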

Contributor

@AnnaTrainingG AnnaTrainingG left a comment

LGTM

@sneaxiy sneaxiy merged commit 643a268 into PaddlePaddle:develop Dec 21, 2021
@sneaxiy sneaxiy deleted the mean_fp16 branch December 21, 2021 06:40