@@ -84,8 +84,6 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
8484
8585 if log :
8686 log = {}
87- log ['Cs' ] = Cs
88- log ['Ct' ] = Ct
8987 log ['Cs12' ] = Cs12
9088 log ['Cs12inv' ] = Cs12inv
9189 return A , b , log
@@ -179,23 +177,13 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
179177 Cs = nx .dot ((xs * ws ).T , xs ) / nx .sum (ws ) + reg * nx .eye (d , type_as = xs )
180178 Ct = nx .dot ((xt * wt ).T , xt ) / nx .sum (wt ) + reg * nx .eye (d , type_as = xt )
181179
182- Cs12 = nx .sqrtm (Cs )
183- Cs12inv = nx .inv (Cs12 )
184-
185- M0 = nx .sqrtm (dots (Cs12 , Ct , Cs12 ))
186-
187- A = dots (Cs12inv , M0 , Cs12inv )
188-
189- b = mxt - nx .dot (mxs , A )
190-
191180 if log :
192- log = {}
181+ A , b , log = bures_wasserstein_mapping ( mxs , mxt , Cs , Ct , log = log )
193182 log ['Cs' ] = Cs
194183 log ['Ct' ] = Ct
195- log ['Cs12' ] = Cs12
196- log ['Cs12inv' ] = Cs12inv
197184 return A , b , log
198185 else :
186+ A , b = bures_wasserstein_mapping (mxs , mxt , Cs , Ct )
199187 return A , b
200188
201189
@@ -251,7 +239,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
251239 Cs12 = nx .sqrtm (Cs )
252240
253241 B = nx .trace (Cs + Ct - 2 * nx .sqrtm (dots (Cs12 , Ct , Cs12 )))
254- W = nx .norm (ms - mt ) + B
242+ W = nx .sqrt ( nx . norm (ms - mt )** 2 + B )
255243 if log :
256244 log = {}
257245 log ['Cs12' ] = Cs12
@@ -334,15 +322,12 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
334322
335323 Cs = nx .dot ((xs * ws ).T , xs ) / nx .sum (ws ) + reg * nx .eye (d , type_as = xs )
336324 Ct = nx .dot ((xt * wt ).T , xt ) / nx .sum (wt ) + reg * nx .eye (d , type_as = xt )
337- Cs12 = nx .sqrtm (Cs )
338325
339- B = nx .trace (Cs + Ct - 2 * nx .sqrtm (dots (Cs12 , Ct , Cs12 )))
340- W = nx .norm (mxs - mxt ) + B
341326 if log :
342- log = {}
327+ W , log = bures_wasserstein_distance ( mxs , mxt , Cs , Ct , log = log )
343328 log ['Cs' ] = Cs
344329 log ['Ct' ] = Ct
345- log ['Cs12' ] = Cs12
346330 return W , log
347331 else :
332+ W = bures_wasserstein_distance (mxs , mxt , Cs , Ct )
348333 return W
0 commit comments