File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -99,12 +99,11 @@ def test():
9999 data , target  =  Variable (data , volatile = True ), Variable (target )
100100 current_batch_size  =  data .data .size ()[0 ]
101101 output  =  model (data )
102-  # F.nll_loss is averaged among each batch 
103-  test_loss  +=  F .nll_loss (output , target ).data [0 ] *  current_batch_size 
102+  test_loss  +=  F .nll_loss (output , target ).data [0 ] *  current_batch_size  # sum up batch loss 
104103 pred  =  output .data .max (1 )[1 ] # get the index of the max log-probability 
105104 correct  +=  pred .eq (target .data ).cpu ().sum ()
106105
107-  test_loss  =  test_loss  # sum of loss function 
106+  test_loss  =  test_loss  # sum of loss function over all data points  
108107 test_loss  /=  len (test_loader .dataset )
109108 print ('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
110109 test_loss , correct , len (test_loader .dataset ),
                         You can’t perform that action at this time. 
           
                  
0 commit comments