Closed
Describe the bug
`emd2_1d` (and `emd_1d`, used in the test case below) raises an error when `metric` is not one of the sped-up distance metrics, e.g. `cosine` or `yule`.
To Reproduce
Steps to reproduce the behavior:
A simple test case adapted from the 1D example code:
```python
import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss

##############################################################################
# Generate data
# -------------

#%% parameters

n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a = gauss(n, m=20, s=5)  # m= mean, s= std
b = gauss(n, m=60, s=10)

# use fast 1D solver
G0 = ot.emd_1d(x, x, a, b, metric="cosine")
```

```
     54 G0 = ot.emd_1d(x, x, a, b, metric="cosine")
     55 
     56 # Equivalent to

~/miniconda3/envs/ms-gen/lib/python3.8/site-packages/ot/lp/solver_1d.py in emd_1d(x_a, x_b, a, b, metric, p, dense, log, check_marginals)
    257     perm_b = nx.argsort(x_b_1d)
    258 
--> 259     G_sorted, indices, cost = emd_1d_sorted(
    260         nx.to_numpy(a[perm_a]).astype(np.float64),
    261         nx.to_numpy(b[perm_b]).astype(np.float64),

ot/lp/emd_wrap.pyx in ot.lp.emd_wrap.emd_1d_sorted()

AttributeError: 'float' object has no attribute 'reshape'
```

Expected behavior
It should return a result instead of raising an error (I can't yet tell whether the math would be correct).
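For checking the math once the error is gone: the traceback shows that the 1D solver sorts both supports (`argsort`) and then runs `emd_1d_sorted`, i.e. it couples the sorted samples monotonically (a north-west-corner sweep). That coupling is optimal when the cost has the form h(x - y) with h convex, e.g. |x - y|^p for p >= 1, but not for an arbitrary metric. On scalars, cosine distance only depends on signs (0 for equal signs, 2 for opposite signs) and is undefined at 0, so the sorted coupling need not give the true OT cost; solving `ot.emd` on a full cost matrix built with `ot.dist` would be a cross-check. The NumPy sketch below is not POT's implementation; `monotone_coupling_1d` and its callable `cost` argument are hypothetical names used only to illustrate what the sorted solver computes.

```python
import numpy as np


def monotone_coupling_1d(x_a, x_b, a, b, cost, tol=1e-12):
    """Couple two weighted 1-D samples monotonically (north-west corner on sorted supports).

    x_a, x_b : 1-D sample positions
    a, b     : nonnegative weights with equal total mass
    cost     : hypothetical callable cost(u, v) -> float taking two Python floats
    Returns a list of (i, j, mass) triples in the original indexing, and the total cost.
    """
    x_a = np.asarray(x_a, dtype=np.float64)
    x_b = np.asarray(x_b, dtype=np.float64)
    perm_a = np.argsort(x_a)
    perm_b = np.argsort(x_b)
    # Remaining mass at each sorted position (copies, so the inputs are untouched).
    wa = np.asarray(a, dtype=np.float64)[perm_a].copy()
    wb = np.asarray(b, dtype=np.float64)[perm_b].copy()

    plan, total = [], 0.0
    i = j = 0
    while i < len(wa) and j < len(wb):
        mass = min(wa[i], wb[j])
        ia, jb = int(perm_a[i]), int(perm_b[j])
        if mass > 0:
            plan.append((ia, jb, float(mass)))
            # cost receives plain floats, so it must not rely on ndarray methods
            total += mass * cost(float(x_a[ia]), float(x_b[jb]))
        wa[i] -= mass
        wb[j] -= mass
        # Advance past every position whose mass is used up (both sides on a tie).
        if wa[i] <= tol:
            i += 1
        if wb[j] <= tol:
            j += 1
    return plan, float(total)


# Two half-unit masses at 0 and 1 moved to a single point at 2, squared cost.
plan, total = monotone_coupling_1d([0.0, 1.0], [2.0], [0.5, 0.5], [1.0],
                                   cost=lambda u, v: (u - v) ** 2)
print(plan, total)  # [(0, 0, 0.5), (1, 0, 0.5)] 2.5
```

With a convex cost such as the squared difference above, this sweep gives the exact OT cost; with `cosine` it would only give the cost of the monotone coupling.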
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.8.18
- How was POT installed (source, pip, conda): pip
- Build command you used (if compiling from source): pip install POT
Output of the following code snippet:
```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```

```
Linux-5.15.0-117-generic-x86_64-with-glibc2.10
Python 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36)
[GCC 12.3.0]
NumPy 1.24.4
SciPy 1.10.1
POT 0.9.4
```
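A possible cause, inferred only from the error message: for metrics without a fast path, `emd_1d_sorted` seems to call `.reshape` on a single indexed sample. Inside a Cython loop over a typed `double` buffer, such an element is a C double, and when it is used as a Python object it presumably becomes a plain Python `float`, which (unlike a NumPy scalar) has no `.reshape`. The snippet below reproduces that failure mode without POT and shows one way to build the `(1, 1)` array that a pairwise-distance call expects; `as_1x1` is a hypothetical helper, not part of POT.

```python
import numpy as np

# Stand-in for u[i] inside the Cython loop: a plain Python float, not np.float64.
u_i = float(np.arange(5, dtype=np.float64)[2])


def as_1x1(value):
    """Hypothetical helper: wrap a scalar as a (1, 1) float64 array."""
    return np.array([[value]], dtype=np.float64)


try:
    u_i.reshape((1, 1))  # fails the same way as in the traceback
except AttributeError as exc:
    print(exc)  # 'float' object has no attribute 'reshape'

print(as_1x1(u_i).shape)                        # (1, 1)
print(np.float64(u_i).reshape((1, 1)).shape)    # (1, 1): NumPy scalars do have reshape
```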