Skip to content
Merged
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))


- Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))


### Fixed

- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
Expand Down
6 changes: 5 additions & 1 deletion src/lightning/fabric/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
from torch.nn import Module
from typing_extensions import Self, override

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3


class _DeviceDtypeModuleMixin(Module):
__jit_unused_properties__: list[str] = ["device", "dtype"]

def __init__(self) -> None:
super().__init__()
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device("cpu")
# Workarounds from the original pytorch issue:
# https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687
self._device = torch.get_default_device() if _TORCH_GREATER_EQUAL_2_3 else torch.empty(0).device

@property
def dtype(self) -> Union[str, torch.dtype]:
Expand Down
25 changes: 25 additions & 0 deletions tests/tests_fabric/utilities/test_device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch import nn as nn

from lightning.fabric.plugins.precision.utils import _DtypeContextManager
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -50,6 +51,30 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
assert model.dtype == model.module.module.dtype == dst_type


@pytest.mark.parametrize(
"dst_device_str",
[
"cpu",
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("mps:0", marks=RunIf(mps=True)),
],
)
@pytest.mark.parametrize(
"dst_type",
[
torch.float,
pytest.param(torch.half, marks=RunIf(mps=False)),
pytest.param(torch.double, marks=RunIf(mps=False)),
],
)
def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
dst_device = torch.device(dst_device_str)
with _DtypeContextManager(dst_type), dst_device:
model = TopModule()
assert model.device == dst_device
assert model.dtype == dst_type


@pytest.mark.parametrize(
"device",
[
Expand Down
16 changes: 16 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
strategy.tensor_init_context.reset_mock()


@pytest.mark.parametrize(
("target_device", "accelerator", "devices"),
[
("cpu", "cpu", "auto"),
pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)),
pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)),
],
)
def test_init_module_device_type(target_device, accelerator, devices):
"""Test that the strategy returns the context manager for initializing the module."""
trainer = Trainer(accelerator=accelerator, devices=devices)
with trainer.init_module():
model = BoringModel()
assert model.device == torch.device(target_device)


def test_expand_home_trainer():
"""Test that the dirpath gets expanded if it contains `~`."""
home_root = Path.home()
Expand Down
Loading