Skip to content

Commit 8d3f377

Browse files
committed
mv erf op to phi
1 parent b798fb0 commit 8d3f377

File tree

13 files changed

+273
-101
lines changed

13 files changed

+273
-101
lines changed

paddle/fluid/operators/erf_op.cc

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ limitations under the License. */
1616
#include <string>
1717
#include <unordered_map>
1818

19-
#include "paddle/fluid/operators/erf_op.h"
19+
#include "paddle/fluid/framework/infershape_utils.h"
20+
#include "paddle/fluid/framework/op_registry.h"
2021
#include "paddle/fluid/platform/float16.h"
22+
#include "paddle/phi/core/infermeta_utils.h"
23+
#include "paddle/phi/infermeta/unary.h"
2124

2225
namespace paddle {
2326
namespace operators {
@@ -29,18 +32,6 @@ class ErfOp : public framework::OperatorWithKernel {
2932
const framework::AttributeMap &attrs)
3033
: OperatorWithKernel(type, inputs, outputs, attrs) {}
3134

32-
void InferShape(framework::InferShapeContext *ctx) const override {
33-
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
34-
platform::errors::InvalidArgument(
35-
"Input(%s) of ErfOp should not be null.", "X"));
36-
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
37-
platform::errors::InvalidArgument(
38-
"Output(%s) of ErfOp should not be null.", "Out"));
39-
40-
ctx->ShareDim("X", /*->*/ "Out");
41-
ctx->ShareLoD("X", /*->*/ "Out");
42-
}
43-
4435
protected:
4536
framework::OpKernelType GetExpectedKernelType(
4637
const framework::ExecutionContext &ctx) const override {
@@ -116,28 +107,10 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker<T> {
116107

117108
namespace ops = paddle::operators;
118109

110+
DECLARE_INFER_SHAPE_FUNCTOR(erf, ErfInferShapeFunctor,
111+
PD_INFER_META(phi::ErfInferMeta));
119112
REGISTER_OPERATOR(erf, ops::ErfOp, ops::ErfOpMaker,
120113
ops::ErfGradOpMaker<paddle::framework::OpDesc>,
121-
ops::ErfGradOpMaker<paddle::imperative::OpBase>);
114+
ops::ErfGradOpMaker<paddle::imperative::OpBase>,
115+
ErfInferShapeFunctor);
122116
REGISTER_OPERATOR(erf_grad, ops::ErfGradOp);
123-
REGISTER_OP_CPU_KERNEL(
124-
erf, ops::ErfKernel<paddle::platform::CPUDeviceContext, float>,
125-
ops::ErfKernel<paddle::platform::CPUDeviceContext, double>,
126-
ops::ErfKernel<paddle::platform::CPUDeviceContext,
127-
paddle::platform::float16>);
128-
REGISTER_OP_CPU_KERNEL(
129-
erf_grad, ops::ErfGradKernel<paddle::platform::CPUDeviceContext, float>,
130-
ops::ErfGradKernel<paddle::platform::CPUDeviceContext, double>,
131-
ops::ErfGradKernel<paddle::platform::CPUDeviceContext,
132-
paddle::platform::float16>);
133-
134-
REGISTER_OP_CUDA_KERNEL(
135-
erf, ops::ErfKernel<paddle::platform::CUDADeviceContext, float>,
136-
ops::ErfKernel<paddle::platform::CUDADeviceContext, double>,
137-
ops::ErfKernel<paddle::platform::CUDADeviceContext,
138-
paddle::platform::float16>);
139-
REGISTER_OP_CUDA_KERNEL(
140-
erf_grad, ops::ErfGradKernel<paddle::platform::CUDADeviceContext, float>,
141-
ops::ErfGradKernel<paddle::platform::CUDADeviceContext, double>,
142-
ops::ErfGradKernel<paddle::platform::CUDADeviceContext,
143-
paddle::platform::float16>);

paddle/fluid/operators/erf_op.h

Lines changed: 0 additions & 66 deletions
This file was deleted.

paddle/phi/infermeta/unary.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,12 @@ void TransposeInferMeta(const MetaTensor& x,
10881088
out->set_dtype(x.dtype());
10891089
}
10901090

1091+
void ErfInferMeta(const MetaTensor& x, MetaTensor* out) {
1092+
out->set_dims(x.dims());
1093+
out->share_lod(x);
1094+
out->set_dtype(x.dtype());
1095+
}
1096+
10911097
} // namespace phi
10921098

10931099
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);

paddle/phi/infermeta/unary.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,5 @@ void TransposeInferMeta(const MetaTensor& x,
154154
const std::vector<int>& axis,
155155
MetaTensor* out);
156156

157+
void ErfInferMeta(const MetaTensor& x, MetaTensor* out);
157158
} // namespace phi
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/kernels/erf_grad_kernel.h"
16+
#include "paddle/phi/backends/cpu/cpu_context.h"
17+
#include "paddle/phi/common/float16.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/impl/erf_grad_kernel_impl.h"
20+
21+
PD_REGISTER_KERNEL(erf_grad,
22+
CPU,
23+
ALL_LAYOUT,
24+
phi::ErfGradKernel,
25+
float,
26+
double,
27+
phi::dtype::float16) {}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/kernels/erf_kernel.h"
16+
#include "paddle/phi/backends/cpu/cpu_context.h"
17+
#include "paddle/phi/common/float16.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/impl/erf_kernel_impl.h"
20+
21+
PD_REGISTER_KERNEL(
22+
erf, CPU, ALL_LAYOUT, phi::ErfKernel, float, double, phi::dtype::float16) {}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
19+
namespace phi {
20+
21+
template <typename T, typename Context>
22+
void ErfGradKernel(const Context& dev_ctx,
23+
const DenseTensor& x,
24+
const DenseTensor& out_grad,
25+
DenseTensor* x_grad);
26+
27+
} // namespace phi

paddle/phi/kernels/erf_kernel.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
19+
namespace phi {
20+
21+
template <typename T, typename Context>
22+
void ErfKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
23+
24+
} // namespace phi
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/backends/gpu/gpu_context.h"
16+
#include "paddle/phi/common/float16.h"
17+
#include "paddle/phi/core/kernel_registry.h"
18+
#include "paddle/phi/kernels/erf_grad_kernel.h"
19+
#include "paddle/phi/kernels/impl/erf_grad_kernel_impl.h"
20+
21+
PD_REGISTER_KERNEL(erf_grad,
22+
GPU,
23+
ALL_LAYOUT,
24+
phi::ErfGradKernel,
25+
float,
26+
double,
27+
phi::dtype::float16) {}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/backends/gpu/gpu_context.h"
16+
#include "paddle/phi/common/float16.h"
17+
#include "paddle/phi/core/kernel_registry.h"
18+
#include "paddle/phi/kernels/erf_kernel.h"
19+
#include "paddle/phi/kernels/impl/erf_kernel_impl.h"
20+
21+
PD_REGISTER_KERNEL(
22+
erf, GPU, ALL_LAYOUT, phi::ErfKernel, float, double, phi::dtype::float16) {}

0 commit comments

Comments
 (0)