@@ -30,7 +30,6 @@ limitations under the License. */
3030#include " paddle/fluid/operators/math/fc.h"
3131#include " paddle/fluid/operators/math/functors.h"
3232#include " paddle/fluid/operators/math/math_function.h"
33- #include " paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
3433#include " paddle/fluid/operators/reduce_ops/reduce_op.h"
3534#include " paddle/fluid/operators/transpose_op.h"
3635#include " paddle/fluid/operators/unique_op.h"
@@ -66,14 +65,22 @@ using LoDTensor = framework::LoDTensor;
6665 dev_ctx); \
6766 cast_functor.template apply<dtype>()
6867
68+ template <typename T>
69+ struct MaxFunctor {
70+ template <typename DeviceContext, typename X, typename Y, typename Dim>
71+ void operator ()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
72+ y->device (place) = x->maximum (dim);
73+ }
74+ };
75+
6976template <typename DeviceContext, typename T>
7077inline void MAX_FUNC (const framework::ExecutionContext& ctx,
7178 const Tensor* input, Tensor* output,
7279 const std::vector<int >& dims) {
7380 auto cast_out_dtype =
7481 static_cast <framework::proto::VarType::Type>(output->type ());
7582 framework::VisitDataType (cast_out_dtype,
76- ReduceKernelFunctor<DeviceContext, T, MaxFunctor>(
83+ ReduceKernelFunctor<DeviceContext, T, MaxFunctor<T> >(
7784 input, output, dims, false , false , ctx));
7885}
7986
0 commit comments