2020
2121@serializable ("bayesflow.networks" )
2222class FlowMatching (InferenceNetwork ):
23- """(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated
24- from [1-3].
25-
26- [1] Rectified Flow: arXiv:2209.03003
27- [2] Flow Matching: arXiv:2210.02747
28- [3] Optimal Transport Flow Matching: arXiv:2302.00482
23+ """(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas
24+ incorporated from [1-5].
25+
26+ [1] Liu et al. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow.
27+ arXiv preprint arXiv:2209.03003.
28+ [2] Lipman et al. (2022). Flow matching for generative modeling.
29+ arXiv preprint arXiv:2210.02747.
30+ [3] Tong et al. (2023). Improving and generalizing flow-based generative models with minibatch optimal transport.
31+ arXiv preprint arXiv:2302.00482.
32+ [4] Wildberger et al. (2023). Flow matching for scalable simulation-based inference.
33+ Advances in Neural Information Processing Systems, 36, 16837-16864.
34+ [5] Orsini et al. (2025). Flow matching posterior estimation for simulation-based atmospheric retrieval of
35+ exoplanets. IEEE Access.
2936 """
3037
3138 MLP_DEFAULT_CONFIG = {
@@ -59,6 +66,7 @@ def __init__(
5966 integrate_kwargs : dict [str , any ] = None ,
6067 optimal_transport_kwargs : dict [str , any ] = None ,
6168 subnet_kwargs : dict [str , any ] = None ,
69+ time_power_law_alpha : float = 0.0 ,
6270 ** kwargs ,
6371 ):
6472 """
@@ -96,6 +104,9 @@ def __init__(
96104 into a single vector or passed as separate arguments. If set to False, the subnet
97105 must accept three separate inputs: 'x' (noisy parameters), 't' (time),
98106 and optional 'conditions'. Default is True.
107+ time_power_law_alpha: float, optional
108+ Changes the distribution of sampled times during training. Time is sampled from a power law distribution
109+ p(t) ∝ t^(1/(1+α)), where α is the provided value. Default is α=0, which corresponds to uniform sampling.
99110 **kwargs
100111 Additional keyword arguments passed to the subnet and other components.
101112 """
@@ -107,6 +118,9 @@ def __init__(
107118 self .optimal_transport_kwargs = FlowMatching .OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
108119
109120 self .loss_fn = keras .losses .get (loss_fn )
121+ self .time_power_law_alpha = float (time_power_law_alpha )
122+ if self .time_power_law_alpha <= - 1.0 :
123+ raise ValueError ("'time_power_law_alpha' must be greater than -1.0." )
110124
111125 self .seed_generator = keras .random .SeedGenerator ()
112126
@@ -164,6 +178,7 @@ def get_config(self):
164178 "integrate_kwargs" : self .integrate_kwargs ,
165179 "optimal_transport_kwargs" : self .optimal_transport_kwargs ,
166180 "concatenate_subnet_input" : self ._concatenate_subnet_input ,
181+ "time_power_law_alpha" : self .time_power_law_alpha ,
167182 # we do not need to store subnet_kwargs
168183 }
169184
@@ -307,7 +322,9 @@ def compute_metrics(
307322 # conditions must be resampled along with x1
308323 conditions = keras .ops .take (conditions , assignments , axis = 0 )
309324
310- t = keras .random .uniform ((keras .ops .shape (x0 )[0 ],), seed = self .seed_generator )
325+ u = keras .random .uniform ((keras .ops .shape (x0 )[0 ],), seed = self .seed_generator )
326+ # p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
327+ t = u ** (1 + self .time_power_law_alpha )
311328 t = expand_right_as (t , x0 )
312329
313330 x = t * x1 + (1 - t ) * x0
0 commit comments