
Commit 68c6c92

stefanradev93, LarsKue, vpratz, and han-ol authored
[WIP] Move standardization into approximators and make adapter stateless. (#486)
* Add standardization to continuous approximator and test
* Fix init bugs, adapt notebooks
* Add training flag to build_from_data
* Fix inference conditions check
* Fix tests
* Remove unnecessary init calls
* Add deprecation warning
* Refactor compute metrics and add standardization to model comp
* Fix standardization in cont approx
* Fix sample keys -> condition keys
* amazing keras fix
* moving_mean and moving_std still not loading [WIP]
* remove hacky approximator serialization test
* fix building of models in tests
* Fix standardization
* Add standardization to model comp and let it use inheritance
* make assert_models/layers_equal more thorough
* [no ci] use map_shape_structure to convert shapes to arrays. This automatically takes care of nested structures.
* Extend Standardization to support nested inputs (#501). By using `keras.tree.flatten` and `keras.tree.pack_sequence_as`, we can support arbitrary nested structures (a sketch of this pattern follows the list). A `flatten_shape` function is introduced, analogous to `map_shape_structure`, for use in the build function.
  * keep tree utils in submodule
  * Streamline call
  * Fix typehint
  * Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
* Update moments before transform and update test
* Update notebooks
* Refactor and simplify due to standardize
* Add comment for fetching the dict's first item, deprecate logits arg and fix typehint
* add missing import in test
* Refactor preparation of data for networks and new point_appr.log_prob: ContinuousApproximator._prepare_data unifies all preparation in sample, log_prob, and estimate for both ContinuousApproximator and PointApproximator; PointApproximator now overrides log_prob
* Add class attributes to inform proper standardization
* Implement stable moving mean and std (see the Welford-style sketch after this list)
* Adapt and fix tests
* minor adaptations to moving average (update time, init). We should put the update before the standardization, to use the maximum amount of information available. We can then also initialize the moving M^2 with zero, as it will be filled immediately. The special case of M^2 = 0 is not problematic: no variance automatically indicates that all entries are equal, and we can set them to zero (see my comment). I added another test case to cover that case, and added a test for the standard deviation to the existing test.
* increase tolerance of allclose tests
* [no ci] set trainable to False explicitly in ModelComparisonApproximator
* point estimate of covariance compatible with standardization
* properly set values to zero if std is zero. Cases for inf and -inf were missing.
* fix sample post-processing in point approximator
* activate tests for multivariate normal score
* [no ci] undo prev commit: MVN test still not stable, was hidden by std of 0
* specify explicit build functions for approximators
* set std for untrained standardization layer to one. An untrained layer thereby does not modify the input.
* [no ci] reformulate zero std case
* approximator builds: add guards against building networks twice
* [no ci] add comparison with loaded approx to workflow test
* Cleanup and address building standardization layers when None specified
* Cleanup and address building standardization layers when None specified 2
* Add default case for std transform and add transformation to doc.
* adapt handling of the special case M^2 = 0
* [no ci] minor fix in concatenate_valid_shapes
* [no ci] extend test suite for approximators
* fixes for standardize=None case
* skip unstable MVN score case
* Better transformation types
* Add test for both_sides_scale inverse standardization
* Add test for left_side_scale inverse standardization
* Remove flaky test failing due to sampling error
* Fix input dtypes in inverse standardization transformation_type tests
* Use concatenate_valid in _sample
* Replace PositiveDefinite link with CholeskyFactor. This finally makes the MVN score sampling test stable for the jax backend, for which the keras.ops.cholesky operation is numerically unstable. The score's sample method avoids calling keras.ops.cholesky to resolve the issue: instead, the estimation head returns the Cholesky factor directly rather than the covariance matrix (as it used to be). A sketch of sampling from a returned Cholesky factor follows the list.
* Reintroduce test sampling with MVN score
* Address TODOs and adapt docstrings and workflow
* Adapt notebooks
* Fix in model comparison
* Update readme and add point estimation nb

Co-authored-by: LarsKue <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: han-ol <g@hans.olischlaeger.com>
Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>
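Three of the items above carry enough technical weight to warrant a short illustration. The snippets below are editorial sketches, not code from this commit; the names `standardize_nested`, `MovingStandardization`, and `sample_mvn_from_cholesky` are hypothetical.

First, the nested-input support in `Standardization` (#501) rests on a flatten/transform/repack pattern. A minimal sketch, assuming per-leaf moments are already available in flattened order:

```python
import keras


def standardize_nested(data, means, stds):
    # flatten arbitrary nesting (dicts/lists/tuples) into a list of tensors
    leaves = keras.tree.flatten(data)
    transformed = [(x - m) / s for x, m, s in zip(leaves, means, stds)]
    # restore the original nesting around the transformed leaves
    return keras.tree.pack_sequence_as(data, transformed)
```

Second, the "stable moving mean and std": one standard realization is a Welford-style parallel update of a running mean and an M^2 accumulator. Initializing M^2 with zero is safe because the update runs before the first transform, and M^2 = 0 simply means all observed entries were equal, so standardized values are set to zero rather than ±inf:

```python
import numpy as np


class MovingStandardization:
    def __init__(self, dim: int):
        self.count = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)  # filled immediately by the first update

    def update(self, batch: np.ndarray) -> None:
        # merge the moments of batch (shape (n, dim)) into the running moments
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        total = self.count + n
        delta = batch_mean - self.mean
        self.mean = self.mean + delta * n / total
        self.m2 = self.m2 + batch_m2 + delta**2 * self.count * n / total
        self.count = total

    def standardize(self, x: np.ndarray) -> np.ndarray:
        if self.count == 0:
            return x  # untrained layer: std of one, input passes through
        std = np.sqrt(self.m2 / self.count)
        # std == 0 means all observed entries were equal: map them to zero
        # instead of producing inf / -inf / nan
        return np.where(std > 0.0, (x - self.mean) / np.where(std > 0.0, std, 1.0), 0.0)
```

Third, the `CholeskyFactor` link: once the estimation head returns the Cholesky factor L of the covariance directly, sampling reduces to an affine transform of standard-normal draws, and `keras.ops.cholesky` (numerically unstable under jax) is never called:

```python
import keras


def sample_mvn_from_cholesky(mean, chol, num_samples: int):
    # x = mean + L z with z ~ N(0, I) gives x ~ N(mean, L @ L.T)
    dim = keras.ops.shape(mean)[-1]
    z = keras.random.normal((num_samples, dim))
    return mean + keras.ops.einsum("ij,nj->ni", chol, z)
```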
1 parent 735969c commit 68c6c92


41 files changed (+3560, −2681 lines)

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -130,7 +130,7 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
 1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
 2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
 3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
-4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
+4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation.ipynb)
 5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
 7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
```

bayesflow/adapters/transforms/standardize.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+import warnings

 import numpy as np

@@ -69,6 +70,14 @@ def __init__(
     ):
         super().__init__()

+        if mean is None or std is None:
+            warnings.warn(
+                "Dynamic standardization is deprecated and will be removed in later versions. "
+                "Instead, use the standardize argument of the approximator / workflow instance or provide "
+                "fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
+                DeprecationWarning,
+            )
+
         self.mean = mean
         self.std = std
```
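As a migration note, a minimal sketch of what the deprecation warning asks for; the `standardize="all"` value is an assumption, so consult the approximator docstrings for the supported options:

```python
import bayesflow as bf

# Deprecated: estimating mean/std from the data makes the adapter stateful.
adapter = bf.Adapter().standardize()

# Still supported: fixed moments keep the adapter stateless.
adapter = bf.Adapter().standardize(mean=0.0, std=1.0)

# Recommended: let the approximator own standardization instead.
approximator = bf.ContinuousApproximator(
    adapter=bf.Adapter(),
    inference_network=...,  # placeholder for an actual network
    standardize="all",  # assumed value; see the docstring for options
)
```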

bayesflow/adapters/transforms/transform.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -22,6 +22,7 @@ def __repr__(self):

     @classmethod
     def from_config(cls, config: dict, custom_objects=None):
+        # noinspection PyArgumentList
         return cls(**deserialize(config, custom_objects=custom_objects))

     def get_config(self) -> dict:
```

bayesflow/approximators/approximator.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -11,18 +11,16 @@


 class Approximator(BackendApproximator):
-    def build(self, data_shapes: any) -> None:
-        mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
-        self.build_from_data(mock_data)
+    def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
+        raise NotImplementedError

     @classmethod
     def build_adapter(cls, **kwargs) -> Adapter:
         # implemented by each respective architecture
         raise NotImplementedError

-    def build_from_data(self, data: dict[str, any]) -> None:
-        self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
-        self.built = True
+    def build_from_data(self, adapted_data: dict[str, any]) -> None:
+        raise NotImplementedError

     @classmethod
     def build_dataset(
@@ -61,6 +59,9 @@ def build_dataset(
             max_queue_size=max_queue_size,
         )

+    def call(self, *args, **kwargs):
+        return self.compute_metrics(*args, **kwargs)
+
     def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
         """
         Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
@@ -132,6 +133,7 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
             logging.info("Building on a test batch.")
             mock_data = dataset[0]
             mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
-            self.build_from_data(mock_data)
+            mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
+            self.build(mock_data_shapes)

         return super().fit(dataset=dataset, **kwargs)
```
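The base class now only fixes the contract: `fit` derives per-tensor shapes from a test batch and hands them to `build`. A minimal sketch of how a concrete subclass might satisfy it, mirroring the logic removed from the base class (the subclass name is hypothetical, and the `map_shape_structure` import path is an assumption based on the commit notes):

```python
import keras

from bayesflow.utils import map_shape_structure  # import path is an assumption


class MyApproximator(Approximator):
    def build(self, data_shapes: dict) -> None:
        if self.built:
            return  # guard against building the networks twice
        # map_shape_structure applies keras.ops.zeros per shape tuple, so
        # nested dicts of shapes are handled automatically
        mock_data = map_shape_structure(keras.ops.zeros, data_shapes)
        self.build_from_data(mock_data)

    def build_from_data(self, adapted_data: dict) -> None:
        # a single training-stage forward pass materializes all weights,
        # including the moving moments of the standardization layers
        self.compute_metrics(**adapted_data, stage="training")
        self.built = True
```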
