
Commit d24f5a3

Robustify kwargs passing inference networks, add class variables
1 parent de8e1cb commit d24f5a3

6 files changed: +51, -35 lines changed

bayesflow/approximators/approximator.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
         raise NotImplementedError
 
     def build_from_data(self, data: Mapping[str, any]) -> None:
-        self.compute_metrics(**data, stage="training")
+        self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
         self.built = True
 
     @classmethod
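
The `filter_kwargs` helper used in the new line comes from BayesFlow's utilities and is not shown in this diff. As a rough sketch only, a signature-based filter of this kind might look as follows; the function name matches the diff, but the body and imports here are assumptions rather than the library's actual implementation:

    import inspect
    from collections.abc import Callable, Mapping

    def filter_kwargs(kwargs: Mapping[str, any], f: Callable) -> dict:
        """Keep only the entries of `kwargs` that `f` accepts as keyword arguments (illustrative)."""
        params = inspect.signature(f).parameters
        # If f itself takes **kwargs, nothing needs to be dropped
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return dict(kwargs)
        return {k: v for k, v in kwargs.items() if k in params}

With a filter like this, `build_from_data` can forward a full data dictionary without tripping over keys that `compute_metrics` does not accept.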

bayesflow/approximators/continuous_approximator.py

Lines changed: 17 additions & 3 deletions
@@ -32,6 +32,8 @@ class ContinuousApproximator(Approximator):
         Additional arguments passed to the :py:class:`bayesflow.approximators.Approximator` class.
     """
 
+    SAMPLE_KEYS = ["summary_variables", "inference_conditions"]
+
     def __init__(
         self,
         *,
@@ -51,6 +53,7 @@ def build_adapter(
         inference_variables: Sequence[str],
         inference_conditions: Sequence[str] = None,
         summary_variables: Sequence[str] = None,
+        standardize: bool = True,
         sample_weight: str = None,
     ) -> Adapter:
         """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
@@ -63,9 +66,12 @@ def build_adapter(
             Names of the inference conditions in the data
         summary_variables : Sequence of str, optional
             Names of the summary variables in the data
+        standardize : bool, optional
+            Decide whether to standardize all variables, default is True
         sample_weight : str, optional
             Name of the sample weights
         """
+
         adapter = Adapter()
         adapter.to_array()
         adapter.convert_dtype("float64", "float32")
@@ -82,7 +88,9 @@ def build_adapter(
             adapter = adapter.rename(sample_weight, "sample_weight")
 
         adapter.keep(["inference_variables", "inference_conditions", "summary_variables", "sample_weight"])
-        adapter.standardize(exclude="sample_weight")
+
+        if standardize:
+            adapter.standardize(exclude="sample_weight")
 
         return adapter
 
@@ -334,12 +342,18 @@ def sample(
         dict[str, np.ndarray]
             Dictionary containing generated samples with the same keys as `conditions`.
         """
+
+        # Apply adapter transforms to raw simulated / real quantities
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        # at inference time, inference_variables are estimated by the networks and thus ignored in conditions
-        conditions.pop("inference_variables", None)
+
+        # Ensure only keys relevant for sampling are present in the conditions dictionary
+        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.SAMPLE_KEYS}
+
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
         conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
         conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
+
+        # Back-transform quantities and samples
         conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
 
         if split:
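
A minimal usage sketch of the new `standardize` flag; the variable names "theta" and "x" are placeholders chosen for illustration and do not come from this commit:

    from bayesflow.approximators import ContinuousApproximator

    adapter = ContinuousApproximator.build_adapter(
        inference_variables=["theta"],
        inference_conditions=["x"],
        standardize=False,  # skips adapter.standardize(exclude="sample_weight")
    )

At sampling time, only the keys listed in `SAMPLE_KEYS` ("summary_variables" and "inference_conditions") survive the new filtering step, so any other adapted quantities are simply dropped instead of being forwarded to the inference network.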

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 9 additions & 3 deletions
@@ -30,11 +30,13 @@ class ModelComparisonApproximator(Approximator):
         The network backbone (e.g, an MLP) that is used for model classification.
         The input of the classifier network is created by concatenating `classifier_variables`
         and (optional) output of the summary_network.
-    summary_network: bg.networks.SummaryNetwork, optional
+    summary_network: bf.networks.SummaryNetwork, optional
         The summary network used for data summarization (default is None).
         The input of the summary network is `summary_variables`.
     """
 
+    SAMPLE_KEYS = ["summary_variables", "inference_conditions"]
+
     def __init__(
         self,
         *,
@@ -304,9 +306,13 @@ def predict(
         np.ndarray
             Predicted posterior model probabilities given `conditions`.
         """
+
+        # Apply adapter transforms to raw simulated / real quantities
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        # at inference time, model_indices are predicted by the networks and thus ignored in conditions
-        conditions.pop("model_indices", None)
+
+        # Ensure only keys relevant for sampling are present in the conditions dictionary
+        conditions = {k: v for k, v in conditions.items() if k in ModelComparisonApproximator.SAMPLE_KEYS}
+
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
 
         output = self._predict(**conditions, **kwargs)
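
Because `SAMPLE_KEYS` is a class variable rather than a list hard-coded inside `predict`, a subclass could adjust which condition keys are forwarded without overriding the method. A hypothetical example, not part of this commit:

    class SummaryOnlyComparisonApproximator(ModelComparisonApproximator):
        # Hypothetical subclass: forward only summary variables to the classifier
        SAMPLE_KEYS = ["summary_variables"]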

bayesflow/approximators/point_approximator.py

Lines changed: 3 additions & 1 deletion
@@ -156,8 +156,10 @@ def log_prob(
 
     def _prepare_conditions(self, conditions: Mapping[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
         """Adapts and converts the conditions to tensors."""
+
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        conditions.pop("inference_variables", None)
+        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.SAMPLE_KEYS}
+
         return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
 
     def _apply_inverse_adapter_to_estimates(

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 19 additions & 25 deletions
@@ -187,7 +187,7 @@ def build(self, xz_shape, conditions_shape=None):
         self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
         self.c_huber2 = self.c_huber**2
 
-        ## Calculate discretization schedule in advance
+        # Calculate discretization schedule in advance
         # The Jax compiler requires fixed-size arrays, so we have
         # to store all the discretized_times in one matrix in advance
         # and later only access the relevant entries.
@@ -213,34 +213,24 @@ def build(self, xz_shape, conditions_shape=None):
             disc = ops.convert_to_numpy(self._discretize_time(n))
             discretized_times[i, : len(disc)] = disc
             discretization_map[n] = i
+
         # Finally, we convert the vectors to tensors
         self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32")
         self.discretization_map = ops.convert_to_tensor(discretization_map)
 
-    def call(
-        self,
-        xz: Tensor,
-        conditions: Tensor = None,
-        inverse: bool = False,
-        **kwargs,
-    ):
-        if inverse:
-            return self._inverse(xz, conditions=conditions, **kwargs)
-        return self._forward(xz, conditions=conditions, **kwargs)
-
-    def _forward_train(self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
-        """Forward function for training. Calls consistency function with
-        noisy input
-        """
+    def _forward_train(
+        self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
+    ) -> Tensor:
+        """Forward function for training. Calls consistency function with noisy input"""
         inp = x + t * noise
-        return self.consistency_function(inp, t, conditions=conditions, **kwargs)
+        return self.consistency_function(inp, t, conditions=conditions, training=training)
 
     def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         # Consistency Models only learn the direction from noise distribution
         # to target distribution, so we cannot implement this function.
         raise NotImplementedError("Consistency Models are not invertible")
 
-    def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+    def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
         """Generate random draws from the approximate target distribution
         using the multistep sampling algorithm from [1], Algorithm 1.
@@ -249,7 +239,9 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         z : Tensor
             Samples from a standard normal distribution
         conditions : Tensor, optional, default: None
-            Conditions for a approximate conditional distribution
+            Conditions for the approximate conditional distribution
+        training : bool, optional, default: False
+            Whether internal layers (e.g., dropout) should behave in train or inference mode.
         **kwargs : dict, optional, default: {}
             Additional keyword arguments. Include `steps` (default: 10) to
             adjust the number of sampling steps.
@@ -263,15 +255,17 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         x = keras.ops.copy(z) * self.max_time
         discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1)
         t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
-        x = self.consistency_function(x, t, conditions=conditions)
+
+        x = self.consistency_function(x, t, conditions=conditions, training=training)
+
         for n in range(1, steps):
             noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
             x_n = x + keras.ops.sqrt(keras.ops.square(discretized_time[n]) - self.eps**2) * noise
             t = keras.ops.full_like(t, discretized_time[n])
-            x = self.consistency_function(x_n, t, conditions=conditions)
+            x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x
 
-    def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+    def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         """Compute consistency function.
 
         Parameters
@@ -282,16 +276,16 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
             Vector of time samples in [eps, T]
         conditions : Tensor
             The conditioning vector
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the network.
+        training : bool, optional, default: False
+            Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
 
         if conditions is not None:
             xtc = ops.concatenate([x, t, conditions], axis=-1)
         else:
             xtc = ops.concatenate([x, t], axis=-1)
 
-        f = self.output_projector(self.subnet(xtc, **kwargs))
+        f = self.output_projector(self.subnet(xtc, training=training))
 
         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
         # Thus, we can do a cross product with the time vector which is (batch_size, 1) for
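
Threading an explicit `training` flag through `_forward_train`, `_inverse`, and `consistency_function` (instead of forwarding arbitrary `**kwargs` to the subnet) makes the train/inference behavior of the underlying Keras layers unambiguous. A minimal, self-contained illustration in plain Keras, unrelated to BayesFlow's own classes:

    import keras

    # Dropout is the typical layer whose behavior depends on `training`
    net = keras.Sequential([keras.layers.Dense(8, activation="relu"), keras.layers.Dropout(0.5)])
    x = keras.ops.ones((4, 8))
    y_train = net(x, training=True)   # dropout active, stochastic output
    y_infer = net(x, training=False)  # dropout disabled, deterministic output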

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 2 additions & 2 deletions
@@ -152,7 +152,7 @@ def _forward(
         z = x
         log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
         for layer in self.invertible_layers:
-            z, det = layer(z, conditions=conditions, inverse=False, training=training, **kwargs)
+            z, det = layer(z, conditions=conditions, inverse=False, training=training)
             log_det += det
 
         if density:
@@ -168,7 +168,7 @@ def _inverse(
         x = z
         log_det = keras.ops.zeros(keras.ops.shape(z)[:-1])
         for layer in reversed(self.invertible_layers):
-            x, det = layer(x, conditions=conditions, inverse=True, training=training, **kwargs)
+            x, det = layer(x, conditions=conditions, inverse=True, training=training)
             log_det += det
 
         if density:
