There was an error while loading. Please reload this page.
1 parent ec7fc97 commit 083dd6aCopy full SHA for 083dd6a
README.md
@@ -128,6 +128,18 @@ To use lightning do 2 things:
128
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
129
tensorboard_logs = {'val_loss': avg_loss}
130
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
131
+
132
+ def test_step(self, batch, batch_idx):
133
+ # OPTIONAL
134
+ x, y = batch
135
+ y_hat = self.forward(x)
136
+ return {'test_loss': F.cross_entropy(y_hat, y)}
137
138
+ def test_end(self, outputs):
139
140
+ avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
141
+ tensorboard_logs = {'test_loss': avg_loss}
142
+ return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
143
144
def configure_optimizers(self):
145
# REQUIRED
0 commit comments