Skip to content

Commit e514713

Browse files
committed
Change args in conv_transpose1d
test=develop
1 parent 3d21a3f commit e514713

File tree

6 files changed

+195
-468
lines changed

6 files changed

+195
-468
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 0 additions & 266 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
'inplace_abn',
5959
'instance_norm',
6060
'data_norm',
61-
'conv1d_transpose',
6261
'conv2d_transpose',
6362
'conv3d_transpose',
6463
'reduce_sum',
@@ -3728,271 +3727,6 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
37283727
return out
37293728

37303729

3731-
def conv1d_transpose(input,
                     num_filters,
                     output_size=None,
                     filter_size=None,
                     padding=0,
                     stride=1,
                     dilation=1,
                     groups=None,
                     param_attr=None,
                     bias_attr=None,
                     use_cudnn=True,
                     act=None,
                     name=None):
    r"""
    :api_attr: Static Graph

    The convolution1D transpose layer calculates the output based on the input,
    filter, and dilation, stride, padding. Input(Input) and output(Output)
    are in NCL format where N is batch size, C is the number of channels,
    L is the length of the feature. For the details of the transposed
    convolution, please refer to the following explanation and references
    `therein <https://arxiv.org/pdf/1603.07285.pdf>`_.
    If bias attribution and activation type are provided, bias is added to
    the output of the convolution, and the corresponding activation function
    is applied to the final result.

    Internally the input is unsqueezed to 4-D (a trailing axis of size 1 is
    appended), a ``conv2d_transpose`` (or ``depthwise_conv2d_transpose``)
    operator is appended with a trailing stride/dilation/filter extent of 1,
    and the trailing axis is squeezed away from the result.

    For each input :math:`X`, the equation is:

    .. math::

        Out = \sigma (W \ast X + b)

    Where:

    * :math:`X`: Input value, a 3-D Tensor with NCL format.
    * :math:`W`: Filter value, a 3-D Tensor with MCL format.
    * :math:`\ast`: Convolution operation.
    * :math:`b`: Bias value, a 2-D Tensor with shape [M, 1].
    * :math:`\sigma`: Activation function.
    * :math:`Out`: Output value, a 3-D Tensor with data format 'NCL', the shape of :math:`Out` and :math:`X` may be different.

    Example:

        - Input:

          Input shape: :math:`(N, C_{in}, L_{in})`

          Filter shape: :math:`(C_{in}, C_{out}, L_f)`

        - Output:

          Output shape: :math:`(N, C_{out}, L_{out})`

        Where

        .. math::

           L^\prime_{out} &= (L_{in} - 1) * stride - pad_top - pad_bottom + dilation * (L_f - 1) + 1 \\
           L_{out} &\in [ L^\prime_{out}, L^\prime_{out} + stride ]

    Note:
          The conv1d_transpose can be seen as the backward of the conv1d. For conv1d,
          when stride > 1, conv1d maps multiple input shape to the same output shape,
          so for conv1d_transpose, when stride > 1, input shape maps multiple output shape.
          If output_size is None, :math:`L_{out} = L^\prime_{out}`;
          else, the :math:`L_{out}` of the output size must be between :math:`L^\prime_{out}`
          and :math:`L^\prime_{out} + stride`. When filter_size is None,
          conv1d_transpose computes the filter size from output_size, stride,
          padding and dilation.

    Args:
        input(Variable): 3-D Tensor with [N, C, L] format,
            its data type is float32 or float64.
        num_filters(int): The number of the filter. It is the same as the
            number of output channels.
        output_size(int|list|tuple, optional): The output length. If output_size is a
            list or tuple, it must contain one integer, (feature_length). None if
            filter_size, padding, and stride are used to calculate the output size.
            If output_size and filter_size are specified at the same time, they
            should follow the formula above. Default: None. output_size and filter_size
            should not be None at the same time.
        filter_size(int|tuple, optional): The filter size. If filter_size is a tuple,
            it must contain one integer, (filter_size). None if
            output_size is used to calculate filter_size. Default: None. filter_size and
            output_size should not be None at the same time.
        padding(int|list|str|tuple, optional): The padding size. If `padding` is a
            string, either 'VALID' or 'SAME' (case-insensitive) is supported, which
            selects the padding algorithm. Otherwise it can be an int, a one-element
            list/tuple `[pad]`, or the nested list
            `[[0, 0], [0, 0], [pad_top, pad_bottom]]` whose batch and channel
            entries must be `[0, 0]`. Default: padding = 0.
            NOTE(review): a flat `[pad_top, pad_bottom]` is forwarded to
            `utils.convert_to_list(padding, 1, 'padding')`, which presumably
            rejects a two-element value -- confirm before relying on that form.
        stride(int|tuple, optional): The stride size. It means the stride in transposed convolution.
            If stride is a tuple, it must contain one integer, (stride_size).
            Default: stride = 1.
        dilation(int|tuple, optional): The dilation size. It means the spacing between the kernel points.
            If dilation is a tuple, it must contain one integer, (dilation_size).
            Default: dilation = 1.
        groups(int, optional): The groups number of the conv1d transpose layer. Inspired by
            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
            when group=2, the first half of the filters is only connected to the
            first half of the input channels, while the second half of the
            filters is only connected to the second half of the input channels.
            Default: None, which is treated as groups = 1.
        param_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights
            of conv1d_transpose. Must not be False. Default: None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of conv1d_transpose.
            If it is set to False, no bias will be added to the output units.
            Default: None.
        use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
            library is installed. Default: True.
        act (str, optional): Activation type, if it is set to None, activation is not appended.
            Default: None.
        name(str, optional): For detailed information, please refer
           to :ref:`api_guide_Name`. Usually name is no need to set and
           None by default.

    Returns:
        A Variable holding Tensor representing the conv1d_transpose, whose
        data type is the same with input and shape is (num_batches, channels, length).
        If act is None, the tensor variable storing the transposed convolution result,
        and if act is not None, the tensor variable storing transposed convolution and
        non-linearity activation result.

    Raises:
        TypeError: If `input` is not a Variable.
        ValueError: If the type of `use_cudnn` is not bool.
        ValueError: If `padding` is a string, but not "SAME" or "VALID".
        ValueError: If `padding` is a nested list, but the element corresponding to the input's batch size is not [0, 0]
            or the element corresponding to the input's channel is not [0, 0].
        ValueError: If `output_size` and `filter_size` are None at the same time.
        ValueError: If `output_size` is not None, an int, a list or a tuple.

        The appended operator may raise further shape errors (for example on
        mismatched channel counts); those checks are not part of this function.

    Examples:
       .. code-block:: python

          import paddle.fluid as fluid
          data = fluid.data(name='data', shape=[None, 3, 32], dtype='float32')
          conv1d_transpose = fluid.layers.conv1d_transpose(input=data, num_filters=2, filter_size=3)
    """
    assert param_attr is not False, "param_attr should not be False in conv1d_transpose."

    # Check the type before reading ``input.shape`` so that a wrong argument
    # yields the intended TypeError instead of an AttributeError.
    if not isinstance(input, Variable):
        raise TypeError("Input of conv1d_transpose must be Variable")

    input_channel = input.shape[1]
    # The 1-D layer is lowered onto the 2-D transposed-convolution operators.
    op_type = 'conv2d_transpose'
    if (input_channel == groups and num_filters == input_channel and
            not use_cudnn):
        op_type = 'depthwise_conv2d_transpose'

    # NOTE: ``locals()`` forwards the call arguments (param_attr, bias_attr,
    # act, name, ...) to the helper; keep this call ahead of any new locals.
    helper = LayerHelper(op_type, **locals())

    # Append a unit entry for the dummy trailing (width) axis.
    stride = utils.convert_to_list(stride, 1, 'stride') + [1]
    dilation = utils.convert_to_list(dilation, 1, 'dilation') + [1]

    if not isinstance(use_cudnn, bool):
        raise ValueError("use_cudnn should be True or False")

    def _update_padding(padding):
        """Expand `padding` to the 4-element [top, bottom, left, right] form."""

        def is_list_or_tuple(ele):
            return isinstance(ele, (list, tuple))

        if is_list_or_tuple(padding) and len(padding) == 3:
            if is_list_or_tuple(padding[0]):
                # Nested per-axis form: batch and channel entries must be zero.
                if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
                    raise ValueError(
                        "Non-zero padding(%s) in the batch or channel dimensions "
                        "is not supported." % str(padding))
                # Add a zero pair for the dummy width axis, then keep only the
                # two spatial pairs and flatten them.
                padding = padding + [[0, 0]]
                padding = padding[2:4]
                padding = [ele for a_list in padding for ele in a_list]
            else:
                padding = padding + [0]

            padding = utils.convert_to_list(padding, 4, 'padding')
        else:
            # Scalar / one-element form: symmetric on the length axis and zero
            # on the dummy width axis.
            padding = utils.convert_to_list(padding, 1, 'padding') + [0]
            padding = [padding[0], padding[0], padding[1], padding[1]]
        return padding

    padding_algorithm = "EXPLICIT"
    if isinstance(padding, str):
        padding = padding.upper()
        if padding not in ["SAME", "VALID"]:
            raise ValueError(
                "Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." %
                str(padding))
        # Both algorithms carry a placeholder explicit padding of zero.
        padding_algorithm = padding
        padding = [0]

    padding = _update_padding(padding)

    # NCL -> NCL1 so that the 2-D operator can be used.
    input = unsqueeze(input=input, axes=[-1])

    if filter_size is None and output_size is None:
        raise ValueError("output_size must be set when filter_size is None")

    # Normalise output_size up front so that the filter-size inference below
    # works for the int form and the one-element list/tuple form alike
    # (indexing a one-element sequence at [1] used to raise IndexError).
    if output_size is None:
        output_size = []
    elif isinstance(output_size, (list, tuple, int)):
        output_size = utils.convert_to_list(output_size, 1, 'output_size') + [1]
    else:
        raise ValueError("output_size should be int, list[int] or tuple[int]")

    if filter_size is None:
        h_in = input.shape[2]
        w_in = input.shape[3]

        # Invert the output-length formula from the docstring to obtain the
        # filter extent along each axis.
        filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + padding[0] +
                         padding[1] - 1) // dilation[0] + 1
        filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + padding[2] +
                         padding[3] - 1) // dilation[1] + 1
        filter_size = [filter_size_h, filter_size_w]
    else:
        filter_size = utils.convert_to_list(
            filter_size, 1, 'conv1d_transpose.filter_size') + [1]

    # Collapse symmetric padding to the compact two-element form.
    if len(padding) == 4 and utils._is_symmetric_padding(padding, 2):
        padding = [padding[0], padding[2]]

    groups = 1 if groups is None else groups
    filter_shape = [input_channel, num_filters // groups] + filter_size

    img_filter = helper.create_parameter(
        dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)

    # Reshape to the same shape the parameter was created with.
    img_filter = reshape(img_filter, shape=filter_shape)

    pre_bias = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type=op_type,
        inputs={'Input': [input],
                'Filter': [img_filter]},
        outputs={'Output': pre_bias},
        attrs={
            'output_size': output_size,
            'strides': stride,
            'paddings': padding,
            'padding_algorithm': padding_algorithm,
            'dilations': dilation,
            'groups': groups,
            'use_cudnn': use_cudnn,
            'data_format': "NCHW"
        })

    pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
    out = helper.append_activation(pre_act)
    # Drop the dummy trailing axis again: NCL1 -> NCL.
    out = squeeze(input=out, axes=[-1])
    return out
