Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source-pytorch/data/iterables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ To choose a different mode, you can use the :class:`~lightning.pytorch.utilities
trainer.fit(model, combined_loader)


Currently, the ``trainer.predict`` method only supports the ``"sequential"`` mode, while ``trainer.fit`` method does not support it.
Currently, the ``trainer.fit`` method does not support the ``"sequential"`` mode.
Support for this feature is tracked in this `issue <https://github.com/Lightning-AI/lightning/issues/16830>`__.

Note that when using the ``"sequential"`` mode, you need to add an additional argument ``dataloader_idx`` to some specific hooks.
Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` ([#17626](https://github.com/Lightning-AI/lightning/pull/17626))

- Added support for the `max_size_cycle|max_size|min_size` iteration modes during prediction ([#17749](https://github.com/Lightning-AI/lightning/pull/17749))

### Changed

Expand Down
55 changes: 39 additions & 16 deletions src/lightning/pytorch/loops/prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
self.epoch_batch_indices: List[List[List[int]]] = []
self.current_batch_indices: List[int] = [] # used by PredictionWriter
self.batch_progress = _Progress() # across dataloaders
self.max_batches: List[Union[int, float]] = []
# list in "sequential" mode, number otherwise
self.max_batches: Union[int, float, List[Union[int, float]]] = []

self._warning_cache = WarningCache()
self._data_source = _DataLoaderSource(None, "predict_dataloader")
Expand Down Expand Up @@ -94,7 +95,12 @@ def num_dataloaders(self) -> int:

@property
def skip(self) -> bool:
return sum(self.max_batches) == 0
return sum(self.max_batches) == 0 if isinstance(self.max_batches, list) else self.max_batches == 0

@property
def _is_sequential(self) -> bool:
assert self._combined_loader is not None
return self._combined_loader._mode == "sequential"

@_no_grad_context
def run(self) -> Optional[_PREDICT_OUTPUT]:
Expand All @@ -107,7 +113,17 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:
assert data_fetcher is not None
while True:
try:
batch, batch_idx, dataloader_idx = next(data_fetcher)
if self._is_sequential:
batch, batch_idx, dataloader_idx = next(data_fetcher)
else:
batch_idx = (
data_fetcher.fetched
if isinstance(data_fetcher, _DataLoaderIterDataFetcher)
else self.batch_progress.current.ready
)
batch = next(data_fetcher)
dataloader_idx = 0

self.batch_progress.is_last_batch = data_fetcher.done
self._predict_step(batch, batch_idx, dataloader_idx)
except StopIteration:
Expand Down Expand Up @@ -139,19 +155,24 @@ def setup_data(self) -> None:
trainer_fn = TrainerFn.PREDICTING
stage = RunningStage.PREDICTING
dataloaders = []
self.max_batches = []
for dl in combined_loader.flattened:
_check_dataloader_iterable(dl, source, trainer_fn)
dl = _process_dataloader(trainer, trainer_fn, stage, dl)
dataloaders.append(dl)

# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
num_batches = _parse_num_batches(stage, length, trainer.limit_predict_batches)
self.max_batches.append(num_batches)
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

if self._is_sequential:
self.max_batches = []
for dl in combined_loader.flattened:
# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
num_batches = _parse_num_batches(stage, length, trainer.limit_predict_batches)
self.max_batches.append(num_batches)
else:
has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)
self.max_batches = len(combined_loader) if has_len_all_ranks_ else float("inf")

def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
self.batch_progress.reset_on_run()
Expand All @@ -163,13 +184,13 @@ def reset(self) -> None:
)
combined_loader = self._combined_loader
assert combined_loader is not None
if combined_loader._mode != "sequential":
raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.')

data_fetcher.setup(combined_loader)
iter(data_fetcher) # creates the iterator inside the fetcher
assert isinstance(combined_loader._iterator, _Sequential)
# set the per-dataloader limits
combined_loader._iterator.limits = self.max_batches
if isinstance(combined_loader._iterator, _Sequential):
# set the per-dataloader limits
combined_loader._iterator.limits = self.max_batches

# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
data_fetcher._start_profiler = self._on_before_fetch
Expand Down Expand Up @@ -217,7 +238,9 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None

any_on_epoch = self._store_data_for_prediction_writer(batch_idx, dataloader_idx)

step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)
step_kwargs = self._build_kwargs(
batch, batch_idx, dataloader_idx if self._is_sequential and self.num_dataloaders > 1 else None
)

call._call_callback_hooks(trainer, "on_predict_batch_start", *step_kwargs.values())
call._call_lightning_module_hook(trainer, "on_predict_batch_start", *step_kwargs.values())
Expand Down Expand Up @@ -339,7 +362,7 @@ def _verify_dataloader_idx_requirement(self) -> None:
assert self._combined_loader is not None
_verify_dataloader_idx_requirement(
("predict_step", "on_predict_batch_start", "on_predict_batch_end"),
self._combined_loader._mode == "sequential" and self.num_dataloaders > 1,
self._is_sequential and self.num_dataloaders > 1,
RunningStage.PREDICTING,
trainer.lightning_module,
)
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,7 @@ def num_test_batches(self) -> Union[int, float, List[Union[int, float]]]:
return self.test_loop.max_batches

@property
def num_predict_batches(self) -> List[Union[int, float]]:
def num_predict_batches(self) -> Union[int, float, List[Union[int, float]]]:
"""The number of prediction batches that will be used during ``trainer.predict()``."""
return self.predict_loop.max_batches

Expand Down
28 changes: 28 additions & 0 deletions tests/tests_pytorch/loops/test_prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.utilities import CombinedLoader


def test_prediction_loop_stores_predictions(tmp_path):
Expand Down Expand Up @@ -286,3 +287,30 @@ def on_predict_batch_end(self, outputs, batch, batch_idx, **_):
model = IgnoringModel2()
with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.on_predict_batch_end"):
trainer.predict(model)


@pytest.mark.parametrize(
("mode", "expected"),
[
("max_size_cycle", [{"a": 0, "b": 3}, {"a": 1, "b": 4}, {"a": 2, "b": 3}]),
("min_size", [{"a": 0, "b": 3}, {"a": 1, "b": 4}]),
("max_size", [{"a": 0, "b": 3}, {"a": 1, "b": 4}, {"a": 2, "b": None}]),
],
)
def test_prediction_loop_non_sequential_mode_supprt(tmp_path, mode, expected):
iterables = {"a": [0, 1, 2], "b": {3, 4}}
cl = CombinedLoader(iterables, mode)
seen = []

class MyModel(BoringModel):
def predict_step(self, batch, batch_idx):
seen.append(batch)

model = MyModel()
trainer = Trainer(default_root_dir=tmp_path, barebones=True)

trainer.predict(model, cl)

actual = trainer.num_predict_batches
assert actual == (2 if mode == "min_size" else 3)
assert seen == expected