Skip to content

Conversation

@pfebrer
Copy link
Contributor

@pfebrer pfebrer commented Oct 29, 2025

This PR is an attempt at closing #860 by defining the way in which hyperparameters should be documented in metatrain.

In the end I'm quite happy with the result because the static type checker knows which type everything is (at least the one I have in Visual Studio Code) and the sphinx documentation looks quite nice with little effort and very low mainteinance needed moving forward. For now I only implemented the changes in PET as a proof of concept, if you are happy with the design we can apply it to the other architectures. Here is a summary of the things that I did:

  • Created a new pet/hypers.py file to store the TypedDicts specifying all the model and trainer hypers. These typed dicts contain the type, default and docstring for each parameter. They don't mess with torchscript because at runtime TypedDict is simply dict. The hypers could also be stored in the model and trainer files instead.
  • Annotated the hypers argument in the trainer and model.
  • Added Generic[HypersType] to the inheritance of the model and trainer interfaces. This is needed for the static type checker to understand what self.hypers is. Just annotating the hypers argument in the PET model is not enough.
  • Created a utils/hypers.py file with some helpers, mainly the specification for strange types like the scaler and composition model weights.
  • Created a utils/_dev/gen_hypers_files.py with the code to generate default-hypers.yaml and schema-hypers.json for a given model. The usage is as simple as python -m metatrain.utils._dev.gen_hypers_file pet. It finds the model and the trainer, gets the type annotation of the hypers argument and generates the files.
  • Modified the sphinx documentation to automatically document parameters from the hypers TypedDicts. I think it looks quite nice but it is true that the appearance of the terms PETHypers and PETTrainerHypers might be confusing to users (?) I think it still easy to understand where each parameter should be used exactly given that we provide the default yaml. But let me know what you think, if you think otherwise, we can probably hack it to make it more intuitive. I quite liked the section of "most important hypers to tune" so I kept it. See here the rendered docs for PET.

By the way, the trainer's weights_decay was undocumented :)

Before merging we also have to think how to incorporate the finetuning config. If I add it as a normal hyperparameter, it will appear in the default hypers file, which is not the previous behaviour, but I think it would be fine (?).


📚 Documentation preview 📚: https://metatrain--863.org.readthedocs.build/en/863/

@pfebrer pfebrer requested a review from abmazitov as a code owner October 29, 2025 00:18
@pfebrer pfebrer force-pushed the hypers branch 2 times, most recently from b598868 to 05ac42c Compare October 29, 2025 00:56
@pfebrer
Copy link
Contributor Author

pfebrer commented Oct 29, 2025

Ok, so mypy still has some complaints (see linter CI fail). I will fix them, but the most important one is that it does not allow default values in TypedDict, which is the core of this implementation. I think we could remove this check, given that it is very convenient to use them in this case.

Copy link
Contributor

@PicoCentauri PicoCentauri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I think we could also do it for the general hyper, right?

Also, please document how to use the generation code in the dev docs. Otherwise, the next Pol will complain that devolping for metatrain is too hard because nothing is documented and very complicated.

Comment on lines 33 to 35
.. literalinclude:: ../../../src/metatrain/pet/default-hypers.yaml
:language: yaml
:lines: 2-
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be removed. As default values should be shown by the new class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to not show the line

# This file is auto-generated. Do not edit directly. 

which I thought might be confusing for a user.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaaah you mean the whole file, I thought you meant the lines thing 😅 I think it is nice to have the yaml file that users can just copy, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are a big fan of gen_xxx files xD

I think this file should not be shipped but rather be called by a dev. I would put this into the developer directory in the root of the repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm makes sense, and probably this should be called from tox to follow the philosophy of metatrain. I guess I just don't like the fact that tox creates a full environment for each simple task like this 😅

