
Commit 1716324

Add paddle.lerp API to do a linear interpolation (#37253)
* save temp
* add unittest, test=develop
* fix ci error, test=develop
* fix grad accuracy error, test=develop
* fix unused error, test=develop
* fix compilation error on Windows, test=develop
* add unittest, test=develop
* modify by review comment and add lerp_
* fix inplace api, test=develop
* fix inplace api, test=develop
* fix coverage error, test=develop
1 parent 46212b8 commit 1716324

7 files changed: +629 −0 lines changed

paddle/fluid/operators/lerp_op.cc

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/lerp_op.h"

namespace paddle {
namespace operators {

class LerpOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    auto w_dims = ctx->GetInputDim("Weight");
    framework::DDim out_dims;
    out_dims = GetOutputDims(x_dims, y_dims);
    if (w_dims.size() > 1 || w_dims[0] != 1) {
      out_dims = GetOutputDims(out_dims, w_dims);
    }

    ctx->SetOutputDim("Out", out_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 private:
  framework::DDim GetOutputDims(const framework::DDim& s_dims,
                                const framework::DDim& l_dims) const {
    if (s_dims.size() > l_dims.size()) {
      return GetOutputDims(l_dims, s_dims);
    }
    std::vector<int64_t> shapes = framework::vectorize<int64_t>(l_dims);
    for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
      int64_t s = s_dims[i];
      int64_t l = l_dims[j];
      if (s != l) {
        if (l == 1) {
          shapes[j] = s;
        } else if (s != 1) {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "The shape of tensor a %s:%d must match shape of tensor b "
              "%s:%d.",
              s_dims.to_str(), i, l_dims.to_str(), j));
        }
      }
    }
    return framework::make_ddim(shapes);
  }
};

class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of lerp op.");
    AddInput("Y", "(Tensor), The input tensor of lerp op.");
    AddInput("Weight", "(Tensor, optional), The input tensor of lerp op.");
    AddOutput("Out", "(Tensor), The output tensor of lerp op.");
    AddComment(R"DOC(
Lerp Operator.

This operator is used to do a linear interpolation of input $X$ and $Y$ with $Weight$.

The equation is:

$$Out = X + Weight * (Y - X)$$

Both the input $X$ and $Y$ can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input $X$.

)DOC");
  }
};

class LerpGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(framework::GradVarName("Y"))) {
      ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
    }
  }
};

template <typename T>
class LerpOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType("lerp_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Input("Y"));
    op->SetInput("Weight", this->Input("Weight"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    op->SetAttrMap(this->Attrs());
  }
};

DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(
    lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
    paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
    paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
    paddle::operators::LerpInplaceInferer);

REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);

REGISTER_OP_CPU_KERNEL(
    lerp,
    paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, float>,
    paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    lerp_grad,
    paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
                                      float>,
    paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
                                      double>);
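
The heart of LerpOp::InferShape above is the NumPy-style broadcast rule in GetOutputDims: shapes are aligned from their trailing dimensions, a dimension of size 1 stretches to match its counterpart, and any other mismatch is rejected. The guard "w_dims.size() > 1 || w_dims[0] != 1" keeps a scalar weight from widening the output shape. A minimal standalone sketch of the same rule, using plain std::vector instead of framework::DDim so it compiles outside the framework (the helper name BroadcastShape is ours, for illustration only):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

// Right-aligned broadcast of two shapes, mirroring GetOutputDims:
// trailing dims must match, or one of the pair must be 1.
std::vector<int64_t> BroadcastShape(std::vector<int64_t> a,
                                    std::vector<int64_t> b) {
  if (a.size() > b.size()) std::swap(a, b);  // make a the shorter shape
  std::vector<int64_t> out = b;              // start from the longer shape
  for (int i = a.size() - 1, j = b.size() - 1; i >= 0; --i, --j) {
    if (a[i] != b[j]) {
      if (b[j] == 1) {
        out[j] = a[i];         // size-1 dim of the longer shape stretches
      } else if (a[i] != 1) {  // neither side is 1: incompatible
        throw std::invalid_argument("shapes do not broadcast");
      }
    }
  }
  return out;
}

int main() {
  // X: [3, 1, 5] with Y: [4, 5] -> Out: [3, 4, 5], as InferShape derives.
  for (int64_t d : BroadcastShape({3, 1, 5}, {4, 5})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 3 4 5
  return 0;
}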

paddle/fluid/operators/lerp_op.cu

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/lerp_op.h"

REGISTER_OP_CUDA_KERNEL(
    lerp,
    paddle::operators::LerpKernel<paddle::platform::CUDADeviceContext, float>,
    paddle::operators::LerpKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    lerp_grad,
    paddle::operators::LerpGradKernel<paddle::platform::CUDADeviceContext,
                                      float>,
    paddle::operators::LerpGradKernel<paddle::platform::CUDADeviceContext,
                                      double>);
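
These registrations only bind the templated kernels from lerp_op.h to CUDA for float and double; the arithmetic itself is device-independent. As a plain reference for what each output element ends up holding, here is a short host-side sketch, a hand-rolled loop with a scalar weight rather than the broadcasting Eigen expression the real kernels use:

#include <cstddef>
#include <iostream>
#include <vector>

// Elementwise reference for Out = X + Weight * (Y - X), scalar weight case.
template <typename T>
std::vector<T> LerpReference(const std::vector<T>& x, const std::vector<T>& y,
                             T weight) {
  std::vector<T> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] + weight * (y[i] - x[i]);
  }
  return out;
}

int main() {
  std::vector<double> x{1.0, 2.0, 3.0};
  std::vector<double> y{10.0, 10.0, 10.0};
  for (double v : LerpReference(x, y, 0.5)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 5.5 6 6.5 (halfway between x and y)
  return 0;
}

With weight 0 the output equals x, with weight 1 it equals y, and values in between interpolate linearly.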

paddle/fluid/operators/lerp_op.h

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macro conflict with std::min/max
#endif
#endif

namespace paddle {
namespace operators {

static framework::DDim ExtendDims2Rank(const framework::DDim& in_dims,
                                       int rank) {
  if (in_dims.size() == rank) {
    return in_dims;
  }
  std::vector<int64_t> shapes(rank, 1);
  for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
    shapes[j] = in_dims[i];
  }
  return framework::make_ddim(shapes);
}

template <size_t D>
static void GetBroadcastDims(const framework::DDim& in_dims,
                             const framework::DDim& out_dims,
                             Eigen::DSizes<int, D>* bcast_dims) {
  for (size_t i = 0; i < D; ++i) {
    if (in_dims[i] == out_dims[i]) {
      (*bcast_dims)[i] = 1;
    } else {
      (*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
    }
  }
}

template <typename DeviceContext, typename T, size_t D>
static void LerpFunction(const framework::ExecutionContext& ctx) {
  auto x = ctx.Input<framework::Tensor>("X");
  auto y = ctx.Input<framework::Tensor>("Y");
  auto w = ctx.Input<framework::Tensor>("Weight");
  auto out = ctx.Output<framework::Tensor>("Out");
  out->mutable_data<T>(ctx.GetPlace());

  auto out_dims = out->dims();
  auto x_dims = ExtendDims2Rank(x->dims(), D);
  auto y_dims = ExtendDims2Rank(y->dims(), D);
  auto w_dims = ExtendDims2Rank(w->dims(), D);
  Eigen::DSizes<int, D> x_bcast_dims;
  Eigen::DSizes<int, D> y_bcast_dims;
  Eigen::DSizes<int, D> w_bcast_dims;
  GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
  GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
  GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);

  auto eigen_x = framework::EigenTensor<T, D>::From(*x, x_dims);
  auto eigen_y = framework::EigenTensor<T, D>::From(*y, y_dims);
  auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
  auto eigen_out = framework::EigenTensor<T, D>::From(*out);

  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
  eigen_out.device(place) =
      eigen_x.broadcast(x_bcast_dims) +
      eigen_w.broadcast(w_bcast_dims) *
          (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}

template <typename DeviceContext, typename T, size_t D>
static void LerpGradFunction(const framework::ExecutionContext& ctx) {
  auto w = ctx.Input<framework::Tensor>("Weight");
  auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  auto dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));

  auto dout_dims = dout->dims();
  auto dx_dims = ExtendDims2Rank(dx->dims(), D);
  auto dy_dims = ExtendDims2Rank(dy->dims(), D);
  auto w_dims = ExtendDims2Rank(w->dims(), D);
  Eigen::DSizes<int, D> dx_bcast_dims;
  Eigen::DSizes<int, D> dy_bcast_dims;
  Eigen::DSizes<int, D> w_bcast_dims;
  GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
  GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
  GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);

  auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
  auto eigen_dout = framework::EigenTensor<T, D>::From(*dout);

  Eigen::DSizes<int, D * 2> dx_reshape_dims;
  Eigen::DSizes<int, D * 2> dy_reshape_dims;
  Eigen::DSizes<int, D> reduce_dims;
  for (int i = 0; i < dout_dims.size(); ++i) {
    dx_reshape_dims[2 * i] = dx_bcast_dims[i];
    dx_reshape_dims[2 * i + 1] = dx_dims[i];
    dy_reshape_dims[2 * i] = dy_bcast_dims[i];
    dy_reshape_dims[2 * i + 1] = dy_dims[i];
    reduce_dims[i] = 2 * i;
  }

  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
    auto eigen_dx = framework::EigenTensor<T, D>::From(*dx, dx_dims);
    auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
    eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
                                 .sum(reduce_dims)
                                 .reshape(eigen_dx.dimensions());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
    auto eigen_dy = framework::EigenTensor<T, D>::From(*dy, dy_dims);
    auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
    eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
                                 .sum(reduce_dims)
                                 .reshape(eigen_dy.dimensions());
  }
}

template <typename DeviceContext, typename T>
class LerpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int rank = ctx.Output<framework::Tensor>("Out")->dims().size();
    PADDLE_ENFORCE_GE(
        rank, 1,
        platform::errors::InvalidArgument(
            "The number of dimensions for LerpOp must be "
            "greater than or equal to 1, but the value received is %d.",
            rank));
    PADDLE_ENFORCE_LE(
        rank, 6, platform::errors::InvalidArgument(
                     "The number of dimensions for LerpOp must be "
                     "less than or equal to 6, but the value received is %d.",
                     rank));
    switch (rank) {
      case 1:
        LerpFunction<DeviceContext, T, 1>(ctx);
        break;
      case 2:
        LerpFunction<DeviceContext, T, 2>(ctx);
        break;
      case 3:
        LerpFunction<DeviceContext, T, 3>(ctx);
        break;
      case 4:
        LerpFunction<DeviceContext, T, 4>(ctx);
        break;
      case 5:
        LerpFunction<DeviceContext, T, 5>(ctx);
        break;
      case 6:
        LerpFunction<DeviceContext, T, 6>(ctx);
        break;
    }
  }
};

template <typename DeviceContext, typename T>
class LerpGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
                   ->dims()
                   .size();
    PADDLE_ENFORCE_GE(
        rank, 1,
        platform::errors::InvalidArgument(
            "The number of dimensions for LerpGradOp must be "
            "greater than or equal to 1, but the value received is %d.",
            rank));
    PADDLE_ENFORCE_LE(
        rank, 6, platform::errors::InvalidArgument(
                     "The number of dimensions for LerpGradOp must be "
                     "less than or equal to 6, but the value received is %d.",
                     rank));
    switch (rank) {
      case 1:
        LerpGradFunction<DeviceContext, T, 1>(ctx);
        break;
      case 2:
        LerpGradFunction<DeviceContext, T, 2>(ctx);
        break;
      case 3:
        LerpGradFunction<DeviceContext, T, 3>(ctx);
        break;
      case 4:
        LerpGradFunction<DeviceContext, T, 4>(ctx);
        break;
      case 5:
        LerpGradFunction<DeviceContext, T, 5>(ctx);
        break;
      case 6:
        LerpGradFunction<DeviceContext, T, 6>(ctx);
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle
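
The gradient path in LerpGradFunction above is a two-step derivation. Differentiating the forward equation

$$Out = X + Weight * (Y - X)$$

gives the local derivatives

$$\frac{\partial Out}{\partial X} = 1 - Weight, \qquad \frac{\partial Out}{\partial Y} = Weight$$

so the chain rule yields, elementwise,

$$dX = (1 - Weight) * dOut, \qquad dY = Weight * dOut$$

with one extra step whenever X or Y was broadcast in the forward pass: the copies of dOut along every broadcast axis must be summed back so each gradient matches its input's shape. That is exactly what the reshape trick implements: each of the D output axes is split into a (broadcast extent, input extent) pair via dx_reshape_dims and dy_reshape_dims, and summing over the even positions (reduce_dims[i] = 2 * i) collapses precisely the broadcast copies before the final reshape to the input's extents. For example, if X has shape [1, 3] and dOut has shape [4, 3], then dx_reshape_dims is [4, 1, 1, 3]; summing axes 0 and 2 leaves [1, 3], the shape of X.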
