Skip to content

Commit e5ce20a

Browse files
committed
Test hypers in architecture tests
1 parent 9e7a287 commit e5ce20a

File tree

6 files changed

+59
-14
lines changed

6 files changed

+59
-14
lines changed

src/metatrain/deprecated/nanopet/tests/test_functionality.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
from . import DEFAULT_HYPERS, MODEL_HYPERS
1818

1919

20+
def test_valid_defaults():
21+
"""Tests that the default hypers pass the architecture options check."""
22+
hypers = OmegaConf.create(DEFAULT_HYPERS)
23+
check_architecture_options(
24+
name="deprecated.nanopet", options=OmegaConf.to_container(hypers)
25+
)
26+
27+
2028
def test_nanopet_padding():
2129
"""Tests that the model predicts the same energy independently of the
2230
padding size."""

src/metatrain/experimental/flashmd/tests/test_functionality.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import torch
33
from metatensor.torch import Labels, TensorBlock, TensorMap
44
from metatomic.torch import ModelOutput, System
5+
from omegaconf import OmegaConf
56

67
from metatrain.experimental.flashmd.model import FlashMD
7-
from metatrain.utils.architectures import get_default_hypers
8+
from metatrain.utils.architectures import check_architecture_options, get_default_hypers
89
from metatrain.utils.data import DatasetInfo
910
from metatrain.utils.data.target_info import TargetInfo
1011
from metatrain.utils.neighbor_lists import (
@@ -13,6 +14,14 @@
1314
)
1415

1516

17+
def test_valid_defaults():
18+
"""Tests that the default hypers pass the architecture options check."""
19+
hypers = OmegaConf.create(get_default_hypers("experimental.flashmd"))
20+
check_architecture_options(
21+
name="experimental.flashmd", options=OmegaConf.to_container(hypers)
22+
)
23+
24+
1625
@pytest.mark.filterwarnings("ignore:custom data:UserWarning")
1726
def test_forward():
1827
"Run a forward pass of FlashMD on two small systems and verify the output shapes."
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from omegaconf import OmegaConf
2+
3+
from metatrain.utils.architectures import check_architecture_options
4+
5+
from . import DEFAULT_HYPERS
6+
7+
8+
def test_valid_defaults():
9+
"""Tests that the default hypers pass the architecture options check."""
10+
hypers = OmegaConf.create(DEFAULT_HYPERS)
11+
check_architecture_options(name="gap", options=OmegaConf.to_container(hypers))

src/metatrain/pet/tests/test_functionality.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
from . import DEFAULT_HYPERS, MODEL_HYPERS
1919

2020

21+
def test_valid_defaults():
22+
"""Tests that the default hypers pass the architecture options check."""
23+
hypers = OmegaConf.create(DEFAULT_HYPERS)
24+
check_architecture_options(name="pet", options=OmegaConf.to_container(hypers))
25+
26+
2127
def test_prediction():
2228
"""Tests the basic functionality of the forward pass of the model."""
2329
dataset_info = DatasetInfo(

src/metatrain/soap_bpnn/tests/test_functionality.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222
from . import DEFAULT_HYPERS, MODEL_HYPERS
2323

2424

25+
def test_valid_defaults():
26+
"""Tests that the default hypers pass the architecture options check."""
27+
hypers = OmegaConf.create(DEFAULT_HYPERS)
28+
check_architecture_options(name="soap_bpnn", options=OmegaConf.to_container(hypers))
29+
30+
2531
def test_prediction_subset_elements():
2632
"""Tests that the model can predict on a subset of the elements it was trained
2733
on."""

tests/utils/test_architectures.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ def test_get_architecture_path():
3636
assert get_architecture_path("soap_bpnn") == PACKAGE_ROOT / "soap_bpnn"
3737

3838

39-
@pytest.mark.parametrize("name", find_all_architectures())
40-
def test_get_default_hypers(name):
41-
"""Test that architecture hypers for all arches can be loaded."""
42-
if name == "llpr":
43-
# Skip this architecture as it is not a valid architecture but a wrapper
44-
return
39+
def test_get_default_hypers():
40+
"""Test that architecture default hypers can be loaded.
41+
42+
We use soap_bpnn as the test architecture to see if the function works.
43+
Other architectures might have dependencies, and therefore loading their
44+
default hypers could fail. The loading of default hypers should be
45+
tested in the tests of each architecture.
46+
"""
47+
name = "soap_bpnn"
4548
default_hypers = get_default_hypers(name)
4649
assert type(default_hypers) is dict
4750
assert default_hypers["name"] == name
@@ -79,7 +82,6 @@ def test_check_architecture_name_deprecated():
7982
[
8083
PACKAGE_ROOT / "soap_bpnn",
8184
PACKAGE_ROOT / "soap_bpnn" / "__init__.py",
82-
PACKAGE_ROOT / "soap_bpnn" / "default-hypers.yaml",
8385
],
8486
)
8587
def test_get_architecture_name(path_type, path):
@@ -100,12 +102,15 @@ def test_get_architecture_name_err_no_such_arch():
100102
get_architecture_name(path)
101103

102104

103-
@pytest.mark.parametrize("name", find_all_architectures())
104-
def test_check_valid_default_architecture_options(name):
105-
"""Test that all default hypers are according to the provided schema."""
106-
if name == "llpr":
107-
# Skip this architecture as it is not a valid architecture but a wrapper
108-
return
105+
def test_check_valid_default_architecture_options():
106+
"""Test that validating architecture options works.
107+
108+
We use soap_bpnn as the test architecture to see if the function works.
109+
Other architectures might have dependencies, and therefore loading their
110+
default hypers could fail. The loading of default hypers should be
111+
tested in the tests of each architecture.
112+
"""
113+
name = "soap_bpnn"
109114
options = get_default_hypers(name)
110115
check_architecture_options(name=name, options=options)
111116

0 commit comments

Comments
 (0)