
mark_forward_method does not work with ModelParallelStrategy #20710

@tonyf

Bug description

When using the ModelParallelStrategy, methods annotated with `mark_forward_method` raise an exception if the method's signature does not match that of the module's `forward` method. Specifically, this fails when the number of args/kwargs differs between the two functions.

For example, calling `generate` below fails in an FSDP2 setting with `TypeError: Model.forward() got an unexpected keyword argument 'cfg'`:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x

    def generate(self, x, y, cfg: float = 0.5):
        z_1 = self.forward(x, y)
        z_2 = self.forward(x, torch.zeros_like(y))
        ...
```
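For context, a minimal sketch of how the failure is triggered through Fabric. The `parallelize` function, the `"data_parallel"` mesh dimension name, and the tensor shapes are illustrative assumptions, not from the original report:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(model, device_mesh):
    # Hypothetical parallelize_fn: shard the model with FSDP2's fully_shard
    # over the data-parallel mesh dimension.
    from torch.distributed.fsdp import fully_shard

    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


fabric = Fabric(strategy=ModelParallelStrategy(parallelize_fn=parallelize), devices=8)
fabric.launch()

model = fabric.setup(Model())
model.mark_forward_method("generate")  # route generate() through the strategy wrapper

x = torch.randn(2, 8, device=fabric.device)
y = torch.randn(2, 8, device=fabric.device)
# Raises: TypeError: Model.forward() got an unexpected keyword argument 'cfg'
out = model.generate(x, y, cfg=1.0)
```

With a plain FSDP strategy or no strategy wrapper, the same `generate` call goes through, which is what makes this specific to ModelParallelStrategy.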

What version are you seeing the problem on?

v2.5

Error messages and logs

```
[rank0]:   473           ):
[rank0]:   474               self.callbacks.on_validation_step_start(self, batch_idx)
[rank0]:   475
[rank0]: ❱ 476               result = self.validation_step(batch, batch_idx)
[rank0]:   477               self.callbacks.on_validation_step_end(self, result, batch_idx)
[rank0]:   478
[rank0]:   479       result = self.on_validation_epoch_end()
[rank0]:
[rank0]: /home/tony/workspace/models/models/flow_matching/stage_1_train.py:112 in validation_step
[rank0]:   109           B, _, T, H, W = samples.shape
[rank0]:   110           ct, ch, cw = self.autoencoder.compression
[rank0]:   111
[rank0]: ❱ 112           samples = self.model.sample(
[rank0]:   113               shape=(B, (T - 1) // ct + 1, H // ch, W // cw, self.autoencoder.latent_dim),
[rank0]:   114               text=text_embeds,
[rank0]:   115               sample_steps=self.config.sample_steps,
[rank0]:
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:197 in call_forward_module
[rank0]:   194       def call_forward_module(*args: Any, **kwargs: Any) -> Any:
[rank0]:   195           # Patch the original_module's forward, so we can redirect the arguments back
[rank0]:   196           self._original_module.forward = wrapped_forward
[rank0]: ❱ 197           return self.forward(*args, **kwargs)
[rank0]:   198
[rank0]:   199       return call_forward_module
[rank0]:
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:136 in forward
[rank0]:   133       args, kwargs = precision.convert_input((args, kwargs))
[rank0]:   134
[rank0]:   135       with precision.forward_context():
[rank0]: ❱ 136           output = self._forward_module(*args, **kwargs)
[rank0]:   137
[rank0]:   138       output = precision.convert_output(output)
[rank0]:
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl
[rank0]:   1736       if self._compiled_call_impl is not None:
[rank0]:   1737           return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
[rank0]:   1738       else:
[rank0]: ❱ 1739           return self._call_impl(*args, **kwargs)
[rank0]:
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl
[rank0]:   1747       if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks
[rank0]:   1748               or _global_backward_pre_hooks or _global_backward_hooks
[rank0]:   1749               or _global_forward_hooks or _global_forward_pre_hooks):
[rank0]: ❱ 1750           return forward_call(*args, **kwargs)
[rank0]:
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:574 in _fn
[rank0]:   573       try:
[rank0]: ❱ 574           return fn(*args, **kwargs)
[rank0]:   575       finally:
[rank0]:   576           # Restore the dynamic layer stack depth if necessary.
[rank0]:   577           torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
[rank0]:
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl
[rank0]: /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl
(the module.py frames at 1739 and 1750 repeat as above)

[rank0]: TypeError: Rem.forward() got an unexpected keyword argument 'shape'
```
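The traceback points at the redirection that `mark_forward_method` sets up in `lightning/fabric/wrappers.py`. Paraphrased from the frames above (names simplified; this is a sketch of the mechanism, not the verbatim implementation):

```python
from typing import Any, Callable


def wrap_with_redirection(fabric_module: Any, marked_method: Callable) -> Callable:
    # Sketch of the redirection shown at wrappers.py:194-197 in the traceback.
    def call_forward_module(*args: Any, **kwargs: Any) -> Any:
        # Patch the original module's forward so the wrapped forward pass
        # should land back in the marked method...
        fabric_module._original_module.forward = marked_method
        # ...then call through the strategy's forward module (wrappers.py:136),
        # so precision contexts and sharding hooks still run.
        return fabric_module._forward_module(*args, **kwargs)

    return call_forward_module
```

Under ModelParallelStrategy the forward module is the FSDP2-sharded (and, per the `eval_frame.py` frame, compiled) module. On that path the patched `forward` is evidently not picked up, so the marked method's arguments (`cfg`, `shape`) are handed to the original `forward()` and rejected.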

Environment

Current environment
- PyTorch Lightning Version: 2.5.0.post
- PyTorch Version: 2.6.0+cu124
- Python version: 3.11
- OS: Linux
- CUDA/cuDNN version: 12.4
- GPU models and configuration: 8xH100
- How you installed Lightning (`conda`, `pip`, source): pip

More info

No response

cc @justusschock @lantiga


Labels

bug, distributed, strategy: fsdp, ver: 2.5.x, waiting on author
