Skip to content

Conversation

@arrjon
Copy link
Member

@arrjon arrjon commented Sep 16, 2025

This pull request introduces improvements to keyword argument handling in SetTransformer (please check if this fine) and adds tests for the ContinuousApproximator sampling functionality. Before calling workflow.sample with additional kwargs, e.g. method='euler' for FlowMatching, was failing with the SetTransformer.

@arrjon arrjon self-assigned this Sep 16, 2025
@arrjon arrjon marked this pull request as ready for review September 16, 2025 10:29
@codecov
Copy link

codecov bot commented Sep 16, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
bayesflow/networks/transformers/mab.py 96.77% <100.00%> (ø)
bayesflow/networks/transformers/set_transformer.py 93.75% <100.00%> (ø)

... and 3 files with indirect coverage changes

Copy link
Contributor

@stefanradev93 stefanradev93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was surprised that the existing filter in ContinuousApproximator (L525-527) is not sufficient:

summary_outputs = self.summary_network( summary_variables, **filter_kwargs(kwargs, self.summary_network.call) )

Perhaps removing the unnecessary **kwargs argument from all attention helpers (e.g., MultiHeadAttentionBlock) solves it cleanly?

Perhaps it is because the call method itself accepts kwargs, so the argument gets propagated down to the multihead attention block.

@arrjon
Copy link
Member Author

arrjon commented Sep 16, 2025

Perhaps removing the unnecessary **kwargs argument from all attention helpers (e.g., MultiHeadAttentionBlock) solves it cleanly?

I can have a look.

Perhaps it is because the call method itself accepts kwargs, so the argument gets propagated down to the multihead attention block.

That is indeed the problem.

@arrjon
Copy link
Member Author

arrjon commented Sep 16, 2025

self.attention_blocks = keras.Sequential()
this part did not accept kwargs at all, so I removed them entirely

@arrjon arrjon merged commit cede7f8 into dev Sep 23, 2025
9 checks passed
@arrjon arrjon deleted the fix_sampling_method_kwargs branch September 23, 2025 11:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants