Skip to content

Commit 3b60784

Browse files
committed
add comments
1 parent 3aca2d7 commit 3b60784

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

mnist/main.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,11 @@ def test():
9999
data, target = Variable(data, volatile=True), Variable(target)
100100
current_batch_size = data.data.size()[0]
101101
output = model(data)
102-
# F.nll_loss is averaged among each batch
103-
test_loss += F.nll_loss(output, target).data[0] * current_batch_size
102+
test_loss += F.nll_loss(output, target).data[0] * current_batch_size # sum up batch loss
104103
pred = output.data.max(1)[1] # get the index of the max log-probability
105104
correct += pred.eq(target.data).cpu().sum()
106105

107-
test_loss = test_loss # sum of loss function
106+
test_loss = test_loss # sum of loss function over all data points
108107
test_loss /= len(test_loader.dataset)
109108
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
110109
test_loss, correct, len(test_loader.dataset),

0 commit comments

Comments
 (0)