And yes I like pieces of code that generate files automatically if those files are meant to just be kept in sync with something else, I think (hoping to) syncronizing them manually is a bit crazy and does not work in the long run. People would probably use ChatGPT for that, for tasks as simple as this one I prefer to use something deterministic

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't like many things about tox, but it is the best solution imho to have an easy way to run tests, and dev works without running weird setup scripts. If you have better ideas to implement an easy test suites that can be run with a single command and share environments, I look forward to your ideas.

I would also add a line in the beginning of the generated file that this is a generated and it should not be modified by hand.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also add a line in the beginning of the generated file that this is a generated and it should not be modified by hand.

Yes, this is done, see the other comment. On the JSON file I can't do it though because JSON doesn't allow comments (I think).

Regarding tox, yes we should stick to it, I'm not proposing to change that hahah But if there is a way to have a dev environment that is shared for multiple tasks that would be nice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative to having the files committed inside the repository could be to generate them when installing the project, although this might break for people running pip install -e.

@pfebrer
Copy link
Contributor Author

pfebrer commented Oct 29, 2025

Yes I was planning to document, but first the design needs to be fully decided 😅 The most important issue is what do we do with the fact that mypy doesn't like default values on TypedDict

@pfebrer
Copy link
Contributor Author

pfebrer commented Oct 29, 2025

Regarding the general hypers, I'm not sure, I'm a bit scared of touching them since I don't know exactly how they work.

Even just with PET hypers there are already problems, because for example the loss function can be inputed as a str, so str needs to be in the specification, but then it is ensured to be a dict, so in PET really it is assumed to be a dict. Because we typehint hypers: PETHypers, mypy complains saying that the loss can be a string. The obvious solution is to have InputPETHypers and SanitizedPETHypers with different type hints for the hypers that are sanitized, but I don't know if it is a nice thing to do haha

@PicoCentauri
Copy link
Contributor

Regarding the general hypers, I'm not sure, I'm a bit scared of touching them since I don't know exactly how they work.

I can show you how they work. As I am anyway forced to remove omegaconf by @HaoZeke we have to touch the logic anyway.

Yes I was planning to document, but first the design needs to be fully decided 😅 The most important issue is what do we do with the fact that mypy doesn't like default values on TypedDict

I am fine with disabling the test for the file. I think mypy has a regex for these. If not we can also disable it globally.

Even just with PET hypers there are already problems, because for example the loss function can be inputed as a str, so str needs to be in the specification, but then it is ensured to be a dict, so in PET really it is assumed to be a dict. Because we typehint hypers: PETHypers, mypy complains saying that the loss can be a string. The obvious solution is to have InputPETHypers and SanitizedPETHypers with different type hints for the hypers that are sanitized, but I don't know if it is a nice thing to do haha

Can't you define it as a Union[str, dict[Any, Any]]? mypy should accept this.

@pfebrer
Copy link
Contributor Author

pfebrer commented Oct 29, 2025

Can't you define it as a Union[str, dict[Any, Any]]? mypy should accept this.

Yes, you can (this is what is done right now), but the problem is that once it enters PET it is really just dict[Any, Any], so mypy is right in complaining.

Another solution would be something like:

def expand_loss(value: str | dict) -> dict: ... class Hypers: def __getitem__(self, key): return getattr(self, key) class PETHypers(Hypers): def __init__(self, loss: str | dict): self.loss = expand_loss(loss)

And then you document the inputs based on the signature of PETHypers. I'm not 100% sure if type checkers will understand this and/or if it would work with torchscript.

@pfebrer
Copy link
Contributor Author

pfebrer commented Oct 29, 2025

Or as @Luthaf was suggesting, using pydantic. I imagine these things should already be considered there. I don't know if it works well with torchscript though

Model,
Trainer,
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, after merging the defaults with the inputs from the user, we validate using pydantic.

__config__={"extra": "forbid", "strict": True},
)

ArchitectureOptions.model_validate(options)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this new file to do validation with pydantic. For now it only has the function to validate architecture options. It grabs the TypeDicts from the annotations of the model and the trainer, creates a pydantic model on the fly and validates the options.


