|
18 | 18 | import torch.nn as nn |
19 | 19 |
|
20 | 20 | from transformers import PretrainedConfig |
| 21 | +from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping |
21 | 22 | from transformers.core_model_loading import ( |
22 | 23 | Chunk, |
23 | 24 | Concatenate, |
@@ -505,5 +506,43 @@ def __init__(self): |
505 | 506 | torch.testing.assert_close(dequantized_q, expected_q, rtol=1e-2, atol=1e-2) |
506 | 507 |
|
507 | 508 |
|
class TestConversionMapping(unittest.TestCase):
    """Tests for registering and retrieving checkpoint conversion mappings."""

    def test_register_checkpoint_conversion_mapping(self):
        """A freshly registered mapping is retrievable under its key."""
        renamings = [WeightRenaming(".block_sparse_moe.gate", ".mlp.gate")]
        register_checkpoint_conversion_mapping("foobar", renamings)
        self.assertEqual(len(get_checkpoint_conversion_mapping("foobar")), 1)

    def test_register_checkpoint_conversion_mapping_overwrites(self):
        """Re-registering an existing key raises unless overwrite=True is passed."""
        register_checkpoint_conversion_mapping(
            "foobarbaz",
            [WeightRenaming(".block_sparse_moe.gate", ".mlp.gate")],
        )

        replacement = [
            WeightRenaming(".block_sparse_moe.foo", ".mlp.foo"),
            WeightRenaming(".block_sparse_moe.bar", ".mlp.bar"),
        ]

        # Registering under an already-used key without overwrite must fail.
        with self.assertRaises(ValueError):
            register_checkpoint_conversion_mapping("foobarbaz", replacement)

        # With overwrite=True the prior mapping is replaced by the new one.
        register_checkpoint_conversion_mapping("foobarbaz", replacement, overwrite=True)

        self.assertEqual(len(get_checkpoint_conversion_mapping("foobarbaz")), 2)
# Allow running this test module directly (e.g. `python tests/<this_file>.py`)
# instead of going through the pytest/unittest discovery machinery.
if __name__ == "__main__":
    unittest.main()