Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
7ab9037
Gromov-Wasserstein distance
Aug 28, 2017
0a68bf4
gromov:flake8 and other
Aug 28, 2017
3007f1d
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
bc68cc3
minor corrections
Aug 31, 2017
986f46d
Merge branch 'master' into gromov
ncourty Aug 31, 2017
e89f09d
remove linewidth error message
Slasnista Jul 28, 2017
f469205
first proposal for OT wrappers
Slasnista Jul 28, 2017
fa36e77
small modifs according to NG proposals
Slasnista Jul 28, 2017
aa19b6a
integrate AG comments
Slasnista Jul 28, 2017
5ab5035
own BaseEstimator class written + rflamary comments addressed
Slasnista Jul 31, 2017
c7eaaf4
update SinkhornTransport class + added test for class
Slasnista Aug 1, 2017
d5c6cc1
added EMDTransport Class from NG's code + added dedicated test
Slasnista Aug 1, 2017
cd4fa72
added test for fit_transform + correction of fit_transform bug (missi…
Slasnista Aug 4, 2017
0659abe
added new class SinkhornLpl1Transport() + dedicated test
Slasnista Aug 4, 2017
2005a09
added new class SinkhornL1l2Transport() + dedicated test
Slasnista Aug 4, 2017
4e562a1
semi supervised mode supported
Slasnista Aug 4, 2017
62b40a9
correction of semi supervised mode
Slasnista Aug 4, 2017
266abb6
reformat doc strings + remove useless log / verbose parameters for emd
Slasnista Aug 4, 2017
b8672f6
out of samples by Ferradans supported for transform and inverse_trans…
Slasnista Aug 4, 2017
117cd33
added new class MappingTransport to support linear and kernel mapping…
Slasnista Aug 4, 2017
d20a067
make doc strings compliant with numpy / modif according to AG review
Slasnista Aug 23, 2017
8d19d36
out of samples transform and inverse transform by batch
Slasnista Aug 23, 2017
c8ae584
test functions for MappingTransport Class
Slasnista Aug 23, 2017
fc58f39
added deprecation warning on old classes
Slasnista Aug 23, 2017
6167f34
solving log issues to avoid errors and adding further tests
Slasnista Aug 25, 2017
181fcd3
refactoring examples according to new DA classes
Slasnista Aug 25, 2017
e1a3984
small corrections for examples
Slasnista Aug 25, 2017
4f802cf
set properly path of data
Slasnista Aug 25, 2017
e1606c1
move no da objects into utils.py
Slasnista Aug 28, 2017
f79f483
handling input arguments in fit, transform... methods + remove old ex…
Slasnista Aug 28, 2017
84e56a0
check input parameters with helper functions
Slasnista Aug 28, 2017
5964001
update readme
Slasnista Aug 28, 2017
24362ec
Gromov-Wasserstein distance
Aug 28, 2017
f8744a3
gromov:flake8 and other
Aug 28, 2017
3730779
addressed AG comments + adding random seed
Slasnista Aug 29, 2017
5a9795f
pass on examples | introduced RandomState
Slasnista Aug 29, 2017
6ae3ad7
Changes to LP solver:
toto6 Aug 29, 2017
b562927
Fix param order
toto6 Aug 29, 2017
0f7cd92
Type print
toto6 Aug 29, 2017
ceeb063
Changes:
toto6 Aug 30, 2017
8875f65
Rename for emd and emd2
toto6 Aug 30, 2017
5076131
Fix name error
toto6 Aug 30, 2017
6d60230
Move normalize function in utils.py
toto6 Aug 30, 2017
93dee55
Move norm out of fit to init for deprecated OTDA
toto6 Aug 30, 2017
8c52517
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
4ec5b33
minor corrections
Aug 31, 2017
ab6ed1d
docstrings and naming
Sep 1, 2017
64a5d3c
docstrings and naming
Sep 1, 2017
46fc12a
solving conflicts :/
Sep 1, 2017
f12322c
add barycenters to Readme.md
Sep 1, 2017
53e1115
docstrings + naming
Sep 1, 2017
8ea74ad
docstrings + naming
Sep 1, 2017
36bf599
Corrections on Gromov
Sep 12, 2017
24784ed
Corrections on Gromov
Sep 12, 2017
84c2723
Corrections on Gromov
Sep 12, 2017
55db350
Corrections on Gromov
Sep 12, 2017
5a2ebfa
Corrections on Gromov
ncourty Sep 13, 2017
7e5df4c
Merge branch 'gromov' of https://github.com/rflamary/POT into gromov
ncourty Sep 13, 2017
c86cc4f
Merge branch 'master' into gromov
ncourty Sep 13, 2017
c7eef9d
Merge branch 'master' into gromov
ncourty Sep 13, 2017
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
docstrings and naming
  • Loading branch information
Nicolas Courty authored and Nicolas Courty committed Sep 1, 2017
commit ab6ed1df93cd78bb7f1a54282103d4d830e68bcb
10 changes: 5 additions & 5 deletions examples/plot_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
"""

n = 30 # nb samples
n_samples = 30 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
Expand All @@ -35,9 +35,9 @@
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])


xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n, 3).dot(P) + mu_t
xt = np.random.randn(n_samples, 3).dot(P) + mu_t


"""
Expand Down Expand Up @@ -75,8 +75,8 @@
=============================================
"""

p = ot.unif(n)
q = ot.unif(n)
p = ot.unif(n_samples)
q = ot.unif(n_samples)

gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
Expand Down
20 changes: 10 additions & 10 deletions examples/plot_gromov_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))


carre = spi.imread('../data/carre.png').astype(np.float64) / 256
rond = spi.imread('../data/rond.png').astype(np.float64) / 256
square = spi.imread('../data/carre.png').astype(np.float64) / 256
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you rename maybe the png files? also I see arrow = coeur. Is this a bug?


shapes = [carre, rond, triangle, fleche]
shapes = [square, circle, triangle, arrow]

S = 4
xs = [[] for i in range(S)]
Expand All @@ -118,36 +118,36 @@ def im2mat(I):
The four distributions are constructed from 4 simple images
"""
ns = [len(xs[s]) for s in range(S)]
N = 30
n_samples = 30

"""Compute all distances matrices for the four shapes"""
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
Cs = [cs / cs.max() for cs in Cs]

ps = [ot.unif(ns[s]) for s in range(S)]
p = ot.unif(N)
p = ot.unif(n_samples)


lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]

