Skip to content

Commit e79bff3

Browse files
author
Jeff Yang
authored
feat: run evaluation for 1 epoch before training (#57)
1 parent 23c87b4 commit e79bff3

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

templates/image_classification/main.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ def _():
164164
eval_engine.run(eval_dataloader, max_epochs=1)
165165
eval_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), log_metrics, tag="eval")
166166

167+
# --------------------------------------------------
168+
# let's try run evaluation first as a sanity check
169+
# --------------------------------------------------
170+
171+
@train_engine.on(Events.STARTED)
172+
def _():
173+
eval_engine.run(eval_dataloader, max_epochs=1, epoch_length=2)
174+
eval_engine.state.max_epochs = None
175+
167176
# ------------------------------------------
168177
# setup if done. let's run the training
169178
# ------------------------------------------
@@ -195,7 +204,7 @@ def main():
195204

196205
if config.output_dir:
197206
now = datetime.now().strftime("%Y%m%d-%H%M%S")
198-
name = f'{config.model}-backend-{idist.backend()}-{now}'
207+
name = f"{config.model}-backend-{idist.backend()}-{now}"
199208
path = Path(config.output_dir, name)
200209
path.mkdir(parents=True, exist_ok=True)
201210
config.output_dir = path

templates/single/main.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ def _():
140140
eval_engine.run(eval_dataloader, max_epochs=1)
141141
eval_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), log_metrics, tag="eval")
142142

143+
# --------------------------------------------------
144+
# let's try run evaluation first as a sanity check
145+
# --------------------------------------------------
146+
147+
@train_engine.on(Events.STARTED)
148+
def _():
149+
eval_engine.run(eval_dataloader, max_epochs=1, epoch_length=2)
150+
eval_engine.state.max_epochs = None
151+
143152
# ------------------------------------------
144153
# setup if done. let's run the training
145154
# ------------------------------------------
@@ -172,7 +181,7 @@ def main():
172181

173182
if config.output_dir:
174183
now = datetime.now().strftime("%Y%m%d-%H%M%S")
175-
name = f'backend-{idist.backend()}-{now}'
184+
name = f"backend-{idist.backend()}-{now}"
176185
path = Path(config.output_dir, name)
177186
path.mkdir(parents=True, exist_ok=True)
178187
config.output_dir = path

0 commit comments

Comments
 (0)