There was an error while loading. Please reload this page.
1 parent 9053040 commit fb9ca4dCopy full SHA for fb9ca4d
mnist/main.py
@@ -99,8 +99,8 @@ def test():
99
data, target = Variable(data, volatile=True), Variable(target)
100
output = model(data)
101
test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
102
- pred = output.data.max(1)[1] # get the index of the max log-probability
103
- correct += pred.eq(target.data).cpu().sum()
+ pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
+ correct += pred.eq(target.data.view_as(pred)).cpu().sum()
104
105
test_loss /= len(test_loader.dataset)
106
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
0 commit comments