Skip to content

Commit 083dd6a

Browse files
Frédéric Branchaud-CharronwilliamFalcon
authored andcommitted
Update Readme so that .test will work. (#659)
When one follows the Readme, the example will fail once we call `trainer.test()` because the methods are not overridden. Fixes #428
1 parent ec7fc97 commit 083dd6a

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,18 @@ To use lightning do 2 things:
128128
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
129129
tensorboard_logs = {'val_loss': avg_loss}
130130
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
131+
132+
def test_step(self, batch, batch_idx):
133+
# OPTIONAL
134+
x, y = batch
135+
y_hat = self.forward(x)
136+
return {'test_loss': F.cross_entropy(y_hat, y)}
137+
138+
def test_end(self, outputs):
139+
# OPTIONAL
140+
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
141+
tensorboard_logs = {'test_loss': avg_loss}
142+
return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
131143

132144
def configure_optimizers(self):
133145
# REQUIRED

0 commit comments

Comments
 (0)