Skip to content

Commit 31455e5

Browse files
[MRG] Fix order problems in (F)GW barycenters and utils (#576)
* fix local merge * correct gw.utils file * fix gw barycenter functions * fix tests * complete coverage * add PR to release * corrections Remi
1 parent ef6c3c1 commit 31455e5

File tree

5 files changed

+115
-63
lines changed

5 files changed

+115
-63
lines changed

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
2727
- Create `ot/bregman/` repository (Issue #567, PR #569)
2828
- Fix matrix feature shape in `entropic_fused_gromov_barycenters` (Issue #574, PR #573)
29+
- Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss` (PR #576)
2930

3031

3132
## 0.9.1
@@ -602,4 +603,4 @@ It provides the following solvers:
602603
* Optimal transport for domain adaptation with group lasso regularization
603604
* Conditional gradient and Generalized conditional gradient for regularized OT.
604605

605-
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
606+
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

ot/gromov/_bregman.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -457,17 +457,17 @@ def entropic_gromov_barycenters(
457457
Cprev = C
458458
if warmstartT:
459459
T = [entropic_gromov_wasserstein(
460-
Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, T[s],
460+
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, T[s],
461461
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
462462
else:
463463
T = [entropic_gromov_wasserstein(
464-
Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None,
464+
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, None,
465465
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
466466

467467
if loss_fun == 'square_loss':
468-
C = update_square_loss(p, lambdas, T, Cs)
468+
C = update_square_loss(p, lambdas, T, Cs, nx)
469469
elif loss_fun == 'kl_loss':
470-
C = update_kl_loss(p, lambdas, T, Cs)
470+
C = update_kl_loss(p, lambdas, T, Cs, nx)
471471

472472
if cpt % 10 == 0:
473473
# we can speed up the process by checking for the error only all
@@ -962,9 +962,9 @@ def entropic_fused_gromov_barycenters(
962962
Y = init_Y
963963

964964
if warmstartT:
965-
T = [nx.outer(p_, p) for p_ in ps]
965+
T = [None] * S
966966

967-
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
967+
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]
968968

969969
cpt = 0
970970
err = 1
@@ -984,23 +984,22 @@ def entropic_fused_gromov_barycenters(
984984

985985
if warmstartT:
986986
T = [entropic_fused_gromov_wasserstein(
987-
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
987+
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
988988
T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
989989

990990
else:
991991
T = [entropic_fused_gromov_wasserstein(
992-
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
992+
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
993993
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
994994

995995
if loss_fun == 'square_loss':
996-
C = update_square_loss(p, lambdas, T, Cs)
996+
C = update_square_loss(p, lambdas, T, Cs, nx)
997997
elif loss_fun == 'kl_loss':
998-
C = update_kl_loss(p, lambdas, T, Cs)
998+
C = update_kl_loss(p, lambdas, T, Cs, nx)
999999

10001000
Ys_temp = [y.T for y in Ys]
1001-
T_temp = [Ts.T for Ts in T]
1002-
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p).T
1003-
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
1001+
Y = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
1002+
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]
10041003

10051004
if cpt % 10 == 0:
10061005
# we can speed up the process by checking for the error only all

ot/gromov/_gw.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -830,16 +830,18 @@ def gromov_barycenters(
830830
Cprev = C
831831

832832
if warmstartT:
833-
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s],
834-
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
833+
T = [gromov_wasserstein(
834+
C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s],
835+
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
835836
else:
836-
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=None,
837-
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
837+
T = [gromov_wasserstein(
838+
C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=None,
839+
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
838840
if loss_fun == 'square_loss':
839-
C = update_square_loss(p, lambdas, T, Cs)
841+
C = update_square_loss(p, lambdas, T, Cs, nx)
840842

841843
elif loss_fun == 'kl_loss':
842-
C = update_kl_loss(p, lambdas, T, Cs)
844+
C = update_kl_loss(p, lambdas, T, Cs, nx)
843845

844846
if cpt % 10 == 0:
845847
# we can speed up the process by checking for the error only all
@@ -898,14 +900,14 @@ def fgw_barycenters(
898900
If left to its default value None, uniform weights are taken.
899901
alpha : float, optional
900902
Alpha parameter for the fgw distance.
901-
fixed_structure : bool
902-
Whether to fix the structure of the barycenter during the updates
903-
fixed_features : bool
903+
fixed_structure : bool, optional
904+
Whether to fix the structure of the barycenter during the updates.
905+
fixed_features : bool, optional
904906
Whether to fix the feature of the barycenter during the updates
905907
p : array-like, shape (N,), optional
906908
Weights in the targeted barycenter.
907909
If left to its default value None, uniform distribution is taken.
908-
loss_fun : str
910+
loss_fun : str, optional
909911
Loss function used for the solver either 'square_loss' or 'kl_loss'
910912
symmetric : bool, optional
911913
Either structures are to be assumed symmetric or not. Default value is True.
@@ -1024,19 +1026,18 @@ def fgw_barycenters(
10241026
T = [fused_gromov_wasserstein(
10251027
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
10261028
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
1027-
# T is N,ns
1029+
10281030
if not fixed_features:
10291031
Ys_temp = [y.T for y in Ys]
1030-
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
1032+
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
10311033
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
10321034

10331035
if not fixed_structure:
1034-
T_temp = [t.T for t in T]
10351036
if loss_fun == 'square_loss':
1036-
C = update_square_loss(p, lambdas, T_temp, Cs)
1037+
C = update_square_loss(p, lambdas, T, Cs, nx)
10371038

10381039
elif loss_fun == 'kl_loss':
1039-
C = update_kl_loss(p, lambdas, T_temp, Cs)
1040+
C = update_kl_loss(p, lambdas, T, Cs, nx)
10401041

10411042
err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
10421043
err_structure = nx.norm(C - Cprev)

ot/gromov/_utils.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -253,46 +253,75 @@ def gwggrad(constC, hC1, hC2, T, nx=None):
253253
T, nx) # [12] Prop. 2 misses a 2 factor
254254

255255

256-
def update_square_loss(p, lambdas, T, Cs):
256+
def update_square_loss(p, lambdas, T, Cs, nx=None):
257257
r"""
258-
Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
259-
couplings calculated at each iteration
258+
Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S`
259+
:math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
260+
barycenter problem in :ref:`[12]`:
261+
262+
.. math::
263+
264+
\mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
265+
266+
Where :
267+
268+
- :math:`\mathbf{C}_s`: metric cost matrix
269+
- :math:`\mathbf{p}_s`: distribution
260270
261271
Parameters
262272
----------
263273
p : array-like, shape (N,)
264274
Masses in the targeted barycenter.
265275
lambdas : list of float
266276
List of the `S` spaces' weights.
267-
T : list of S array-like of shape (ns,N)
277+
T : list of S array-like of shape (N, ns)
268278
The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
269279
Cs : list of S array-like, shape(ns,ns)
270280
Metric cost matrices.
281+
nx : backend, optional
282+
If left to its default value None, a backend test will be conducted.
271283
272284
Returns
273285
----------
274286
C : array-like, shape (`N`, `N`)
275287
Updated :math:`\mathbf{C}` matrix.
288+
289+
References
290+
----------
291+
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
292+
"Gromov-Wasserstein averaging of kernel and distance matrices."
293+
International Conference on Machine Learning (ICML). 2016.
294+
276295
"""
277-
T = list_to_array(*T)
278-
Cs = list_to_array(*Cs)
279-
p = list_to_array(p)
280-
nx = get_backend(p, *T, *Cs)
296+
if nx is None:
297+
nx = get_backend(p, *T, *Cs)
281298

299+
# Correct order mistake in Equation 14 in [12]
282300
tmpsum = sum([
283301
lambdas[s] * nx.dot(
284-
nx.dot(T[s].T, Cs[s]),
285-
T[s]
302+
nx.dot(T[s], Cs[s]),
303+
T[s].T
286304
) for s in range(len(T))
287305
])
288306
ppt = nx.outer(p, p)
289307

290308
return tmpsum / ppt
291309

292310

293-
def update_kl_loss(p, lambdas, T, Cs):
311+
def update_kl_loss(p, lambdas, T, Cs, nx=None):
294312
r"""
295-
Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
313+
Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S`
314+
:math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
315+
barycenter problem in :ref:`[12]`:
316+
317+
.. math::
318+
319+
\mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
320+
321+
Where :
322+
323+
- :math:`\mathbf{C}_s`: metric cost matrix
324+
- :math:`\mathbf{p}_s`: distribution
296325
297326
298327
Parameters
@@ -301,33 +330,41 @@ def update_kl_loss(p, lambdas, T, Cs):
301330
Weights in the targeted barycenter.
302331
lambdas : list of float
303332
List of the `S` spaces' weights
304-
T : list of S array-like of shape (ns,N)
333+
T : list of S array-like of shape (N, ns)
305334
The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
306335
Cs : list of S array-like, shape(ns,ns)
307336
Metric cost matrices.
337+
nx : backend, optional
338+
If left to its default value None, a backend test will be conducted.
308339
309340
Returns
310341
----------
311342
C : array-like, shape (`N`, `N`)
312343
updated :math:`\mathbf{C}` matrix
344+
345+
References
346+
----------
347+
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
348+
"Gromov-Wasserstein averaging of kernel and distance matrices."
349+
International Conference on Machine Learning (ICML). 2016.
350+
313351
"""
314-
Cs = list_to_array(*Cs)
315-
T = list_to_array(*T)
316-
p = list_to_array(p)
317-
nx = get_backend(p, *T, *Cs)
352+
if nx is None:
353+
nx = get_backend(p, *T, *Cs)
318354

355+
# Correct order mistake in Equation 15 in [12]
319356
tmpsum = sum([
320357
lambdas[s] * nx.dot(
321-
nx.dot(T[s].T, Cs[s]),
322-
T[s]
358+
nx.dot(T[s], nx.log(nx.maximum(Cs[s], 1e-15))),
359+
T[s].T
323360
) for s in range(len(T))
324361
])
325362
ppt = nx.outer(p, p)
326363

327364
return nx.exp(tmpsum / ppt)
328365

329366

330-
def update_feature_matrix(lambdas, Ys, Ts, p):
367+
def update_feature_matrix(lambdas, Ys, Ts, p, nx=None):
331368
r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
332369
333370
@@ -340,10 +377,12 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
340377
masses in the targeted barycenter
341378
lambdas : list of float
342379
List of the `S` spaces' weights
343-
Ts : list of S array-like, shape (ns,N)
380+
Ts : list of S array-like, shape (N, ns)
344381
The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
345382
Ys : list of S array-like, shape (d,ns)
346383
The features.
384+
nx : backend, optional
385+
If left to its default value None, a backend test will be conducted.
347386
348387
Returns
349388
-------
@@ -357,10 +396,8 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
357396
"Optimal Transport for structured data with application on graphs"
358397
International Conference on Machine Learning (ICML). 2019.
359398
"""
360-
p = list_to_array(p)
361-
Ts = list_to_array(*Ts)
362-
Ys = list_to_array(*Ys)
363-
nx = get_backend(*Ys, *Ts, p)
399+
if nx is None:
400+
nx = get_backend(*Ys, *Ts, p)
364401

365402
p = 1. / p
366403
tmpsum = sum([

test/test_gromov.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,20 +1441,29 @@ def test_fgw_barycenter(nx):
14411441
p = ot.unif(n_samples)
14421442

14431443
ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
1444-
1445-
Xb, Cb = ot.gromov.fgw_barycenters(
1446-
n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False,
1447-
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
1444+
lambdas = [.5, .5]
1445+
Csb = [C1b, C2b]
1446+
Ysb = [ysb, ytb]
1447+
Xb, Cb, logb = ot.gromov.fgw_barycenters(
1448+
n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False,
1449+
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
1450+
random_state=12345, log=True
14481451
)
1452+
# test correspondence with utils function
1453+
recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb)
1454+
recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T
1455+
1456+
np.testing.assert_allclose(Cb, recovered_Cb)
1457+
np.testing.assert_allclose(Xb, recovered_Xb)
14491458

14501459
xalea = rng.randn(n_samples, 2)
14511460
init_C = ot.dist(xalea, xalea)
14521461
init_C /= init_C.max()
14531462
init_Cb = nx.from_numpy(init_C)
14541463

1455-
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None`
1464+
with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None`
14561465
Xb, Cb = ot.gromov.fgw_barycenters(
1457-
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
1466+
n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None,
14581467
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
14591468
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
14601469
)
@@ -1471,7 +1480,7 @@ def test_fgw_barycenter(nx):
14711480
init_X = rng.randn(n_samples, ys.shape[1])
14721481
init_Xb = nx.from_numpy(init_X)
14731482

1474-
with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_features=True`and `init_X=None`
1483+
with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None`
14751484
Xb, Cb, logb = ot.gromov.fgw_barycenters(
14761485
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
14771486
fixed_structure=False, fixed_features=True, init_X=None,
@@ -1490,14 +1499,19 @@ def test_fgw_barycenter(nx):
14901499
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
14911500

14921501
# add test with 'kl_loss'
1493-
X, C = ot.gromov.fgw_barycenters(
1502+
X, C, log = ot.gromov.fgw_barycenters(
14941503
n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5,
14951504
fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss',
1496-
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345
1505+
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True,
1506+
random_state=12345, log=True
14971507
)
14981508
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
14991509
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
15001510

1511+
# test correspondence with utils function
1512+
recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['Ts_iter'][-1], [C1, C2])
1513+
np.testing.assert_allclose(C, recovered_C)
1514+
15011515

15021516
def test_gromov_wasserstein_linear_unmixing(nx):
15031517
n = 4

0 commit comments

Comments
 (0)