@@ -259,7 +259,8 @@ def bw_hook(inc, h_module, grad_input, grad_output):
         self.assertEqual(counter['forwards'], 2)
         self.assertEqual(counter['backwards'], 0)

-        test_bwd = module.register_backward_hook(lambda *args: bw_hook(1, *args))
+        test_bwd = module.register_backward_hook(
+            lambda *args: bw_hook(1, *args))

         output = module(input)
         self.assertEqual(counter['forwards'], 3)
@@ -816,7 +817,8 @@ def test_parallel_apply(self):
         inputs = ((i1,), (i2,))
         modules = (l1, l2)
         expected_outputs = (expected1, expected2)
-        outputs = dp.parallel_apply(modules, inputs)
+
+        outputs = dp.parallel_apply(modules, inputs, None)
         for out, expected in zip(outputs, expected_outputs):
             self.assertEqual(out.data, expected)

@@ -833,27 +835,67 @@ def test_data_parallel_noop(self):
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_multiple_input(self):
         class TestModule(nn.Module):
-            def forward(self, x, y):
-                return x + y
-
-        m = TestModule()
-        x = Variable(torch.randn(5, 5).float())
-        y = Variable(torch.randn(5, 5).float())
-        expected = m(x, y)

-        out = dp.data_parallel(m, (x, y), (0, 1))
-        self.assertEqual(out, expected)
+            def forward(self, var1, var2, float1, var3=None):
+                if var3 is None:
+                    return float1 * (var1 * var2)
+                else:
+                    return float1 * (var1 * var2 + var3)

-        out = dp.data_parallel(m, (x, y), (0,))
-        self.assertEqual(out, expected)
+        m = TestModule()
+        var1 = Variable(torch.randn(5, 5).float(), requires_grad=True)
+        var2 = Variable(torch.randn(5, 5).float(), requires_grad=True)
+        var3 = Variable(torch.randn(5, 5).float(), requires_grad=False)
+
+        float1 = torch.randn(1)[0]
+        target = Variable(torch.randn(5, 5).float()).cuda()
+        crit = nn.MSELoss()
+
+        expected = m(var1, var2, float1)
+        loss = expected.sum()
+        loss.backward()
+        gvar1_exp = var1.grad.clone()
+        gvar2_exp = var2.grad.clone()
+
+        def local_test(out):
+            var1.grad.data.fill_(0.0)
+            var2.grad.data.fill_(0.0)
+            loss = out.sum()
+            loss.backward()
+            self.assertEqual(out, expected)
+            self.assertEqual(gvar1_exp, var1.grad)
+            self.assertEqual(gvar2_exp, var2.grad)
+
+        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
+        local_test(out)
+
+        out = dp.data_parallel(m, (var1, var2, float1), (0,))
+        local_test(out)
+
+        var1.grad.data.fill_(0.0)
+        var2.grad.data.fill_(0.0)
+        expected = m(var1, var2, float1, var3=var3)
+        loss = expected.sum()
+        loss.backward()
+        gvar1_exp = var1.grad.clone()
+        gvar2_exp = var2.grad.clone()

         dpm = nn.DataParallel(TestModule())
-        out = dpm(x, y)
-        self.assertEqual(out, expected)
+        out = dpm(var1, var2, float1, var3=var3)
+        local_test(out)

         dpm = nn.DataParallel(TestModule(), device_ids=[0])
-        out = dpm(x, y)
-        self.assertEqual(out, expected)
+        out = dpm(var1, var2, float1, var3=var3)
+        local_test(out)
+
+        kwarg_wrap = {'var3': var3}
+        out = dp.data_parallel(
+            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap)
+        local_test(out)
+
+        out = dp.data_parallel(
+            m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
+        local_test(out)

     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_small_back(self):
@@ -1426,8 +1468,10 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu):
         for nonlinearity in ('tanh', 'relu'):
             hx_val = torch.randn(num_layers, batch, hidden_size)
             input_val = torch.randn(seq_length, batch, input_size)
-            grad_output = torch.randn(seq_length, batch, hidden_size * num_directions)
-            grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size)
+            grad_output = torch.randn(
+                seq_length, batch, hidden_size * num_directions)
+            grad_hy = torch.randn(
+                num_layers * num_directions, batch, hidden_size)

             rnn = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity)
             outputs_cpu = forward_backward(False, rnn, input_val, hx_val, grad_output, grad_hy, rnn.all_weights)
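
For readers skimming the diff, here is a minimal, illustrative sketch of the module_kwargs path that test_data_parallel_multiple_input now exercises. It assumes the Variable-era API used in the test (torch.nn.parallel imported as dp) and at least two visible GPUs; the Product module and its offset keyword are hypothetical names made up for this example and are not part of the commit.

import torch
import torch.nn as nn
import torch.nn.parallel as dp
from torch.autograd import Variable

class Product(nn.Module):
    # hypothetical toy module: multiplies two inputs, optionally adds an offset
    def forward(self, a, b, offset=None):
        out = a * b
        if offset is not None:
            out = out + offset
        return out

m = Product()
a = Variable(torch.randn(5, 5), requires_grad=True)
b = Variable(torch.randn(5, 5), requires_grad=True)
offset = Variable(torch.randn(5, 5))

# positional inputs are scattered across device_ids along dim 0;
# module_kwargs appears to be forwarded to each replica's forward call
out = dp.data_parallel(m, (a, b), (0, 1), module_kwargs={'offset': offset})
out.sum().backward()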