
Commit 6c70dd7

Fix attribute error on _NotYetLoadedTensor after loading checkpoint into quantized model with _lazy_load() (#20121)
1 parent f91349c commit 6c70dd7

3 files changed, +11 -2 lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))
+
 
 -
 
src/lightning/fabric/utilities/load.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def __getattr__(self, name: str) -> Any:
             return getattr(self.metatensor, name)
 
         # materializing these is needed for quantization (see lit-gpt)
-        if name in {"contiguous", "cuda", "half"}:
+        if name in {"contiguous", "cuda", "half", "data"}:
             return getattr(self._load_tensor(), name)
 
         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
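
For context, here is a minimal sketch of the lazy-tensor pattern this one-line change touches, assuming an illustrative class (`LazyTensorSketch` is not Lightning's `_NotYetLoadedTensor`): metadata queries are answered from a meta tensor, while attributes whose values need real storage force a load from disk. Before this change, accessing `.data` on the placeholder fell through to the final `raise AttributeError`, which is the failure the commit title describes; adding `"data"` to the materializing set loads the real tensor instead.

# A minimal sketch (not Lightning's actual _NotYetLoadedTensor); the class
# name, checkpoint path handling, and on-demand load below are assumptions.
import torch


class LazyTensorSketch:
    def __init__(self, path, key, shape, dtype):
        self._path = path  # checkpoint file on disk
        self._key = key    # state-dict key to load on demand
        # metadata-only stand-in; shape/dtype queries need no file I/O
        self.metatensor = torch.empty(shape, dtype=dtype, device="meta")

    def _load_tensor(self):
        # hypothetical on-demand read of a single entry from the checkpoint
        return torch.load(self._path, map_location="cpu")[self._key]

    def __getattr__(self, name):
        # cheap metadata attributes can be answered from the meta tensor
        if name in {"shape", "dtype", "ndim"}:
            return getattr(self.metatensor, name)
        # attributes the quantized-loading path touches must see real values,
        # so they force materialization ("data" is the name this commit adds)
        if name in {"contiguous", "cuda", "half", "data"}:
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
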

tests/tests_fabric/plugins/precision/test_bitsandbytes.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
 from lightning.fabric.connector import _Connector
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
 from lightning.fabric.utilities.init import _materialize_meta_tensors
+from lightning.fabric.utilities.load import _lazy_load
 
 from tests_fabric.helpers.runif import RunIf
 

@@ -264,3 +265,10 @@ def forward(self, x):
     assert model.linear.weight.shape == (128, 1)
     # Shapes match during forward (weight is being dequantized during forward)
     model(torch.randn(2, 16, device=fabric.device))
+
+    # Test with lazy load (LitGPT uses this)
+    # TODO: Replace `_lazy_load` with `torch.load(..., mmap=True)` in LitGPT
+    state_dict = _lazy_load(tmp_path / "checkpoint.pt")
+    model.load_state_dict(state_dict)
+    assert model.linear.weight.dtype == torch.uint8
+    assert model.linear.weight.shape == (128, 1)
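
The TODO in the new test points at PyTorch's built-in memory-mapped loading as the eventual replacement for `_lazy_load` in LitGPT. A small self-contained sketch of that alternative, using a placeholder model and checkpoint path:

import torch

# placeholder model and checkpoint; in practice the checkpoint already exists on disk
model = torch.nn.Linear(16, 16, bias=False)
torch.save(model.state_dict(), "checkpoint.pt")

# mmap=True (PyTorch >= 2.1) maps the checkpoint file instead of reading it
# eagerly; tensor data is paged in as load_state_dict consumes it
state_dict = torch.load("checkpoint.pt", mmap=True, map_location="cpu")
model.load_state_dict(state_dict)
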

0 commit comments
