@@ -38,7 +38,9 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/fast_divmod.h"
+#include "paddle/fluid/string/string_helper.h"
 
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   }
 }
 
+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<!std::is_same<Tx, platform::float16>::value,
+                               void>::type
+CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
+                           const TransformOp& transform, int reduce_num,
+                           const platform::Place& place, gpuStream_t stream) {
+  auto reducer = ReduceOp<Ty>();
+  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                  transform);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, reducer.initial(), stream);
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<uint8_t>(
+      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
+  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, reducer.initial(), stream);
+}
+
+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<std::is_same<Tx, platform::float16>::value,
+                               void>::type
+CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
+                           const TransformOp& transform, int reduce_num,
+                           const platform::Place& place, gpuStream_t stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
+}
+
 template <typename Tx, typename Ty, template <typename> class ReduceOp,
           typename TransformOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                              const TransformOp& transform,
-                             std::vector<int> origin_reduce_dims,
+                             const std::vector<int>& origin_reduce_dims,
                              gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
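
The two `CubTensorReduceFunctorImpl` overloads above use `std::enable_if` so that the CUB code is only instantiated for types CUB can actually reduce: the runtime `if (use_cub_reduce)` in `TensorReduceFunctorImpl` below is not enough on its own, because the compiler would still instantiate the CUB call for `float16` even on the branch that never runs. Below is a minimal standalone sketch (not Paddle code) of the same dispatch pattern; the `float16` struct is a hypothetical stand-in for `platform::float16`, and the function bodies are placeholders.

```cpp
#include <cstdio>
#include <type_traits>

struct float16 {};  // hypothetical stand-in for platform::float16

// Selected when Tx is NOT float16: this is where the CUB call would live.
template <typename Tx>
typename std::enable_if<!std::is_same<Tx, float16>::value, void>::type
CubReduce() {
  std::printf("fast path: cub::DeviceReduce::Reduce\n");
}

// Selected when Tx IS float16: the CUB path is never instantiated for it,
// so this overload only reports the unsupported type.
template <typename Tx>
typename std::enable_if<std::is_same<Tx, float16>::value, void>::type
CubReduce() {
  std::printf("float16: handled by the hand-written reduce kernels\n");
}

int main() {
  CubReduce<float>();    // resolves to the CUB overload
  CubReduce<float16>();  // resolves to the error/fallback overload
  return 0;
}
```

Exactly one overload is viable for any `Tx`, so the dispatch is decided entirely at compile time, mirroring how the PR keeps `float16` inputs away from `cub::DeviceReduce::Reduce`.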
@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
 
   config.SetOutputData(y_data, x.place(), &tmp);
-  bool use_cub_reduce = (config.reduce_num == numel) &&
-                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  constexpr bool kIsTxFP16 = std::is_same<Tx, paddle::platform::float16>::value;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
   if (use_cub_reduce) {
-    // launch CUB::Reduce
-    auto reducer = ReduceOp<Ty>();
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
-                                                                    transform);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, reducer.initial(),
-                              stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        x.place());
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, reducer.initial(),
-                              stream);
-
+    CubTensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp>(
+        x_data, y_data, transform, config.reduce_num, x.place(), stream);
     return;
   }
 
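
For reference, the helper factored out above follows CUB's standard two-phase workspace idiom: the first `cub::DeviceReduce::Reduce` call with a null workspace pointer only writes the required `temp_storage_bytes`, and the second call performs the reduction. Below is a self-contained sketch of that idiom under assumed inputs; the `Square` transform, the buffer sizes, and the plain `cudaMalloc` workspace are illustrative (the PR allocates the workspace through a `framework::Tensor` instead, so it is tracked by Paddle's allocator).

```cpp
#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

// Illustrative transform, playing the role of the PR's TransformOp.
struct Square {
  __host__ __device__ float operator()(const float& x) const { return x * x; }
};

int main() {
  const int n = 1024;
  std::vector<float> host(n, 2.0f);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Fuse the transform into the reduction input, as the PR does.
  cub::TransformInputIterator<float, Square, const float*> it(d_in, Square());

  // Phase 1: null workspace -> CUB only computes the required byte count.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Reduce(d_temp, temp_bytes, it, d_out, n, cub::Sum(), 0.f);

  // Phase 2: allocate the workspace and run the actual reduction.
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Reduce(d_temp, temp_bytes, it, d_out, n, cub::Sum(), 0.f);

  float result = 0.f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum of squares = %f (expect %f)\n", result, 4.0f * n);

  cudaFree(d_temp);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```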