Skip to content

Commit 943f2c4

Browse files
committed
Add test script to test dict output and custom instance output
1 parent 9d374b9 commit 943f2c4

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

test/test_nn.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,44 @@ def forward(self, input):
13671367
self.assertIsInstance(output[1][2], list)
13681368
self.assertIsInstance(output[1][2][0], Variable)
13691369
self.assertIsInstance(output[2], Variable)
1370+
1371+
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
1372+
def test_data_parallel_dict_and_custom_instance_output(self):
1373+
1374+
class State(object):
1375+
def __init__(self, state_list):
1376+
self.state_list = state_list
1377+
self.i = -1
1378+
def __iter__(self):
1379+
return self
1380+
def next(self):
1381+
self.i += 1
1382+
if self.i >= len(self.state_list):
1383+
self.i = -1
1384+
raise StopIteration()
1385+
return self.state_list[i]
1386+
def __next__(self):
1387+
self.next()
1388+
1389+
def fn(input):
1390+
return {'input':input}, State([input.sin(), input.cos()])
1391+
1392+
class Net(nn.Module):
1393+
def forward(self, input):
1394+
return fn(input)
1395+
1396+
1397+
i = Variable(torch.randn(2, 2).float().cuda())
1398+
gpus = range(torch.cuda.device_count())
1399+
output = dp.data_parallel(Net(), i, gpus)
1400+
expected_out = fn(i)
1401+
self.assertIsInstance(output, tuple)
1402+
self.assertEqual(len(output), 2)
1403+
self.assertIsInstance(output[0], dict)
1404+
self.assertIsInstance(output[1], State)
1405+
self.assertEqual(output[0].keys(), expected_out[0].keys())
1406+
self.assertEqual(output[0].values(), expected_out[0].values())
1407+
self.assertEqual(tuple(output[1]), tuple(expected_out[1]))
13701408

13711409
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
13721410
def test_data_parallel_nested_input(self):

0 commit comments

Comments
 (0)