Skip to content

Commit dea7d7d

Browse files
【2.0 API】Add conv1d_transpose API (#26356)
1 parent 7bd7b18 commit dea7d7d

File tree

6 files changed

+674
-1
lines changed

6 files changed

+674
-1
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import paddle
17+
from paddle import fluid, nn
18+
import paddle.fluid.dygraph as dg
19+
import paddle.nn.functional as F
20+
import paddle.fluid.initializer as I
21+
import unittest
22+
23+
24+
class ConvTranspose1dTestCase(unittest.TestCase):
25+
def __init__(self,
26+
methodName='runTest',
27+
batch_size=4,
28+
spartial_shape=16,
29+
in_channels=6,
30+
out_channels=8,
31+
filter_size=3,
32+
output_size=None,
33+
padding=0,
34+
output_padding=0,
35+
stride=1,
36+
dilation=1,
37+
groups=1,
38+
no_bias=False,
39+
data_format="NCL",
40+
dtype="float32"):
41+
super(ConvTranspose1dTestCase, self).__init__(methodName)
42+
self.batch_size = batch_size
43+
self.in_channels = in_channels
44+
self.out_channels = out_channels
45+
self.spartial_shape = spartial_shape
46+
self.filter_size = filter_size
47+
self.output_size = output_size
48+
49+
self.padding = padding
50+
self.output_padding = output_padding
51+
self.stride = stride
52+
self.dilation = dilation
53+
self.groups = groups
54+
self.no_bias = no_bias
55+
self.data_format = data_format
56+
self.dtype = dtype
57+
58+
def setUp(self):
59+
60+
self.channel_last = False if self.data_format == "NCL" else True
61+
input_shape = (self.batch_size, self.in_channels,
62+
self.spartial_shape) if not self.channel_last else (
63+
self.batch_size,
64+
self.spartial_shape,
65+
self.in_channels, )
66+
self.input = np.random.randn(*input_shape).astype(self.dtype)
67+
68+
if isinstance(self.filter_size, int):
69+
filter_size = [self.filter_size]
70+
else:
71+
filter_size = self.filter_size
72+
self.weight_shape = weight_shape = (self.in_channels, self.out_channels
73+
// self.groups) + tuple(filter_size)
74+
self.weight = np.random.uniform(
75+
-1, 1, size=weight_shape).astype(self.dtype)
76+
if not self.no_bias:
77+
self.bias = np.random.uniform(
78+
-1, 1, size=(self.out_channels, )).astype(self.dtype)
79+
else:
80+
self.bias = None
81+
82+
def functional(self, place):
83+
main = fluid.Program()
84+
start = fluid.Program()
85+
with fluid.unique_name.guard():
86+
with fluid.program_guard(main, start):
87+
input_shape = (-1, self.in_channels,
88+
-1) if not self.channel_last else (
89+
-1, -1, self.in_channels)
90+
x_var = fluid.data("input", input_shape, dtype=self.dtype)
91+
w_var = fluid.data(
92+
"weight", self.weight_shape, dtype=self.dtype)
93+
b_var = fluid.data(
94+
"bias", (self.out_channels, ), dtype=self.dtype)
95+
y_var = F.conv_transpose1d(
96+
x_var,
97+
w_var,
98+
None if self.no_bias else b_var,
99+
output_size=self.output_size,
100+
padding=self.padding,
101+
output_padding=self.output_padding,
102+
stride=self.stride,
103+
dilation=self.dilation,
104+
groups=self.groups,
105+
data_format=self.data_format)
106+
feed_dict = {"input": self.input, "weight": self.weight}
107+
if self.bias is not None:
108+
feed_dict["bias"] = self.bias
109+
exe = fluid.Executor(place)
110+
exe.run(start)
111+
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
112+
return y_np
113+
114+
def paddle_nn_layer(self):
115+
x_var = paddle.to_tensor(self.input)
116+
conv = nn.ConvTranspose1d(
117+
self.in_channels,
118+
self.out_channels,
119+
self.filter_size,
120+
padding=self.padding,
121+
output_padding=self.output_padding,
122+
stride=self.stride,
123+
dilation=self.dilation,
124+
groups=self.groups,
125+
data_format=self.data_format)
126+
conv.weight.set_value(self.weight)
127+
if not self.no_bias:
128+
conv.bias.set_value(self.bias)
129+
y_var = conv(x_var, output_size=self.output_size)
130+
y_np = y_var.numpy()
131+
return y_np
132+
133+
def _test_equivalence(self, place):
134+
result1 = self.functional(place)
135+
with dg.guard(place):
136+
result2 = self.paddle_nn_layer()
137+
np.testing.assert_array_almost_equal(result1, result2)
138+
139+
def runTest(self):
140+
place = fluid.CPUPlace()
141+
self._test_equivalence(place)
142+
143+
if fluid.core.is_compiled_with_cuda():
144+
place = fluid.CUDAPlace(0)
145+
self._test_equivalence(place)
146+
147+
148+
class ConvTranspose1dErrorTestCase(ConvTranspose1dTestCase):
149+
def runTest(self):
150+
place = fluid.CPUPlace()
151+
with dg.guard(place):
152+
with self.assertRaises(ValueError):
153+
self.paddle_nn_layer()
154+
155+
156+
def add_cases(suite):
157+
suite.addTest(ConvTranspose1dTestCase(methodName='runTest'))
158+
suite.addTest(
159+
ConvTranspose1dTestCase(
160+
methodName='runTest', stride=[2], no_bias=True, dilation=2))
161+
suite.addTest(
162+
ConvTranspose1dTestCase(
163+
methodName='runTest',
164+
filter_size=(3),
165+
output_size=[36],
166+
stride=[2],
167+
dilation=2))
168+
suite.addTest(
169+
ConvTranspose1dTestCase(
170+
methodName='runTest', stride=2, dilation=(2)))
171+
suite.addTest(
172+
ConvTranspose1dTestCase(
173+
methodName='runTest', padding="valid"))
174+
suite.addTest(
175+
ConvTranspose1dTestCase(
176+
methodName='runTest', padding='valid'))
177+
suite.addTest(
178+
ConvTranspose1dTestCase(
179+
methodName='runTest', filter_size=1, padding=3))
180+
suite.addTest(ConvTranspose1dTestCase(methodName='runTest', padding=[2]))
181+
suite.addTest(
182+
ConvTranspose1dTestCase(
183+
methodName='runTest', data_format="NLC"))
184+
suite.addTest(
185+
ConvTranspose1dTestCase(
186+
methodName='runTest', groups=2, padding="valid"))
187+
suite.addTest(
188+
ConvTranspose1dTestCase(
189+
methodName='runTest',
190+
out_channels=6,
191+
in_channels=3,
192+
groups=3,
193+
padding="valid"))
194+
suite.addTest(
195+
ConvTranspose1dTestCase(
196+
methodName='runTest',
197+
data_format="NLC",
198+
spartial_shape=16,
199+
output_size=18))
200+
suite.addTest(
201+
ConvTranspose1dTestCase(
202+
methodName='runTest', data_format="NLC", stride=3,
203+
output_padding=2))
204+
205+
206+
def add_error_cases(suite):
207+
suite.addTest(
208+
ConvTranspose1dErrorTestCase(
209+
methodName='runTest', data_format="not_valid"))
210+
suite.addTest(
211+
ConvTranspose1dErrorTestCase(
212+
methodName='runTest', in_channels=5, groups=2))
213+
suite.addTest(
214+
ConvTranspose1dErrorTestCase(
215+
methodName='runTest', stride=2, output_padding=3))
216+
suite.addTest(
217+
ConvTranspose1dErrorTestCase(
218+
methodName='runTest', output_size="not_valid"))
219+
220+
221+
def load_tests(loader, standard_tests, pattern):
222+
suite = unittest.TestSuite()
223+
add_cases(suite)
224+
add_error_cases(suite)
225+
return suite
226+
227+
228+
if __name__ == '__main__':
229+
unittest.main()

python/paddle/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
from .layer.conv import Conv1d #DEFINE_ALIAS
9898
from .layer.conv import Conv2d #DEFINE_ALIAS
9999
from .layer.conv import Conv3d #DEFINE_ALIAS
100+
from .layer.conv import ConvTranspose1d #DEFINE_ALIAS
100101
from .layer.conv import ConvTranspose2d #DEFINE_ALIAS
101102
from .layer.conv import ConvTranspose3d #DEFINE_ALIAS
102103
# from .layer.conv import TreeConv #DEFINE_ALIAS

python/paddle/nn/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from .common import assign #DEFINE_ALIAS
7272
from .common import interpolate #DEFINE_ALIAS
7373
from .conv import conv1d #DEFINE_ALIAS
74+
from .conv import conv_transpose1d #DEFINE_ALIAS
7475
from .conv import conv2d #DEFINE_ALIAS
7576
from .conv import conv_transpose2d #DEFINE_ALIAS
7677
from .conv import conv3d #DEFINE_ALIAS

0 commit comments

Comments
 (0)