Skip to content

Commit 2bc41ad

Browse files
committed
rng gpu
1 parent 4a45135 commit 2bc41ad

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

test/test_gpu.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@
1414
@pytest.mark.skipif(nogpu, reason="No GPU available")
1515
def test_gpu_sinkhorn():
1616

17-
np.random.seed(0)
17+
rng = np.random.RandomState(0)
1818

1919
def describe_res(r):
2020
print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
2121
np.min(r), np.max(r), np.mean(r), np.std(r)))
2222

2323
for n_samples in [50, 100, 500, 1000]:
2424
print(n_samples)
25-
a = np.random.rand(n_samples // 4, 100)
26-
b = np.random.rand(n_samples, 100)
25+
a = rng.rand(n_samples // 4, 100)
26+
b = rng.rand(n_samples, 100)
2727
time1 = time.time()
2828
transport = ot.da.OTDA_sinkhorn()
2929
transport.fit(a, b)
@@ -43,17 +43,18 @@ def describe_res(r):
4343

4444
@pytest.mark.skipif(nogpu, reason="No GPU available")
4545
def test_gpu_sinkhorn_lpl1():
46-
np.random.seed(0)
46+
47+
rng = np.random.RandomState(0)
4748

4849
def describe_res(r):
4950
print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
5051
.format(np.min(r), np.max(r), np.mean(r), np.std(r)))
5152

5253
for n_samples in [50, 100, 500]:
5354
print(n_samples)
54-
a = np.random.rand(n_samples // 4, 100)
55+
a = rng.rand(n_samples // 4, 100)
5556
labels_a = np.random.randint(10, size=(n_samples // 4))
56-
b = np.random.rand(n_samples, 100)
57+
b = rng.rand(n_samples, 100)
5758
time1 = time.time()
5859
transport = ot.da.OTDA_lpl1()
5960
transport.fit(a, labels_a, b)

0 commit comments

Comments
 (0)