Commit 9cc104e

Added correct type hints for scaler and composition weights
1 parent 46eb6b9 commit 9cc104e
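
In short: the placeholder aliases CompositionWeightsDict and ScalingWeightsDict (both just plain dict) are replaced by two concrete aliases, defined in the diffs below. A self-contained restatement of their shapes (only the Union import is added here):

from typing import Union

# Alias shapes as introduced by this commit:
FixedCompositionWeights = dict[str, dict[int, float]]
FixedScalerWeights = dict[str, Union[float, dict[int, float]]]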

File tree

10 files changed: +31 -65 lines changed

src/metatrain/pet/hypers.py

Lines changed: 6 additions & 8 deletions

@@ -5,13 +5,11 @@
 
 from typing_extensions import TypedDict
 
-from metatrain.utils.hypers import (
-    CompositionWeightsDict,
-    ScalingWeightsDict,
-    init_with_defaults,
-)
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
 from metatrain.utils.long_range import LongRangeHypers
 from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
 
 from .modules.finetuning import FinetuneHypers
 
@@ -117,10 +115,10 @@ class PETTrainerHypers(TypedDict):
     """Interval to save checkpoints."""
     scale_targets: bool = True
     """Normalize targets to unit std during training."""
-    fixed_composition_weights: CompositionWeightsDict = {}
+    fixed_composition_weights: FixedCompositionWeights = {}
     """Weights for atomic contributions."""
-    fixed_scaling_weights: ScalingWeightsDict = {}
-
+    fixed_scaling_weights: FixedScalerWeights = {}
+    """Weights for target scaling."""
     per_structure_targets: list[str] = []
     """Targets to calculate per-structure losses."""
     num_workers: Optional[int] = None
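
For illustration, trainer options satisfying the two new annotations could look like this (a sketch: the target name and numeric values are invented, only the shapes follow from the aliases):

# Hypothetical PETTrainerHypers fragment; inner-dict keys are atomic
# numbers, matching the dict[str, dict[int, float]] shape.
training_hypers = {
    "scale_targets": True,
    "fixed_composition_weights": {"energy": {1: -13.6, 8: -2040.0}},
    "fixed_scaling_weights": {"energy": 1.0},
}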

src/metatrain/pet/modules/finetuning.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ class LoRaFinetuneHypers(TypedDict):
 
 class HeadsFinetuneHypers(TypedDict):
     """Hyperparameters for heads finetuning of PET models.
-
+
     Freezes all model parameters except for the prediction heads
     and last layers.
     """

src/metatrain/pet/tests/test_functionality.py

Lines changed: 6 additions & 5 deletions

@@ -3,16 +3,17 @@
 import torch
 from metatomic.torch import ModelOutput, System
 from omegaconf import OmegaConf
+from pydantic import ValidationError
 
-from metatrain.pet import PET
+from metatrain.pet import PET, Trainer
 from metatrain.pet.modules.transformer import AttentionBlock
-from metatrain.utils.architectures import check_architecture_options
 from metatrain.utils.data import DatasetInfo
 from metatrain.utils.data.target_info import (
     get_energy_target_info,
     get_generic_target_info,
 )
 from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
+from metatrain.utils.pydantic import validate_architecture_options
 
 from . import DEFAULT_HYPERS, MODEL_HYPERS
 
@@ -365,16 +366,16 @@ def test_fixed_composition_weights():
         }
     }
     hypers = OmegaConf.create(hypers)
-    check_architecture_options(name="pet", options=OmegaConf.to_container(hypers))
+    validate_architecture_options(OmegaConf.to_container(hypers), PET, Trainer)
 
 
 def test_fixed_composition_weights_error():
     """Test that only input of type Dict[str, Dict[int, float]] are allowed."""
     hypers = DEFAULT_HYPERS.copy()
     hypers["training"]["fixed_composition_weights"] = {"energy": {"H": 300.0}}
     hypers = OmegaConf.create(hypers)
-    with pytest.raises(ValueError, match=r"'H' does not match '\^\[0-9\]\+\$'"):
-        check_architecture_options(name="pet", options=OmegaConf.to_container(hypers))
+    with pytest.raises(ValidationError, match=r"Input should be a valid integer"):
+        validate_architecture_options(OmegaConf.to_container(hypers), PET, Trainer)
 
 
 @pytest.mark.parametrize("per_atom", [True, False])
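
The updated test expects a pydantic ValidationError instead of the old jsonschema-pattern ValueError: the key type of the alias is now enforced by pydantic v2. A minimal sketch of that behavior (using pydantic's TypeAdapter directly; the project's actual wiring lives in metatrain.utils.pydantic):

from pydantic import TypeAdapter, ValidationError

adapter = TypeAdapter(dict[str, dict[int, float]])
adapter.validate_python({"energy": {"1": 300.0}})  # ok: "1" is coerced to int
try:
    adapter.validate_python({"energy": {"H": 300.0}})
except ValidationError as err:
    print(err)  # message includes "Input should be a valid integer"
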
src/metatrain/utils/additive/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
-from .composition import CompositionModel  # noqa: F401
+from .composition import CompositionModel, FixedCompositionWeights  # noqa: F401
 from .remove import get_remove_additive_transform, remove_additive  # noqa: F401
 from .zbl import ZBL  # noqa: F401

src/metatrain/utils/additive/_base_composition.py

Lines changed: 4 additions & 1 deletion

@@ -12,6 +12,9 @@
 from metatomic.torch import ModelOutput, System
 
 
+FixedCompositionWeights = dict[str, dict[int, float]]
+
+
 class BaseCompositionModel(torch.nn.Module):
     """
     Fits a composition model for a dict of targets.
@@ -270,7 +273,7 @@ def accumulate(
 
     def fit(
         self,
-        fixed_weights: Optional[Dict[str, Dict[int, float]]] = None,
+        fixed_weights: Optional[FixedCompositionWeights] = None,
         targets_to_fit: Optional[List[str]] = None,
     ) -> None:
         """

src/metatrain/utils/additive/composition.py

Lines changed: 6 additions & 2 deletions

@@ -17,7 +17,11 @@
 from ..data import DatasetInfo, TargetInfo, unpack_batch
 from ..jsonschema import validate
 from ..transfer import batch_to
-from ._base_composition import BaseCompositionModel, _include_key
+from ._base_composition import (
+    BaseCompositionModel,
+    FixedCompositionWeights,
+    _include_key,
+)
 from .remove import remove_additive
 
 
@@ -175,7 +179,7 @@ def train_model(
         additive_models: List[torch.nn.Module],
         batch_size: int,
         is_distributed: bool,
-        fixed_weights: Optional[Dict[str, Dict[int, float]]] = None,
+        fixed_weights: Optional[FixedCompositionWeights] = None,
     ) -> None:
         """
         Train the composition model on the provided training data in the ``datasets``.

src/metatrain/utils/hypers.py

Lines changed: 0 additions & 43 deletions

@@ -25,46 +25,3 @@ def get_hypers_annotation(module: type) -> Type[dict]:
     :return: The TypedDict annotation of the 'hypers' parameter.
     """
     return inspect.signature(module).parameters["hypers"].annotation
-
-
-# Generic dictionary types to use as annotations in hypers
-ScalingWeightsDict = dict
-# class ScalingWeightsDict(dict):
-#     """Dictionary type for scaling weights."""
-
-#     @classmethod
-#     def get_json_schema(cls) -> dict:
-#         return {
-#             "type": "object",
-#             "patternProperties": {
-#                 "^.*$": {
-#                     "anyOf": [
-#                         {
-#                             "type": "object",
-#                             "propertyNames": {"pattern": "^[0-9]+$"},
-#                             "additionalProperties": {"type": "number"},
-#                         },
-#                         {"type": "number"},
-#                     ]
-#                 }
-#             },
-#             "additionalProperties": False,
-#         }
-
-CompositionWeightsDict = dict
-# class CompositionWeightsDict(dict):
-#     """Dictionary type for composition weights."""
-
-#     @classmethod
-#     def get_json_schema(cls) -> dict:
-#         return {
-#             "type": "object",
-#             "patternProperties": {
-#                 "^.*$": {
-#                     "type": "object",
-#                     "propertyNames": {"pattern": "^[0-9]+$"},
-#                     "additionalProperties": {"type": "number"},
-#                 }
-#             },
-#             "additionalProperties": False,
-#         }
src/metatrain/utils/scaler/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 from .remove import get_remove_scale_transform, remove_scale  # noqa: F401
-from .scaler import Scaler  # noqa: F401
+from .scaler import FixedScalerWeights, Scaler  # noqa: F401

src/metatrain/utils/scaler/_base_scaler.py

Lines changed: 4 additions & 1 deletion

@@ -11,6 +11,9 @@
 from metatomic.torch import System
 
 
+FixedScalerWeights = dict[str, Union[float, dict[int, float]]]
+
+
 class BaseScaler(torch.nn.Module):
     """
     Fits a scaler for a dict of targets. Scales are computed as the per-property (and
@@ -236,7 +239,7 @@ def accumulate(
 
     def fit(
         self,
-        fixed_weights: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
+        fixed_weights: Optional[FixedScalerWeights] = None,
        targets_to_fit: Optional[List[str]] = None,
     ) -> None:
         """

src/metatrain/utils/scaler/scaler.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 from ..data import DatasetInfo, TargetInfo, unpack_batch
 from ..jsonschema import validate
 from ..transfer import batch_to
-from ._base_scaler import BaseScaler
+from ._base_scaler import BaseScaler, FixedScalerWeights
 
 
 class Scaler(torch.nn.Module):
@@ -139,7 +139,7 @@ def train_model(
         additive_models: List[torch.nn.Module],
         batch_size: int,
         is_distributed: bool,
-        fixed_weights: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
+        fixed_weights: Optional[FixedScalerWeights] = None,
     ) -> None:
         """
         Placeholder docs.
