
Commit 0c2c199

mean first version
1 parent 0e9597d commit 0c2c199

File tree

9 files changed, +200 -62 lines changed


paddle/fluid/operators/kernel_primitives/functor_primitives.h

Lines changed: 11 additions & 4 deletions
@@ -14,7 +14,10 @@

 #pragma once

+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/eigen_ext.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {
@@ -74,16 +77,20 @@ struct IdentityFunctor {
  */
 template <typename Tx, typename Ty = Tx>
 struct DivideFunctor {
-  HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); }
+ private:
+  using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type;
+
+ public:
+  HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<MPType>(1.0f); }

-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {}
+  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {}

   HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x * n_inv);
+    return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
   }

  private:
-  Tx n_inv;
+  MPType n_inv;
 };

 /**
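
Note: the point of this change is that DivideFunctor now stores and applies its reciprocal in a wider "compute type" selected by MPTypeTrait, so float16 inputs are multiplied and accumulated in float. A minimal standalone sketch of that trait pattern (half_t and the names below are illustrative stand-ins, not Paddle's types):

#include <cstdint>
#include <type_traits>

struct half_t { uint16_t bits; };  // stand-in for platform::float16

// Default: a type accumulates in itself.
template <typename T>
struct MPTypeTraitSketch { using Type = T; };

// Specialization: fp16 accumulates in float.
template <>
struct MPTypeTraitSketch<half_t> { using Type = float; };

static_assert(std::is_same<MPTypeTraitSketch<double>::Type, double>::value,
              "non-fp16 types keep their own precision");
static_assert(std::is_same<MPTypeTraitSketch<half_t>::Type, float>::value,
              "fp16 is promoted to float for intermediate math");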

paddle/fluid/operators/mean_op.cu

Lines changed: 21 additions & 29 deletions
@@ -18,30 +18,23 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
 #include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {

-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n)
-      : n_inv(static_cast<T>(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-
 template <typename T>
 __global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  T data = in_data[0];
+  auto data = static_cast<MT>(in_data[0]);
   for (; idx < N; idx += blockDim.x * gridDim.x) {
-    out_data[idx] = data / (static_cast<T>(N));
+    out_data[idx] = static_cast<T>(data / (static_cast<MT>(N)));
   }
 }

@@ -53,26 +46,23 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
     auto* output = context.Output<Tensor>("Out");

     output->mutable_data<T>(context.GetPlace());
-    auto size_prob = input->numel();
+    auto numel = input->numel();
+    auto rank = input->dims().size();
+    if (rank == 0) return;
+
     const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(context.GetPlace());
     auto stream = context.cuda_device_context().stream();

-    DivideFunctor<T> transformer(size_prob);
-    cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
-        in_data, transformer);
-    size_t temp_storage_bytes = 0;
-
-    auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
-                                      out_data, size_prob, stream);
-    PADDLE_ENFORCE_GPU_SUCCESS(err);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        context.GetPlace());
-    err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
-                                 out_data, size_prob, stream);
-    PADDLE_ENFORCE_GPU_SUCCESS(err);
+    using MT = typename details::MPTypeTrait<T>::Type;
+    using Div = kernel_primitives::DivideFunctor<T, MT>;
+    std::vector<int> reduce_dims;
+    reduce_dims.reserve(rank);
+    for (decltype(rank) i = 0; i < rank; ++i) {
+      reduce_dims.push_back(i);
+    }
+    TensorReduceFunctorImpl<T, T, kernel_primitives::AddFunctor, Div>(
+        *input, output, Div(numel), reduce_dims, stream);
   }
 };

@@ -91,6 +81,8 @@ class MeanCUDAGradKernel : public framework::OpKernel<T> {

     auto in_data = OG->data<T>();
     auto size_prob = IG->numel();
+    if (IG->dims().size() == 0) return;
+
     auto out_data = IG->data<T>();
     int threads = 512;
     int grid = (size_prob + threads - 1) / threads;
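
Note: with this change the forward mean no longer drives cub::DeviceReduce::Sum directly; it reduces over every dimension through TensorReduceFunctorImpl with an AddFunctor reducer and a DivideFunctor<T, MT> transform, so each fp16 element is divided and summed in float. A plain C++ sketch of that computation for a 1-D input (illustrative names, not Paddle's API):

#include <cstddef>
#include <vector>

// Mean of x: transform each element by 1/N, accumulate in MT, cast back to T.
// MT would be float when T is a half-precision type.
template <typename T, typename MT = T>
T MeanReference(const std::vector<T>& x) {
  const MT n_inv = static_cast<MT>(1.0 / static_cast<double>(x.size()));
  MT acc = static_cast<MT>(0);
  for (const T& v : x) {
    acc += static_cast<MT>(v) * n_inv;  // DivideFunctor<T, MT>, then AddFunctor
  }
  return static_cast<T>(acc);
}
// Usage: MeanReference<float>({1.0f, 2.0f, 3.0f}) returns 2.0f.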

paddle/fluid/operators/reduce_ops/reduce_functor_op.h

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ struct CustomSub {

 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
-  using Transformer = kps::DivideFunctor<Tx>;
+  using Transformer = kps::DivideFunctor<Tx, Ty>;

   inline Ty initial() { return static_cast<Ty>(0.0f); }

paddle/fluid/operators/reduce_ops/reduce_mean_op.cu

Lines changed: 2 additions & 0 deletions
@@ -19,5 +19,7 @@
 REGISTER_OP_CUDA_KERNEL(
     reduce_mean,
     ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
+    ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
+                          kps::DivideFunctor>,
     ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>,
     ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>);

paddle/fluid/operators/reduce_ops/reduce_mean_op.h

Lines changed: 13 additions & 0 deletions
@@ -35,5 +35,18 @@ struct MeanGradFunctor {
   }
 };

+// TODO(zengjinle): Should refine the numeric stability of FP16 reduce_mean
+// and reduce_mean_grad later.
+struct FP16MeanGradFunctor {
+  template <typename DeviceContext, typename X, typename Y, typename DX,
+            typename DY, typename Dim>
+  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
+                  const Dim& dim, int size) {
+    dx->device(place) = (dy->template cast<float>().broadcast(dim) /
+                         dx->template cast<float>().constant(size))
+                            .template cast<platform::float16>();
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
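
Note: FP16MeanGradFunctor broadcasts dy over the input shape and divides by the element count, doing the arithmetic in float and casting only the final result back to float16. A scalar sketch of that idea (HalfT is any type convertible to and from float; illustrative, not Paddle's Eigen expression):

#include <cstddef>
#include <vector>

// Each dx element receives dy / size, computed in fp32 and cast back at the end.
template <typename HalfT>
std::vector<HalfT> Fp16MeanGradReference(HalfT dy, int size) {
  const float g = static_cast<float>(dy) / static_cast<float>(size);
  return std::vector<HalfT>(static_cast<std::size_t>(size),
                            static_cast<HalfT>(g));
}
// Usage: Fp16MeanGradReference<float>(1.0f, 4) yields {0.25f, 0.25f, 0.25f, 0.25f}.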

paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel =
     ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
                           ops::MeanGradFunctor, true>;

+using FP16CUDAReduceMeanGradKernel =
+    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
+                          paddle::platform::float16, ops::FP16MeanGradFunctor,
+                          true>;
+
 REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
+                        FP16CUDAReduceMeanGradKernel,
                         CUDAReduceMeanGradKernel<float>,
                         CUDAReduceMeanGradKernel<double>);

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 38 additions & 19 deletions
@@ -38,7 +38,9 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/fast_divmod.h"
+#include "paddle/fluid/string/string_helper.h"

 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   }
 }

+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<!std::is_same<Tx, platform::float16>::value,
+                               void>::type
+CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
+                           const TransformOp& transform, int reduce_num,
+                           const platform::Place& place, gpuStream_t stream) {
+  auto reducer = ReduceOp<Ty>();
+  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                  transform);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, reducer.initial(), stream);
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<uint8_t>(
+      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
+  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, reducer.initial(), stream);
+}
+
+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<std::is_same<Tx, platform::float16>::value,
+                               void>::type
+CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
+                           const TransformOp& transform, int reduce_num,
+                           const platform::Place& place, gpuStream_t stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
+}
+
 template <typename Tx, typename Ty, template <typename> class ReduceOp,
           typename TransformOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                              const TransformOp& transform,
-                             std::vector<int> origin_reduce_dims,
+                             const std::vector<int>& origin_reduce_dims,
                              gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }

   config.SetOutputData(y_data, x.place(), &tmp);
-  bool use_cub_reduce = (config.reduce_num == numel) &&
-                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  constexpr bool kIsTxFP16 = std::is_same<Tx, paddle::platform::float16>::value;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
   if (use_cub_reduce) {
-    // launch CUB::Reduce
-    auto reducer = ReduceOp<Ty>();
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
-                                                                    transform);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, reducer.initial(),
-                              stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        x.place());
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, reducer.initial(),
-                              stream);
-
+    CubTensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp>(
+        x_data, y_data, transform, config.reduce_num, x.place(), stream);
     return;
   }
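
Note: the two CubTensorReduceFunctorImpl overloads use std::enable_if so the cub::DeviceReduce path is only instantiated for non-float16 input types; the float16 overload exists purely to fail loudly if it is ever reached. A minimal standalone sketch of that dispatch (half_t and the function name are illustrative stand-ins, not Paddle's):

#include <stdexcept>
#include <type_traits>

struct half_t {};  // stand-in for platform::float16

// Chosen for every Tx except half_t: the real reduction would run here.
template <typename Tx>
typename std::enable_if<!std::is_same<Tx, half_t>::value, void>::type
CubReduceSketch(const Tx* /*x*/, int /*n*/) {
  // cub::DeviceReduce::Reduce(...) would be called here.
}

// Chosen only when Tx is half_t: never a valid code path, so it just throws.
template <typename Tx>
typename std::enable_if<std::is_same<Tx, half_t>::value, void>::type
CubReduceSketch(const Tx* /*x*/, int /*n*/) {
  throw std::invalid_argument("float16 input must not take the cub reduce path");
}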

paddle/fluid/operators/reduce_ops/reduce_op.h

Lines changed: 5 additions & 3 deletions
@@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
     std::vector<int> reduce_dims =
         GetReduceDim(dims, input->dims().size(), reduce_all);
     int reduce_num = 1;
-    for (int i = 0; i < input->dims().size(); i++) {
+    for (auto i : reduce_dims) {
      reduce_num *= (input->dims())[i];
     }
     gpuStream_t stream = context.cuda_device_context().stream();
@@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
       TensorReduceFunc<T, ReduceOp, TransformOp>(
           *input, output, reduce_dims, reduce_num, stream));
     } else {
-      TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, T>>(
-          *input, output, TransformOp<T, T>(reduce_num), reduce_dims, stream);
+      using MPType = typename details::MPTypeTrait<T>::Type;
+      TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
+          *input, output, TransformOp<T, MPType>(reduce_num), reduce_dims,
+          stream);
     }
   }
 };
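
Note: the first hunk changes how reduce_num is computed: it is now the product of the sizes of the reduced axes only, rather than of every axis of the input. A small sketch of the corrected loop (illustrative helper, not Paddle's API):

#include <vector>

// Number of elements folded into each output value of the reduction.
inline int ReduceNum(const std::vector<int>& input_dims,
                     const std::vector<int>& reduce_dims) {
  int reduce_num = 1;
  for (int axis : reduce_dims) {  // previously iterated over every axis
    reduce_num *= input_dims[axis];
  }
  return reduce_num;
}
// Usage: for a [2, 3, 4] input reduced over axes {1, 2}, ReduceNum returns 12.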
