Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9f51c14
example for log treatment in bregman.py
AdrienCorenflos May 8, 2020
a07330c
Improve doc
AdrienCorenflos Jul 14, 2020
d3292a8
Merge remote-tracking branch 'origin/master'
AdrienCorenflos Jul 14, 2020
dfa2c9d
Revert "example for log treatment in bregman.py"
AdrienCorenflos Jul 14, 2020
36377cc
Add comments by Flamary
AdrienCorenflos Jul 20, 2020
110f382
Delete repetitive description
AdrienCorenflos Jul 20, 2020
cbf6bf5
Added raw string to avoid pbs with backslashes
AdrienCorenflos Jul 20, 2020
22e7f6b
Implements sliced wasserstein
AdrienCorenflos Jul 20, 2020
7beac55
Merge branch 'master' into sliced_wasserstein
rflamary Jul 20, 2020
ba04ed6
Changed formatting of string for py3.5 support
AdrienCorenflos Jul 20, 2020
391df18
Merge remote-tracking branch 'origin/sliced_wasserstein' into sliced_…
AdrienCorenflos Jul 20, 2020
ca8364c
Docstest, expected 0.0 and not 0.
AdrienCorenflos Jul 20, 2020
2d893f2
Adressed comments by @rflamary
AdrienCorenflos Aug 4, 2020
7d9b920
No 3d plot here
AdrienCorenflos Aug 4, 2020
b68e2c2
add sliced to the docs
AdrienCorenflos Aug 4, 2020
a1309da
Merge branch 'master' into sliced_wasserstein
rflamary Aug 25, 2020
5c5c589
Merge branch 'master' into sliced_wasserstein
rflamary Aug 31, 2020
abeba45
Merge remote-tracking branch 'upstream/master' into sliced_wasserstein
AdrienCorenflos Aug 31, 2020
9a8edb5
Incorporate comments by @rflamary
AdrienCorenflos Aug 31, 2020
64fc3e1
Merge remote-tracking branch 'origin/sliced_wasserstein' into sliced_…
AdrienCorenflos Aug 31, 2020
5590a79
add link to pdf
rflamary Sep 4, 2020
1a718b2
Merge branch 'master' into sliced_wasserstein
rflamary Oct 22, 2020
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Incorporate comments by @rflamary
  • Loading branch information
AdrienCorenflos committed Aug 31, 2020
commit 9a8edb56461b8f89bd45fe20814ca9605d8d7720
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,3 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.

[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45

[32] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019
10 changes: 5 additions & 5 deletions ot/sliced.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_random_projections(n_projections, d, seed=None):
def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
r"""
Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance

.. math::
\mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}}

Expand Down Expand Up @@ -98,7 +99,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed
----------

.. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
.. [32] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019
"""
from .lp import emd2_1d

Expand Down Expand Up @@ -126,19 +126,19 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed
X_t_projections = np.dot(projections, X_t.T)

if log:
projected_emd = []
projected_emd = np.empty(n_projections)
else:
projected_emd = None

res = 0.

for X_s_proj, X_t_proj in zip(X_s_projections, X_t_projections):
for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)):
emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False)
if projected_emd is not None:
projected_emd.append(emd)
projected_emd[i] = emd
res += emd

res = (res / n_projections) ** 0.5
if log:
return res, {"projections": projections.tolist(), "projected_emds": projected_emd}
return res, {"projections": projections, "projected_emds": projected_emd}
return res