 PREFETCH_SIZE = 4096  # samples to prefetch
 
 
+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
+    return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
+
+
 class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
 
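A quick worked example of the new even_split_indices helper: it cuts a named split into TFDS absolute-index subsplits, one per global pipeline, with per-pipeline sample counts that differ by at most one. Illustrative values only:

    >>> even_split_indices('validation', 4, 10)
    ['validation[0:2]', 'validation[2:5]', 'validation[5:8]', 'validation[8:10]']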
@@ -52,7 +57,7 @@ class ParserTfds(Parser):
         components.
 
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +67,8 @@ def __init__(self, root, name, split='train', shuffle=False, is_training=False,
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
+        self.repeats = repeats
+        self.subsplit = None
 
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
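For context, a hedged sketch of constructing the wrapper with the new repeats argument (the root path and dataset name are hypothetical, and the dataset is assumed to already be prepared via the tfds CLI as the note above says):

    parser = ParserTfds(
        root='/data/tfds',      # hypothetical data_dir already prepared via the tfds CLI
        name='imagenet2012',    # hypothetical dataset name
        split='validation',
        is_training=False,
        batch_size=None,
        repeats=2)              # iterate the split twice per pass; 0 (default) keeps the old behaviour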
@@ -95,6 +102,7 @@ def _lazy_init(self):
         if worker_info is not None:
             self.worker_info = worker_info
             num_workers = worker_info.num_workers
+            global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
 
             # FIXME I need to spend more time figuring out the best way to distribute/split data across
@@ -114,19 +122,31 @@ def _lazy_init(self):
             #     split = split + '[{}:]'.format(start)
             # else:
             #     split = split + '[{}:{}]'.format(start, start + split_size)
-
-            input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
-                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
-            )
-
-        read_config = tfds.ReadConfig(input_context=input_context)
-        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+            if not self.is_training and '[' not in self.split:
+                # If not training, and split doesn't define a subsplit, manually split the dataset
+                # for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
+                    self.dist_rank * num_workers + worker_id]
+
+        if self.subsplit is None:
+            input_context = tf.distribute.InputContext(
+                num_input_pipelines=self.dist_num_replicas * num_workers,
+                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
+            )
+        else:
+            input_context = None
+
+        read_config = tfds.ReadConfig(
+            shuffle_seed=42,
+            shuffle_reshuffle_each_iteration=True,
+            input_context=input_context)
+        ds = self.builder.as_dataset(
+            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
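To illustrate the pipeline indexing above, a standalone sketch (replica/worker counts and sample count are made up) of which subsplit each validation worker would end up reading; it relies only on the even_split_indices helper added earlier in this diff:

    num_samples = 50000
    dist_num_replicas, num_workers = 2, 2
    global_num_workers = dist_num_replicas * num_workers
    for dist_rank in range(dist_num_replicas):
        for worker_id in range(num_workers):
            subsplit = even_split_indices('validation', global_num_workers, num_samples)[
                dist_rank * num_workers + worker_id]
            print(f'rank {dist_rank}, worker {worker_id} -> {subsplit}')
    # rank 0, worker 0 -> validation[0:12500]
    # rank 0, worker 1 -> validation[12500:25000]
    # rank 1, worker 0 -> validation[25000:37500]
    # rank 1, worker 1 -> validation[37500:50000]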
@@ -143,7 +163,7 @@ def __iter__(self):
         # This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
@@ -160,8 +180,8 @@ def __iter__(self):
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
-            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
+            # approach is not optimal
             yield img, sample['label']  # yield prev sample again
             sample_count += 1
 
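As a worked example of the target count arithmetic above (all numbers illustrative): with repeats=2, 10000 samples, 8 global pipelines and batch_size=32,

    import math
    repeats, num_samples, num_pipelines, batch_size = 2, 10000, 8, 32
    target = math.ceil(max(1, repeats) * num_samples / num_pipelines)  # 2500 samples per pipeline
    target = math.ceil(target / batch_size) * batch_size               # 2528 after rounding up to full batches (training only)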
@@ -176,7 +196,7 @@ def _num_pipelines(self):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
 
     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples
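Mirroring the __len__ change, a small sanity-check sketch with illustrative numbers:

    import math

    def est_len(num_samples, dist_num_replicas, repeats):
        # same estimate as __len__ above: repeats scales the per-replica sample count
        return math.ceil(max(1, repeats) * num_samples / dist_num_replicas)

    assert est_len(50000, 4, 0) == 12500  # repeats=0 (default) matches the previous behaviour
    assert est_len(50000, 4, 2) == 25000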