Skip to content

Commit a649c6a

Browse files
committed
Test hypers in architecture tests
1 parent 9e7a287 commit a649c6a

File tree

6 files changed

+44
-16
lines changed

6 files changed

+44
-16
lines changed

src/metatrain/deprecated/nanopet/tests/test_functionality.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
from . import DEFAULT_HYPERS, MODEL_HYPERS
1818

19+
def test_valid_defaults():
20+
"""Tests that the default hypers pass the architecture options check."""
21+
hypers = OmegaConf.create(DEFAULT_HYPERS)
22+
check_architecture_options(name="deprecated.nanopet", options=OmegaConf.to_container(hypers))
1923

2024
def test_nanopet_padding():
2125
"""Tests that the model predicts the same energy independently of the

src/metatrain/experimental/flashmd/tests/test_functionality.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,21 @@
22
import torch
33
from metatensor.torch import Labels, TensorBlock, TensorMap
44
from metatomic.torch import ModelOutput, System
5+
from omegaconf import OmegaConf
56

67
from metatrain.experimental.flashmd.model import FlashMD
7-
from metatrain.utils.architectures import get_default_hypers
8+
from metatrain.utils.architectures import get_default_hypers, check_architecture_options
89
from metatrain.utils.data import DatasetInfo
910
from metatrain.utils.data.target_info import TargetInfo
1011
from metatrain.utils.neighbor_lists import (
1112
get_requested_neighbor_lists,
1213
get_system_with_neighbor_lists,
1314
)
1415

16+
def test_valid_defaults():
17+
"""Tests that the default hypers pass the architecture options check."""
18+
hypers = OmegaConf.create(get_default_hypers("experimental.flashmd"))
19+
check_architecture_options(name="experimental.flashmd", options=OmegaConf.to_container(hypers))
1520

1621
@pytest.mark.filterwarnings("ignore:custom data:UserWarning")
1722
def test_forward():
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from . import DEFAULT_HYPERS
2+
from metatrain.utils.architectures import check_architecture_options
3+
from omegaconf import OmegaConf
4+
5+
def test_valid_defaults():
6+
"""Tests that the default hypers pass the architecture options check."""
7+
hypers = OmegaConf.create(DEFAULT_HYPERS)
8+
check_architecture_options(name="gap", options=OmegaConf.to_container(hypers))

src/metatrain/pet/tests/test_functionality.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
from . import DEFAULT_HYPERS, MODEL_HYPERS
1919

20+
def test_valid_defaults():
21+
"""Tests that the default hypers pass the architecture options check."""
22+
hypers = OmegaConf.create(DEFAULT_HYPERS)
23+
check_architecture_options(name="pet", options=OmegaConf.to_container(hypers))
2024

2125
def test_prediction():
2226
"""Tests the basic functionality of the forward pass of the model."""

src/metatrain/soap_bpnn/tests/test_functionality.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121

2222
from . import DEFAULT_HYPERS, MODEL_HYPERS
2323

24+
def test_valid_defaults():
25+
"""Tests that the default hypers pass the architecture options check."""
26+
hypers = OmegaConf.create(DEFAULT_HYPERS)
27+
check_architecture_options(name="soap_bpnn", options=OmegaConf.to_container(hypers))
2428

2529
def test_prediction_subset_elements():
2630
"""Tests that the model can predict on a subset of the elements it was trained

tests/utils/test_architectures.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ def test_find_all_architectures():
3535
def test_get_architecture_path():
3636
assert get_architecture_path("soap_bpnn") == PACKAGE_ROOT / "soap_bpnn"
3737

38-
39-
@pytest.mark.parametrize("name", find_all_architectures())
40-
def test_get_default_hypers(name):
41-
"""Test that architecture hypers for all arches can be loaded."""
42-
if name == "llpr":
43-
# Skip this architecture as it is not a valid architecture but a wrapper
44-
return
38+
def test_get_default_hypers():
39+
"""Test that architecture default hypers can be loaded.
40+
41+
We use soap_bpnn as the test architecture to see if the function works.
42+
Other architectures might have dependencies, and therefore loading their
43+
default hypers could fail. The loading of default hypers should be
44+
tested in the tests of each architecture.
45+
"""
46+
name = "soap_bpnn"
4547
default_hypers = get_default_hypers(name)
4648
assert type(default_hypers) is dict
4749
assert default_hypers["name"] == name
@@ -79,7 +81,6 @@ def test_check_architecture_name_deprecated():
7981
[
8082
PACKAGE_ROOT / "soap_bpnn",
8183
PACKAGE_ROOT / "soap_bpnn" / "__init__.py",
82-
PACKAGE_ROOT / "soap_bpnn" / "default-hypers.yaml",
8384
],
8485
)
8586
def test_get_architecture_name(path_type, path):
@@ -99,13 +100,15 @@ def test_get_architecture_name_err_no_such_arch():
99100
with pytest.raises(ValueError, match=match):
100101
get_architecture_name(path)
101102

102-
103-
@pytest.mark.parametrize("name", find_all_architectures())
104-
def test_check_valid_default_architecture_options(name):
105-
"""Test that all default hypers are according to the provided schema."""
106-
if name == "llpr":
107-
# Skip this architecture as it is not a valid architecture but a wrapper
108-
return
103+
def test_check_valid_default_architecture_options():
104+
"""Test that validating architecture options works.
105+
106+
We use soap_bpnn as the test architecture to see if the function works.
107+
Other architectures might have dependencies, and therefore loading their
108+
default hypers could fail. The loading of default hypers should be
109+
tested in the tests of each architecture.
110+
"""
111+
name = "soap_bpnn"
109112
options = get_default_hypers(name)
110113
check_architecture_options(name=name, options=options)
111114

0 commit comments

Comments
 (0)