
Commit 75f1989

uridah authored and soumith committed
Add nn.Bilinear and tests
1 parent e221536 commit 75f1989

5 files changed: 148 additions, 4 deletions

test/test_nn.py

Lines changed: 30 additions & 0 deletions
@@ -14,6 +14,7 @@
 import torch.nn.parallel as dp
 import torch.nn.init as init
 import torch.nn.utils.rnn as rnn_utils
+import torch.legacy.nn as legacy
 from torch.nn.utils import clip_grad_norm
 from torch.autograd import Variable, gradcheck
 from torch.nn import Parameter
@@ -2048,6 +2049,35 @@ def test_triplet_margin_swap_loss(self):
         self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
             x1, x2, x3, swap=True), (input1, input2, input3)))
 
+    def test_bilinear(self):
+        module = nn.Bilinear(10, 10, 8)
+        module2 = legacy.Bilinear(10, 10, 8)
+
+        module2.weight.copy_(module.weight.data)
+        module2.bias.copy_(module.bias.data)
+
+        input1 = torch.randn(4, 10)
+        input2 = torch.randn(4, 10)
+
+        output = module(Variable(input1), Variable(input2))
+        output2 = module2.forward([input1, input2])
+
+        input1_1 = Variable(input1, requires_grad=True)
+        input2_1 = Variable(input2, requires_grad=True)
+
+        output3 = module(input1_1, input2_1)
+        grad = torch.randn(*output3.size())
+        output3.backward(grad)
+        gi1 = input1_1.grad.data.clone()
+        gi2 = input2_1.grad.data.clone()
+
+        self.assertEqual(output.data, output2)
+        self.assertEqual([gi1, gi2], module2.backward([input1, input2], grad))
+
+        def forward(x1, x2):
+            return F.bilinear(x1, x2, module.weight, module.bias)
+        self.assertTrue(gradcheck(forward, (input1_1, input2_1)))
+
 
 class TestNNInit(TestCase):
     def setUp(self):

torch/nn/_functions/linear.py

Lines changed: 53 additions & 0 deletions
@@ -29,3 +29,56 @@ def backward(self, grad_output):
             return grad_input, grad_weight, grad_bias
         else:
             return grad_input, grad_weight
+
+
+class Bilinear(Function):
+
+    def forward(self, input1, input2, weight, bias=None):
+        self.save_for_backward(input1, input2, weight, bias)
+
+        output = input1.new(input1.size(0), weight.size(0))
+
+        buff = input1.new()
+
+        # compute output scores:
+        for k, w in enumerate(weight):
+            torch.mm(input1, w, out=buff)
+            buff.mul_(input2)
+            torch.sum(buff, 1, out=output.narrow(1, k, 1))
+
+        if bias is not None:
+            output.add_(bias.expand_as(output))
+
+        return output
+
+    def backward(self, grad_output):
+        input1, input2, weight, bias = self.saved_tensors
+        grad_input1 = grad_input2 = grad_weight = grad_bias = None
+
+        buff = input1.new()
+
+        if self.needs_input_grad[0] or self.needs_input_grad[1]:
+            grad_input1 = torch.mm(input2, weight[0].t())
+            grad_input1.mul_(grad_output.narrow(1, 0, 1).expand(grad_input1.size()))
+            grad_input2 = torch.mm(input1, weight[0])
+            grad_input2.mul_(grad_output.narrow(1, 0, 1).expand(grad_input2.size()))
+
+            for k in range(1, weight.size(0)):
+                torch.mm(input2, weight[k].t(), out=buff)
+                buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input1.size()))
+                grad_input1.add_(buff)
+
+                torch.mm(input1, weight[k], out=buff)
+                buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input2.size()))
+                grad_input2.add_(buff)
+
+        if self.needs_input_grad[2]:
+            # accumulate parameter gradients, one (in1 x in2) slice per output unit:
+            grad_weight = weight.new(weight.size())
+            for k in range(weight.size(0)):
+                torch.mul(input1, grad_output.narrow(1, k, 1).expand_as(input1), out=buff)
+                torch.mm(buff.t(), input2, out=grad_weight[k])
+
+        if bias is not None and self.needs_input_grad[3]:
+            grad_bias = grad_output.sum(0)
+
+        return grad_input1, grad_input2, grad_weight, grad_bias
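
For each output unit k, the loop in Bilinear.forward above computes the bilinear form x1ᵀ·W_k·x2 for every batch row, i.e. output[n, k] = Σ_ij input1[n, i]·W[k, i, j]·input2[n, j]. As a quick sanity sketch (not part of this commit; it uses the current tensor API and torch.einsum, and the sizes are arbitrary illustrative choices), the same contraction can be written in a single call and checked against the explicit loop:

    import torch

    N, D1, D2, K = 4, 10, 10, 8          # batch, in1_features, in2_features, out_features
    x1, x2 = torch.randn(N, D1), torch.randn(N, D2)
    weight, bias = torch.randn(K, D1, D2), torch.randn(K)

    # explicit loop, mirroring Bilinear.forward above
    out_loop = torch.empty(N, K)
    for k, w in enumerate(weight):
        out_loop[:, k] = (x1 @ w * x2).sum(dim=1)
    out_loop += bias

    # same bilinear form as one einsum contraction
    out_einsum = torch.einsum('ni,kij,nj->nk', x1, weight, x2) + bias

    assert torch.allclose(out_loop, out_einsum)

The backward pass follows from the same expression: the gradient with respect to input1 is Σ_k grad_output[:, k]·(input2·W_kᵀ), which is exactly what the accumulation loop over k computes, and symmetrically for input2.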

torch/nn/functional.py

Lines changed: 8 additions & 0 deletions
@@ -449,6 +449,14 @@ def linear(input, weight, bias=None):
     return state(input, weight) if bias is None else state(input, weight, bias)
 
 
+def bilinear(input1, input2, weight, bias=None):
+    state = _functions.linear.Bilinear()
+    if bias is None:
+        return state(input1, input2, weight)
+    else:
+        return state(input1, input2, weight, bias)
+
+
 def batch_norm(input, running_mean, running_var, weight=None, bias=None,
                training=False, momentum=0.1, eps=1e-5):
     f = torch._C._functions.BatchNorm(running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled)
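
The new functional entry point mirrors F.linear: it instantiates the Bilinear autograd function and applies it to the two inputs plus the parameters. A minimal usage sketch (my own example, not from the commit; written against the current tensor API rather than the Variable wrapper shown in the diff):

    import torch
    import torch.nn.functional as F

    x1 = torch.randn(4, 10, requires_grad=True)    # (batch, in1_features)
    x2 = torch.randn(4, 10, requires_grad=True)    # (batch, in2_features)
    weight = torch.randn(8, 10, 10)                # (out_features, in1_features, in2_features)
    bias = torch.randn(8)                          # (out_features,)

    out = F.bilinear(x1, x2, weight, bias)         # -> shape (4, 8)
    out.sum().backward()                           # populates x1.grad and x2.grad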

torch/nn/modules/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 from .module import Module
-from .linear import Linear
+from .linear import Linear, Bilinear
 from .conv import Conv1d, Conv2d, Conv3d, \
     ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
 from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
@@ -43,5 +43,5 @@
     'Embedding', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell',
     'PixelShuffle', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance',
     'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d',
-    'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad2d'
+    'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad2d', 'Bilinear',
 ]

torch/nn/modules/linear.py

Lines changed: 55 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from torch.nn.parameter import Parameter
-
+from .. import functional as F
 from .module import Module
 
 
@@ -59,5 +59,58 @@ def __repr__(self):
             + str(self.out_features) + ')'
 
 
-# TODO: Bilinear
+class Bilinear(Module):
+    r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1 * A * x_2 + b`
+
+    Args:
+        in1_features: size of each first input sample
+        in2_features: size of each second input sample
+        out_features: size of each output sample
+        bias: If set to False, the layer will not learn an additive bias. Default: True
+
+    Shape:
+        - Input: :math:`(N, in1\_features)`, :math:`(N, in2\_features)`
+        - Output: :math:`(N, out\_features)`
+
+    Attributes:
+        weight: the learnable weights of the module of shape (out_features x in1_features x in2_features)
+        bias: the learnable bias of the module of shape (out_features)
+
+    Examples::
+
+        >>> m = nn.Bilinear(20, 30, 40)
+        >>> input1 = autograd.Variable(torch.randn(128, 20))
+        >>> input2 = autograd.Variable(torch.randn(128, 30))
+        >>> output = m(input1, input2)
+        >>> print(output.size())
+    """
+
+    def __init__(self, in1_features, in2_features, out_features, bias=True):
+        super(Bilinear, self).__init__()
+        self.in1_features = in1_features
+        self.in2_features = in2_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))
+
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+    def forward(self, input1, input2):
+        return F.bilinear(input1, input2, self.weight, self.bias)
+
+    def __repr__(self):
+        return self.__class__.__name__ + ' (' \
+            + 'in1_features=' + str(self.in1_features) \
+            + ', in2_features=' + str(self.in2_features) \
+            + ', out_features=' + str(self.out_features) + ')'
+
 # TODO: PartialLinear - maybe in sparse?
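
To verify the module end to end, a numerical gradient check similar to the one in the new test can be run directly on nn.Bilinear. A hedged sketch (my own, not part of the commit; the double precision and small sizes are chosen only to keep the finite-difference comparison accurate):

    import torch
    from torch.autograd import gradcheck

    m = torch.nn.Bilinear(5, 4, 3).double()
    x1 = torch.randn(2, 5, dtype=torch.double, requires_grad=True)
    x2 = torch.randn(2, 4, dtype=torch.double, requires_grad=True)

    # compares the analytical gradients from Bilinear's backward with finite differences
    assert gradcheck(lambda a, b: m(a, b), (x1, x2))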
