@@ -266,76 +266,67 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
266266
267267 def run_evaluation (self , test = False ):
268268 # when testing make sure user defined a test step
269- can_run_test_step = False
269+ if test and not (self .is_overriden ('test_step' ) and self .is_overriden ('test_end' )):
270+ m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
271+ Please define and try again'''
272+ raise MisconfigurationException (m )
273+
274+ # hook
275+ model = self .get_model ()
276+ model .on_pre_performance_check ()
277+
278+ # select dataloaders
270279 if test :
271- can_run_test_step = self .is_overriden ('test_step' ) and self .is_overriden ('test_end' )
272- if not can_run_test_step :
273- m = '''You called .test() without defining a test step or test_end.
274- Please define and try again'''
275- raise MisconfigurationException (m )
276-
277- # validate only if model has validation_step defined
278- # test only if test_step or validation_step are defined
279- run_val_step = self .is_overriden ('validation_step' )
280-
281- if run_val_step or can_run_test_step :
282-
283- # hook
284- model = self .get_model ()
285- model .on_pre_performance_check ()
286-
287- # select dataloaders
288- if test :
289- dataloaders = self .get_test_dataloaders ()
290- max_batches = self .num_test_batches
291- else :
292- # val
293- dataloaders = self .get_val_dataloaders ()
294- max_batches = self .num_val_batches
295-
296- # cap max batches to 1 when using fast_dev_run
297- if self .fast_dev_run :
298- max_batches = 1
299-
300- # init validation or test progress bar
301- # main progress bar will already be closed when testing so initial position is free
302- position = 2 * self .process_position + (not test )
303- desc = 'Testing' if test else 'Validating'
304- pbar = tqdm .tqdm (desc = desc , total = max_batches , leave = test , position = position ,
305- disable = not self .show_progress_bar , dynamic_ncols = True ,
306- unit = 'batch' , file = sys .stdout )
307- setattr (self , f'{ "test" if test else "val" } _progress_bar' , pbar )
308-
309- # run evaluation
310- eval_results = self .evaluate (self .model ,
311- dataloaders ,
312- max_batches ,
313- test )
314- _ , prog_bar_metrics , log_metrics , callback_metrics , _ = self .process_output (
315- eval_results )
316-
317- # add metrics to prog bar
318- self .add_tqdm_metrics (prog_bar_metrics )
319-
320- # log metrics
321- self .log_metrics (log_metrics , {})
322-
323- # track metrics for callbacks
324- self .callback_metrics .update (callback_metrics )
325-
326- # hook
327- model .on_post_performance_check ()
328-
329- # add model specific metrics
330- tqdm_metrics = self .training_tqdm_dict
331- if not test :
332- self .main_progress_bar .set_postfix (** tqdm_metrics )
333-
334- # close progress bar
335- if test :
336- self .test_progress_bar .close ()
337- else :
338- self .val_progress_bar .close ()
280+ dataloaders = self .get_test_dataloaders ()
281+ max_batches = self .num_test_batches
282+ else :
283+ # val
284+ dataloaders = self .get_val_dataloaders ()
285+ max_batches = self .num_val_batches
286+
287+ # cap max batches to 1 when using fast_dev_run
288+ if self .fast_dev_run :
289+ max_batches = 1
290+
291+ # init validation or test progress bar
292+ # main progress bar will already be closed when testing so initial position is free
293+ position = 2 * self .process_position + (not test )
294+ desc = 'Testing' if test else 'Validating'
295+ pbar = tqdm .tqdm (desc = desc , total = max_batches , leave = test , position = position ,
296+ disable = not self .show_progress_bar , dynamic_ncols = True ,
297+ unit = 'batch' , file = sys .stdout )
298+ setattr (self , f'{ "test" if test else "val" } _progress_bar' , pbar )
299+
300+ # run evaluation
301+ eval_results = self .evaluate (self .model ,
302+ dataloaders ,
303+ max_batches ,
304+ test )
305+ _ , prog_bar_metrics , log_metrics , callback_metrics , _ = self .process_output (
306+ eval_results )
307+
308+ # add metrics to prog bar
309+ self .add_tqdm_metrics (prog_bar_metrics )
310+
311+ # log metrics
312+ self .log_metrics (log_metrics , {})
313+
314+ # track metrics for callbacks
315+ self .callback_metrics .update (callback_metrics )
316+
317+ # hook
318+ model .on_post_performance_check ()
319+
320+ # add model specific metrics
321+ tqdm_metrics = self .training_tqdm_dict
322+ if not test :
323+ self .main_progress_bar .set_postfix (** tqdm_metrics )
324+
325+ # close progress bar
326+ if test :
327+ self .test_progress_bar .close ()
328+ else :
329+ self .val_progress_bar .close ()
339330
340331 # model checkpointing
341332 if self .proc_rank == 0 and self .checkpoint_callback is not None and not test :
0 commit comments