|
3 | 3 | import torch |
4 | 4 | from metatomic.torch import ModelOutput, System |
5 | 5 | from omegaconf import OmegaConf |
| 6 | +from pydantic import ValidationError |
6 | 7 |
|
7 | | -from metatrain.pet import PET |
| 8 | +from metatrain.pet import PET, Trainer |
8 | 9 | from metatrain.pet.modules.transformer import AttentionBlock |
9 | | -from metatrain.utils.architectures import check_architecture_options |
10 | 10 | from metatrain.utils.data import DatasetInfo |
11 | 11 | from metatrain.utils.data.target_info import ( |
12 | 12 | get_energy_target_info, |
13 | 13 | get_generic_target_info, |
14 | 14 | ) |
15 | 15 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists |
| 16 | +from metatrain.utils.pydantic import validate_architecture_options |
16 | 17 |
|
17 | 18 | from . import DEFAULT_HYPERS, MODEL_HYPERS |
18 | 19 |
|
@@ -365,16 +366,16 @@ def test_fixed_composition_weights(): |
365 | 366 | } |
366 | 367 | } |
367 | 368 | hypers = OmegaConf.create(hypers) |
368 | | - check_architecture_options(name="pet", options=OmegaConf.to_container(hypers)) |
| 369 | + validate_architecture_options(OmegaConf.to_container(hypers), PET, Trainer) |
369 | 370 |
|
370 | 371 |
|
371 | 372 | def test_fixed_composition_weights_error(): |
372 | 373 | """Test that only input of type Dict[str, Dict[int, float]] are allowed.""" |
373 | 374 | hypers = DEFAULT_HYPERS.copy() |
374 | 375 | hypers["training"]["fixed_composition_weights"] = {"energy": {"H": 300.0}} |
375 | 376 | hypers = OmegaConf.create(hypers) |
376 | | - with pytest.raises(ValueError, match=r"'H' does not match '\^\[0-9\]\+\$'"): |
377 | | - check_architecture_options(name="pet", options=OmegaConf.to_container(hypers)) |
| 377 | + with pytest.raises(ValidationError, match=r"Input should be a valid integer"): |
| 378 | + validate_architecture_options(OmegaConf.to_container(hypers), PET, Trainer) |
378 | 379 |
|
379 | 380 |
|
380 | 381 | @pytest.mark.parametrize("per_atom", [True, False]) |
|
0 commit comments