Skip to content
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611)
- Split `test/test_gromov.py` into `test/gromov/` (PR #619)
- Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628)
- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)

## 0.9.3
*January 2024*
Expand Down
15 changes: 9 additions & 6 deletions ot/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
"""

# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
# License: MIT License
# Yikun Bai < yikun.bai@vanderbilt.edu >
# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>

import numpy as np
from .lp import emd
from .backend import get_backend
from .utils import list_to_array
from .backend import get_backend
from .lp import emd
import numpy as np

# License: MIT License


def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
Expand Down Expand Up @@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
" equal than min(|a|_1, |b|_1).")

if G0 is None:
G0 = np.outer(p, q)
G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q.

dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies)
q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies)
Expand All @@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,

Gprev = np.copy(G0)

M = gwgrad_partial(C1, C2, G0)
M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc
M_emd = np.zeros(dim_G_extended)
M_emd[:len(p), :len(q)] = M
M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2
Expand Down