Merged
Changes from 3 commits
1 change: 1 addition & 0 deletions RELEASES.md
@@ -4,6 +4,7 @@

#### New features

- Added Bures-Wasserstein distance in `ot.gaussian` (PR #428)
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
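For context on the first entry above: the Bures-Wasserstein distance between two Gaussians N(m_s, C_s) and N(m_t, C_t) has a closed form, which is what the new `ot.gaussian` code computes. Below is a minimal NumPy/SciPy sketch of that formula for illustration; it is not the new `bures_wasserstein_distance` API itself, whose exact signature lives in `ot/gaussian.py` and is not shown in this diff.

```python
# Closed-form 2-Wasserstein (Bures-Wasserstein) distance between
# N(ms, Cs) and N(mt, Ct): ||ms - mt||^2 + Tr(Cs + Ct - 2 (Cs^1/2 Ct Cs^1/2)^1/2).
# Illustrative sketch only, independent of the ot.gaussian implementation.
import numpy as np
from scipy.linalg import sqrtm


def bw_distance_sketch(ms, mt, Cs, Ct):
    Cs12 = np.real(sqrtm(Cs))                 # Sigma_s^{1/2}
    cross = np.real(sqrtm(Cs12 @ Ct @ Cs12))  # (Sigma_s^{1/2} Sigma_t Sigma_s^{1/2})^{1/2}
    bures2 = np.trace(Cs + Ct - 2 * cross)    # squared Bures metric between covariances
    w2sq = np.sum((ms - mt) ** 2) + bures2    # add squared distance between the means
    return np.sqrt(max(w2sq, 0.0))            # clip tiny negative round-off


ms, mt = np.zeros(2), np.ones(2)
Cs, Ct = np.eye(2), 2 * np.eye(2)
print(bw_distance_sketch(ms, mt, Cs, Ct))     # ~1.53 for this example
```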
1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -31,6 +31,7 @@ API and modules
   sliced
   weak
   factored
   gaussian

.. autosummary::
   :toctree: ../modules/generated/
4 changes: 3 additions & 1 deletion ot/__init__.py
@@ -35,6 +35,7 @@
from . import weak
from . import factored
from . import solvers
from . import gaussian

# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -48,6 +49,7 @@
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve
from .gaussian import bures_wasserstein_distance

# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -63,4 +65,4 @@
           'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
           'max_sliced_wasserstein_distance', 'weak_optimal_transport',
           'factored_optimal_transport', 'solve',
           'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers']
           'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'bures_wasserstein_distance']
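With the two import lines and the extended `__all__` above, the new distance should be reachable both from the submodule and from the package root. A minimal check of that wiring, assuming no later renames in the PR:

```python
# Sketch: verify the re-exports added in ot/__init__.py above.
import ot
from ot.gaussian import bures_wasserstein_distance

assert ot.bures_wasserstein_distance is bures_wasserstein_distance
assert 'bures_wasserstein_distance' in ot.__all__
```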
109 changes: 1 addition & 108 deletions ot/da.py
@@ -19,6 +19,7 @@
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
from .utils import list_to_array, check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .gaussian import OT_mapping_linear
from .optim import cg
from .optim import gcg

@@ -679,114 +680,6 @@ def df(G):
    return G, L


def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
                      wt=None, bias=True, log=False):
    r"""Return OT linear operator between samples.

    The function estimates the optimal linear operator that aligns the two
    empirical distributions. This is equivalent to estimating the closed
    form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
    and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
    :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
    :ref:`[15] <references-OT-mapping-linear>`.

    The linear operator from source to target :math:`M` is

    .. math::
        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}

    where:

    .. math::
        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
        \Sigma_s^{-1/2}

        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s

    Parameters
    ----------
    xs : array-like (ns,d)
        samples in the source domain
    xt : array-like (nt,d)
        samples in the target domain
    reg : float, optional
        regularization added to the diagonals of covariances (>0)
    ws : array-like (ns,1), optional
        weights for the source samples
    wt : array-like (nt,1), optional
        weights for the target samples
    bias : boolean, optional
        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default: True)
    log : bool, optional
        record log if True


    Returns
    -------
    A : (d, d) array-like
        Linear operator
    b : (1, d) array-like
        bias
    log : dict
        log dictionary, returned only if ``log`` is True


    .. _references-OT-mapping-linear:
    References
    ----------
    .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
        distributions", Journal of Optimization Theory and Applications,
        Vol 43, 1984

    .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
        Transport".


    """
    xs, xt = list_to_array(xs, xt)
    nx = get_backend(xs, xt)

    d = xs.shape[1]

    if bias:
        mxs = nx.mean(xs, axis=0)[None, :]
        mxt = nx.mean(xt, axis=0)[None, :]

        xs = xs - mxs
        xt = xt - mxt
    else:
        mxs = nx.zeros((1, d), type_as=xs)
        mxt = nx.zeros((1, d), type_as=xs)

    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)

    Cs12 = nx.sqrtm(Cs)
    Cs_12 = nx.inv(Cs12)

    M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))

    A = dots(Cs_12, M0, Cs_12)

    b = mxt - nx.dot(mxs, A)

    if log:
        log = {}
        log['Cs'] = Cs
        log['Ct'] = Ct
        log['Cs12'] = Cs12
        log['Cs_12'] = Cs_12
        return A, b, log
    else:
        return A, b


def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
                numItermax=100, stopThr=1e-9, numInnerItermax=100000,
                stopInnerThr=1e-9, log=False, verbose=False):
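The `OT_mapping_linear` block removed above is relocated to `ot.gaussian` by this PR; its docstring gives the closed-form affine map A, b between two Gaussians. The sketch below evaluates those same formulas from empirical means and covariances with plain NumPy/SciPy, as an illustration rather than a call to the relocated function (whose final home is `ot/gaussian.py`, not shown in this diff).

```python
# Sketch of the affine Monge map between two Gaussians, following the
# formulas documented in OT_mapping_linear above:
#   A = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2},  b = mu_t - A mu_s
# Illustration only; not the relocated ot.gaussian function itself.
import numpy as np
from scipy.linalg import sqrtm, inv

rng = np.random.RandomState(0)
xs = rng.randn(300, 2) * np.array([1.0, 0.5])         # source samples
xt = rng.randn(300, 2) * np.array([2.0, 1.0]) + 4.0   # target samples

mu_s, mu_t = xs.mean(axis=0), xt.mean(axis=0)
Cs = np.cov(xs.T) + 1e-6 * np.eye(2)                  # regularized empirical covariances
Ct = np.cov(xt.T) + 1e-6 * np.eye(2)

Cs12 = np.real(sqrtm(Cs))
Cs12_inv = inv(Cs12)
A = Cs12_inv @ np.real(sqrtm(Cs12 @ Ct @ Cs12)) @ Cs12_inv   # linear part of the map
b = mu_t - A @ mu_s                                          # bias term

xs_mapped = xs @ A.T + b   # transported sources, roughly matching the target cloud
print(xs_mapped.mean(axis=0), np.cov(xs_mapped.T))
```

Since A is symmetric positive definite here, `xs @ A.T` equals `xs @ A`; the row-vector convention mirrors `b = mxt - nx.dot(mxs, A)` in the removed code.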
28 changes: 28 additions & 0 deletions ot/datasets.py
@@ -40,6 +40,34 @@ def get_1D_gauss(n, m, sigma):
    return make_1D_gauss(n, m, sigma)


def make_1D_samples_gauss(n, m, sigma, random_state=None):
    r"""Return `n` samples drawn from 1D gaussian :math:`\mathcal{N}(m, \sigma)`

    Parameters
    ----------
    n : int
        number of samples to make
    m : float
        mean value of the gaussian distribution
    sigma : float
        variance of the gaussian distribution
    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    X : ndarray, shape (`n`, 1)
        n samples drawn from :math:`\mathcal{N}(m, \sigma)`.
    """

    generator = check_random_state(random_state)
    res = generator.randn(n, 1) * np.sqrt(sigma) + m
    return res


def make_2D_samples_gauss(n, m, sigma, random_state=None):
    r"""Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`

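A quick usage sketch of the new sampler. Note that `sigma` enters the draw through `np.sqrt(sigma)`, so it plays the role of a variance, matching the covariance argument of `make_2D_samples_gauss` just above:

```python
# Sketch: draw from the new 1D sampler and check the empirical moments.
import numpy as np
from ot.datasets import make_1D_samples_gauss

X = make_1D_samples_gauss(n=5000, m=2.0, sigma=4.0, random_state=0)
print(X.shape)              # (5000, 1)
print(X.mean(), X.var())    # close to 2.0 and 4.0
```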