Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
7ab9037
Gromov-Wasserstein distance
Aug 28, 2017
0a68bf4
gromov:flake8 and other
Aug 28, 2017
3007f1d
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
bc68cc3
minor corrections
Aug 31, 2017
986f46d
Merge branch 'master' into gromov
ncourty Aug 31, 2017
e89f09d
remove linewidth error message
Slasnista Jul 28, 2017
f469205
first proposal for OT wrappers
Slasnista Jul 28, 2017
fa36e77
small modifs according to NG proposals
Slasnista Jul 28, 2017
aa19b6a
integrate AG comments
Slasnista Jul 28, 2017
5ab5035
own BaseEstimator class written + rflamary comments addressed
Slasnista Jul 31, 2017
c7eaaf4
update SinkhornTransport class + added test for class
Slasnista Aug 1, 2017
d5c6cc1
added EMDTransport Class from NG's code + added dedicated test
Slasnista Aug 1, 2017
cd4fa72
added test for fit_transform + correction of fit_transform bug (missi…
Slasnista Aug 4, 2017
0659abe
added new class SinkhornLpl1Transport() + dedicated test
Slasnista Aug 4, 2017
2005a09
added new class SinkhornL1l2Transport() + dedicated test
Slasnista Aug 4, 2017
4e562a1
semi supervised mode supported
Slasnista Aug 4, 2017
62b40a9
correction of semi supervised mode
Slasnista Aug 4, 2017
266abb6
reformat doc strings + remove useless log / verbose parameters for emd
Slasnista Aug 4, 2017
b8672f6
out of samples by Ferradans supported for transform and inverse_trans…
Slasnista Aug 4, 2017
117cd33
added new class MappingTransport to support linear and kernel mapping…
Slasnista Aug 4, 2017
d20a067
make doc strings compliant with numpy / modif according to AG review
Slasnista Aug 23, 2017
8d19d36
out of samples transform and inverse transform by batch
Slasnista Aug 23, 2017
c8ae584
test functions for MappingTransport Class
Slasnista Aug 23, 2017
fc58f39
added deprecation warning on old classes
Slasnista Aug 23, 2017
6167f34
solving log issues to avoid errors and adding further tests
Slasnista Aug 25, 2017
181fcd3
refactoring examples according to new DA classes
Slasnista Aug 25, 2017
e1a3984
small corrections for examples
Slasnista Aug 25, 2017
4f802cf
set properly path of data
Slasnista Aug 25, 2017
e1606c1
move no da objects into utils.py
Slasnista Aug 28, 2017
f79f483
handling input arguments in fit, transform... methods + remove old ex…
Slasnista Aug 28, 2017
84e56a0
check input parameters with helper functions
Slasnista Aug 28, 2017
5964001
update readme
Slasnista Aug 28, 2017
24362ec
Gromov-Wasserstein distance
Aug 28, 2017
f8744a3
gromov:flake8 and other
Aug 28, 2017
3730779
addressed AG comments + adding random seed
Slasnista Aug 29, 2017
5a9795f
pass on examples | introduced RandomState
Slasnista Aug 29, 2017
6ae3ad7
Changes to LP solver:
toto6 Aug 29, 2017
b562927
Fix param order
toto6 Aug 29, 2017
0f7cd92
Type print
toto6 Aug 29, 2017
ceeb063
Changes:
toto6 Aug 30, 2017
8875f65
Rename for emd and emd2
toto6 Aug 30, 2017
5076131
Fix name error
toto6 Aug 30, 2017
6d60230
Move normalize function in utils.py
toto6 Aug 30, 2017
93dee55
Move norm out of fit to init for deprecated OTDA
toto6 Aug 30, 2017
8c52517
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
4ec5b33
minor corrections
Aug 31, 2017
ab6ed1d
docstrings and naming
Sep 1, 2017
64a5d3c
docstrings and naming
Sep 1, 2017
46fc12a
solving conflicts :/
Sep 1, 2017
f12322c
add barycenters to Readme.md
Sep 1, 2017
53e1115
docstrings + naming
Sep 1, 2017
8ea74ad
docstrings + naming
Sep 1, 2017
36bf599
Corrections on Gromov
Sep 12, 2017
24784ed
Corrections on Gromov
Sep 12, 2017
84c2723
Corrections on Gromov
Sep 12, 2017
55db350
Corrections on Gromov
Sep 12, 2017
5a2ebfa
Corrections on Gromov
ncourty Sep 13, 2017
7e5df4c
Merge branch 'gromov' of https://github.com/rflamary/POT into gromov
ncourty Sep 13, 2017
c86cc4f
Merge branch 'master' into gromov
ncourty Sep 13, 2017
c7eef9d
Merge branch 'master' into gromov
ncourty Sep 13, 2017
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
solving log issues to avoid errors and adding further tests
  • Loading branch information
Slasnista authored and Nicolas Courty committed Sep 1, 2017
commit 6167f34a721886d4b9038a8b1746a2c8c81132ce
57 changes: 42 additions & 15 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport):

Attributes
----------
coupling_ : the optimal coupling
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
The dictionary of log, empty dict if parameter log is not True

References
----------
Expand Down Expand Up @@ -1367,11 +1370,18 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)

# coupling estimation
self.coupling_ = sinkhorn(
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)

# deal with the value of log
if self.log:
self.coupling_, self.log_ = returned_
else:
self.coupling_ = returned_
self.log_ = dict()

return self


Expand Down Expand Up @@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport):