# Create a loss function:
loss_hypers = self.hypers["loss"]
assert not isinstance(loss_hypers, str) # For mypy type checking
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to just add this check here for now instead of having the complexity of defining a type for raw inputs and a type for the sanitized hypers. Since loss is the only hyper that is sanitized in this way it should be fine for now, and we could think of more automated solutions if the problem comes up in other hypers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth to get @ppegolo and @jwa7 opinions here as well, since they worked on the loss!

Copy link
Contributor Author

@pfebrer pfebrer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following what we discussed today, I incorporated pydantic validation, removing the need of jsonschema (for the moment just in the architecture hypers as a proof of concept). Defining directly the pydantic models became really problematic because torchscript does not like them, so I went with the approach proposed by @Luthaf: we define hypers as TypedDicts and we use these definitions in pydantic to do validation. I think it is quite a simple approach that avoids the complications of using the pydantic models directly. I guess it has some limitations in what you can do, but it is possible that we never want to do very complex validation, let's see.

Let me know if you think the approach is sane. If so, I will finish the implementation.

@pfebrer
Copy link
Contributor Author

pfebrer commented Oct 31, 2025

Turns out typing the scaler and composition weights for pydantic validation was much easier/transparent than the previous json schemas :)

I'm almost done with PET, I would need some validation/feedback before I move to doing the full thing 😅 I know the PR touches a lot of files, but most of it is just adding types

Copy link
Member

@Luthaf Luthaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes looks fine overall, although we need some explanations on what's going on in the documentation for "adding a new architecture"

Comment on lines 1 to 3
# mypy: disable-error-code=misc
# We ignore misc errors in this file because TypedDict
# with default values is not allowed by mypy.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we disable this check with a mypy option instead? This is not very nice for new model authors otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the problem here is that there is no specific error code for this error, and therefore we would have to disable all misc errors as far as I understand 😪 But maybe there is a way to filter by message? Do you know (I couldn't find any way of doing it)?



