
Commit 1213149

adamlerer authored and apaszke committed
add bias option to linear; allow modules to return nested lists/tuples of tensors (pytorch#106)
1 parent 398b6f7 commit 1213149

6 files changed, +43 -27 lines

test/common_nn.py

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,13 @@
         input_size=(4, 10),
         reference_fn=lambda i,p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
     ),
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8, False),
+        input_size=(4, 10),
+        desc='no_bias',
+        reference_fn=lambda i,p: torch.mm(i, p[0].t())
+    ),
     dict(
         module_name='Threshold',
         constructor_args=(2, 1),
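The two reference_fn lambdas pin down the expected math: with a bias, the output is a matrix product plus a row-broadcast bias; without one, it is the matrix product alone. A minimal standalone sketch of the same check (the tensor names here are illustrative, not from the test harness):

    import torch

    # i: a batch matching input_size=(4, 10); W, b: parameters of Linear(10, 8).
    i = torch.randn(4, 10)
    W = torch.randn(8, 10)
    b = torch.randn(8)

    out_bias = torch.mm(i, W.t()) + b.view(1, -1).expand(4, 8)  # default case
    out_no_bias = torch.mm(i, W.t())                            # bias=False case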

test/test_nn.py

Lines changed: 3 additions & 3 deletions
@@ -498,7 +498,7 @@ def test_data_parallel(self):
     def test_parameter_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Container(
-            conv=nn.Conv2d(3, 3, 3, no_bias=True)
+            conv=nn.Conv2d(3, 3, 3, bias=False)
         )
         net = nn.Container(
             linear1=l,
@@ -530,7 +530,7 @@ def test_parameter_dict(self):
     def test_load_parameter_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Container(
-            conv=nn.Conv2d(3, 3, 3, no_bias=True)
+            conv=nn.Conv2d(3, 3, 3, bias=False)
         )
         net = nn.Container(
             linear1=l,
@@ -606,7 +606,7 @@ def add_test(test):
     ),
     dict(
         module_name='Conv2d',
-        constructor_args=(3, 4, (3, 3), 1, 0, None, 1, True),
+        constructor_args=(3, 4, (3, 3), 1, 0, None, 1, False),
         input_size=(2, 3, 6, 6),
         desc='no_bias',
     ),
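Note why the last positional argument flips from True to False: constructor_args are splatted into the constructor, and with the signature change below the eighth positional parameter is now bias rather than no_bias. Expanded with keywords (a sketch using the new signature from torch/nn/modules/conv.py):

    # constructor_args=(3, 4, (3, 3), 1, 0, None, 1, False) is equivalent to:
    m = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=(3, 3),
                  stride=1, padding=0, dilation=None, groups=1, bias=False)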

torch/nn/functions/linear.py

Lines changed: 13 additions & 9 deletions
@@ -24,13 +24,17 @@ def backward(self, grad_output):
             bias = None
         else:
             input, weight, bias = tensors
-        grad_tuple = (
-            torch.mm(grad_output, weight) if \
-                self.needs_input_grad[0] else None,
-            torch.mm(grad_output.t(), input) if \
-                self.needs_input_grad[1] else None,
-            torch.mv(grad_output.t(), self.add_buffer) if \
-                bias is not None and self.needs_input_grad[2] else None,
-        )
-        return grad_tuple
+
+        grad_input = grad_weight = grad_bias = None
+        if self.needs_input_grad[0]:
+            grad_input = torch.mm(grad_output, weight)
+        if self.needs_input_grad[1]:
+            grad_weight = torch.mm(grad_output.t(), input)
+        if bias is not None and self.needs_input_grad[2]:
+            grad_bias = torch.mv(grad_output.t(), self.add_buffer)
+
+        if bias is not None:
+            return grad_input, grad_weight, grad_bias
+        else:
+            return grad_input, grad_weight
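For output = input.mm(weight.t()) + bias, the three formulas in the rewritten backward are the standard linear-layer gradients; self.add_buffer is effectively a ones vector used to sum grad_output over the batch. A shape-checked sketch, with add_buffer replaced by an explicit ones tensor:

    import torch

    N, I, O = 4, 10, 8                    # batch, in_features, out_features
    input = torch.randn(N, I)
    weight = torch.randn(O, I)
    grad_output = torch.randn(N, O)

    grad_input = torch.mm(grad_output, weight)       # N x I
    grad_weight = torch.mm(grad_output.t(), input)   # O x I
    ones = torch.ones(N)                             # plays the role of self.add_buffer
    grad_bias = torch.mv(grad_output.t(), ones)      # O: batch-sum of grad_output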

torch/nn/modules/conv.py

Lines changed: 6 additions & 6 deletions
@@ -91,7 +91,7 @@ class Conv2d(Module):
     stride: the stride of the convolving kernel. Can be a single number s or a tuple (sh x sw). Default: 1
     padding: implicit zero padding on the input. Can be a single number s or a tuple. Default: 0
     dilation: If given, will do dilated (or atrous) convolutions. Can be a single number s or a tuple. Default: None
-    no_bias: If set to true, the layer will not learn an additive bias. Default: False
+    bias: If set to False, the layer will not learn an additive bias. Default: True
     Input Shape: [ * , in_channels , * , * ] : Input is minibatch x in_channels x iH x iW
     Output Shape:[ * , out_channels , * , * ] : Output shape is precisely minibatch x out_channels x floor((iH + 2*padH - kH) / dH + 1) x floor((iW + 2*padW - kW) / dW + 1)
     Members:
@@ -108,7 +108,7 @@ class Conv2d(Module):
     >>> output = m(input)
     """
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=None, groups=1, no_bias=False):
+                 padding=0, dilation=None, groups=1, bias=True):
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kh, self.kw = _pair(kernel_size)
@@ -121,7 +121,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
 
         weight = torch.Tensor(self.out_channels, self.in_channels, self.kh,
                               self.kw)
-        bias = None if no_bias else torch.Tensor(self.out_channels)
+        bias = torch.Tensor(self.out_channels) if bias else None
         super(Conv2d, self).__init__(
             weight=weight,
             bias=bias,
@@ -166,7 +166,7 @@ class FullConv2d(Conv2d):
     stride: the stride of the convolving kernel. Can be a single number or a tuple (sh x sw). Default: 1
     padding: implicit zero padding on the input. Can be a single number or a tuple. Default: 0
     output_padding: A padding of 0 or 1 pixels that should be added to the output. Can be a single number or a tuple. Default: 0
-    no_bias: If set to true, the layer will not learn an additive bias. Default: False
+    bias: If set to False, the layer will not learn an additive bias. Default: True
     Input Shape: [ * , in_channels , * , * ] : Input is minibatch x in_channels x iH x iW
     Output Shape:[ * , out_channels , * , * ] : Output shape is precisely minibatch x out_channels x (iH - 1) * sH - 2*padH + kH + output_paddingH x (iW - 1) * sW - 2*padW + kW
     Members:
@@ -181,9 +181,9 @@ class FullConv2d(Conv2d):
     >>> output = m(input)
     """
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, output_padding=0, no_bias=False):
+                 padding=0, output_padding=0, bias=True):
         super(FullConv2d, self).__init__(in_channels, out_channels, kernel_size,
-                                         stride, padding, no_bias)
+                                         stride, padding, bias)
         self.out_padh, self.out_padw = _pair(output_padding)
 
     def forward(self, input):
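As a concrete check of the docstring's output-shape formula against the updated test above: the Conv2d no_bias entry uses input_size=(2, 3, 6, 6), kernel (3, 3), stride 1, padding 0, so oH = floor((iH + 2*padH - kH) / dH + 1) = floor((6 + 0 - 3) / 1 + 1) = 4, and likewise oW = 4, giving an output of shape 2 x 4 x 4 x 4.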

torch/nn/modules/linear.py

Lines changed: 9 additions & 4 deletions
@@ -14,6 +14,7 @@ class Linear(Module):
     Args:
         in_features: size of each input sample
         out_features: size of each output sample
+        bias: If set to False, the layer will not learn an additive bias. Default: True
     Input Shape: [*, in_features] : Input can be of shape minibatch x in_features
     Output Shape:[*, out_features] : Output is of shape minibatch x out_features
     Members:
@@ -25,23 +26,27 @@ class Linear(Module):
     >>> output = m(input)
     >>> print(output.size())
     """
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias=True):
         self.in_features = in_features
         self.out_features = out_features
 
         super(Linear, self).__init__(
             weight=torch.Tensor(out_features, in_features),
-            bias=torch.Tensor(out_features)
+            bias=torch.Tensor(out_features) if bias else None
         )
         self.reset_parameters()
 
     def reset_parameters(self):
         stdv = 1./math.sqrt(self.weight.size(1))
         self.weight.data.uniform_(-stdv, stdv)
-        self.bias.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
 
     def forward(self, input):
-        return self._backend.Linear()(input, self.weight, self.bias)
+        if self.bias is None:
+            return self._backend.Linear()(input, self.weight)
+        else:
+            return self._backend.Linear()(input, self.weight, self.bias)
 
 
 # TODO: Bilinear
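Taken together, the Linear changes make the flag a real on/off switch: construction, initialization, and forward all branch on whether self.bias exists. A hedged usage sketch:

    import torch.nn as nn

    m = nn.Linear(10, 8)                 # bias=True by default: weight and bias params
    m_nb = nn.Linear(10, 8, bias=False)  # self.bias is None; no bias is initialized

    # forward() now passes only (input, weight) to the backend when bias is None,
    # matching the two-element gradient tuple returned by backward() above.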

torch/nn/modules/module.py

Lines changed: 5 additions & 5 deletions
@@ -73,12 +73,12 @@ def __call__(self, *input):
         result = self.forward(*input)
         for hook in self.forward_hooks.values():
             hook(self, input, result)
-        if isinstance(result, tuple):
-            fn = result[0].creator
-        else:
-            fn = result.creator
+        var = result
+        while not isinstance(var, Variable):
+            var = var[0]
+        creator = var.creator
         for key, hook in self.backward_hooks.items():
-            fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
+            creator.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
         return result
 
     def __getattr__(self, name):
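This new loop is what allows modules to return nested lists/tuples of tensors, the second half of the commit message: instead of handling only a flat tuple, it descends into position 0 of each container until it reaches a Variable, then registers the backward hooks on that Variable's creator. A standalone sketch of the same unwrapping (assuming Variable is in scope, as in module.py):

    # Descend through nested lists/tuples to the first Variable; e.g. for a
    # forward() that returns ([out1, out2], aux), this yields out1, whose
    # .creator receives the backward hooks.
    def first_variable(result):
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        return var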
