@@ -357,19 +357,31 @@ def bw_hook(module, grad_input, grad_output):
357357 self .assertEqual (input .grad .data , expected_grad )
358358
def test_zero_grad(self):
    """zero_grad() must skip untracked parameters and zero out tracked grads in place."""
    inp = Variable(torch.randn(2, 5), requires_grad=True)
    module = nn.Linear(5, 5)

    # With every parameter frozen, zero_grad() is a no-op and must not raise.
    for param in module.parameters():
        param.requires_grad = False
    module.zero_grad()

    # Re-enable grad tracking for the weight only; its grad is still
    # uninitialized, and zero_grad() must not materialize it.
    module.weight.requires_grad = True
    module.zero_grad()
    self.assertIsNone(module.weight.grad)  # uninitialized grad

    # A backward pass populates weight.grad with something nonzero ...
    module(inp).sum().backward()
    self.assertIsNotNone(module.weight.grad)
    self.assertGreater(module.weight.grad.data.abs().sum(), 0)
    # ... which zero_grad() then clears in place.
    module.zero_grad()
    self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())

    # Now track the bias as well: it has no grad yet, while the weight keeps
    # its (zeroed) grad buffer from the previous phase.
    module.bias.requires_grad = True
    module.zero_grad()
    self.assertIsNotNone(module.weight.grad)
    self.assertIsNone(module.bias.grad)
    # After another backward pass both grads exist and are nonzero ...
    module(inp).sum().backward()
    self.assertIsNotNone(module.weight.grad)
    self.assertIsNotNone(module.bias.grad)
    self.assertGreater(module.weight.grad.data.abs().sum(), 0)
    self.assertGreater(module.bias.grad.data.abs().sum(), 0)
    # ... and zero_grad() zeroes both of them.
    module.zero_grad()
    self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
    self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
0 commit comments