Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
7ab9037
Gromov-Wasserstein distance
Aug 28, 2017
0a68bf4
gromov:flake8 and other
Aug 28, 2017
3007f1d
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
bc68cc3
minor corrections
Aug 31, 2017
986f46d
Merge branch 'master' into gromov
ncourty Aug 31, 2017
e89f09d
remove linewidth error message
Slasnista Jul 28, 2017
f469205
first proposal for OT wrappers
Slasnista Jul 28, 2017
fa36e77
small modifs according to NG proposals
Slasnista Jul 28, 2017
aa19b6a
integrate AG comments
Slasnista Jul 28, 2017
5ab5035
own BaseEstimator class written + rflamary comments addressed
Slasnista Jul 31, 2017
c7eaaf4
update SinkhornTransport class + added test for class
Slasnista Aug 1, 2017
d5c6cc1
added EMDTransport Class from NG's code + added dedicated test
Slasnista Aug 1, 2017
cd4fa72
added test for fit_transform + correction of fit_transform bug (missi…
Slasnista Aug 4, 2017
0659abe
added new class SinkhornLpl1Transport() + dedicated test
Slasnista Aug 4, 2017
2005a09
added new class SinkhornL1l2Transport() + dedicated test
Slasnista Aug 4, 2017
4e562a1
semi supervised mode supported
Slasnista Aug 4, 2017
62b40a9
correction of semi supervised mode
Slasnista Aug 4, 2017
266abb6
reformat doc strings + remove useless log / verbose parameters for emd
Slasnista Aug 4, 2017
b8672f6
out of samples by Ferradans supported for transform and inverse_trans…
Slasnista Aug 4, 2017
117cd33
added new class MappingTransport to support linear and kernel mapping…
Slasnista Aug 4, 2017
d20a067
make doc strings compliant with numpy / modif according to AG review
Slasnista Aug 23, 2017
8d19d36
out of samples transform and inverse transform by batch
Slasnista Aug 23, 2017
c8ae584
test functions for MappingTransport Class
Slasnista Aug 23, 2017
fc58f39
added deprecation warning on old classes
Slasnista Aug 23, 2017
6167f34
solving log issues to avoid errors and adding further tests
Slasnista Aug 25, 2017
181fcd3
refactoring examples according to new DA classes
Slasnista Aug 25, 2017
e1a3984
small corrections for examples
Slasnista Aug 25, 2017
4f802cf
set properly path of data
Slasnista Aug 25, 2017
e1606c1
move no da objects into utils.py
Slasnista Aug 28, 2017
f79f483
handling input arguments in fit, transform... methods + remove old ex…
Slasnista Aug 28, 2017
84e56a0
check input parameters with helper functions
Slasnista Aug 28, 2017
5964001
update readme
Slasnista Aug 28, 2017
24362ec
Gromov-Wasserstein distance
Aug 28, 2017
f8744a3
gromov:flake8 and other
Aug 28, 2017
3730779
addressed AG comments + adding random seed
Slasnista Aug 29, 2017
5a9795f
pass on examples | introduced RandomState
Slasnista Aug 29, 2017
6ae3ad7
Changes to LP solver:
toto6 Aug 29, 2017
b562927
Fix param order
toto6 Aug 29, 2017
0f7cd92
Type print
toto6 Aug 29, 2017
ceeb063
Changes:
toto6 Aug 30, 2017
8875f65
Rename for emd and emd2
toto6 Aug 30, 2017
5076131
Fix name error
toto6 Aug 30, 2017
6d60230
Move normalize function in utils.py
toto6 Aug 30, 2017
93dee55
Move norm out of fit to init for deprecated OTDA
toto6 Aug 30, 2017
8c52517
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
4ec5b33
minor corrections
Aug 31, 2017
ab6ed1d
docstrings and naming
Sep 1, 2017
64a5d3c
docstrings and naming
Sep 1, 2017
46fc12a
solving conflicts :/
Sep 1, 2017
f12322c
add barycenters to Readme.md
Sep 1, 2017
53e1115
docstrings + naming
Sep 1, 2017
8ea74ad
docstrings + naming
Sep 1, 2017
36bf599
Corrections on Gromov
Sep 12, 2017
24784ed
Corrections on Gromov
Sep 12, 2017
84c2723
Corrections on Gromov
Sep 12, 2017
55db350
Corrections on Gromov
Sep 12, 2017
5a2ebfa
Corrections on Gromov
ncourty Sep 13, 2017
7e5df4c
Merge branch 'gromov' of https://github.com/rflamary/POT into gromov
ncourty Sep 13, 2017
c86cc4f
Merge branch 'master' into gromov
ncourty Sep 13, 2017
c7eef9d
Merge branch 'master' into gromov
ncourty Sep 13, 2017
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added new class SinkhornL1l2Transport() + dedicated test
  • Loading branch information
Slasnista authored and Nicolas Courty committed Sep 1, 2017
commit 2005a09548a6f6d42cd9aafadbb4583e4029936c
109 changes: 109 additions & 0 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,10 @@ class SinkhornLpl1Transport(BaseTransport):

Parameters
----------
reg_e : float, optional (default=1)
Entropic regularization parameter
reg_cl : float, optional (default=0.1)
Class regularization parameter
mode : string, optional (default="unsupervised")
The DA mode. If "unsupervised" no target labels are taken into account
to modify the cost matrix. If "semisupervised" the target labels
Expand All @@ -1384,6 +1388,11 @@ class SinkhornLpl1Transport(BaseTransport):
The ground metric for the Wasserstein problem
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
max_iter : int, float, optional (default=10)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged
max_inner_iter : int, float, optional (default=200)
The number of iteration in the inner loop
verbose : int, optional (default=0)
Controls the verbosity of the optimization algorithm
log : int, optional (default=0)
Expand Down Expand Up @@ -1452,3 +1461,103 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
verbose=self.verbose, log=self.log)

