Skip to content

Commit 76f8703

Browse files
authored
[Phi] Move allclose op kernel into phi (#40469)
* Move the allclose kernel into phi
* Remove the old allclose op kernel
* Fix the coverage check failure
1 parent 39de9b8 commit 76f8703

File tree

9 files changed

+276
-215
lines changed

9 files changed

+276
-215
lines changed

paddle/fluid/operators/allclose_op.cc

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/operators/allclose_op.h"
1615
#include <cmath>
1716
#include <string>
17+
1818
#include "paddle/fluid/framework/op_registry.h"
1919
#include "paddle/fluid/framework/op_version_registry.h"
2020
#include "paddle/fluid/framework/operator.h"
@@ -23,41 +23,6 @@
2323
namespace paddle {
2424
namespace operators {
2525

26-
// CPU specialization of GetTensorValue: reads the first (scalar) element of
// `tensor` straight from host memory.
// NOTE(review): deleted by this commit (diff `-` lines) — the logic now lives
// in the phi allclose kernel, which calls Scalar::to<double>() instead.
template <typename T>
struct GetTensorValue<platform::CPUDeviceContext, T> {
  T operator()(const platform::CPUDeviceContext& dev_ctx,
               const framework::Tensor& tensor) const {
    // assumes `tensor` holds exactly one value of type T — TODO confirm callers
    return *(tensor.data<T>());
  }
};
33-
34-
// CPU specialization of AllcloseFunctor: reduces `in` and `other` to a single
// bool that is true iff every element pair satisfies
//   |a - b| <= atol + rtol * |b|   (NaN handling controlled by equal_nan).
// NOTE(review): deleted by this commit (diff `-` lines) — superseded by
// phi::AllCloseKernel, which carries the same element-wise logic.
template <typename T>
struct AllcloseFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& other,
                  const double rtol, const double atol, bool equal_nan,
                  framework::Tensor* output) {
    auto* in_a = in.data<T>();
    auto* in_b = other.data<T>();
    auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
    auto num = in.numel();
    *out_data = true;  // start from "all close"; AND each element's verdict in
    for (int i = 0; i < num; i++) {
      const T a = in_a[i], b = in_b[i];
      bool val;
      if (std::isnan(a) || std::isnan(b)) {
        // With at least one NaN present, the pair counts as close only when
        // equal_nan is set and BOTH sides are NaN.
        val = equal_nan && std::isnan(a) == std::isnan(b);
      } else {
        T left = (a > b ? a - b : b - a);              // |a - b|
        T right = atol + (b > 0 ? rtol * b : (-rtol) * b);  // atol + rtol*|b|
        T diff = (left > right ? left - right : right - left);
        // The `diff <= 1e-15` term absorbs rounding error in the comparison.
        val = a == b || left <= right || diff <= 1e-15;
      }
      *out_data &= val;
    }
  }
};
60-
6126
class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
6227
public:
6328
void Make() override {
@@ -157,8 +122,6 @@ REGISTER_OPERATOR(
157122
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
158123
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
159124
ops::AllcloseOpVarTypeInference);
160-
REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel<CPU, float>,
161-
ops::AllcloseKernel<CPU, double>);
162125

163126
/* ========================== register checkpoint ===========================*/
164127
REGISTER_OP_VERSION(allclose)

paddle/fluid/operators/allclose_op.cu

Lines changed: 0 additions & 84 deletions
This file was deleted.

paddle/fluid/operators/allclose_op.h

Lines changed: 0 additions & 93 deletions
This file was deleted.

paddle/phi/api/lib/utils/tensor_utils.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
4040
auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);
4141
if (variable.IsType<framework::LoDTensor>()) {
4242
const auto& tensor = variable.Get<framework::LoDTensor>();
43+
PADDLE_ENFORCE_EQ(
44+
tensor.numel(),
45+
1UL,
46+
platform::errors::InvalidArgument("The DenseTensor used to construct "
47+
"the Scalar contains more than 1 "
48+
"value, it contains `%d` values.",
49+
tensor.numel()));
4350
if (!platform::is_same_place(tensor.place(), expected_place)) {
4451
framework::LoDTensor tmp_tensor;
4552
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/common/scalar.h"
18+
#include "paddle/phi/core/dense_tensor.h"
19+
20+
namespace phi {

// Decides whether all elements of `x` and `y` are pairwise "close":
//   |x[i] - y[i]| <= atol + rtol * |y[i]|
// with NaNs treated as equal when `equal_nan` is true. Writes a single bool
// into `out` (see the CPU implementation in allclose_kernel.cc).
//
//   dev_ctx   - device context used to allocate `out`
//   x, y      - tensors compared element-wise (same shape assumed — TODO
//               confirm the op's shape check covers this)
//   rtol      - relative tolerance; must be a double Scalar (enforced in the
//               kernel body)
//   atol      - absolute tolerance; must be a double Scalar
//   equal_nan - whether two NaNs compare as equal
//   out       - output tensor receiving the single bool result
template <typename T, typename Context>
void AllCloseKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    const Scalar& rtol,
                    const Scalar& atol,
                    bool equal_nan,
                    DenseTensor* out);

}  // namespace phi
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/allclose_kernel.h"
16+
17+
#include <cmath>
18+
19+
#include "paddle/phi/core/enforce.h"
20+
#include "paddle/phi/core/kernel_registry.h"
21+
22+
namespace phi {
23+
24+
template <typename T, typename Context>
25+
void AllCloseKernel(const Context& dev_ctx,
26+
const DenseTensor& x,
27+
const DenseTensor& y,
28+
const Scalar& rtol,
29+
const Scalar& atol,
30+
bool equal_nan,
31+
DenseTensor* out) {
32+
PADDLE_ENFORCE_EQ(
33+
rtol.dtype(),
34+
DataType::FLOAT64,
35+
phi::errors::InvalidArgument(
36+
"Input (Rtol) type must be double, but get %s.", rtol.dtype()));
37+
PADDLE_ENFORCE_EQ(
38+
atol.dtype(),
39+
DataType::FLOAT64,
40+
phi::errors::InvalidArgument(
41+
"Input (Atol) type must be double, but get %s.", atol.dtype()));
42+
43+
auto* in_a = x.data<T>();
44+
auto* in_b = y.data<T>();
45+
auto rtol_v = rtol.to<double>();
46+
auto atol_v = atol.to<double>();
47+
auto* out_data = dev_ctx.template Alloc<bool>(out);
48+
*out_data = true;
49+
50+
auto num = x.numel();
51+
for (int64_t i = 0; i < num; ++i) {
52+
const T a = in_a[i], b = in_b[i];
53+
bool val;
54+
if (std::isnan(a) || std::isnan(b)) {
55+
val = equal_nan && std::isnan(a) == std::isnan(b);
56+
} else {
57+
T left = (a > b ? a - b : b - a);
58+
T right = atol_v + (b > 0 ? rtol_v * b : (-rtol_v) * b);
59+
T diff = (left > right ? left - right : right - left);
60+
val = a == b || left <= right || diff <= 1e-15;
61+
}
62+
*out_data &= val;
63+
}
64+
}
65+
66+
} // namespace phi
67+
68+
// Register the CPU allclose kernel for float and double inputs; the output
// dtype is forced to bool regardless of the input element type.
PD_REGISTER_KERNEL(
    allclose, CPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}

0 commit comments

Comments
 (0)