1 | 1 | import math |
2 | 2 | import random |
| 3 | +import string |
3 | 4 | import unittest |
4 | 5 | import itertools |
5 | 6 | import contextlib |
@@ -2106,6 +2107,47 @@ def _create_random_nd_tensor(self, dims, size_min, size_max, as_variable): |
2106 | 2107 | def _random_float(self, a, b): |
2107 | 2108 | return (b - a) * random.random() + a |
2108 | 2109 | |
| 2110 | + def test_calculate_gain_linear(self): |
| 2111 | + for fn in ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']: |
| 2112 | + gain = init.calculate_gain(fn) |
| 2113 | + self.assertEqual(gain, 1) |
| 2114 | + |
| 2115 | + def test_calculate_gain_nonlinear(self): |
| 2116 | + for fn in ['sigmoid', 'tanh', 'relu', 'leaky_relu']: |
| 2117 | + gain = init.calculate_gain(fn) |
| 2118 | + if fn == 'sigmoid': |
| 2119 | + self.assertEqual(gain, 1) |
| 2120 | + elif fn == 'tanh': # 5 / 3 |
| 2121 | + self.assertEqual(gain, 1.6666666666666667) |
| 2122 | + elif fn == 'relu': # sqrt(2) |
| 2123 | + self.assertEqual(gain, 1.4142135623730951) |
| 2125 | + elif fn == 'leaky_relu': # sqrt(2 / (1 + slope^2)) |
| 2125 | + self.assertEqual(gain, 1.4141428569978354) |
| 2126 | + |
| 2127 | + def test_calculate_gain_leaky_relu(self): |
| 2128 | + for param in [None, 0, 0.01, 10]: |
| 2129 | + gain = init.calculate_gain('leaky_relu', param) |
| 2130 | + if param is None: # Default slope is 0.01 |
| 2131 | + self.assertEqual(gain, 1.4141428569978354) |
| 2132 | + elif param == 0: # No slope = same gain as normal ReLU |
| 2133 | + self.assertEqual(gain, 1.4142135623730951) |
| 2134 | + elif param == 0.01: |
| 2135 | + self.assertEqual(gain, 1.4141428569978354) |
| 2136 | + elif param == 10: |
| 2137 | + self.assertEqual(gain, 0.14071950894605836) |
| 2138 | + |
| 2139 | + def test_calculate_gain_leaky_relu_only_accepts_numbers(self): |
| 2140 | + for param in [True, [1], {'a': 'b'}]: |
| 2141 | + with self.assertRaises(ValueError): |
| 2142 | + init.calculate_gain('leaky_relu', param) |
| 2143 | + |
| 2144 | + def test_calculate_gain_only_accepts_valid_nonlinearities(self): |
| 2145 | + for n in [2, 5, 25]: |
| 2146 | + # Random lowercase strings whose lengths match no supported nonlinearity |
| 2147 | + random_string = ''.join([random.choice(string.ascii_lowercase) for _ in range(n)]) |
| 2148 | + with self.assertRaises(ValueError): |
| 2149 | + init.calculate_gain(random_string) |
| 2150 | + |
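For reference, every leaky_relu constant asserted above falls out of the gain formula gain = sqrt(2 / (1 + negative_slope^2)); the other nonlinearities use fixed values (1 for sigmoid, 5/3 for tanh, sqrt(2) for relu). A minimal standalone check of the leaky_relu values, using only the standard library:

    import math

    def leaky_relu_gain(negative_slope):
        # gain = sqrt(2 / (1 + slope^2)), per the comments in the tests above
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    assert abs(leaky_relu_gain(0.01) - 1.4141428569978354) < 1e-12  # default slope
    assert abs(leaky_relu_gain(0) - 1.4142135623730951) < 1e-12     # same as plain ReLU
    assert abs(leaky_relu_gain(10) - 0.14071950894605836) < 1e-12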
2109 | 2151 | @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") |
2110 | 2152 | def test_uniform(self): |
2111 | 2153 | for as_variable in [True, False]: |
@@ -2138,6 +2180,79 @@ def test_constant(self): |
2138 | 2180 | |
2139 | 2181 | self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) |
2140 | 2182 | |
| 2183 | + def test_eye(self): |
| 2184 | + for as_variable in [True, False]: |
| 2185 | + input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5, as_variable=as_variable) |
| 2186 | + init.eye(input_tensor) |
| 2187 | + if as_variable: |
| 2188 | + input_tensor = input_tensor.data |
| 2189 | + |
| 2190 | + # Check every single element |
| 2191 | + for i in range(input_tensor.size(0)): |
| 2192 | + for j in range(input_tensor.size(1)): |
| 2193 | + if i == j: |
| 2194 | + assert input_tensor[i][j] == 1 |
| 2195 | + else: |
| 2196 | + assert input_tensor[i][j] == 0 |
| 2197 | + |
| 2198 | + def test_eye_only_works_on_2d_inputs(self): |
| 2199 | + for as_variable in [True, False]: |
| 2200 | + for dims in [1, 3]: |
| 2201 | + with self.assertRaises(ValueError): |
| 2202 | + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable) |
| 2203 | + init.eye(tensor) |
| 2204 | + |
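The element-wise loop in test_eye above is checking a rectangular identity: ones on the main diagonal, zeros elsewhere, with no squareness requirement. A short sketch of the expected result, built by hand (shapes here are arbitrary):

    import torch

    rows, cols = 3, 5                   # deliberately non-square
    expected = torch.zeros(rows, cols)
    for i in range(min(rows, cols)):
        expected[i][i] = 1              # diagonal entries only
    # init.eye on a (rows, cols) tensor should produce exactly this pattern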
| 2205 | + def test_dirac_properties(self): |
| 2206 | + for as_variable in [True, False]: |
| 2207 | + for dims in [3, 4, 5]: |
| 2208 | + input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5, as_variable=as_variable) |
| 2209 | + init.dirac(input_tensor) |
| 2210 | + if as_variable: |
| 2211 | + input_tensor = input_tensor.data |
| 2212 | + |
| 2213 | + c_out, c_in = input_tensor.size(0), input_tensor.size(1) |
| 2214 | + min_d = min(c_out, c_in) |
| 2215 | + # Check number of nonzeros is equivalent to smallest dim |
| 2216 | + assert torch.nonzero(input_tensor).size(0) == min_d |
| 2217 | + # Check the sum of the values also equals the smallest dim (assertEqual tolerates float precision) |
| 2218 | + self.assertEqual(input_tensor.sum(), min_d) |
| 2219 | + |
| 2220 | + def test_dirac_identity(self): |
| 2221 | + batch, in_c, out_c, size, kernel_size = 8, 3, 4, 5, 3 |
| 2222 | + # Test 1D |
| 2223 | + input_var = Variable(torch.randn(batch, in_c, size)) |
| 2224 | + filter_var = Variable(torch.zeros(out_c, in_c, kernel_size)) |
| 2225 | + init.dirac(filter_var) |
| 2226 | + output_var = F.conv1d(input_var, filter_var) |
| 2227 | + input_tensor, output_tensor = input_var.data, output_var.data # Variables do not support nonzero |
| 2228 | + self.assertEqual(input_tensor[:, :, 1:-1], output_tensor[:, :in_c, :]) # Assert in_c outputs are preserved |
| 2229 | + assert torch.nonzero(output_tensor[:, in_c:, :]).numel() == 0 # Assert extra outputs are 0 |
| 2230 | + |
| 2231 | + # Test 2D |
| 2232 | + input_var = Variable(torch.randn(batch, in_c, size, size)) |
| 2233 | + filter_var = Variable(torch.zeros(out_c, in_c, kernel_size, kernel_size)) |
| 2234 | + init.dirac(filter_var) |
| 2235 | + output_var = F.conv2d(input_var, filter_var) |
| 2236 | + input_tensor, output_tensor = input_var.data, output_var.data |
| 2237 | + self.assertEqual(input_tensor[:, :, 1:-1, 1:-1], output_tensor[:, :in_c, :, :]) |
| 2238 | + assert torch.nonzero(output_tensor[:, in_c:, :, :]).numel() == 0 |
| 2239 | + |
| 2240 | + # Test 3D |
| 2241 | + input_var = Variable(torch.randn(batch, in_c, size, size, size)) |
| 2242 | + filter_var = Variable(torch.zeros(out_c, in_c, kernel_size, kernel_size, kernel_size)) |
| 2243 | + init.dirac(filter_var) |
| 2244 | + output_var = F.conv3d(input_var, filter_var) |
| 2245 | + input_tensor, output_tensor = input_var.data, output_var.data |
| 2246 | + self.assertEqual(input_tensor[:, :, 1:-1, 1:-1, 1:-1], output_tensor[:, :in_c, :, :, :]) |
| 2247 | + assert torch.nonzero(output_tensor[:, in_c:, :, :, :]).numel() == 0 |
| 2248 | + |
| 2249 | + def test_dirac_only_works_on_3_4_5d_inputs(self): |
| 2250 | + for as_variable in [True, False]: |
| 2251 | + for dims in [1, 2, 6]: |
| 2252 | + with self.assertRaises(ValueError): |
| 2253 | + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3, as_variable=as_variable) |
| 2254 | + init.dirac(tensor) |
| 2255 | + |
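The identity behaviour verified above follows from the filter layout dirac produces: a single 1 at the spatial centre of channel d of filter d, for each d < min(c_out, c_in), and zeros everywhere else. A hand-rolled sketch of the 1D case (illustrative only, not the library's implementation):

    import torch

    def dirac_1d(out_c, in_c, k):
        w = torch.zeros(out_c, in_c, k)
        for d in range(min(out_c, in_c)):
            w[d][d][k // 2] = 1   # one centred impulse per matched channel pair
        return w

    # With no padding, convolving with this filter trims (k - 1) / 2 entries from
    # each end -- hence the input_tensor[:, :, 1:-1] slice for kernel_size == 3 --
    # copies the input into the first in_c output channels, and leaves the
    # remaining out_c - in_c channels at zero, matching the assertions above.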
2141 | 2256 | def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self): |
2142 | 2257 | for as_variable in [True, False]: |
2143 | 2258 | for dims in [0, 1]: |