
Commit 05ac42c

Documenting hyperparameters in python

committed · 1 parent 4ef6a91 · commit 05ac42c

File tree

10 files changed: +582, -197 lines


docs/src/architectures/pet.rst

Lines changed: 68 additions & 95 deletions
@@ -25,108 +25,81 @@ This will install the PET model along with the ``metatrain`` package.
 Default Hyperparameters
 -----------------------
 
-The default hyperparameters for the PET model are:
+The description of all the hyperparameters used in PET is provided further
+down this page. However, here we provide you with a yaml file containing all
+the default hyperparameters, which might be convenient as a starting point to
+create your own hyperparameter files:
 
 .. literalinclude:: ../../../src/metatrain/pet/default-hypers.yaml
     :language: yaml
+    :lines: 2-
 
-Tuning Hyperparameters
+Tuning hyperparameters
 ----------------------
 
-PET offers a number of tuning knobs for flexibility across datasets:
-
 The default hyperparameters above will work well in most cases, but they
-may not be optimal for your specific dataset. In general, the most important
-hyperparameters to tune are (in decreasing order of importance):
-
-- ``cutoff``: This should be set to a value after which most of the interactions between
-  atoms is expected to be negligible. A lower cutoff will lead to faster models.
-- ``learning_rate``: The learning rate for the neural network. This hyperparameter
-  controls how much the weights of the network are updated at each step of the
-  optimization. A larger learning rate will lead to faster training, but might cause
-  instability and/or divergence.
-- ``batch_size``: The number of samples to use in each batch of training. This
-  hyperparameter controls the tradeoff between training speed and memory usage. In
-  general, larger batch sizes will lead to faster training, but might require more
-  memory.
-- ``d_pet``: This hyperparameters controls width of the neural network. In general,
-  increasing it might lead to better accuracy, especially on larger datasets, at the
-  cost of increased training and evaluation time.
-- ``d_node``: The dimension of the node features. Increasing this hyperparameter
-  might lead to better accuracy, with a relatively small increase in inference time.
-- ``num_gnn_layers``: The number of graph neural network layers. In general, decreasing
-  this hyperparameter to 1 will lead to much faster models, at the expense of accuracy.
-  Increasing it may or may not lead to better accuracy, depending on the dataset, at the
-  cost of increased training and evaluation time.
-- ``num_attention_layers``: The number of attention layers in each layer of the graph
-  neural network. Depending on the dataset, increasing this hyperparameter might lead to
-  better accuracy, at the cost of increased training and evaluation time.
-- ``loss``: This section describes the loss function to be used. See the
-  :ref:`loss-functions` for more details.
-- ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions
-  might be beneficial for the accuracy of the model and/or its physical correctness.
-  See below for a breakdown of the long-range section of the model hyperparameters.
-
-All Hyperparameters
--------------------
-
-:param name: ``pet``
-
-model
-#####
-
-:param cutoff: Cutoff radius for neighbor search
-:param cutoff_width: Width of the smoothing function at the cutoff
-:param d_pet: Dimension of the edge features
-:param d_head: Dimension of the attention heads
-:param d_node: Dimension of the node features
-:param d_feedforward: Dimension of the feedforward network in the attention layer
-:param num_heads: Attention heads per attention layer
-:param num_attention_layers: Number of attention layers per GNN layer
-:param num_gnn_layers: Number of GNN layers
-:param normalization: Layer normalization type. Currently available options are
-  ``RMSNorm`` or ``LayerNorm``.
-:param activation: Activation function. Currently available options are ``SiLU``,
-  and ``SwiGLU``.
-:param transformer_type: The order in which the layer normalization and attention
-  are applied in a transformer block. Available options are ``PreLN``
-  (normalization before attention) and ``PostLN`` (normalization after attention).
-:param featurizer_type: Implementation of the featurizer of the model to use. Available
-  options are ``residual`` (the original featurizer from the PET paper, that uses
-  residual connections at each GNN layer for readout) and ``feedforward`` (a modern
-  version that uses the last representation after all GNN iterations for readout).
-  Additionally, the feedforward version uses bidirectional features flow during the
-  message passing iterations, that favors features flowing from atom ``i`` to atom
-  ``j`` to be not equal to the features flowing from atom ``j`` to atom ``i``.
-:param zbl: Use ZBL potential for short-range repulsion
-:param long_range: Long-range Coulomb interactions parameters:
-  - ``enable``: Toggle for enabling long-range interactions
-  - ``use_ewald``: Use Ewald summation. If False, P3M is used
-  - ``smearing``: Smearing width in Fourier space
-  - ``kspace_resolution``: Resolution of the reciprocal space grid
-  - ``interpolation_nodes``: Number of grid points for interpolation (for PME only)
-
-training
-########
-
-:param distributed: Whether to use distributed training
-:param distributed_port: Port for DDP communication
-:param batch_size: Training batch size
-:param num_epochs: Number of epochs
-:param warmup_fraction: Fraction of training steps used for learning rate warmup
-:param learning_rate: Learning rate
-:param log_interval: Interval to log metrics
-:param checkpoint_interval: Interval to save checkpoints
-:param scale_targets: Normalize targets to unit std during training
-:param fixed_composition_weights: Weights for atomic contributions
-:param per_structure_targets: Targets to calculate per-structure losses
-:param log_mae: Log MAE alongside RMSE
-:param log_separate_blocks: Log per-block error
-:param grad_clip_norm: Maximum gradient norm value, by default inf (no clipping)
-:param loss: Loss configuration (see above)
-:param best_model_metric: Metric used to select best checkpoint (e.g., ``rmse_prod``)
-:param num_workers: Number of workers for data loading. If not provided, it is set
-  automatically.
+may not be optimal for your specific dataset. There is a good number of
+parameters to tune, both for the :ref:`model <pet_model_hypers>` and the
+:ref:`trainer <pet_trainer_hypers>`. Since seeing them for the first time
+might be overwhelming, here we provide a list of the parameters that are,
+in general, the most important (in decreasing order of importance):
+
+.. autoattribute:: metatrain.pet.hypers.PETHypers.cutoff
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETTrainerHypers.learning_rate
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETTrainerHypers.batch_size
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETHypers.d_pet
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETHypers.d_node
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETHypers.num_gnn_layers
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETHypers.num_attention_layers
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETTrainerHypers.loss
+    :no-index:
+
+.. autoattribute:: metatrain.pet.hypers.PETHypers.long_range
+    :no-index:
+
+.. _pet_model_hypers:
+
+Model hyperparameters
+---------------------
+
+The parameters that go under the ``architecture.model`` section of the config file
+are the following:
+
+.. autoclass:: metatrain.pet.hypers.PETHypers
+    :members:
+    :undoc-members:
+
+with the long-range section being:
+
+.. autoclass:: metatrain.pet.hypers.LongRangeHypers
+    :members:
+    :undoc-members:
+
+.. _pet_trainer_hypers:
+
+Trainer hyperparameters
+-----------------------
+
+The parameters that go under the ``architecture.trainer`` section of the config file
+are the following:
+
+.. autoclass:: metatrain.pet.hypers.PETTrainerHypers
+    :members:
+    :undoc-members:
 
 References
 ----------
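
As a short, hypothetical illustration of the tuning advice above (not part of this commit), the following Python sketch builds a minimal options dictionary that overrides two of the model hyperparameters and writes it to a YAML file. The overall layout of a full metatrain options file (dataset and trainer sections, CLI usage) is assumed here and should be checked against the metatrain documentation.

import yaml  # PyYAML

# Override only what should differ from default-hypers.yaml; unspecified keys
# keep their defaults. The nesting below mirrors the default hypers file and
# is an assumption, not something taken from this commit.
options = {
    "architecture": {
        "name": "pet",
        "model": {
            "cutoff": 5.0,  # larger neighbor-list cutoff than the 4.5 default
            "d_pet": 256,   # wider edge features than the 128 default
        },
    },
}

with open("options.yaml", "w") as f:
    yaml.safe_dump(options, f, sort_keys=False)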

src/metatrain/pet/default-hypers.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+# This file is auto-generated. Do not edit directly.
 architecture:
   name: pet
   model:
@@ -27,7 +28,7 @@ architecture:
     batch_size: 16
     num_epochs: 1000
     warmup_fraction: 0.01
-    learning_rate: 1e-4
+    learning_rate: 0.0001
     weight_decay: null
     log_interval: 1
     checkpoint_interval: 100
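
A side note on the ``learning_rate`` value shown above (an observation, not something stated in the commit): PyYAML follows the YAML 1.1 float resolver, so ``1e-4`` without a decimal point loads as a string, while ``0.0001`` or ``1.0e-4`` load as floats. This may be one reason an auto-generated file writes the value out in decimal form. A minimal check in Python:

import yaml  # PyYAML

print(type(yaml.safe_load("lr: 1e-4")["lr"]))    # <class 'str'>   (no decimal point)
print(type(yaml.safe_load("lr: 0.0001")["lr"]))  # <class 'float'>
print(type(yaml.safe_load("lr: 1.0e-4")["lr"]))  # <class 'float'>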

src/metatrain/pet/hypers.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+from typing import Literal, Optional, TypedDict
+
+from metatrain.utils.hypers import (
+    CompositionWeightsDict,
+    LossDict,
+    ScalingWeightsDict,
+    init_with_defaults,
+)
+
+
+class LongRangeHypers(TypedDict):
+    """In some systems and datasets, enabling long-range Coulomb interactions
+    might be beneficial for the accuracy of the model and/or
+    its physical correctness."""
+
+    enable: bool = False
+    """Toggle for enabling long-range interactions"""
+    use_ewald: bool = False
+    """Use Ewald summation. If False, P3M is used"""
+    smearing: float = 1.4
+    """Smearing width in Fourier space"""
+    kspace_resolution: float = 1.33
+    """Resolution of the reciprocal space grid"""
+    interpolation_nodes: int = 5
+    """Number of grid points for interpolation (for PME only)"""
+
+
+class PETHypers(TypedDict):
+    """Hyperparameters for the PET model."""
+
+    cutoff: float = 4.5
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+    cutoff_width: float = 0.2
+    """Width of the smoothing function at the cutoff"""
+    d_pet: int = 128
+    """Dimension of the edge features.
+
+    This hyperparameter controls the width of the neural network. In general,
+    increasing it might lead to better accuracy, especially on larger datasets, at the
+    cost of increased training and evaluation time.
+    """
+    d_head: int = 128
+    """Dimension of the attention heads."""
+    d_node: int = 256
+    """Dimension of the node features.
+
+    Increasing this hyperparameter might lead to better accuracy,
+    with a relatively small increase in inference time.
+    """
+    d_feedforward: int = 256
+    """Dimension of the feedforward network in the attention layer."""
+    num_heads: int = 8
+    """Attention heads per attention layer."""
+    num_attention_layers: int = 2
+    """The number of attention layers in each layer of the graph
+    neural network. Depending on the dataset, increasing this hyperparameter might
+    lead to better accuracy, at the cost of increased training and evaluation time.
+    """
+    num_gnn_layers: int = 2
+    """The number of graph neural network layers.
+
+    In general, decreasing this hyperparameter to 1 will lead to much faster models,
+    at the expense of accuracy. Increasing it may or may not lead to better accuracy,
+    depending on the dataset, at the cost of increased training and evaluation time.
+    """
+    normalization: Literal["RMSNorm", "LayerNorm"] = "RMSNorm"
+    """Layer normalization type."""
+    activation: Literal["SiLU", "SwiGLU"] = "SwiGLU"
+    """Activation function."""
+    transformer_type: Literal["PreLN", "PostLN"] = "PreLN"
+    """The order in which the layer normalization and attention
+    are applied in a transformer block. Available options are ``PreLN``
+    (normalization before attention) and ``PostLN`` (normalization after attention)."""
+    featurizer_type: Literal["residual", "feedforward"] = "feedforward"
+    """Implementation of the featurizer of the model to use. Available
+    options are ``residual`` (the original featurizer from the PET paper, which uses
+    residual connections at each GNN layer for readout) and ``feedforward`` (a modern
+    version that uses the last representation after all GNN iterations for readout).
+    Additionally, the feedforward version uses a bidirectional feature flow during the
+    message-passing iterations, which allows the features flowing from atom ``i`` to
+    atom ``j`` to differ from the features flowing from atom ``j`` to atom ``i``."""
+    zbl: bool = False
+    """Use ZBL potential for short-range repulsion"""
+    long_range: LongRangeHypers = init_with_defaults(LongRangeHypers)
+    """Long-range Coulomb interactions parameters."""
+
+
+class PETTrainerHypers(TypedDict):
+    """Hyperparameters for training PET models."""
+
+    distributed: bool = False
+    """Whether to use distributed training"""
+    distributed_port: int = 39591
+    """Port for DDP communication"""
+    batch_size: int = 16
+    """The number of samples to use in each batch of training. This
+    hyperparameter controls the tradeoff between training speed and memory usage. In
+    general, larger batch sizes will lead to faster training, but might require more
+    memory."""
+    num_epochs: int = 1000
+    """Number of epochs."""
+    warmup_fraction: float = 0.01
+    """Fraction of training steps used for learning rate warmup."""
+    learning_rate: float = 1e-4
+    """Learning rate."""
+    weight_decay: Optional[float] = None
+
+    log_interval: int = 1
+    """Interval to log metrics."""
+    checkpoint_interval: int = 100
+    """Interval to save checkpoints."""
+    scale_targets: bool = True
+    """Normalize targets to unit std during training."""
+    fixed_composition_weights: CompositionWeightsDict = {}
+    """Weights for atomic contributions."""
+    fixed_scaling_weights: ScalingWeightsDict = {}
+
+    per_structure_targets: list[str] = []
+    """Targets to calculate per-structure losses."""
+    num_workers: Optional[int] = None
+    """Number of workers for data loading. If not provided, it is set
+    automatically."""
+    log_mae: bool = True
+    """Log MAE alongside RMSE"""
+    log_separate_blocks: bool = False
+    """Log per-block error."""
+    best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "mae_prod"
+    """Metric used to select best checkpoint (e.g., ``rmse_prod``)"""
+    grad_clip_norm: float = 1.0
+    """Maximum gradient norm for gradient clipping; set to ``inf`` to disable clipping."""
+    loss: str | LossDict = "mse"
+    """This section describes the loss function to be used. See the
+    :ref:`loss-functions` for more details."""
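
A hedged sketch of how these TypedDict-based hyperparameter classes might be consumed (not shown in the commit): based on the ``long_range`` default above, the assumption is that ``init_with_defaults`` returns a plain dict populated with the class-level defaults, which can then be read and updated like any other dict.

from metatrain.pet.hypers import PETHypers
from metatrain.utils.hypers import init_with_defaults

# Assumed behavior: materialize the documented defaults into a plain dict.
hypers: PETHypers = init_with_defaults(PETHypers)

hypers["cutoff"] = 5.0       # TypedDicts are ordinary dicts at runtime
print(hypers["long_range"])  # nested defaults taken from LongRangeHypers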

src/metatrain/pet/model.py

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,5 @@
 import logging
+import typing
 import warnings
 from math import prod
 from typing import Any, Dict, List, Literal, Optional, Tuple
@@ -26,16 +27,17 @@
 from metatrain.utils.sum_over_atoms import sum_over_atoms
 
 from . import checkpoints
+from .hypers import PETHypers
 from .modules.finetuning import apply_finetuning_strategy
 from .modules.structures import systems_to_batch
 from .modules.transformer import CartesianTransformer
 from .modules.utilities import cutoff_func
 
 
-AVAILABLE_FEATURIZERS = ["feedforward", "residual"]
+AVAILABLE_FEATURIZERS = typing.get_args(PETHypers.__annotations__["featurizer_type"])
 
 
-class PET(ModelInterface):
+class PET(ModelInterface[PETHypers]):
     """
     Metatrain-native implementation of the PET architecture.
 
@@ -56,7 +58,7 @@ class PET(ModelInterface):
     component_labels: Dict[str, List[List[Labels]]]
     NUM_FEATURE_TYPES: int = 2  # node + edge features
 
-    def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None:
+    def __init__(self, hypers: PETHypers, dataset_info: DatasetInfo) -> None:
         super().__init__(hypers, dataset_info, self.__default_metadata__)
 
         # Cache frequently accessed hyperparameters
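
For readers unfamiliar with the ``AVAILABLE_FEATURIZERS`` change above, here is a self-contained illustration of the pattern (``ExampleHypers`` is a stand-in class, not part of metatrain): ``typing.get_args`` unpacks the options of a ``Literal`` annotation, so the list of valid values is defined in one place, the type hint.

import typing
from typing import Literal, TypedDict


class ExampleHypers(TypedDict):
    featurizer_type: Literal["residual", "feedforward"]


# Same pattern as in model.py: read the Literal options from the annotation.
AVAILABLE = typing.get_args(ExampleHypers.__annotations__["featurizer_type"])
print(AVAILABLE)  # ('residual', 'feedforward')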
