Skip to content

Commit 756c70a

Browse files
kuynzereb authored and williamFalcon committed
Clearer disable validation logic (#650)
* Clearer disable validation logic
* Fix for fast_dev_run
* flake8 fix
* Test check fix
* Update error message
1 parent 083dd6a commit 756c70a

File tree

3 files changed

+78
-77
lines changed

3 files changed

+78
-77
lines changed

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 60 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -266,76 +266,67 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
266266

267267
def run_evaluation(self, test=False):
268268
# when testing make sure user defined a test step
269-
can_run_test_step = False
269+
if test and not (self.is_overriden('test_step') and self.is_overriden('test_end')):
270+
m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
271+
Please define and try again'''
272+
raise MisconfigurationException(m)
273+
274+
# hook
275+
model = self.get_model()
276+
model.on_pre_performance_check()
277+
278+
# select dataloaders
270279
if test:
271-
can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
272-
if not can_run_test_step:
273-
m = '''You called .test() without defining a test step or test_end.
274-
Please define and try again'''
275-
raise MisconfigurationException(m)
276-
277-
# validate only if model has validation_step defined
278-
# test only if test_step or validation_step are defined
279-
run_val_step = self.is_overriden('validation_step')
280-
281-
if run_val_step or can_run_test_step:
282-
283-
# hook
284-
model = self.get_model()
285-
model.on_pre_performance_check()
286-
287-
# select dataloaders
288-
if test:
289-
dataloaders = self.get_test_dataloaders()
290-
max_batches = self.num_test_batches
291-
else:
292-
# val
293-
dataloaders = self.get_val_dataloaders()
294-
max_batches = self.num_val_batches
295-
296-
# cap max batches to 1 when using fast_dev_run
297-
if self.fast_dev_run:
298-
max_batches = 1
299-
300-
# init validation or test progress bar
301-
# main progress bar will already be closed when testing so initial position is free
302-
position = 2 * self.process_position + (not test)
303-
desc = 'Testing' if test else 'Validating'
304-
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
305-
disable=not self.show_progress_bar, dynamic_ncols=True,
306-
unit='batch', file=sys.stdout)
307-
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
308-
309-
# run evaluation
310-
eval_results = self.evaluate(self.model,
311-
dataloaders,
312-
max_batches,
313-
test)
314-
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
315-
eval_results)
316-
317-
# add metrics to prog bar
318-
self.add_tqdm_metrics(prog_bar_metrics)
319-
320-
# log metrics
321-
self.log_metrics(log_metrics, {})
322-
323-
# track metrics for callbacks
324-
self.callback_metrics.update(callback_metrics)
325-
326-
# hook
327-
model.on_post_performance_check()
328-
329-
# add model specific metrics
330-
tqdm_metrics = self.training_tqdm_dict
331-
if not test:
332-
self.main_progress_bar.set_postfix(**tqdm_metrics)
333-
334-
# close progress bar
335-
if test:
336-
self.test_progress_bar.close()
337-
else:
338-
self.val_progress_bar.close()
280+
dataloaders = self.get_test_dataloaders()
281+
max_batches = self.num_test_batches
282+
else:
283+
# val
284+
dataloaders = self.get_val_dataloaders()
285+
max_batches = self.num_val_batches
286+
287+
# cap max batches to 1 when using fast_dev_run
288+
if self.fast_dev_run:
289+
max_batches = 1
290+
291+
# init validation or test progress bar
292+
# main progress bar will already be closed when testing so initial position is free
293+
position = 2 * self.process_position + (not test)
294+
desc = 'Testing' if test else 'Validating'
295+
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
296+
disable=not self.show_progress_bar, dynamic_ncols=True,
297+
unit='batch', file=sys.stdout)
298+
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
299+
300+
# run evaluation
301+
eval_results = self.evaluate(self.model,
302+
dataloaders,
303+
max_batches,
304+
test)
305+
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
306+
eval_results)
307+
308+
# add metrics to prog bar
309+
self.add_tqdm_metrics(prog_bar_metrics)
310+
311+
# log metrics
312+
self.log_metrics(log_metrics, {})
313+
314+
# track metrics for callbacks
315+
self.callback_metrics.update(callback_metrics)
316+
317+
# hook
318+
model.on_post_performance_check()
319+
320+
# add model specific metrics
321+
tqdm_metrics = self.training_tqdm_dict
322+
if not test:
323+
self.main_progress_bar.set_postfix(**tqdm_metrics)
324+
325+
# close progress bar
326+
if test:
327+
self.test_progress_bar.close()
328+
else:
329+
self.val_progress_bar.close()
339330

340331
# model checkpointing
341332
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:

pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def __init__(
213213
# training state
214214
self.model = None
215215
self.testing = False
216+
self.disable_validation = False
216217
self.lr_schedulers = []
217218
self.optimizers = None
218219
self.global_step = 0
@@ -486,11 +487,16 @@ def run_pretrain_routine(self, model):
486487
self.run_evaluation(test=True)
487488
return
488489

490+
# check if we should run validation during training
491+
self.disable_validation = ((self.num_val_batches == 0 or
492+
not self.is_overriden('validation_step')) and
493+
not self.fast_dev_run)
494+
489495
# run tiny validation (if validation defined)
490496
# to make sure program won't crash during val
491497
ref_model.on_sanity_check_start()
492498
ref_model.on_train_start()
493-
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
499+
if not self.disable_validation and self.num_sanity_val_steps > 0:
494500
# init progress bars for validation sanity check
495501
pbar = tqdm.tqdm(desc='Validation sanity check',
496502
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),

pytorch_lightning/trainer/training_loop.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def __init__(self):
184184
self.num_training_batches = None
185185
self.val_check_batch = None
186186
self.num_val_batches = None
187+
self.disable_validation = None
187188
self.fast_dev_run = None
188189
self.is_iterable_train_dataloader = None
189190
self.main_progress_bar = None
@@ -294,14 +295,16 @@ def train(self):
294295
model.current_epoch = epoch
295296
self.current_epoch = epoch
296297

297-
# val can be checked multiple times in epoch
298-
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
299-
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
300-
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
298+
total_val_batches = 0
299+
if not self.disable_validation:
300+
# val can be checked multiple times in epoch
301+
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
302+
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
303+
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
304+
total_val_batches = self.num_val_batches * val_checks_per_epoch
301305

302306
# total batches includes multiple val checks
303-
self.total_batches = (self.num_training_batches +
304-
self.num_val_batches * val_checks_per_epoch)
307+
self.total_batches = self.num_training_batches + total_val_batches
305308
self.batch_loss_value = 0 # accumulated grads
306309

307310
if self.fast_dev_run:
@@ -390,7 +393,8 @@ def run_training_epoch(self):
390393
# ---------------
391394
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
392395
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
393-
should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch)
396+
should_check_val = (not self.disable_validation and can_check_epoch and
397+
(is_val_check_batch or early_stop_epoch))
394398

395399
# fast_dev_run always forces val checking after train batch
396400
if self.fast_dev_run or should_check_val:

0 commit comments

Comments (0)