Skip to content

Commit cd2d093

Browse files
Add New Flow Matching Schedules (bayesflow-org#565)
* add fm schedule * add fm schedule * add comments * expose time_power_law_alpha * Improve doc [skip ci] --------- Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
1 parent 076fdc8 commit cd2d093

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,19 @@
2020

2121
@serializable("bayesflow.networks")
2222
class FlowMatching(InferenceNetwork):
23-
"""(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated
24-
from [1-3].
25-
26-
[1] Rectified Flow: arXiv:2209.03003
27-
[2] Flow Matching: arXiv:2210.02747
28-
[3] Optimal Transport Flow Matching: arXiv:2302.00482
23+
"""(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas
24+
incorporated from [1-5].
25+
26+
[1] Liu et al. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow.
27+
arXiv preprint arXiv:2209.03003.
28+
[2] Lipman et al. (2022). Flow matching for generative modeling.
29+
arXiv preprint arXiv:2210.02747.
30+
[3] Tong et al. (2023). Improving and generalizing flow-based generative models with minibatch optimal transport.
31+
arXiv preprint arXiv:2302.00482.
32+
[4] Wildberger et al. (2023). Flow matching for scalable simulation-based inference.
33+
Advances in Neural Information Processing Systems, 36, 16837-16864.
34+
[5] Orsini et al. (2025). Flow matching posterior estimation for simulation-based atmospheric retrieval of
35+
exoplanets. IEEE Access.
2936
"""
3037

3138
MLP_DEFAULT_CONFIG = {
@@ -59,6 +66,7 @@ def __init__(
5966
integrate_kwargs: dict[str, any] = None,
6067
optimal_transport_kwargs: dict[str, any] = None,
6168
subnet_kwargs: dict[str, any] = None,
69+
time_power_law_alpha: float = 0.0,
6270
**kwargs,
6371
):
6472
"""
@@ -96,6 +104,9 @@ def __init__(
96104
into a single vector or passed as separate arguments. If set to False, the subnet
97105
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
98106
and optional 'conditions'. Default is True.
107+
time_power_law_alpha: float, optional
108+
Changes the distribution of sampled times during training. Time is sampled from a power law distribution
109+
p(t) ∝ t^(1/(1+α)), where α is the provided value. Default is α=0, which corresponds to uniform sampling.
99110
**kwargs
100111
Additional keyword arguments passed to the subnet and other components.
101112
"""
@@ -107,6 +118,9 @@ def __init__(
107118
self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
108119

109120
self.loss_fn = keras.losses.get(loss_fn)
121+
self.time_power_law_alpha = float(time_power_law_alpha)
122+
if self.time_power_law_alpha <= -1.0:
123+
raise ValueError("'time_power_law_alpha' must be greater than -1.0.")
110124

111125
self.seed_generator = keras.random.SeedGenerator()
112126

@@ -164,6 +178,7 @@ def get_config(self):
164178
"integrate_kwargs": self.integrate_kwargs,
165179
"optimal_transport_kwargs": self.optimal_transport_kwargs,
166180
"concatenate_subnet_input": self._concatenate_subnet_input,
181+
"time_power_law_alpha": self.time_power_law_alpha,
167182
# we do not need to store subnet_kwargs
168183
}
169184

@@ -307,7 +322,9 @@ def compute_metrics(
307322
# conditions must be resampled along with x1
308323
conditions = keras.ops.take(conditions, assignments, axis=0)
309324

310-
t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
325+
u = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
326+
# p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
327+
t = u ** (1 + self.time_power_law_alpha)
311328
t = expand_right_as(t, x0)
312329

313330
x = t * x1 + (1 - t) * x0

0 commit comments

Comments
 (0)