Ct01 = [0 for i in range(2)]
for i in range(2):
Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numItermax -> max_iter?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lines too long. You don't check pep8 on examples with travis?


Ct02 = [0 for i in range(2)]
for i in range(2):
Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)

Ct13 = [0 for i in range(2)]
for i in range(2):
Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)

Ct23 = [0 for i in range(2)]
for i in range(2):
Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)

"""
Expand Down
18 changes: 9 additions & 9 deletions ot/gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return(np.exp(np.divide(tmpsum, ppt)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for ()



def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein coupling between the two measured similarity matrices

Expand Down Expand Up @@ -248,7 +248,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
numItermax : int, optional
max_iter : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
Expand All @@ -274,7 +274,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
cpt = 0
err = 1

while (err > stopThr and cpt < numItermax):
while (err > stopThr and cpt < max_iter):

Tprev = T

Expand Down Expand Up @@ -307,7 +307,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
return T


def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stopThr -> tol? or stop_thr

"""
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices

Expand Down Expand Up @@ -362,10 +362,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh

if log:
gw, logv = gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
else:
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
epsilon, numItermax, stopThr, verbose, log)
epsilon, max_iter, stopThr, verbose, log)

if loss_fun == 'square_loss':
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
Expand All @@ -379,7 +379,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
return gw_dist


def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices

Expand Down Expand Up @@ -442,12 +442,12 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000

error = []

while(err > stopThr and cpt < numItermax):
while(err > stopThr and cpt < max_iter):

Cprev = C

T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
numItermax, 1e-5, verbose, log) for s in range(S)]
max_iter, 1e-5, verbose, log) for s in range(S)]

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
Expand Down
10 changes: 5 additions & 5 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@


def test_gromov():
n = 50 # nb samples
n_samples = 50 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)

xt = [xs[n - (i + 1)] for i in range(n)]
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for you really need the for loop here?

xt = np.array(xt)

p = ot.unif(n)
q = ot.unif(n)
p = ot.unif(n_samples)
q = ot.unif(n_samples)

C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
Expand Down