44from keras .saving import register_keras_serializable as serializable
55
66from bayesflow .types import Shape , Tensor
7- from bayesflow .utils import logging
7+ from bayesflow .utils import logging , weighted_mean
88from bayesflow .links import OrderedQuantiles
99
1010from .scoring_rule import ScoringRule
@@ -39,7 +39,7 @@ def get_config(self):
3939 base_config = super ().get_config ()
4040 return base_config | self .config
4141
42- def get_head_shapes_from_target_shape (self , target_shape : Shape ):
42+ def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [ str , tuple ] :
4343 # keras.saving.load_model sometimes passes target_shape as a list, so we force a conversion
4444 target_shape = tuple (target_shape )
4545 return dict (value = (len (self .q ),) + target_shape [1 :])
@@ -49,5 +49,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
4949 pointwise_differance = estimates - targets [:, None , :]
5050 scores = pointwise_differance * (keras .ops .cast (pointwise_differance > 0 , float ) - self ._q [None , :, None ])
5151 scores = keras .ops .mean (scores , axis = 1 )
52- score = self . aggregate (scores , weights )
52+ score = weighted_mean (scores , weights )
5353 return score
0 commit comments