Skip to content

Commit 6ddc7d4

Browse files
committed
add MaxFunctor
1 parent a0777ff commit 6ddc7d4

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

paddle/fluid/operators/viterbi_decode_op.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ limitations under the License. */
3030
#include "paddle/fluid/operators/math/fc.h"
3131
#include "paddle/fluid/operators/math/functors.h"
3232
#include "paddle/fluid/operators/math/math_function.h"
33-
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
3433
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
3534
#include "paddle/fluid/operators/transpose_op.h"
3635
#include "paddle/fluid/operators/unique_op.h"
@@ -66,14 +65,22 @@ using LoDTensor = framework::LoDTensor;
6665
dev_ctx); \
6766
cast_functor.template apply<dtype>()
6867

68+
template <typename T>
69+
struct MaxFunctor {
70+
template <typename DeviceContext, typename X, typename Y, typename Dim>
71+
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
72+
y->device(place) = x->maximum(dim);
73+
}
74+
};
75+
6976
template <typename DeviceContext, typename T>
7077
inline void MAX_FUNC(const framework::ExecutionContext& ctx,
7178
const Tensor* input, Tensor* output,
7279
const std::vector<int>& dims) {
7380
auto cast_out_dtype =
7481
static_cast<framework::proto::VarType::Type>(output->type());
7582
framework::VisitDataType(cast_out_dtype,
76-
ReduceKernelFunctor<DeviceContext, T, MaxFunctor>(
83+
ReduceKernelFunctor<DeviceContext, T, MaxFunctor<T>>(
7784
input, output, dims, false, false, ctx));
7885
}
7986

0 commit comments

Comments
 (0)