@@ -1367,6 +1367,47 @@ def forward(self, input):
         self.assertIsInstance(output[1][2], list)
         self.assertIsInstance(output[1][2][0], Variable)
         self.assertIsInstance(output[2], Variable)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_data_parallel_dict_and_custom_instance_output(self):
+
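+        # Simple custom iterable: yields each element of state_list once, then resets.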
+        class State(object):
+            def __init__(self, state_list):
+                self.state_list = state_list
+                self.i = -1
+            def __iter__(self):
+                return self
+            def next(self):
+                self.i += 1
+                if self.i >= len(self.state_list):
+                    self.i = -1
+                    raise StopIteration()
+                return self.state_list[self.i]
+            def __next__(self):
+                return self.next()
+
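+        # Reference function: returns a dict and a State instance built from the input.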
+        def fn(input):
+            return {'input': input}, State([input.sin(), input.cos()])
+
+        class Net(nn.Module):
+            def forward(self, input):
+                return fn(input)
+
+
+        i = Variable(torch.randn(2, 2).float().cuda())
+        gpus = range(torch.cuda.device_count())
+        output = dp.data_parallel(Net(), i, gpus)
+        expected_out = fn(i)
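+        # The gathered output should mirror the structure of fn(i): a (dict, State) pair.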
+        self.assertIsInstance(output, tuple)
+        self.assertEqual(len(output), 2)
+        self.assertIsInstance(output[0], dict)
+        self.assertIsInstance(output[1], State)
+        self.assertEqual(output[0].keys(), expected_out[0].keys())
+        self.assertEqual(output[0].values(), expected_out[0].values())
+        self.assertEqual(tuple(output[1]), tuple(expected_out[1]))
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_nested_input(self):