3994-
3995-
39963730
def conv2d_transpose(input,
39973731
num_filters,
39983732
output_size=None,

python/paddle/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
from .layer.common import Linear #DEFINE_ALIAS
7979
from .layer.common import Flatten #DEFINE_ALIAS
8080
from .layer.common import UpSample #DEFINE_ALIAS
81-
from .layer.conv import Conv1DTranspose #DEFINE_ALIAS
81+
from .layer.conv import ConvTranspose1D #DEFINE_ALIAS
8282
from .layer.conv import Conv2D #DEFINE_ALIAS
8383
from .layer.conv import Conv2DTranspose #DEFINE_ALIAS
8484
from .layer.conv import Conv3D #DEFINE_ALIAS

python/paddle/nn/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
# from .common import bilinear_tensor_product #DEFINE_ALIAS
6464
from .common import assign #DEFINE_ALIAS
6565
from .common import interpolate #DEFINE_ALIAS
66-
from .conv import conv1d_transpose #DEFINE_ALIAS
66+
from .conv import conv_transpose1d #DEFINE_ALIAS
6767
from .conv import conv2d #DEFINE_ALIAS
6868
from .conv import conv2d_transpose #DEFINE_ALIAS
6969
from .conv import conv3d #DEFINE_ALIAS

0 commit comments

Comments
 (0)