File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
pytorch_lightning/trainer Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -24,7 +24,7 @@ def init_train_dataloader(self, model):
2424 self .get_train_dataloader = model .train_dataloader
2525
2626 # determine number of training batches
27- if isinstance (self .get_train_dataloader (), IterableDataset ):
27+ if isinstance (self .get_train_dataloader (). dataset , IterableDataset ):
2828 self .nb_training_batches = float ('inf' )
2929 else :
3030 self .nb_training_batches = len (self .get_train_dataloader ())
@@ -167,7 +167,7 @@ def get_dataloaders(self, model):
167167 self .get_val_dataloaders ()
168168
169169 # support IterableDataset for train data
170- self .is_iterable_train_dataloader = isinstance (self .get_train_dataloader (), IterableDataset )
170+ self .is_iterable_train_dataloader = isinstance (self .get_train_dataloader (). dataset , IterableDataset )
171171 if self .is_iterable_train_dataloader and not isinstance (self .val_check_interval , int ):
172172 m = '''
173173 When using an iterableDataset for train_dataloader,
You can’t perform that action at this time.
0 commit comments