Skip to content

Commit cbb9f08

Browse files
Kaixhin authored and soumith committed
Add new init methods gain, eye and dirac (pytorch#1172)
1 parent f75ab85 commit cbb9f08

File tree

3 files changed

+267
-51
lines changed

3 files changed

+267
-51
lines changed

docs/source/nn.rst

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -855,9 +855,12 @@ torch.nn.init
855855
=============
856856

857857
.. currentmodule:: torch.nn.init
858+
.. autofunction:: calculate_gain
858859
.. autofunction:: uniform
859860
.. autofunction:: normal
860861
.. autofunction:: constant
862+
.. autofunction:: eye
863+
.. autofunction:: dirac
861864
.. autofunction:: xavier_uniform
862865
.. autofunction:: xavier_normal
863866
.. autofunction:: kaiming_uniform

test/test_nn.py

Lines changed: 115 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,6 @@
11
import math
22
import random
3+
import string
34
import unittest
45
import itertools
56
import contextlib
@@ -2106,6 +2107,47 @@ def _create_random_nd_tensor(self, dims, size_min, size_max, as_variable):
21062107
def _random_float(self, a, b):
21072108
return (b - a) * random.random() + a
21082109

2110+
def test_calculate_gain_linear(self):
2111+
for fn in ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose2d', 'conv_transpose2d', 'conv_transpose3d']:
2112+
gain = init.calculate_gain(fn)
2113+
self.assertEqual(gain, 1)
2114+
2115+
def test_calculate_gain_nonlinear(self):
2116+
for fn in ['sigmoid', 'tanh', 'relu', 'leaky_relu']:
2117+
gain = init.calculate_gain(fn)
2118+
if fn == 'sigmoid':
2119+
self.assertEqual(gain, 1)
2120+
elif fn == 'tanh': # 5 / 3
2121+
self.assertEqual(gain, 1.6666666666666667)
2122+
elif fn == 'relu': # sqrt(2)
2123+
self.assertEqual(gain, 1.4142135623730951)
2124+
elif fn == 'leaky_relu': # sqrt(2 / 1 + slope^2))
2125+
self.assertEqual(gain, 1.4141428569978354)
2126+
2127+
def test_calculate_gain_leaky_relu(self):
2128+
for param in [None, 0, 0.01, 10]:
2129+
gain = init.calculate_gain('leaky_relu', param)
2130+
if param is None: # Default slope is 0.01
2131+
self.assertEqual(gain, 1.4141428569978354)
2132+
elif param == 0: # No slope = same gain as normal ReLU
2133+
self.assertEqual(gain, 1.4142135623730951)
2134+
elif param == 0.01:
2135+
self.assertEqual(gain, 1.4141428569978354)
2136+
elif param == 10:
2137+
self.assertEqual(gain, 0.14071950894605836)
2138+
2139+
def test_calculate_gain_leaky_relu_only_accepts_numbers(self):
2140+
for param in [True, [1], {'a': 'b'}]:
2141+
with self.assertRaises(ValueError):
2142+
init.calculate_gain('leaky_relu', param)
2143+
2144+
def test_calculate_gain_only_accepts_valid_nonlinearities(self):
2145+
for n in [2, 5, 25]:
2146+
# Generate random strings of lengths that definitely aren't supported
2147+
random_string = ''.join([random.choice(string.ascii_lowercase) for i in range(n)])
2148+
with self.assertRaises(ValueError):
2149+
init.calculate_gain(random_string)
2150+
21092151
@unittest.skipIf(not TEST_SCIPY, "Scipy not found.")
21102152
def test_uniform(self):
21112153
for as_variable in [True, False]:
@@ -2138,6 +2180,79 @@ def test_constant(self):
21382180

21392181
self.assertEqual(input_tensor, input_tensor.clone().fill_(val))
21402182

2183+
def test_eye(self):
    """init.eye must write an identity pattern into any 2D tensor."""
    for as_variable in [True, False]:
        tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5, as_variable=as_variable)
        init.eye(tensor)
        if as_variable:
            tensor = tensor.data

        # Every entry: 1 on the main diagonal, 0 everywhere else.
        rows, cols = tensor.size(0), tensor.size(1)
        for r in range(rows):
            for c in range(cols):
                assert tensor[r][c] == (1 if r == c else 0)
2197+
2198+
def test_eye_only_works_on_2d_inputs(self):
    """Tensors that are not exactly 2D must be rejected with ValueError."""
    for as_variable in [True, False]:
        for dims in [1, 3]:
            with self.assertRaises(ValueError):
                bad = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable)
                init.eye(bad)
2204+
2205+
def test_dirac_properties(self):
    """Dirac init: nonzero count and total sum equal min(c_out, c_in)."""
    for as_variable in [True, False]:
        for dims in [3, 4, 5]:
            tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5, as_variable=as_variable)
            init.dirac(tensor)
            if as_variable:
                tensor = tensor.data

            out_channels, in_channels = tensor.size(0), tensor.size(1)
            smallest = min(out_channels, in_channels)
            # One nonzero per preserved channel, i.e. smallest-dim many.
            assert torch.nonzero(tensor).size(0) == smallest
            # Each nonzero is 1, so the sum matches too; assertEqual is
            # used because float summation can have precision issues.
            self.assertEqual(tensor.sum(), smallest)
2219+
2220+
def test_dirac_identity(self):
    """A dirac-initialized convolution must pass its input through.

    For each of conv1d/conv2d/conv3d: the first in_c output channels
    reproduce the valid interior of the input, and every extra output
    channel stays exactly zero.
    """
    batch, in_c, out_c, size, kernel_size = 8, 3, 4, 5, 3
    for spatial_dims, conv in enumerate([F.conv1d, F.conv2d, F.conv3d], start=1):
        in_shape = (batch, in_c) + (size,) * spatial_dims
        w_shape = (out_c, in_c) + (kernel_size,) * spatial_dims
        input_var = Variable(torch.randn(*in_shape))
        filter_var = Variable(torch.zeros(*w_shape))
        init.dirac(filter_var)
        output_var = conv(input_var, filter_var)
        # Variables do not support nonzero, so compare the raw tensors.
        input_tensor, output_tensor = input_var.data, output_var.data
        # Valid (no-padding, kernel 3) region is the spatial interior.
        interior = (slice(None), slice(None)) + (slice(1, -1),) * spatial_dims
        # Assert the first in_c output channels preserve the input.
        self.assertEqual(input_tensor[interior], output_tensor[:, :in_c])
        # Assert the extra output channels are all 0.
        assert torch.nonzero(output_tensor[:, in_c:]).numel() == 0
2248+
2249+
def test_dirac_only_works_on_3_4_5d_inputs(self):
    """dirac is only defined for 3-, 4- and 5-dimensional tensors."""
    for as_variable in [True, False]:
        for dims in [1, 2, 6]:
            with self.assertRaises(ValueError):
                bad = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable)
                init.dirac(bad)
2255+
21412256
def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self):
21422257
for as_variable in [True, False]:
21432258
for dims in [0, 1]:

0 commit comments

Comments
 (0)