- Notifications
You must be signed in to change notification settings - Fork 534
Gromov-Wasserstein distance #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7ab9037
0a68bf4
3007f1d
bc68cc3
986f46d
e89f09d
f469205
fa36e77
aa19b6a
5ab5035
c7eaaf4
d5c6cc1
cd4fa72
0659abe
2005a09
4e562a1
62b40a9
266abb6
b8672f6
117cd33
d20a067
8d19d36
c8ae584
fc58f39
6167f34
181fcd3
e1a3984
4f802cf
e1606c1
f79f483
84e56a0
5964001
24362ec
f8744a3
3730779
5a9795f
6ae3ad7
b562927
0f7cd92
ceeb063
8875f65
5076131
6d60230
93dee55
8c52517
4ec5b33
ab6ed1d
64a5d3c
46fc12a
f12322c
53e1115
8ea74ad
36bf599
24784ed
84c2723
55db350
5a2ebfa
7e5df4c
c86cc4f
c7eef9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
| @@ -16,7 +16,7 @@ It provides the following solvers: | |
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. | ||
* Joint OT matrix and mapping estimation [8]. | ||
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt). | ||
| ||
* Gromov-Wasserstein distances [12] | ||
| ||
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. | ||
| ||
| @@ -182,3 +182,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t | |
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816. | ||
| ||
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. | ||
| ||
[12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
==================== | ||
Gromov-Wasserstein example | ||
==================== | ||
| ||
| ||
This example is designed to show how to use the Gromov-Wassertsein distance | ||
computation in POT. | ||
| ||
| ||
""" | ||
| ||
# Author: Erwan Vautier <erwan.vautier@gmail.com> | ||
# Nicolas Courty <ncourty@irisa.fr> | ||
# | ||
# License: MIT License | ||
| ||
import scipy as sp | ||
import numpy as np | ||
| ||
import ot | ||
import matplotlib.pylab as pl | ||
| ||
from mpl_toolkits.mplot3d import Axes3D | ||
| ||
| ||
| ||
""" | ||
Sample two Gaussian distributions (2D and 3D) | ||
==================== | ||
| ||
| ||
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. For | ||
demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces. | ||
| ||
""" | ||
n=30 # nb samples | ||
| ||
mu_s=np.array([0,0]) | ||
cov_s=np.array([[1,0],[0,1]]) | ||
| ||
mu_t=np.array([4,4,4]) | ||
cov_t=np.array([[1,0,0],[0,1,0],[0,0,1]]) | ||
| ||
| ||
| ||
xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) | ||
P=sp.linalg.sqrtm(cov_t) | ||
xt= np.random.randn(n,3).dot(P)+mu_t | ||
| ||
| ||
| ||
""" | ||
Plotting the distributions | ||
==================== | ||
""" | ||
fig=pl.figure() | ||
ax1=fig.add_subplot(121) | ||
ax1.plot(xs[:,0],xs[:,1],'+b',label='Source samples') | ||
ax2=fig.add_subplot(122,projection='3d') | ||
ax2.scatter(xt[:,0],xt[:,1],xt[:,2],color='r') | ||
pl.show() | ||
| ||
| ||
""" | ||
Compute distance kernels, normalize them and then display | ||
==================== | ||
""" | ||
| ||
C1=sp.spatial.distance.cdist(xs,xs) | ||
C2=sp.spatial.distance.cdist(xt,xt) | ||
| ||
C1/=C1.max() | ||
C2/=C2.max() | ||
| ||
pl.figure() | ||
pl.subplot(121) | ||
pl.imshow(C1) | ||
pl.subplot(122) | ||
pl.imshow(C2) | ||
pl.show() | ||
| ||
""" | ||
Compute Gromov-Wasserstein plans and distance | ||
==================== | ||
""" | ||
| ||
p=ot.unif(n) | ||
q=ot.unif(n) | ||
| ||
gw=ot.gromov_wasserstein(C1,C2,p,q,'square_loss',epsilon=5e-4) | ||
gw_dist=ot.gromov_wasserstein2(C1,C2,p,q,'square_loss',epsilon=5e-4) | ||
| ||
print('Gromov-Wasserstein distances between the distribution: '+str(gw_dist)) | ||
| ||
pl.figure() | ||
pl.imshow(gw,cmap='jet') | ||
pl.colorbar() | ||
pl.show() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and barycenters