- Notifications
You must be signed in to change notification settings - Fork 534
Gromov-Wasserstein distance #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7ab9037
0a68bf4
3007f1d
bc68cc3
986f46d
e89f09d
f469205
fa36e77
aa19b6a
5ab5035
c7eaaf4
d5c6cc1
cd4fa72
0659abe
2005a09
4e562a1
62b40a9
266abb6
b8672f6
117cd33
d20a067
8d19d36
c8ae584
fc58f39
6167f34
181fcd3
e1a3984
4f802cf
e1606c1
f79f483
84e56a0
5964001
24362ec
f8744a3
3730779
5a9795f
6ae3ad7
b562927
0f7cd92
ceeb063
8875f65
5076131
6d60230
93dee55
8c52517
4ec5b33
ab6ed1d
64a5d3c
46fc12a
f12322c
53e1115
8ea74ad
36bf599
24784ed
84c2723
55db350
5a2ebfa
7e5df4c
c86cc4f
c7eef9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
| @@ -91,12 +91,12 @@ def im2mat(I): | |
return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) | ||
| ||
| ||
carre = spi.imread('../data/carre.png').astype(np.float64) / 256 | ||
rond = spi.imread('../data/rond.png').astype(np.float64) / 256 | ||
square = spi.imread('../data/carre.png').astype(np.float64) / 256 | ||
circle = spi.imread('../data/rond.png').astype(np.float64) / 256 | ||
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256 | ||
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256 | ||
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256 | ||
| ||
shapes = [carre, rond, triangle, fleche] | ||
shapes = [square, circle, triangle, arrow] | ||
| ||
S = 4 | ||
xs = [[] for i in range(S)] | ||
| @@ -118,36 +118,36 @@ def im2mat(I): | |
The four distributions are constructed from 4 simple images | ||
""" | ||
ns = [len(xs[s]) for s in range(S)] | ||
N = 30 | ||
n_samples = 30 | ||
| ||
"""Compute all distances matrices for the four shapes""" | ||
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)] | ||
Cs = [cs / cs.max() for cs in Cs] | ||
| ||
ps = [ot.unif(ns[s]) for s in range(S)] | ||
p = ot.unif(N) | ||
p = ot.unif(n_samples) | ||
| ||
| ||
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]] | ||
| ||
Ct01 = [0 for i in range(2)] | ||
for i in range(2): | ||
Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [ | ||
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [ | ||
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3) | ||
| ||
| ||
Ct02 = [0 for i in range(2)] | ||
for i in range(2): | ||
Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [ | ||
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [ | ||
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3) | ||
| ||
Ct13 = [0 for i in range(2)] | ||
for i in range(2): | ||
Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [ | ||
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [ | ||
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3) | ||
| ||
Ct23 = [0 for i in range(2)] | ||
for i in range(2): | ||
Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [ | ||
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [ | ||
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3) | ||
| ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
| @@ -208,7 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs): | |
return(np.exp(np.divide(tmpsum, ppt))) | ||
| ||
| ||
| ||
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False): | ||
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False): | ||
""" | ||
Returns the gromov-wasserstein coupling between the two measured similarity matrices | ||
| ||
| @@ -248,7 +248,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr | |
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss' | ||
epsilon : float | ||
Regularization term >0 | ||
numItermax : int, optional | ||
max_iter : int, optional | ||
Max number of iterations | ||
stopThr : float, optional | ||
Stop threshold on error (>0) | ||
| @@ -274,7 +274,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr | |
cpt = 0 | ||
err = 1 | ||
| ||
while (err > stopThr and cpt < numItermax): | ||
while (err > stopThr and cpt < max_iter): | ||
| ||
Tprev = T | ||
| ||
| @@ -307,7 +307,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr | |
return T | ||
| ||
| ||
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False): | ||
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False): | ||
| ||
""" | ||
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices | ||
| ||
| @@ -362,10 +362,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh | |
| ||
if log: | ||
gw, logv = gromov_wasserstein( | ||
C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log) | ||
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log) | ||
else: | ||
gw = gromov_wasserstein(C1, C2, p, q, loss_fun, | ||
epsilon, numItermax, stopThr, verbose, log) | ||
epsilon, max_iter, stopThr, verbose, log) | ||
| ||
if loss_fun == 'square_loss': | ||
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw)) | ||
| @@ -379,7 +379,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh | |
return gw_dist | ||
| ||
| ||
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False): | ||
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False): | ||
""" | ||
Returns the gromov-wasserstein barycenters of S measured similarity matrices | ||
| ||
| @@ -442,12 +442,12 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000 | |
| ||
error = [] | ||
| ||
while(err > stopThr and cpt < numItermax): | ||
while(err > stopThr and cpt < max_iter): | ||
| ||
Cprev = C | ||
| ||
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, | ||
numItermax, 1e-5, verbose, log) for s in range(S)] | ||
max_iter, 1e-5, verbose, log) for s in range(S)] | ||
| ||
if loss_fun == 'square_loss': | ||
C = update_square_loss(p, lambdas, T, Cs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
| @@ -10,18 +10,18 @@ | |
| ||
| ||
def test_gromov(): | ||
n = 50 # nb samples | ||
n_samples = 50 # nb samples | ||
| ||
mu_s = np.array([0, 0]) | ||
cov_s = np.array([[1, 0], [0, 1]]) | ||
| ||
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s) | ||
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s) | ||
| ||
xt = [xs[n - (i + 1)] for i in range(n)] | ||
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)] | ||
| ||
xt = np.array(xt) | ||
| ||
p = ot.unif(n) | ||
q = ot.unif(n) | ||
p = ot.unif(n_samples) | ||
q = ot.unif(n_samples) | ||
| ||
C1 = ot.dist(xs, xs) | ||
C2 = ot.dist(xt, xt) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you rename maybe the png files? also I see arrow = coeur. Is this a bug?