Skip to content

Commit 9f11476

Browse files
committed
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into dev
2 parents d62e73d + c10d7d3 commit 9f11476

31 files changed

+1905
-711
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ For an in-depth exposition, check out our walkthrough notebooks below.
5858
5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
5959
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
6060
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
61+
8. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
6162

6263
More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.
6364

bayesflow/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
experimental,
88
networks,
99
simulators,
10-
workflows,
1110
utils,
11+
workflows,
12+
wrappers,
1213
)
1314

1415
from .adapters import Adapter

bayesflow/adapters/adapter.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import MutableSequence, Sequence, Mapping
1+
from collections.abc import Callable, MutableSequence, Sequence, Mapping
22

33
import numpy as np
44

@@ -24,6 +24,7 @@
2424
NumpyTransform,
2525
OneHot,
2626
Rename,
27+
SerializableCustomTransform,
2728
Sqrt,
2829
Standardize,
2930
ToArray,
@@ -274,6 +275,88 @@ def apply(
274275
self.transforms.append(transform)
275276
return self
276277

278+
def apply_serializable(
279+
self,
280+
include: str | Sequence[str] = None,
281+
*,
282+
forward: Callable[[np.ndarray, ...], np.ndarray],
283+
inverse: Callable[[np.ndarray, ...], np.ndarray],
284+
predicate: Predicate = None,
285+
exclude: str | Sequence[str] = None,
286+
**kwargs,
287+
):
288+
"""Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.
289+
290+
Parameters
291+
----------
292+
forward : function, no lambda
293+
Registered serializable function to transform the data in the forward pass.
294+
For the adapter to be serializable, this function has to be serializable
295+
as well (see Notes). Therefore, only proper functions and no lambda
296+
functions can be used here.
297+
inverse : function, no lambda
298+
Registered serializable function to transform the data in the inverse pass.
299+
For the adapter to be serializable, this function has to be serializable
300+
as well (see Notes). Therefore, only proper functions and no lambda
301+
functions can be used here.
302+
predicate : Predicate, optional
303+
Function that indicates which variables should be transformed.
304+
include : str or Sequence of str, optional
305+
Names of variables to include in the transform.
306+
exclude : str or Sequence of str, optional
307+
Names of variables to exclude from the transform.
308+
**kwargs : dict
309+
Additional keyword arguments passed to the transform.
310+
311+
Raises
312+
------
313+
ValueError
314+
When the provided functions are not registered serializable functions.
315+
316+
Notes
317+
-----
318+
Important: The forward and inverse functions have to be registered with Keras.
319+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
320+
They must also be registered (and identical) when loading the adapter
321+
at a later point in time.
322+
323+
Examples
324+
--------
325+
326+
The example below shows how to use the
327+
`keras.saving.register_keras_serializable` decorator to
328+
register functions with Keras. Note that for this simple
329+
example, one usually would use the simpler :py:meth:`apply`
330+
method.
331+
332+
>>> import keras
333+
>>>
334+
>>> @keras.saving.register_keras_serializable("custom")
335+
>>> def forward_fn(x):
336+
>>> return x**2
337+
>>>
338+
>>> @keras.saving.register_keras_serializable("custom")
339+
>>> def inverse_fn(x):
340+
>>> return x**0.5
341+
>>>
342+
>>> adapter = bf.Adapter().apply_serializable(
343+
>>> "x",
344+
>>> forward=forward_fn,
345+
>>> inverse=inverse_fn,
346+
>>> )
347+
"""
348+
transform = FilterTransform(
349+
transform_constructor=SerializableCustomTransform,
350+
predicate=predicate,
351+
include=include,
352+
exclude=exclude,
353+
forward=forward,
354+
inverse=inverse,
355+
**kwargs,
356+
)
357+
self.transforms.append(transform)
358+
return self
359+
277360
def as_set(self, keys: str | Sequence[str]):
278361
"""Append an :py:class:`~transforms.AsSet` transform to the adapter.
279362

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .one_hot import OneHot
1616
from .rename import Rename
1717
from .scale import Scale
18+
from .serializable_custom_transform import SerializableCustomTransform
1819
from .shift import Shift
1920
from .sqrt import Sqrt
2021
from .standardize import Standardize
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from collections.abc import Callable
2+
import numpy as np
3+
from keras.saving import (
4+
deserialize_keras_object as deserialize,
5+
register_keras_serializable as serializable,
6+
serialize_keras_object as serialize,
7+
get_registered_name,
8+
get_registered_object,
9+
)
10+
from .elementwise_transform import ElementwiseTransform
11+
from ...utils import filter_kwargs
12+
import inspect
13+
14+
15+
@serializable(package="bayesflow.adapters")
16+
class SerializableCustomTransform(ElementwiseTransform):
17+
"""
18+
Transforms a parameter using a pair of registered serializable forward and inverse functions.
19+
20+
Parameters
21+
----------
22+
forward : function, no lambda
23+
Registered serializable function to transform the data in the forward pass.
24+
For the adapter to be serializable, this function has to be serializable
25+
as well (see Notes). Therefore, only proper functions and no lambda
26+
functions can be used here.
27+
inverse : function, no lambda
28+
Function to transform the data in the inverse pass.
29+
For the adapter to be serializable, this function has to be serializable
30+
as well (see Notes). Therefore, only proper functions and no lambda
31+
functions can be used here.
32+
33+
Raises
34+
------
35+
ValueError
36+
When the provided functions are not registered serializable functions.
37+
38+
Notes
39+
-----
40+
Important: The forward and inverse functions have to be registered with Keras.
41+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
42+
They must also be registered (and identical) when loading the adapter
43+
at a later point in time.
44+
45+
"""
46+
47+
def __init__(
48+
self,
49+
*,
50+
forward: Callable[[np.ndarray, ...], np.ndarray],
51+
inverse: Callable[[np.ndarray, ...], np.ndarray],
52+
):
53+
super().__init__()
54+
55+
self._check_serializable(forward, label="forward")
56+
self._check_serializable(inverse, label="inverse")
57+
self._forward = forward
58+
self._inverse = inverse
59+
60+
@classmethod
61+
def _check_serializable(cls, function, label=""):
62+
GENERAL_EXAMPLE_CODE = (
63+
"The example code below shows the structure of a correctly decorated function:\n\n"
64+
"```\n"
65+
"import keras\n\n"
66+
"@keras.saving.register_keras_serializable('custom')\n"
67+
f"def my_{label}(...):\n"
68+
" [your code goes here...]\n"
69+
"```\n"
70+
)
71+
if function is None:
72+
raise TypeError(
73+
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
74+
)
75+
registered_name = get_registered_name(function)
76+
# check if function is a lambda function
77+
if registered_name == "<lambda>":
78+
raise ValueError(
79+
f"The provided function for '{label}' is a lambda function, "
80+
"which cannot be serialized. "
81+
"Please provide a registered serializable function by using the "
82+
"@keras.saving.register_keras_serializable decorator."
83+
f"\n{GENERAL_EXAMPLE_CODE}"
84+
)
85+
if inspect.ismethod(function):
86+
raise ValueError(
87+
f"The provided value for '{label}' is a method, not a function. "
88+
"Methods cannot be serialized separately from their classes. "
89+
"Please provide a registered serializable function instead by "
90+
"moving the functionality to a function (i.e., outside of the class) and "
91+
"using the @keras.saving.register_keras_serializable decorator."
92+
f"\n{GENERAL_EXAMPLE_CODE}"
93+
)
94+
registered_object_for_name = get_registered_object(registered_name)
95+
if registered_object_for_name is None:
96+
try:
97+
source_max_lines = 5
98+
function_source_code = inspect.getsource(function).split("\n")
99+
if len(function_source_code) > source_max_lines:
100+
function_source_code = function_source_code[:source_max_lines] + [" [...]"]
101+
102+
example_code = "For your provided function, this would look like this:\n\n"
103+
example_code += "\n".join(
104+
["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
105+
+ function_source_code
106+
+ ["```"]
107+
)
108+
except OSError:
109+
example_code = GENERAL_EXAMPLE_CODE
110+
raise ValueError(
111+
f"The provided function for '{label}' is not registered with Keras.\n"
112+
"Please register the function using the "
113+
"@keras.saving.register_keras_serializable decorator.\n"
114+
f"{example_code}"
115+
)
116+
if registered_object_for_name is not function:
117+
raise ValueError(
118+
f"The provided function for '{label}' does not match the function "
119+
f"registered under its name '{registered_name}'. "
120+
f"(registered function: {registered_object_for_name}, provided function: {function}). "
121+
)
122+
123+
@classmethod
124+
def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
125+
if get_registered_object(config["forward"]["config"], custom_objects) is None:
126+
provided_function_msg = ""
127+
if config["_forward_source_code"]:
128+
provided_function_msg = (
129+
f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```"
130+
)
131+
raise TypeError(
132+
"\n\nPLEASE READ HERE:\n"
133+
"-----------------\n"
134+
"The forward function that was provided as `forward` "
135+
"is not registered with Keras, making deserialization impossible. "
136+
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
137+
"function before loading your model."
138+
f"{provided_function_msg}"
139+
)
140+
if get_registered_object(config["inverse"]["config"], custom_objects) is None:
141+
provided_function_msg = ""
142+
if config["_inverse_source_code"]:
143+
provided_function_msg = (
144+
f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```"
145+
)
146+
raise TypeError(
147+
"\n\nPLEASE READ HERE:\n"
148+
"-----------------\n"
149+
"The inverse function that was provided as `inverse` "
150+
"is not registered with Keras, making deserialization impossible. "
151+
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
152+
"function before loading your model."
153+
f"{provided_function_msg}"
154+
)
155+
forward = deserialize(config["forward"], custom_objects)
156+
inverse = deserialize(config["inverse"], custom_objects)
157+
return cls(
158+
forward=forward,
159+
inverse=inverse,
160+
)
161+
162+
def get_config(self) -> dict:
163+
forward_source_code = inverse_source_code = None
164+
try:
165+
forward_source_code = inspect.getsource(self._forward)
166+
inverse_source_code = inspect.getsource(self._inverse)
167+
except OSError:
168+
pass
169+
return {
170+
"forward": serialize(self._forward),
171+
"inverse": serialize(self._inverse),
172+
"_forward_source_code": forward_source_code,
173+
"_inverse_source_code": inverse_source_code,
174+
}
175+
176+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
177+
# filter kwargs so that other transform args like batch_size, strict, ... are not passed through
178+
kwargs = filter_kwargs(kwargs, self._forward)
179+
return self._forward(data, **kwargs)
180+
181+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
182+
kwargs = filter_kwargs(kwargs, self._inverse)
183+
return self._inverse(data, **kwargs)

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def calibration_ecdf(
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22+
legend_location: str = "upper right",
2223
title_fontsize: int = 18,
2324
tick_fontsize: int = 12,
2425
rank_ecdf_color: str = "#132a70",
@@ -184,7 +185,7 @@ def calibration_ecdf(
184185

185186
for ax, title in zip(plot_data["axes"].flat, titles):
186187
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
187-
ax.legend(fontsize=legend_fontsize)
188+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
188189
ax.set_title(title, fontsize=title_fontsize)
189190

190191
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def calibration_ecdf_from_quantiles(
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22+
legend_location: str = "upper right",
2223
title_fontsize: int = 18,
2324
tick_fontsize: int = 12,
2425
rank_ecdf_color: str = "#132a70",
@@ -173,7 +174,7 @@ def calibration_ecdf_from_quantiles(
173174
alpha=0.2,
174175
label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands" + "\n(pointwise)",
175176
)
176-
ax.legend(fontsize=legend_fontsize)
177+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
177178
ax.set_title(title, fontsize=title_fontsize)
178179

179180
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

bayesflow/networks/transformers/mab.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,17 @@
1010
class MultiHeadAttentionBlock(keras.Layer):
1111
"""Implements the MAB block from [1] which represents learnable cross-attention.
1212
13+
In particular, it uses a so-called "Post-LN" transformer block [2] which applies
14+
layer norm following attention and following MLP. A "Pre-LN" transformer block
15+
can easily be implemented.
16+
1317
[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
1418
Set transformer: A framework for attention-based permutation-invariant neural networks.
1519
In International conference on machine learning (pp. 3744-3753). PMLR.
20+
21+
[2] Xiong, R., Yang, Y., He, D., Zheng, K., Zheng, S., Xing, C., ... & Liu, T. (2020, November).
22+
On layer normalization in the transformer architecture.
23+
In International conference on machine learning (pp. 10524-10533). PMLR.
1624
"""
1725

1826
def __init__(

0 commit comments

Comments
 (0)