
Commit 55d6b87

sum op (#39165)
1 parent b75507d commit 55d6b87

2 files changed: +190 -0 lines changed

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class SumMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto out_var = ctx.OutputVar("Out");
    if (out_var->IsType<framework::LoDTensor>()) {
      // init
      auto *out = out_var->GetMutable<framework::LoDTensor>();
      auto ins = ctx.MultiInput<Tensor>("X");
      out->mutable_data<T>(ctx.GetPlace());
      auto place = ctx.GetPlace();
      int ins_size = static_cast<int>(ins.size());
      if (ins_size == 1) {
        TensorCopy(*ins[0], place, out);
        return;
      }

      // Gather the input descriptors and data pointers for the MLU AddN call.
      std::vector<const void *> inputs;
      std::vector<MLUCnnlTensorDesc> input_descs;
      std::vector<cnnlTensorDescriptor_t> desc_vector;
      for (int i = 0; i < ins_size; i++) {
        input_descs.emplace_back(MLUCnnlTensorDesc(
            *ins[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(ins[i]->type())));
        desc_vector.push_back(input_descs.back().get());
        inputs.push_back(GetBasePtr(ins[i]));
      }
      // init out tensor
      MLUCnnlTensorDesc output_desc(*out, CNNL_LAYOUT_ARRAY,
                                    ToCnnlDataType(out->type()));
      uint32_t ins_size_t = static_cast<uint32_t>(ins_size);
      MLUCnnl::AddN(ctx, ins_size_t, desc_vector.data(), inputs.data(),
                    output_desc.get(), GetBasePtr(out));

    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Expected the type of Output(Out) to be LoDTensor, but got "
          "unsupported type: %s.",
          framework::ToTypeName(out_var->Type())));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_MLU_KERNEL(
    sum, ops::SumMLUKernel<paddle::platform::MLUDeviceContext, float>,
    ops::SumMLUKernel<paddle::platform::MLUDeviceContext,
                      paddle::platform::float16>);
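
For context, a minimal usage sketch of the kernel registered above. This assumes a PaddlePaddle build with MLU support (the "mlu:0" device string is an assumption of such a build); paddle.add_n is the Python API that maps to the `sum` op:

import paddle

paddle.set_device("mlu:0")  # assumption: requires a build with MLU support

x0 = paddle.ones([3, 40], dtype="float32")
x1 = paddle.ones([3, 40], dtype="float32")
# `add_n` dispatches to the `sum` op, i.e. SumMLUKernel on this device.
out = paddle.add_n([x0, x1])
print(out.shape)  # [3, 40]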
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core

paddle.enable_static()
SEED = 2021


class TestSum1(OpTest):
    def setUp(self):
        self.set_mlu()
        self.init_dtype()
        self.op_type = "sum"
        self.place = paddle.MLUPlace(0)

        x0 = np.random.random((3, 40)).astype(self.dtype)
        x1 = np.random.random((3, 40)).astype(self.dtype)
        x2 = np.random.random((3, 40)).astype(self.dtype)
        self.inputs = {'X': [("x0", x0), ("x1", x1), ("x2", x2)]}
        y = x0 + x1 + x2
        self.outputs = {'Out': y}

        self.attrs = {'use_mkldnn': False}

    def init_dtype(self):
        self.dtype = np.float32

    def set_mlu(self):
        self.__class__.use_mlu = True

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestSum2(OpTest):
    def setUp(self):
        self.set_mlu()
        self.init_dtype()
        self.op_type = "sum"
        self.place = paddle.MLUPlace(0)

        x0 = np.random.random((3, 3)).astype(self.dtype)
        x1 = np.random.random((3, 3)).astype(self.dtype)
        x2 = np.random.random((3, 3)).astype(self.dtype)
        x3 = np.random.random((3, 3)).astype(self.dtype)
        self.inputs = {'X': [("x0", x0), ("x1", x1), ("x2", x2), ("x3", x3)]}
        # Using `y = x0 + x1 + x2 + x3` directly as the reference result is
        # problematic: numpy fp16 addition loses precision, so the result of
        # `x0 + x1 + x2 + x3` can differ from that of `x3 + x2 + x1 + x0`.
        # Therefore, convert the inputs to fp32 for the reference calculation.
        y = (x0.astype(np.float32) + x1.astype(np.float32) +
             x2.astype(np.float32) + x3.astype(np.float32)).astype(self.dtype)
        self.outputs = {'Out': y}

        self.attrs = {'use_mkldnn': False}

    def init_dtype(self):
        self.dtype = np.float16

    def set_mlu(self):
        self.__class__.use_mlu = True

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestSum3(OpTest):
    def setUp(self):
        self.set_mlu()
        self.init_dtype()
        self.op_type = "sum"
        self.place = paddle.MLUPlace(0)

        x0 = np.random.random((3, 3)).astype(self.dtype)

        self.inputs = {'X': [("x0", x0)]}
        y = x0
        self.outputs = {'Out': y}

        self.attrs = {'use_mkldnn': False}

    def init_dtype(self):
        self.dtype = np.float16

    def set_mlu(self):
        self.__class__.use_mlu = True

    def test_check_output(self):
        self.check_output_with_place(self.place)


if __name__ == '__main__':
    unittest.main()
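
The fp16 caveat documented in TestSum2 is easy to reproduce in isolation. A minimal numpy sketch, independent of the test harness (the seed and shapes here are illustrative only):

import numpy as np

rng = np.random.default_rng(2021)
xs = [rng.random((3, 3)).astype(np.float16) for _ in range(4)]

# fp16 addition is not associative: a different summation order can round
# differently, so these two results may disagree in the last bit.
fwd = ((xs[0] + xs[1]) + xs[2]) + xs[3]
rev = ((xs[3] + xs[2]) + xs[1]) + xs[0]
print(np.array_equal(fwd, rev))  # may print False

# Accumulating in fp32 is exact for a handful of fp16 inputs, so casting
# back once at the end gives an order-independent reference, as TestSum2 does.
ref = sum(x.astype(np.float32) for x in xs).astype(np.float16)
print(ref.dtype)  # float16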
