
Commit 2b8b16d

[NPU] add reduce_min (#39019)
1 parent 35b03e1 commit 2b8b16d

File tree: 2 files changed, +418 −0 lines changed
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class ReduceMinNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    auto dims = ctx.Attr<std::vector<int>>("dim");
    bool keep_dim = ctx.Attr<bool>("keep_dim");
    bool reduce_all = ctx.Attr<bool>("reduce_all");
    int out_dtype = ctx.Attr<int>("out_dtype");

    auto place = ctx.GetPlace();

    // Intermediate tensor that receives the ReduceMinD result in the
    // input's dtype.
    framework::Tensor cast_out(x->type());
    cast_out.Resize(out->dims());
    cast_out.mutable_data<T>(place);

    // out_dtype == -1 means no explicit output dtype was requested.
    auto cast_out_dtype = x->type();
    if (out_dtype != -1) {
      cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
    }

    // Allocate the real output in the requested dtype; otherwise let the
    // output share cast_out's storage.
    if (x->type() != cast_out_dtype) {
      if (cast_out_dtype == framework::proto::VarType::FP32) {
        out->mutable_data<float>(place);
      } else if (cast_out_dtype == framework::proto::VarType::FP16) {
        out->mutable_data<paddle::platform::float16>(place);
      } else if (cast_out_dtype == framework::proto::VarType::INT16) {
        out->mutable_data<int16_t>(place);
      } else if (cast_out_dtype == framework::proto::VarType::INT32) {
        out->mutable_data<int32_t>(place);
      } else if (cast_out_dtype == framework::proto::VarType::INT64) {
        out->mutable_data<int64_t>(place);
      } else if (cast_out_dtype == framework::proto::VarType::FP64) {
        out->mutable_data<double>(place);
      } else if (cast_out_dtype == framework::proto::VarType::BOOL) {
        out->mutable_data<bool>(place);
      }
    } else {
      out->ShareDataWith(cast_out);
    }

    framework::NPUAttributeMap attr_input = {{"axes", dims},
                                             {"keep_dims", keep_dim}};

    // reduce_all overrides "dim" and reduces over every axis of the input.
    if (reduce_all) {
      std::vector<int> dim_vec;
      for (int i = 0; i < x->dims().size(); i++) {
        dim_vec.push_back(i);
      }

      attr_input = {{"axes", dim_vec}, {"keep_dims", keep_dim}};
    }

    const auto& dev_ctx =
        ctx.template device_context<paddle::platform::NPUDeviceContext>();
    // For int64 inputs, run ReduceMinD through TypeAdapter with int32
    // input/output tensors; otherwise run it directly.
    if (x->type() == framework::proto::VarType::INT64) {
      auto op_func = [](const std::vector<Tensor>& inputs,
                        const std::vector<Tensor>& outputs,
                        const NPUAttributeMap& attrs,
                        const platform::NPUDeviceContext& dev_ctx) {
        const auto& runner =
            NpuOpRunner("ReduceMinD", {inputs[0]}, {outputs[0]}, attrs);
        runner.Run(dev_ctx.stream());
      };

      NpuOpRunner::TypeAdapter({*x}, {cast_out}, attr_input, dev_ctx, op_func,
                               {framework::proto::VarType::INT32},
                               {framework::proto::VarType::INT32});
    } else {
      const auto& runner =
          NpuOpRunner("ReduceMinD", {*x}, {cast_out}, attr_input);
      runner.Run(dev_ctx.stream());
    }

    // If an explicit out_dtype was requested, cast the reduced result into
    // the output tensor.
    if (x->type() != cast_out_dtype) {
      auto dst_dtype = ConvertToNpuDtype(cast_out_dtype);
      const auto& runner_cast =
          NpuOpRunner("Cast", {cast_out}, {*out},
                      {{"dst_type", static_cast<int>(dst_dtype)}});
      runner_cast.Run(dev_ctx.stream());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
    reduce_min, ops::ReduceMinNPUKernel<plat::NPUDeviceContext, float>,
    ops::ReduceMinNPUKernel<plat::NPUDeviceContext, plat::float16>,
#ifdef PADDLE_WITH_ASCEND_INT64
    ops::ReduceMinNPUKernel<plat::NPUDeviceContext, int64_t>,
#endif
    ops::ReduceMinNPUKernel<plat::NPUDeviceContext, int>);
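
For context, the kernel registered above is reached through Paddle's standard reduction API. The following minimal sketch is illustrative only and not part of this commit: the device string, tensor values, and paddle.min calls are assumptions based on the public Paddle 2.x Python API and require a build with Ascend NPU support.

    # Illustrative only; assumes a PaddlePaddle 2.x build with Ascend NPU
    # support so that "npu:0" is a valid device.
    import paddle

    paddle.set_device("npu:0")  # route kernels to the NPU reduce_min kernel

    x = paddle.to_tensor([[1.0, 5.0, 3.0],
                          [4.0, 2.0, 6.0]], dtype="float32")

    # axis/keepdim map to the "dim"/"keep_dim" attributes of the reduce_min op.
    col_min = paddle.min(x, axis=0)                # [1.0, 2.0, 3.0]
    row_min = paddle.min(x, axis=1, keepdim=True)  # [[1.0], [2.0]]

    # Omitting axis sets reduce_all, so every dimension is reduced.
    global_min = paddle.min(x)                     # 1.0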