class PET(ModelInterface):
class PET(ModelInterface[PETHypers]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remind me why we need to parametrize ModelInterface?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so that the type checker knows what self.hypers is


# Create a loss function:
loss_hypers = self.hypers["loss"]
assert not isinstance(loss_hypers, str) # For mypy type checking
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth to get @ppegolo and @jwa7 opinions here as well, since they worked on the loss!

:return: The TypedDict annotation of the 'hypers' parameter.
"""
return inspect.signature(module).parameters["hypers"].annotation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried this is too much magic that would confuse new contributors. Maybe we could use a more explicit __hyper_class__ class attribute (like we do for all other model and trainer metadata)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I see, we could do that. But then we would have two sources of truth (annotation and __hypers_cls__) at the risk of them going out of sync for some reason. I feel like python annotations are already very standard nowadays, so I don't think that the fact that they are used feels more magic than a double underscore attribute. Also when thinking about the "Add a new architecture" docs, I feel like they will be cleaner if we just say: type hint your class and that's It. But if my arguments don't convince you I'm willing to add the __hypers_cls__

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the multiple sources of truth, we can enforce that the annotation and __hypers_cls__ are the same in ModelInterface.__init__(), so I'm not too worried about it.

IMO this is more about what's the easiest to teach/understand, and my main worry with the annotation is that it is a bit too much magic for my taste. Although if others are happy with it I can live with it as well!

@frostedoyster and @PicoCentauri, do you have thoughts about how a new architecture contributor would indicate which class should be used for hyper validation/generation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm writing the "Add a new architecture" section and indeed seems simpler to explain first with __hypers_cls__, because then you don't need typing and one can create a valid architecture with:

class Hypers: alpha = 1 mode = "strict" ... class MyModel(ModelInterface): __hypers_cls__ = Hypers def __init__(hypers: dict, dataset_info: DatasetInfo): ....

which would be the equivalent of the previous "just providing default-hypers.yaml", without causing failures for mypy. On a later section I can explain how to type and document so that the architecture is ready to be stable. Although given how easy it is to document in python, I would at least require that hypers have a docstring to be considered worthy of being experimental. Do you agree with this last sentence?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative to having the files committed inside the repository could be to generate them when installing the project, although this might break for people running pip install -e.

@pfebrer
Copy link
Contributor Author

pfebrer commented Nov 3, 2025

An alternative to having the files committed inside the repository could be to generate them when installing the project, although this might break for people running pip install -e.

I don't know why I can't answer this comment directly, I'll answer here. We don't actually need them anymore, since the defaults could be generated on the fly from the hypers specification. I just generated them here to respect traditions, and because I'm not sure that they are useless since they are a very convenient way to look at all the defaults at once, or even use the file as a template. Maybe if we generate the file for the online docs is enough though. So for me it's either we keep them in the repo or they disappear completely (even after the installation). What would you vote for?

@Luthaf
Copy link
Member

Luthaf commented Nov 4, 2025

I'd be fine removing it from the repo, especially if we can dynamically create them for the documentation.

@pfebrer
Copy link
Contributor Author

pfebrer commented Nov 10, 2025

The general tests that check get_default_hypers:

@pytest.mark.parametrize("name", find_all_architectures())
def test_get_default_hypers(name):
"""Test that architecture hypers for all arches can be loaded."""
if name == "llpr":
# Skip this architecture as it is not a valid architecture but a wrapper
return
default_hypers = get_default_hypers(name)
assert type(default_hypers) is dict
assert default_hypers["name"] == name

and

@pytest.mark.parametrize("name", find_all_architectures())
def test_check_valid_default_architecture_options(name):
"""Test that all default hypers are according to the provided schema."""
if name == "llpr":
# Skip this architecture as it is not a valid architecture but a wrapper
return
options = get_default_hypers(name)
check_architecture_options(name=name, options=options)

can no longer be generic tests because to get the default hypers we now need to import the architecture (and therefore need the architecture dependencies). I have moved them to each architecture's tests (while keeping a test for soap_bpnn on the generic tests)

@pfebrer
Copy link
Contributor Author

pfebrer commented Nov 11, 2025

This is ready! (if the tests agree)

@pfebrer
Copy link
Contributor Author

pfebrer commented Nov 11, 2025

When using this in MACE I realized that to generate the documentation one needs to import the architecture and therefore we need all the dependencies, which is not sustainable in the long run if architectures keep coming in.

We have been discussing with Philip and we are moving to imposing a bit more where the hypers should live (what is the exact path), which will allow us to generate documentation without having to import the architectures.

module.__model__.__hypers_cls__ = documentation.ModelHypers
module.__trainer__.__hypers_cls__ = documentation.TrainerHypers

return module
Copy link
Contributor Author

@pfebrer pfebrer Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made 2 modifications to import_architecture:

  • Raise the "dependencies not installed error" only when really that is the case (using ModuleNotFoundError instead of ImportError).
  • Search for the hypers classes in documentation.py and set them as __hypers_cls__ for the model and trainer. The reason I did this is because in the future if we allow external architectures, we probably don't want to force them to have a documentation.py file, and they will come with a __hypers_cls__ themselves.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raise the "dependencies not installed error" only when really that is the case (using ModuleNotFoundError instead of ImportError).

I would keep using ImportError, code might fail to import for reasons other than ModuleNotFound. But I'm happy to tweak the error message to say "not installed or with a broken installation".

@pfebrer
Copy link
Contributor Author

pfebrer commented Nov 12, 2025

@Luthaf could you check that you don't hate the new approach that we agreed with Philip? Just a quick look here: https://metatrain--863.org.readthedocs.build/en/863/dev-docs/new-architecture.html#documentation-documentation-py will be enough

@pfebrer
Copy link
Contributor Author

pfebrer commented Nov 12, 2025

Ready to merge if/when you approve 👍

@Luthaf
Copy link
Member

Luthaf commented Nov 12, 2025

@Luthaf could you check that you don't hate the new approach that we agreed with Philip?

This looks good, I like that we have one single file for both hypers and general architecture documentation!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants