Skip to content

Commit 55f3ffd

Browse files
MikeScarpwilliamFalcon
authored andcommitted
fixing bug in testing for IterableDataset (#547)
1 parent 4627887 commit 55f3ffd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_lightning/trainer/data_loading_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def init_train_dataloader(self, model):
2424
self.get_train_dataloader = model.train_dataloader
2525

2626
# determine number of training batches
27-
if isinstance(self.get_train_dataloader(), IterableDataset):
27+
if isinstance(self.get_train_dataloader().dataset, IterableDataset):
2828
self.nb_training_batches = float('inf')
2929
else:
3030
self.nb_training_batches = len(self.get_train_dataloader())
@@ -167,7 +167,7 @@ def get_dataloaders(self, model):
167167
self.get_val_dataloaders()
168168

169169
# support IterableDataset for train data
170-
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
170+
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset)
171171
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
172172
m = '''
173173
When using an iterableDataset for train_dataloader,

0 commit comments

Comments
 (0)