Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix test?
  • Loading branch information
tgnassou committed Jan 13, 2023
commit f5fd6f4335cd8f6c0d6f7c01c607cf4bbbdbb4f3
4 changes: 2 additions & 2 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# License: MIT License

import numpy as np
import pytest

import ot
from ot.datasets import make_data_classif
Expand All @@ -31,7 +30,8 @@ def test_linear_mapping(nx):


def test_bures_wasserstein_distance(nx):
ms, mt, Cs, Ct = [0], [10], [[1]], [[1]]
ms, mt = np.array([0]), np.array([10])
Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb)

Expand Down