Attributes
----------
coupling_ : the optimal coupling
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling

References
----------
Expand Down Expand Up @@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport):
The number of iterations in the inner loop
verbose : int, optional (default=0)
Controls the verbosity of the optimization algorithm
log : int, optional (default=0)
Controls the logs of the optimization algorithm
limit_max : float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost

Attributes
----------
coupling_ : the optimal coupling
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling

References
----------
Expand All @@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport):

def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False, log=False,
tol=10e-9, verbose=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
Expand All @@ -1511,7 +1521,6 @@ def __init__(self, reg_e=1., reg_cl=0.1,
self.max_inner_iter = max_inner_iter
self.tol = tol
self.verbose = verbose
self.log = log
self.metric = metric
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
Expand Down Expand Up @@ -1544,7 +1553,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
verbose=self.verbose, log=self.log)
verbose=self.verbose)

return self

Expand Down Expand Up @@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport):

Attributes
----------
coupling_ : the optimal coupling
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
The dictionary of log, empty dict if parameter log is not True

References
----------
Expand Down Expand Up @@ -1641,12 +1653,19 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):

super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)

self.coupling_ = sinkhorn_l1l2_gl(
returned_ = sinkhorn_l1l2_gl(
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
verbose=self.verbose, log=self.log)

# deal with the value of log
if self.log:
self.coupling_, self.log_ = returned_
else:
self.coupling_ = returned_
self.log_ = dict()

return self


Expand Down Expand Up @@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator):

Attributes
----------
coupling_ : array-like, shape (n_source_samples, n_features)
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
mapping_ : array-like, shape (n_features (+ 1), n_features)
(if bias) for kernel == linear
The associated mapping

array-like, shape (n_source_samples (+ 1), n_features)
(if bias) for kernel == gaussian
log_ : dictionary
The dictionary of log, empty dict if parameter log is not True

References
----------
Expand Down Expand Up @@ -1745,19 +1765,26 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
self.Xt = Xt

if self.kernel == "linear":
self.coupling_, self.mapping_ = joint_OT_mapping_linear(
returned_ = joint_OT_mapping_linear(
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
verbose=self.verbose, verbose2=self.verbose2,
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log)

elif self.kernel == "gaussian":
self.coupling_, self.mapping_ = joint_OT_mapping_kernel(
returned_ = joint_OT_mapping_kernel(
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose,
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log)

# deal with the value of log
if self.log:
self.coupling_, self.mapping_, self.log_ = returned_
else:
self.coupling_, self.mapping_ = returned_
self.log_ = dict()

return self

def transform(self, Xs):
Expand Down
39 changes: 33 additions & 6 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def test_sinkhorn_lpl1_transport_class():

# test it's computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(clf, "cost_")
assert hasattr(clf, "coupling_")

# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
Expand Down Expand Up @@ -89,6 +91,9 @@ def test_sinkhorn_l1l2_transport_class():

# test it's computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(clf, "cost_")
assert hasattr(clf, "coupling_")
assert hasattr(clf, "log_")

# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
Expand Down Expand Up @@ -137,6 +142,11 @@ def test_sinkhorn_l1l2_transport_class():

assert n_unsup != n_semisup, "semisupervised mode not working"

# check everything runs well with log=True
clf = ot.da.SinkhornL1l2Transport(log=True)
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
assert len(clf.log_.keys()) != 0


def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
Expand All @@ -152,6 +162,9 @@ def test_sinkhorn_transport_class():

# test it's computed
clf.fit(Xs=Xs, Xt=Xt)
assert hasattr(clf, "cost_")
assert hasattr(clf, "coupling_")
assert hasattr(clf, "log_")

# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
Expand Down Expand Up @@ -200,6 +213,11 @@ def test_sinkhorn_transport_class():

assert n_unsup != n_semisup, "semisupervised mode not working"

# check everything runs well with log=True
clf = ot.da.SinkhornTransport(log=True)
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
assert len(clf.log_.keys()) != 0


def test_emd_transport_class():
"""test_emd_transport
Expand All @@ -215,6 +233,8 @@ def test_emd_transport_class():

# test it's computed
clf.fit(Xs=Xs, Xt=Xt)
assert hasattr(clf, "cost_")
assert hasattr(clf, "coupling_")

# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
Expand Down Expand Up @@ -282,6 +302,9 @@ def test_mapping_transport_class():
# check computation and dimensions if bias == False
clf = ot.da.MappingTransport(kernel="linear", bias=False)
clf.fit(Xs=Xs, Xt=Xt)
assert hasattr(clf, "coupling_")
assert hasattr(clf, "mapping_")
assert hasattr(clf, "log_")

assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
Expand Down Expand Up @@ -369,6 +392,11 @@ def test_mapping_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)

# check everything runs well with log=True
clf = ot.da.MappingTransport(kernel="gaussian", log=True)
clf.fit(Xs=Xs, Xt=Xt)
assert len(clf.log_.keys()) != 0


def test_otda():

Expand Down Expand Up @@ -434,9 +462,8 @@ def test_otda():

# if __name__ == "__main__":

# test_otda()
# test_sinkhorn_transport_class()
# test_emd_transport_class()
# test_sinkhorn_l1l2_transport_class()
# test_sinkhorn_lpl1_transport_class()
# test_mapping_transport_class()
# test_sinkhorn_transport_class()
# test_emd_transport_class()
# test_sinkhorn_l1l2_transport_class()
# test_sinkhorn_lpl1_transport_class()
# test_mapping_transport_class()