Skip to content

Commit 0bff2f4

Browse files
author
chenfeiyu
committed
rename op and api, remove inplace and view machanism for these 2 operators
1 parent 5516c39 commit 0bff2f4

File tree

9 files changed

+62
-303
lines changed

9 files changed

+62
-303
lines changed

paddle/fluid/operators/complex_view_op.cc

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
namespace paddle {
2525
namespace operators {
2626

27-
class ViewAsComplexOp : public framework::OperatorWithKernel {
27+
class AsComplexOp : public framework::OperatorWithKernel {
2828
public:
2929
using framework::OperatorWithKernel::OperatorWithKernel;
3030

3131
void InferShape(framework::InferShapeContext *ctx) const override {
32-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "view_as_complex");
33-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "view_as_complex");
32+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "as_complex");
33+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "as_complex");
3434

3535
auto in_dims = ctx->GetInputDim("X");
3636
const int input_rank = in_dims.size();
@@ -57,15 +57,15 @@ class ViewAsComplexOp : public framework::OperatorWithKernel {
5757
}
5858
};
5959

60-
class ViewAsComplexOpMaker : public framework::OpProtoAndCheckerMaker {
60+
class AsComplexOpMaker : public framework::OpProtoAndCheckerMaker {
6161
public:
6262
void Make() override {
6363
AddInput("X", "(Tensor), The input tensor of view_as_complex op.");
6464
AddOutput("Out", "(Tensor), The output tensor of view_as_complex op.");
6565
AddComment(R"DOC(
66-
View_as_complex Operator.
66+
As_complex Operator.
6767
68-
This operator is used to return a complex tensor view represented
68+
This operator is used to return a complex tensor represented
6969
by an old-fashioned real tensor. The size of the last dimension of
7070
the input tensor should be 2, which corresponds to 'real' and
7171
'complex', respectively.
@@ -75,25 +75,25 @@ the input tensor should be 2, which corresponds to 'real' and
7575
};
7676

7777
template <typename T>
78-
class ViewAsComplexGradMaker : public framework::SingleGradOpMaker<T> {
78+
class AsComplexGradMaker : public framework::SingleGradOpMaker<T> {
7979
public:
8080
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
8181

8282
void Apply(GradOpPtr<T> retv) const override {
83-
retv->SetType("view_as_real");
83+
retv->SetType("as_real");
8484
retv->SetInput("X", this->OutputGrad("Out"));
8585
retv->SetAttrMap(this->Attrs());
8686
retv->SetOutput("Out", this->InputGrad("X"));
8787
}
8888
};
8989

90-
class ViewAsRealOp : public framework::OperatorWithKernel {
90+
class AsRealOp : public framework::OperatorWithKernel {
9191
public:
9292
using framework::OperatorWithKernel::OperatorWithKernel;
9393

9494
void InferShape(framework::InferShapeContext *ctx) const override {
95-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "view_as_real");
96-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "view_as_real");
95+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "as_real");
96+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "as_real");
9797

9898
auto out_dims_v = framework::vectorize(ctx->GetInputDim("X"));
9999
out_dims_v.push_back(2);
@@ -112,15 +112,15 @@ class ViewAsRealOp : public framework::OperatorWithKernel {
112112
}
113113
};
114114

115-
class ViewAsRealOpMaker : public framework::OpProtoAndCheckerMaker {
115+
class AsRealOpMaker : public framework::OpProtoAndCheckerMaker {
116116
public:
117117
void Make() override {
118-
AddInput("X", "(Tensor), The input tensor of view_as_real op.");
119-
AddOutput("Out", "(Tensor), The output tensor of view_as_real op.");
118+
AddInput("X", "(Tensor), The input tensor of as_real op.");
119+
AddOutput("Out", "(Tensor), The output tensor of as_real op.");
120120
AddComment(R"DOC(
121-
View_as_real Operator.
121+
AsReal Operator.
122122
123-
This operator is used to return an old-fashioned real tensor view of a
123+
This operator is used to return an old-fashioned real tensor from a
124124
complex tensor. The size of the last dimension of the output tensor is 2,
125125
which corresponds to 'real' and 'complex', respectively.
126126
@@ -129,12 +129,12 @@ which corresponds to 'real' and 'complex', respectively.
129129
};
130130

131131
template <typename T>
132-
class ViewAsRealGradMaker : public framework::SingleGradOpMaker<T> {
132+
class AsRealGradMaker : public framework::SingleGradOpMaker<T> {
133133
public:
134134
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
135135

136136
void Apply(GradOpPtr<T> retv) const override {
137-
retv->SetType("view_as_complex");
137+
retv->SetType("as_complex");
138138
retv->SetInput("X", this->OutputGrad("Out"));
139139
retv->SetAttrMap(this->Attrs());
140140
retv->SetOutput("Out", this->InputGrad("X"));
@@ -146,23 +146,18 @@ class ViewAsRealGradMaker : public framework::SingleGradOpMaker<T> {
146146

147147
namespace ops = paddle::operators;
148148

149-
REGISTER_OPERATOR(view_as_complex, ops::ViewAsComplexOp,
150-
ops::ViewAsComplexOpMaker,
151-
ops::ViewAsComplexGradMaker<paddle::framework::OpDesc>,
152-
ops::ViewAsComplexGradMaker<paddle::imperative::OpBase>,
153-
ops::ViewAsComplexOpInplaceInferer);
149+
REGISTER_OPERATOR(as_complex, ops::AsComplexOp, ops::AsComplexOpMaker,
150+
ops::AsComplexGradMaker<paddle::framework::OpDesc>,
151+
ops::AsComplexGradMaker<paddle::imperative::OpBase>);
154152

155-
REGISTER_OPERATOR(view_as_real, ops::ViewAsRealOp, ops::ViewAsRealOpMaker,
156-
ops::ViewAsRealGradMaker<paddle::framework::OpDesc>,
157-
ops::ViewAsRealGradMaker<paddle::imperative::OpBase>,
158-
ops::ViewAsRealOpInplaceInferer);
153+
REGISTER_OPERATOR(as_real, ops::AsRealOp, ops::AsRealOpMaker,
154+
ops::AsRealGradMaker<paddle::framework::OpDesc>,
155+
ops::AsRealGradMaker<paddle::imperative::OpBase>);
159156

160157
REGISTER_OP_CPU_KERNEL(
161-
view_as_complex,
162-
ops::ViewAsComplexKernel<paddle::platform::CPUDeviceContext, float>,
163-
ops::ViewAsComplexKernel<paddle::platform::CPUDeviceContext, double>);
158+
as_complex, ops::AsComplexKernel<paddle::platform::CPUDeviceContext, float>,
159+
ops::AsComplexKernel<paddle::platform::CPUDeviceContext, double>);
164160

165161
REGISTER_OP_CPU_KERNEL(
166-
view_as_real,
167-
ops::ViewAsRealKernel<paddle::platform::CPUDeviceContext, float>,
168-
ops::ViewAsRealKernel<paddle::platform::CPUDeviceContext, double>);
162+
as_real, ops::AsRealKernel<paddle::platform::CPUDeviceContext, float>,
163+
ops::AsRealKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/complex_view_op.cu

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
namespace ops = paddle::operators;
2121

2222
REGISTER_OP_CUDA_KERNEL(
23-
view_as_complex,
24-
ops::ViewAsComplexKernel<paddle::platform::CUDADeviceContext, float>,
25-
ops::ViewAsComplexKernel<paddle::platform::CUDADeviceContext, double>);
23+
as_complex,
24+
ops::AsComplexKernel<paddle::platform::CUDADeviceContext, float>,
25+
ops::AsComplexKernel<paddle::platform::CUDADeviceContext, double>);
2626

2727
REGISTER_OP_CUDA_KERNEL(
28-
view_as_real,
29-
ops::ViewAsRealKernel<paddle::platform::CUDADeviceContext, float>,
30-
ops::ViewAsRealKernel<paddle::platform::CUDADeviceContext, double>);
28+
as_real, ops::AsRealKernel<paddle::platform::CUDADeviceContext, float>,
29+
ops::AsRealKernel<paddle::platform::CUDADeviceContext, double>);

paddle/fluid/operators/complex_view_op.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace paddle {
2525
namespace operators {
2626

2727
template <typename DeviceContext, typename T>
28-
class ViewAsComplexKernel : public framework::OpKernel<T> {
28+
class AsComplexKernel : public framework::OpKernel<T> {
2929
public:
3030
void Compute(const framework::ExecutionContext& context) const override {
3131
const auto* x = context.Input<framework::LoDTensor>("X");
@@ -42,7 +42,7 @@ class ViewAsComplexKernel : public framework::OpKernel<T> {
4242
};
4343

4444
template <typename DeviceContext, typename T>
45-
class ViewAsRealKernel : public framework::OpKernel<T> {
45+
class AsRealKernel : public framework::OpKernel<T> {
4646
public:
4747
void Compute(const framework::ExecutionContext& context) const override {
4848
const auto* x = context.Input<framework::LoDTensor>("X");
@@ -56,8 +56,5 @@ class ViewAsRealKernel : public framework::OpKernel<T> {
5656
}
5757
};
5858

59-
DECLARE_INPLACE_OP_INFERER(ViewAsComplexOpInplaceInferer, {"X", "Out"});
60-
DECLARE_INPLACE_OP_INFERER(ViewAsRealOpInplaceInferer, {"X", "Out"});
61-
6259
} // namespace operators
6360
} // namespace paddle

paddle/fluid/pybind/op_function_generator.cc

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -32,78 +32,6 @@
3232
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
3333
#endif
3434

35-
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
36-
// generated in C++ automatically.
37-
// However, some OPs need to pass the outputs from Python instead of generating
38-
// them in C++. There are mainly 2 reasons for that,
39-
// (1) Optimizer OPs need to update the input param in-place, like sgd.
40-
// So they need to pass the output which is same as input param.
41-
// (2) Very few python APIs has out in their arguments, like fill_constant.
42-
// So they need to pass the python output to C++.
43-
// Actually, this is not a good design, since it may break the SSA graph,
44-
// especially in declarative mode.
45-
// For those OPs, we need to manually specify the outs need to pass in this map.
46-
std::map<std::string, std::set<std::string>> op_passing_outs_map = {
47-
{"sgd", {"ParamOut"}},
48-
{"adam",
49-
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
50-
"MasterParamOut"}},
51-
{"adamw",
52-
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
53-
"MasterParamOut"}},
54-
{"average_accumulates",
55-
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
56-
"out_old_num_accumulates", "out_num_updates"}},
57-
{"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
58-
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
59-
{"batch_norm", {"MeanOut", "VarianceOut"}},
60-
{"sync_batch_norm", {"MeanOut", "VarianceOut"}},
61-
{"accuracy", {"Correct", "Total"}},
62-
{"fill_constant", {"Out"}},
63-
{"recv_v2", {"Out"}},
64-
{"partial_recv", {"Out"}},
65-
{"matmul", {"Out"}},
66-
{"c_broadcast", {"Out"}},
67-
{"c_sync_calc_stream", {"Out"}},
68-
{"c_sync_comm_stream", {"Out"}},
69-
{"c_reduce_sum", {"Out"}},
70-
{"c_reduce_max", {"Out"}},
71-
{"c_reduce_min", {"Out"}},
72-
{"c_reduce_prod", {"Out"}},
73-
{"c_reduce", {"Out"}},
74-
{"c_scatter", {"Out"}},
75-
{"barrier", {"Out"}},
76-
{"fake_quantize_dequantize_moving_average_abs_max",
77-
{"Out", "OutScale", "OutAccum", "OutState"}},
78-
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
79-
{"fake_channel_wise_quantize_dequantize_abs_max", {"Out", "OutScale"}},
80-
{"check_finite_and_unscale", {"Out", "FoundInfinite"}},
81-
{"update_loss_scaling",
82-
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
83-
{"moving_average_abs_max_scale",
84-
{"Out", "OutScale", "OutAccum", "OutState"}},
85-
{"lamb",
86-
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
87-
{"rnn", {"DropoutState"}},
88-
{"run_program", {"Out", "DOut", "OutScope"}},
89-
{"clear_float_status", {"FloatStatusOut"}},
90-
{"get_float_status", {"FloatStatusOut"}},
91-
};
92-
93-
// NOTE(pangyoki): Tensor View Strategy.
94-
// In this case, a new output varbase will be created, and this varbase will
95-
// reuse the input varbase's allocation.
96-
// It's a map. The key of outer map is the view op name, the value is
97-
// a pair which implies the mapping relationship between the input and
98-
// output varbase.
99-
std::map<std::string, std::pair<std::string, std::string>> view_op_map = {
100-
{"squeeze2", {"X", "Out"}}, // "X" -> "Out"
101-
{"unsqueeze2", {"X", "Out"}},
102-
{"reshape2", {"X", "Out"}},
103-
{"flatten_contiguous_range", {"X", "Out"}},
104-
{"view_as_complex", {"X", "Out"}},
105-
{"view_as_real", {"X", "Out"}}};
106-
10735
// NOTE(pangyoki): Inplace OP with duplicable input.
10836
// The set includes inplace ops that have duplicable input.
10937
// The first Varbase in input needs to be specified for the inplace strategy

python/paddle/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@
155155
from .tensor.manipulation import chunk # noqa: F401
156156
from .tensor.manipulation import tolist # noqa: F401
157157
from .tensor.manipulation import tensordot # noqa: F401
158-
from .tensor.manipulation import view_as_complex # noqa: F401
159-
from .tensor.manipulation import view_as_real # noqa: F401
158+
from .tensor.manipulation import as_complex # noqa: F401
159+
from .tensor.manipulation import as_real # noqa: F401
160160

161161
from .tensor.math import abs # noqa: F401
162162
from .tensor.math import acos # noqa: F401
@@ -541,8 +541,8 @@
541541
'einsum',
542542
'set_flags',
543543
'get_flags',
544-
'view_as_complex',
545-
'view_as_real',
544+
'as_complex',
545+
'as_real',
546546
'diff',
547547
'angle',
548548
]

python/paddle/fluid/tests/unittests/test_complex_view_op.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def ref_view_as_real(x):
3535

3636
class TestViewAsComplexOp(OpTest):
3737
def setUp(self):
38-
self.op_type = "view_as_complex"
38+
self.op_type = "as_complex"
3939
x = np.random.randn(10, 10, 2).astype("float64")
4040
out_ref = ref_view_as_complex(x)
4141
self.out_grad = np.ones(
@@ -57,7 +57,7 @@ def test_check_grad(self):
5757

5858
class TestViewAsRealOp(OpTest):
5959
def setUp(self):
60-
self.op_type = "view_as_real"
60+
self.op_type = "as_real"
6161
real = np.random.randn(10, 10).astype("float64")
6262
imag = np.random.randn(10, 10).astype("float64")
6363
x = real + 1j * imag
@@ -85,14 +85,14 @@ def setUp(self):
8585
def test_dygraph(self):
8686
with dygraph.guard():
8787
x = paddle.to_tensor(self.x)
88-
out_np = paddle.view_as_complex(x).numpy()
88+
out_np = paddle.as_complex(x).numpy()
8989
self.assertTrue(np.allclose(self.out, out_np))
9090

9191
def test_static(self):
9292
mp, sp = static.Program(), static.Program()
9393
with static.program_guard(mp, sp):
9494
x = static.data("x", shape=[10, 10, 2], dtype="float64")
95-
out = paddle.view_as_complex(x)
95+
out = paddle.as_complex(x)
9696

9797
exe = static.Executor()
9898
exe.run(sp)
@@ -108,14 +108,14 @@ def setUp(self):
108108
def test_dygraph(self):
109109
with dygraph.guard():
110110
x = paddle.to_tensor(self.x)
111-
out_np = paddle.view_as_real(x).numpy()
111+
out_np = paddle.as_real(x).numpy()
112112
self.assertTrue(np.allclose(self.out, out_np))
113113

114114
def test_static(self):
115115
mp, sp = static.Program(), static.Program()
116116
with static.program_guard(mp, sp):
117117
x = static.data("x", shape=[10, 10], dtype="complex128")
118-
out = paddle.view_as_real(x)
118+
out = paddle.as_real(x)
119119

120120
exe = static.Executor()
121121
exe.run(sp)

0 commit comments

Comments
 (0)