Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The contributors to this library are:
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)

## Acknowledgments

Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#### New features

- Added Bures Wasserstein distance in `ot.gaussian` (PR ##428)
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
Expand Down
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ API and modules
sliced
weak
factored
gaussian

.. autosummary::
:toctree: ../modules/generated/
Expand Down
6 changes: 3 additions & 3 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ distributions. In this case there exists a close form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function and can be
also computed from the covariances and means of the source and target
distributions. In the case when the finite sample dataset is supposed Gaussian,
we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the
we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
Monge mapping.


Expand Down Expand Up @@ -628,7 +628,7 @@ approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are supposed to be Gaussian
distributions, there exists a close form solution for the mapping and its an
affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector
:any:`ot.gaussian.bures_wasserstein_mapping` that returns the operator :math:`A` and vector
:math:`b`. Note that if the number of samples is too small there is a parameter
:code:`reg` that provides a regularization for the covariance matrix estimation.

Expand All @@ -640,7 +640,7 @@ method proposed in [8]_ that estimates a continuous mapping approximating the
barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.

.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear
.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.gaussian.bures_wasserstein_mapping
:add-heading: Examples of Monge mapping estimation
:heading-level: "

Expand Down
2 changes: 1 addition & 1 deletion examples/domain-adaptation/plot_otda_linear_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
# Estimate linear mapping and transport
# -------------------------------------

Ae, be = ot.da.OT_mapping_linear(xs, xt)
Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)

xst = xs.dot(Ae) + be

Expand Down
2 changes: 1 addition & 1 deletion examples/gromov/plot_barycenter_fgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
# -------------------------

#%% Create the barycenter
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
for i, v in enumerate(A.ravel()):
bary.add_node(i, attr_name=v)

Expand Down
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from . import weak
from . import factored
from . import solvers
from . import gaussian

# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
Expand All @@ -56,7 +57,7 @@

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'emd2_1d', 'wasserstein_1d', 'backend',
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
Expand Down
119 changes: 10 additions & 109 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
from .utils import list_to_array, check_params, BaseEstimator
from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .unbalanced import sinkhorn_unbalanced
from .gaussian import empirical_bures_wasserstein_mapping
from .optim import cg
from .optim import gcg

Expand Down Expand Up @@ -679,112 +680,12 @@ def df(G):
return G, L


@deprecated()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is enough since teh API did not change, also you did not apss the parameter to the function so it was false

Suggested change
@deprecated()
OT_mapping_linear=deprecated(empirical_bures_wasserstein_mapping)
def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
wt=None, bias=True, log=False):
r"""Return OT linear operator between samples.

The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
:ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
:ref:`[15] <references-OT-mapping-linear>`.

The linear operator from source to target :math:`M`

.. math::
M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}

where :

.. math::
\mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
\Sigma_s^{-1/2}

\mathbf{b} &= \mu_t - \mathbf{A} \mu_s

Parameters
----------
xs : array-like (ns,d)
samples in the source domain
xt : array-like (nt,d)
samples in the target domain
reg : float,optional
regularization added to the diagonals of covariances (>0)
ws : array-like (ns,1), optional
weights for the source samples
wt : array-like (ns,1), optional
weights for the target samples
bias: boolean, optional
estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
log : bool, optional
record log if True


Returns
-------
A : (d, d) array-like
Linear operator
b : (1, d) array-like
bias
log : dict
log dictionary return only if log==True in parameters


.. _references-OT-mapping-linear:
References
----------
.. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984

.. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
Transport", 2018.


"""
xs, xt = list_to_array(xs, xt)
nx = get_backend(xs, xt)

d = xs.shape[1]

if bias:
mxs = nx.mean(xs, axis=0)[None, :]
mxt = nx.mean(xt, axis=0)[None, :]

xs = xs - mxs
xt = xt - mxt
else:
mxs = nx.zeros((1, d), type_as=xs)
mxt = nx.zeros((1, d), type_as=xs)

if ws is None:
ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

if wt is None:
wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)

Cs12 = nx.sqrtm(Cs)
Cs_12 = nx.inv(Cs12)

M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))

A = dots(Cs_12, M0, Cs_12)

b = mxt - nx.dot(mxs, A)

if log:
log = {}
log['Cs'] = Cs
log['Ct'] = Ct
log['Cs12'] = Cs12
log['Cs_12'] = Cs_12
return A, b, log
else:
return A, b
""" Deprecated see ot.gaussian.empirical_bures_wasserstein_mapping"""
return empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
wt=None, bias=True, log=False)


def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
Expand Down Expand Up @@ -1378,10 +1279,10 @@ class label
self.mu_t = self.distribution_estimation(Xt)

# coupling estimation
returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
ws=nx.reshape(self.mu_s, (-1, 1)),
wt=nx.reshape(self.mu_t, (-1, 1)),
bias=self.bias, log=self.log)
returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg,
ws=nx.reshape(self.mu_s, (-1, 1)),
wt=nx.reshape(self.mu_t, (-1, 1)),
bias=self.bias, log=self.log)

# deal with the value of log
if self.log:
Expand Down
Loading