@@ -36,6 +36,15 @@ def __init__(self):
3636 self .shown_warnings = None
3737 self .val_check_interval = None
3838
39+ def _percent_range_check (self , name ):
40+ value = getattr (self , name )
41+ msg = f"`{ name } ` must lie in the range [0.0, 1.0], but got { value :.3f} ."
42+ if name == "val_check_interval" :
43+ msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
44+
45+ if not 0. <= value <= 1. :
46+ raise ValueError (msg )
47+
3948 def init_train_dataloader (self , model ):
4049 """
4150 Dataloaders are provided by the model
@@ -48,6 +57,8 @@ def init_train_dataloader(self, model):
4857 if EXIST_ITER_DATASET and isinstance (self .get_train_dataloader ().dataset , IterableDataset ):
4958 self .num_training_batches = float ('inf' )
5059 else :
60+ self ._percent_range_check ('train_percent_check' )
61+
5162 self .num_training_batches = len (self .get_train_dataloader ())
5263 self .num_training_batches = int (self .num_training_batches * self .train_percent_check )
5364
@@ -56,7 +67,14 @@ def init_train_dataloader(self, model):
5667 # otherwise, it checks in [0, 1.0] % range of a training epoch
5768 if isinstance (self .val_check_interval , int ):
5869 self .val_check_batch = self .val_check_interval
70+ if self .val_check_batch > self .num_training_batches :
71+ raise ValueError (
72+ f"`val_check_interval` ({ self .val_check_interval } ) must be less than or equal "
73+ f"to the number of the training batches ({ self .num_training_batches } ). "
74+ f"If you want to disable validation set `val_percent_check` to 0.0 instead." )
5975 else :
76+ self ._percent_range_check ('val_check_interval' )
77+
6078 self .val_check_batch = int (self .num_training_batches * self .val_check_interval )
6179 self .val_check_batch = max (1 , self .val_check_batch )
6280
@@ -89,13 +107,15 @@ def init_val_dataloader(self, model):
89107 :return:
90108 """
91109 self .get_val_dataloaders = model .val_dataloader
110+ self .num_val_batches = 0
92111
93112 # determine number of validation batches
94113 # val datasets could be none, 1 or 2+
95114 if self .get_val_dataloaders () is not None :
115+ self ._percent_range_check ('val_percent_check' )
116+
96117 self .num_val_batches = sum (len (dataloader ) for dataloader in self .get_val_dataloaders ())
97118 self .num_val_batches = int (self .num_val_batches * self .val_percent_check )
98- self .num_val_batches = max (1 , self .num_val_batches )
99119
100120 on_ddp = self .use_ddp or self .use_ddp2
101121 if on_ddp and self .get_val_dataloaders () is not None :
@@ -134,10 +154,11 @@ def init_test_dataloader(self, model):
134154
135155 # determine number of test batches
136156 if self .get_test_dataloaders () is not None :
157+ self ._percent_range_check ('test_percent_check' )
158+
137159 len_sum = sum (len (dataloader ) for dataloader in self .get_test_dataloaders ())
138160 self .num_test_batches = len_sum
139161 self .num_test_batches = int (self .num_test_batches * self .test_percent_check )
140- self .num_test_batches = max (1 , self .num_test_batches )
141162
142163 on_ddp = self .use_ddp or self .use_ddp2
143164 if on_ddp and self .get_test_dataloaders () is not None :
@@ -208,6 +229,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
208229 self .val_percent_check = val_percent_check
209230 self .test_percent_check = test_percent_check
210231 if overfit_pct > 0 :
232+ if overfit_pct > 1 :
233+ raise ValueError (f"`overfit_pct` must be not greater than 1.0, but got "
234+ f"{ overfit_pct :.3f} ." )
235+
211236 self .train_percent_check = overfit_pct
212237 self .val_percent_check = overfit_pct
213238 self .test_percent_check = overfit_pct
0 commit comments