There was an error while loading. Please reload this page.
1 parent d610b4a commit 53f25e0Copy full SHA for 53f25e0
mnist/main.py
@@ -97,9 +97,8 @@ def test():
97
if args.cuda:
98
data, target = data.cuda(), target.cuda()
99
data, target = Variable(data, volatile=True), Variable(target)
100
- current_batch_size = data.data.size()[0]
101
output = model(data)
102
- test_loss += F.nll_loss(output, target).data[0] * current_batch_size # sum up batch loss
+ test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
103
pred = output.data.max(1)[1] # get the index of the max log-probability
104
correct += pred.eq(target.data).cpu().sum()
105
0 commit comments