Skip to content

Commit 2bf0b53

Browse files
committed
Remove old usage of aggregate and use weighted_mean everywhere
1 parent 9f11476 commit 2bf0b53

File tree

10 files changed

+23
-41
lines changed

10 files changed

+23
-41
lines changed

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
vjp,
1313
serialize_value_or_type,
1414
deserialize_value_or_type,
15+
weighted_mean,
1516
)
1617

1718
from bayesflow.networks import InferenceNetwork
@@ -240,6 +241,6 @@ def decode(z):
240241
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
241242

242243
losses = maximum_likelihood_loss + self.beta * reconstruction_loss
243-
loss = self.aggregate(losses, sample_weight)
244+
loss = weighted_mean(losses, sample_weight)
244245

245246
return base_metrics | {"loss": loss}

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
import numpy as np
88

99
from bayesflow.types import Tensor
10-
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum
10+
from bayesflow.utils import (
11+
find_network,
12+
keras_kwargs,
13+
serialize_value_or_type,
14+
deserialize_value_or_type,
15+
weighted_mean,
16+
)
1117

1218

1319
from ..inference_network import InferenceNetwork
@@ -331,6 +337,6 @@ def compute_metrics(
331337

332338
# Pseudo-huber loss, see [2], Section 3.3
333339
loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
334-
loss = weighted_sum(loss, sample_weight)
340+
loss = weighted_mean(loss, sample_weight)
335341

336342
return base_metrics | {"loss": loss}

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
keras_kwargs,
88
serialize_value_or_type,
99
deserialize_value_or_type,
10-
weighted_sum,
10+
weighted_mean,
1111
)
1212

1313
from .actnorm import ActNorm
@@ -167,6 +167,6 @@ def compute_metrics(
167167
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
168168

169169
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
170-
loss = weighted_sum(-log_density, sample_weight)
170+
loss = weighted_mean(-log_density, sample_weight)
171171

172172
return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
optimal_transport,
1414
serialize_value_or_type,
1515
deserialize_value_or_type,
16-
weighted_sum,
16+
weighted_mean,
1717
)
1818
from ..inference_network import InferenceNetwork
1919

@@ -260,6 +260,6 @@ def compute_metrics(
260260
predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")
261261

262262
loss = self.loss_fn(target_velocity, predicted_velocity)
263-
loss = weighted_sum(loss, sample_weight)
263+
loss = weighted_mean(loss, sample_weight)
264264

265265
return base_metrics | {"loss": loss}

bayesflow/scores/normed_difference_score.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.types import Shape, Tensor
5+
from bayesflow.utils import weighted_mean
56

67
from .scoring_rule import ScoringRule
78

@@ -55,7 +56,7 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
5556
"""
5657
estimates = estimates["value"]
5758
scores = keras.ops.absolute(estimates - targets) ** self.k
58-
score = self.aggregate(scores, weights)
59+
score = weighted_mean(scores, weights)
5960
return score
6061

6162
def get_config(self):

bayesflow/scores/parametric_distribution_score.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from keras.saving import register_keras_serializable as serializable
22

33
from bayesflow.types import Tensor
4+
from bayesflow.utils import weighted_mean
45

56
from .scoring_rule import ScoringRule
67

@@ -29,5 +30,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
2930
:math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))`
3031
"""
3132
scores = -self.log_prob(x=targets, **estimates)
32-
score = self.aggregate(scores, weights)
33+
score = weighted_mean(scores, weights)
3334
return score

bayesflow/scores/quantile_score.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from keras.saving import register_keras_serializable as serializable
55

66
from bayesflow.types import Shape, Tensor
7-
from bayesflow.utils import logging
7+
from bayesflow.utils import logging, weighted_mean
88
from bayesflow.links import OrderedQuantiles
99

1010
from .scoring_rule import ScoringRule
@@ -39,7 +39,7 @@ def get_config(self):
3939
base_config = super().get_config()
4040
return base_config | self.config
4141

42-
def get_head_shapes_from_target_shape(self, target_shape: Shape):
42+
def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, tuple]:
4343
# keras.saving.load_model sometimes passes target_shape as a list, so we force a conversion
4444
target_shape = tuple(target_shape)
4545
return dict(value=(len(self.q),) + target_shape[1:])
@@ -49,5 +49,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
4949
pointwise_differance = estimates - targets[:, None, :]
5050
scores = pointwise_differance * (keras.ops.cast(pointwise_differance > 0, float) - self._q[None, :, None])
5151
scores = keras.ops.mean(scores, axis=1)
52-
score = self.aggregate(scores, weights)
52+
score = weighted_mean(scores, weights)
5353
return score

bayesflow/scores/scoring_rule.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -204,30 +204,3 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor)
204204
<tf.Tensor: shape=(), dtype=float32, numpy=1.013983130455017>
205205
"""
206206
raise NotImplementedError
207-
208-
def aggregate(self, scores: Tensor, weights: Tensor = None) -> Tensor:
209-
"""
210-
Computes the mean of **scores**, optionally applying **weights**.
211-
212-
This function computes the mean value of the given scores. When weights are provided,
213-
it first multiplies the scores by the weights and then computes the mean of the result.
214-
If no weights are provided, it computes the mean of the scores.
215-
216-
Parameters
217-
----------
218-
scores : Tensor
219-
A tensor containing the scores to be aggregated.
220-
weights : Tensor, optional (default - None)
221-
A tensor of weights corresponding to each score. Must be the same shape as `scores`.
222-
If not provided, the function returns the mean of `scores`.
223-
224-
Returns
225-
-------
226-
Tensor
227-
The aggregated score computed as a weighted mean if **weights** is provided,
228-
or as the simple mean of **scores** otherwise.
229-
"""
230-
231-
if weights is not None:
232-
return keras.ops.mean(scores * weights)
233-
return keras.ops.mean(scores)

bayesflow/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
tree_concatenate,
7272
tree_stack,
7373
fill_triangular_matrix,
74-
weighted_sum,
74+
weighted_mean,
7575
)
7676
from .classification import calibration_curve, confusion_matrix
7777
from .validators import check_lengths_same

bayesflow/utils/tensor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def pad(x: Tensor, value: float | Tensor, n: int, axis: int, side: str = "both")
140140
raise TypeError(f"Invalid side type {type(side)!r}. Must be str.")
141141

142142

143-
def weighted_sum(elements: Tensor, weights: Tensor = None) -> Tensor:
143+
def weighted_mean(elements: Tensor, weights: Tensor = None) -> Tensor:
144144
"""
145145
Compute the (optionally) weighted mean of the input tensor.
146146

0 commit comments

Comments
 (0)