33import keras
44
55from bayesflow .types import Shape , Tensor
6- from bayesflow .links import PositiveDefinite
6+ from bayesflow .links import CholeskyFactor
77from bayesflow .utils .serialization import serializable
88
99from .parametric_distribution_score import ParametricDistributionScore
class MultivariateNormalScore(ParametricDistributionScore):
    r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`

    Scores a predicted mean and (Cholesky factor of the) covariance matrix with the log-score
    of the probability of the materialized value.
    """

    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
    """
    Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.

    This variable contains names of prediction heads that should lead to a warning when the adapter is applied
    in inverse direction to them.

    For more information see :py:class:`ScoringRule`.
    """

    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
    """
    Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.

    The appropriate inverse of the standardization operation is

        x_ij = sigma_i * x_ij'.

    For the mean head the default ("location_scale") is not overridden.
    """
@@ -41,7 +42,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
4142 super ().__init__ (links = links , ** kwargs )
4243
4344 self .dim = dim
44- self .links = links or {"covariance " : PositiveDefinite ()}
45+ self .links = links or {"cov_chol " : CholeskyFactor ()}
4546
4647 self .config = {"dim" : dim }
4748
@@ -51,14 +52,14 @@ def get_config(self):
5152
5253 def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [str , Shape ]:
5354 self .dim = target_shape [- 1 ]
54- return dict (mean = (self .dim ,), covariance = (self .dim , self .dim ))
55+ return dict (mean = (self .dim ,), cov_chol = (self .dim , self .dim ))
5556
56- def log_prob (self , x : Tensor , mean : Tensor , covariance : Tensor ) -> Tensor :
57+ def log_prob (self , x : Tensor , mean : Tensor , cov_chol : Tensor ) -> Tensor :
5758 """
5859 Compute the log probability density of a multivariate Gaussian distribution.
5960
6061 This function calculates the log probability density for each sample in `x` under a
61- multivariate Gaussian distribution with the given `mean` and `covariance `.
62+ multivariate Gaussian distribution with the given `mean` and `cov_chol `.
6263
6364 The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
6465 form in the exponential term of the Gaussian density function.
@@ -80,6 +81,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
8081 given Gaussian distribution.
8182 """
8283 diff = x - mean
84+
85+ # Calculate covariance from Cholesky factors
86+ covariance = keras .ops .matmul (
87+ cov_chol ,
88+ keras .ops .swapaxes (cov_chol , - 2 , - 1 ),
89+ )
8390 precision = keras .ops .inv (covariance )
8491 log_det_covariance = keras .ops .slogdet (covariance )[1 ] # Only take the log of the determinant part
8592
@@ -91,14 +98,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
9198
9299 return log_prob
93100
94- def sample (self , batch_shape : Shape , mean : Tensor , covariance : Tensor ) -> Tensor :
101+ def sample (self , batch_shape : Shape , mean : Tensor , cov_chol : Tensor ) -> Tensor :
95102 """
96103 Generate samples from a multivariate Gaussian distribution.
97104
98- This function samples from a multivariate Gaussian distribution with the given `mean`
99- and `covariance` using the Cholesky decomposition method. Independent standard normal
100- samples are transformed using the Cholesky factor of the covariance matrix to generate
101- correlated samples.
105+ Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
106+ to generate correlated samples.
102107
103108 Parameters
104109 ----------
@@ -107,8 +112,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
107112 mean : Tensor
108113 A tensor representing the mean of the multivariate Gaussian distribution.
109114 Must have shape (batch_size, D), where D is the dimensionality of the distribution.
110- covariance : Tensor
111- A tensor representing the covariance matrix of the multivariate Gaussian distribution.
115+ cov_chol : Tensor
116+ A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
112117 Must have shape (batch_size, D, D), where D is the dimensionality.
113118
114119 Returns
@@ -123,16 +128,16 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
123128 if keras .ops .shape (mean ) != (batch_size , dim ):
124129 raise ValueError (f"mean must have shape (batch_size, { dim } ), but got { keras .ops .shape (mean )} " )
125130
126- if keras .ops .shape (covariance ) != (batch_size , dim , dim ):
131+ if keras .ops .shape (cov_chol ) != (batch_size , dim , dim ):
127132 raise ValueError (
128- f"covariance must have shape (batch_size, { dim } , { dim } ), but got { keras .ops .shape (covariance )} "
133+ f"covariance Cholesky factor must have shape (batch_size, { dim } , { dim } ),"
134+ f"but got { keras .ops .shape (cov_chol )} "
129135 )
130136
131137 # Use Cholesky decomposition to generate samples
132- cholesky_factor = keras .ops .cholesky (covariance )
133138 normal_samples = keras .random .normal ((* batch_shape , dim ))
134139
135- scaled_normal = keras .ops .einsum ("ijk,ilk->ilj" , cholesky_factor , normal_samples )
140+ scaled_normal = keras .ops .einsum ("ijk,ilk->ilj" , cov_chol , normal_samples )
136141 samples = mean [:, None , :] + scaled_normal
137142
138143 return samples