Skip to content

Commit 1da6bc5

Browse files
committed
fix equation
1 parent b9e45cc commit 1da6bc5

File tree

3 files changed

+9
-29
lines changed

3 files changed

+9
-29
lines changed

ot/da.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,12 +680,7 @@ def df(G):
680680
return G, L
681681

682682

683-
@deprecated()
684-
def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
685-
wt=None, bias=True, log=False):
686-
""" Deprecated see ot.gaussian.empirical_bures_wasserstein_mapping"""
687-
return empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
688-
wt=None, bias=True, log=False)
683+
OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping)
689684

690685

691686
def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,

ot/gaussian.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
8484

8585
if log:
8686
log = {}
87-
log['Cs'] = Cs
88-
log['Ct'] = Ct
8987
log['Cs12'] = Cs12
9088
log['Cs12inv'] = Cs12inv
9189
return A, b, log
@@ -179,23 +177,13 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
179177
Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
180178
Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
181179

182-
Cs12 = nx.sqrtm(Cs)
183-
Cs12inv = nx.inv(Cs12)
184-
185-
M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
186-
187-
A = dots(Cs12inv, M0, Cs12inv)
188-
189-
b = mxt - nx.dot(mxs, A)
190-
191180
if log:
192-
log = {}
181+
A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log)
193182
log['Cs'] = Cs
194183
log['Ct'] = Ct
195-
log['Cs12'] = Cs12
196-
log['Cs12inv'] = Cs12inv
197184
return A, b, log
198185
else:
186+
A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)
199187
return A, b
200188

201189

@@ -251,7 +239,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
251239
Cs12 = nx.sqrtm(Cs)
252240

253241
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
254-
W = nx.norm(ms - mt) + B
242+
W = nx.sqrt(nx.norm(ms - mt)**2 + B)
255243
if log:
256244
log = {}
257245
log['Cs12'] = Cs12
@@ -334,15 +322,12 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
334322

335323
Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
336324
Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
337-
Cs12 = nx.sqrtm(Cs)
338325

339-
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
340-
W = nx.norm(mxs - mxt) + B
341326
if log:
342-
log = {}
327+
W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log)
343328
log['Cs'] = Cs
344329
log['Ct'] = Ct
345-
log['Cs12'] = Cs12
346330
return W, log
347331
else:
332+
W = bures_wasserstein_distance(mxs, mxt, Cs, Ct)
348333
return W

test/test_gaussian.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ def test_bures_wasserstein_distance(nx):
8383

8484
@pytest.mark.parametrize("bias", [True, False])
8585
def test_empirical_bures_wasserstein_distance(nx, bias):
86-
ns = 200
87-
nt = 200
86+
ns = 400
87+
nt = 400
8888

89-
rng = np.random.RandomState(2)
89+
rng = np.random.RandomState(10)
9090
Xs = rng.normal(0, 1, ns)[:, np.newaxis]
9191
Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis]
9292

0 commit comments

Comments
 (0)