return self


class SinkhornL1l2Transport(BaseTransport):
    """Domain Adaptation OT method based on sinkhorn algorithm +
    l1l2 class regularization.

    Parameters
    ----------
    reg_e : float, optional (default=1)
        Entropic regularization parameter
    reg_cl : float, optional (default=0.1)
        Class regularization parameter
    mode : string, optional (default="unsupervised")
        The DA mode. If "unsupervised" no target labels are taken into account
        to modify the cost matrix. If "semisupervised" the target labels
        are taken into account to set coefficients of the pairwise distance
        matrix to 0 for row and columns indices that correspond to source and
        target samples which share the same labels.
    max_iter : int, float, optional (default=10)
        The maximum number of iterations before stopping the optimization
        algorithm if it has not converged (forwarded as ``numItermax``)
    max_inner_iter : int, float, optional (default=200)
        The number of iterations in the inner loop (forwarded as
        ``numInnerItermax``)
    tol : float, optional (default=10e-9)
        Stopping threshold forwarded to the solver as ``stopInnerThr``
    verbose : bool, optional (default=False)
        Controls the verbosity of the optimization algorithm
    log : bool, optional (default=False)
        Controls the logs of the optimization algorithm
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    distribution_estimation : callable, optional (defaults to the uniform
        distribution)
        The kind of distribution estimation to employ. Only stored by this
        class; presumably consumed by ``BaseTransport.fit`` -- to be
        confirmed against the base class.
    out_of_sample_map : string, optional (default="ferradans")
        Name of the out-of-sample mapping strategy. Only stored by this
        class; presumably consumed by the ``BaseTransport`` transform
        methods -- to be confirmed against the base class.

    Attributes
    ----------
    Coupling_ : the optimal coupling
    log_ : the solver log when ``log`` is enabled, an empty dict otherwise

    References
    ----------

    .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
           "Optimal Transport for Domain Adaptation," in IEEE
           Transactions on Pattern Analysis and Machine Intelligence ,
           vol.PP, no.99, pp.1-1
    .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
           Generalized conditional gradient: analysis of convergence
           and applications. arXiv preprint arXiv:1510.06567.

    """

    def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
                 max_iter=10, max_inner_iter=200,
                 tol=10e-9, verbose=False, log=False,
                 metric="sqeuclidean",
                 distribution_estimation=distribution_estimation_uniform,
                 out_of_sample_map='ferradans'):

        self.reg_e = reg_e
        self.reg_cl = reg_cl
        self.mode = mode
        self.max_iter = max_iter
        self.max_inner_iter = max_inner_iter
        self.tol = tol
        self.verbose = verbose
        self.log = log
        self.metric = metric
        self.distribution_estimation = distribution_estimation
        self.out_of_sample_map = out_of_sample_map

    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Build a coupling matrix from source and target sets of samples
        (Xs, ys) and (Xt, yt)

        Parameters
        ----------
        Xs : array-like of shape = [n_source_samples, n_features]
            The training input samples.
        ys : array-like, shape = [n_source_samples]
            The class labels. They are forwarded to the solver as
            ``labels_a``.
        Xt : array-like of shape = [n_target_samples, n_features]
            The training input samples.
        yt : array-like, shape = [n_labeled_target_samples]
            The class labels

        Returns
        -------
        self : object
            Returns self.
        """

        # The parent fit is relied upon to populate self.mu_s, self.mu_t
        # and self.Cost, which are read right below.
        super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)

        returned_ = sinkhorn_l1l2_gl(
            a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.Cost,
            reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
            numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
            verbose=self.verbose, log=self.log)

        # When logging is requested the solver hands back a
        # (coupling, log) pair; keep Coupling_ a plain matrix in both cases.
        if self.log:
            self.Coupling_, self.log_ = returned_
        else:
            self.Coupling_ = returned_
            self.log_ = dict()

        return self
50 changes: 50 additions & 0 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,56 @@ def test_sinkhorn_lpl1_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)


def test_sinkhorn_l1l2_transport_class():
    """Exercise fit, transform, inverse_transform and fit_transform of
    ot.da.SinkhornL1l2Transport on two small labelled data sets.
    """

    ns = 150
    nt = 200

    Xs, ys = get_data_classif('3gauss', ns)
    Xt, yt = get_data_classif('3gauss2', nt)

    transport = ot.da.SinkhornL1l2Transport()

    # fitting must run through without error
    transport.fit(Xs=Xs, ys=ys, Xt=Xt)

    # cost and coupling are both (n_source, n_target) matrices
    expected_shape = (Xs.shape[0], Xt.shape[0])
    assert_equal(transport.Cost.shape, expected_shape)
    assert_equal(transport.Coupling_.shape, expected_shape)

    # the coupling marginals must match the uniform weights
    mu_s = unif(ns)
    mu_t = unif(nt)
    col_sums = np.sum(transport.Coupling_, axis=0)
    row_sums = np.sum(transport.Coupling_, axis=1)
    assert_allclose(col_sums, mu_t, rtol=1e-3, atol=1e-3)
    assert_allclose(row_sums, mu_s, rtol=1e-3, atol=1e-3)

    # transporting the fitted source samples keeps their shape
    mapped_source = transport.transform(Xs=Xs)
    assert_equal(mapped_source.shape, Xs.shape)

    Xs_new, _ = get_data_classif('3gauss', ns + 1)
    mapped_source_new = transport.transform(Xs_new)

    # unseen source samples are expected to come back unchanged
    assert_equal(mapped_source_new, Xs_new)

    # transporting the fitted target samples back keeps their shape
    mapped_target = transport.inverse_transform(Xt=Xt)
    assert_equal(mapped_target.shape, Xt.shape)

    Xt_new, _ = get_data_classif('3gauss2', nt + 1)
    mapped_target_new = transport.inverse_transform(Xt=Xt_new)

    # unseen target samples are expected to come back unchanged
    assert_equal(mapped_target_new, Xt_new)

    # fit_transform chains both steps and keeps the source shape
    mapped_source = transport.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
    assert_equal(mapped_source.shape, Xs.shape)


def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
"""
Expand Down