Skip to content

Conversation

@arrjon
Copy link
Member

@arrjon arrjon commented Sep 6, 2025

This pull request updates the DiagonalNormal distribution implementation to support multi-dimensional (non-vector) distributions instead of being limited to vectors. The main changes involve replacing the single dim attribute with a tuple dims, and updating all relevant logic to handle arbitrary shapes.

This is necessary if we pass not a vector but something else to inference networks like diffusion, flow matching etc.

@arrjon arrjon self-assigned this Sep 6, 2025
@arrjon arrjon marked this pull request as draft September 6, 2025 09:56
@codecov
Copy link

codecov bot commented Sep 7, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
bayesflow/approximators/continuous_approximator.py 91.87% <100.00%> (+0.12%) ⬆️
bayesflow/distributions/diagonal_normal.py 95.34% <100.00%> (ø)
@arrjon arrjon marked this pull request as ready for review September 7, 2025 16:12
@arrjon arrjon requested a review from stefanradev93 September 8, 2025 15:14
Copy link
Contributor

@stefanradev93 stefanradev93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Jonas, thanks for the PR. See individual comments.

@arrjon
Copy link
Member Author

arrjon commented Sep 9, 2025

I changed the batch shape now to

target_dim = self.inference_network.base_distribution.dims batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] 

that should cover now more use cases as the shape is inferred from the conditions, but the base shape is also considered.

@arrjon arrjon merged commit bc2bda8 into dev Sep 10, 2025
9 checks passed
@arrjon arrjon deleted the normal_distribution_dimension branch September 10, 2025 08:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants