Skip to content

Commit 45d8168

Browse files
allow registration of custom checkpoint conversion mappings (#42634)
* allow registration of custom checkpoint conversion mappings
* add tests
* chore: lint
* move tests to test_core_model_loading.py
* fixup

---------

Co-authored-by: Arthur <arthur.zucker@gmail.com>
1 parent 8ebfd84 commit 45d8168

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

src/transformers/conversion_mapping.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,22 @@ def _build_checkpoint_conversion_mapping():
186186

187187
def get_checkpoint_conversion_mapping(model_type):
    """Return the checkpoint conversion mapping registered for ``model_type``.

    The module-level cache is built lazily on first access. A deep copy of the
    stored mapping is returned so callers can mutate it freely without
    affecting the shared cache; ``None`` is returned when ``model_type`` has no
    registered mapping.
    """
    global _checkpoint_conversion_mapping_cache
    cache = _checkpoint_conversion_mapping_cache
    if cache is None:
        # First access: build the default mapping table and memoize it.
        cache = _build_checkpoint_conversion_mapping()
        _checkpoint_conversion_mapping_cache = cache
    return deepcopy(cache.get(model_type))
191192

192193

194+
def register_checkpoint_conversion_mapping(
    model_type: str, mapping: list[WeightConverter | WeightRenaming], overwrite: bool = False
) -> None:
    """Register a custom checkpoint conversion ``mapping`` for ``model_type``.

    Args:
        model_type: Key under which the mapping is stored.
        mapping: List of weight converters/renamings applied when loading
            checkpoints of this model type.
        overwrite: When ``False`` (default), refuse to replace an existing
            registration.

    Raises:
        ValueError: If ``model_type`` is already registered and ``overwrite``
            is ``False``.
    """
    global _checkpoint_conversion_mapping_cache
    if _checkpoint_conversion_mapping_cache is None:
        # Ensure the default table exists before we check for collisions.
        _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
    already_registered = model_type in _checkpoint_conversion_mapping_cache
    if already_registered and not overwrite:
        raise ValueError(f"Model type {model_type} already exists in the checkpoint conversion mapping.")
    _checkpoint_conversion_mapping_cache[model_type] = mapping
203+
204+
193205
# DO NOT MODIFY, KEPT FOR BC ONLY
194206
VLMS = [
195207
"aria",

tests/utils/test_core_model_loading.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch.nn as nn
1919

2020
from transformers import PretrainedConfig
21+
from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
2122
from transformers.core_model_loading import (
2223
Chunk,
2324
Concatenate,
@@ -505,5 +506,43 @@ def __init__(self):
505506
torch.testing.assert_close(dequantized_q, expected_q, rtol=1e-2, atol=1e-2)
506507

507508

509+
class TestConversionMapping(unittest.TestCase):
    """Tests for registering custom checkpoint conversion mappings."""

    def test_register_checkpoint_conversion_mapping(self):
        """A freshly registered mapping is retrievable via the getter."""
        renamings = [WeightRenaming(".block_sparse_moe.gate", ".mlp.gate")]
        register_checkpoint_conversion_mapping("foobar", renamings)
        self.assertEqual(len(get_checkpoint_conversion_mapping("foobar")), 1)

    def test_register_checkpoint_conversion_mapping_overwrites(self):
        """Re-registering raises unless overwrite=True is passed."""
        register_checkpoint_conversion_mapping(
            "foobarbaz",
            [WeightRenaming(".block_sparse_moe.gate", ".mlp.gate")],
        )

        replacement = [
            WeightRenaming(".block_sparse_moe.foo", ".mlp.foo"),
            WeightRenaming(".block_sparse_moe.bar", ".mlp.bar"),
        ]

        # Without overwrite, a duplicate registration must be rejected.
        with self.assertRaises(ValueError):
            register_checkpoint_conversion_mapping("foobarbaz", replacement)

        # With overwrite=True the new mapping replaces the old one.
        register_checkpoint_conversion_mapping("foobarbaz", replacement, overwrite=True)
        self.assertEqual(len(get_checkpoint_conversion_mapping("foobarbaz")), 2)
545+
546+
508547
# Allow running this test module directly (e.g. `python test_core_model_loading.py`)
# in addition to discovery via pytest.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)