@@ -89,7 +89,9 @@ def test_sinkhorn_lpl1_transport_class(nx):
8989 # test its computed
9090 otda .fit (Xs = Xs , ys = ys , Xt = Xt )
9191 assert hasattr (otda , "cost_" )
92+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
9293 assert hasattr (otda , "coupling_" )
94+ assert np .all (np .isfinite (nx .to_numpy (otda .coupling_ ))), "coupling is finite"
9395
9496 # test dimensions of coupling
9597 assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
@@ -148,7 +150,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
148150 n_semisup = nx .sum (otda_semi .cost_ )
149151
150152 # check that the cost matrix norms are indeed different
151- assert n_unsup != n_semisup , "semisupervised mode not working"
153+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
152154
153155 # check that the coupling forbids mass transport between labeled source
154156 # and labeled target samples
@@ -238,7 +240,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
238240 n_semisup = nx .sum (otda_semi .cost_ )
239241
240242 # check that the cost matrix norms are indeed different
241- assert n_unsup != n_semisup , "semisupervised mode not working"
243+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
242244
243245 # check that the coupling forbids mass transport between labeled source
244246 # and labeled target samples
@@ -331,7 +333,7 @@ def test_sinkhorn_transport_class(nx):
331333 n_semisup = nx .sum (otda_semi .cost_ )
332334
333335 # check that the cost matrix norms are indeed different
334- assert n_unsup != n_semisup , "semisupervised mode not working"
336+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
335337
336338 # check that the coupling forbids mass transport between labeled source
337339 # and labeled target samples
@@ -371,6 +373,10 @@ def test_unbalanced_sinkhorn_transport_class(nx):
371373 # test dimensions of coupling
372374 assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
373375 assert_equal (otda .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
376+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
377+
378+ # test coupling
379+ assert np .all (np .isfinite (nx .to_numpy (otda .coupling_ ))), "coupling is finite"
374380
375381 # test transform
376382 transp_Xs = otda .transform (Xs = Xs )
@@ -409,19 +415,22 @@ def test_unbalanced_sinkhorn_transport_class(nx):
409415 # test unsupervised vs semi-supervised mode
410416 otda_unsup = ot .da .SinkhornTransport ()
411417 otda_unsup .fit (Xs = Xs , Xt = Xt )
418+ assert not np .any (np .isnan (nx .to_numpy (otda_unsup .cost_ ))), "cost is finite"
412419 n_unsup = nx .sum (otda_unsup .cost_ )
413420
414421 otda_semi = ot .da .SinkhornTransport ()
415422 otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
423+ assert not np .any (np .isnan (nx .to_numpy (otda_semi .cost_ ))), "cost is finite"
416424 assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
417425 n_semisup = nx .sum (otda_semi .cost_ )
418426
419427 # check that the cost matrix norms are indeed different
420- assert n_unsup != n_semisup , "semisupervised mode not working"
428+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
421429
422430 # check everything runs well with log=True
423431 otda = ot .da .SinkhornTransport (log = True )
424432 otda .fit (Xs = Xs , ys = ys , Xt = Xt )
433+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
425434 assert len (otda .log_ .keys ()) != 0
426435
427436
@@ -448,7 +457,9 @@ def test_emd_transport_class(nx):
448457
449458 # test dimensions of coupling
450459 assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
460+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
451461 assert_equal (otda .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
462+ assert np .all (np .isfinite (nx .to_numpy (otda .coupling_ ))), "coupling is finite"
452463
453464 # test margin constraints
454465 mu_s = unif (ns )
@@ -495,15 +506,22 @@ def test_emd_transport_class(nx):
495506 # test unsupervised vs semi-supervised mode
496507 otda_unsup = ot .da .EMDTransport ()
497508 otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
509+ assert_equal (otda_unsup .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
510+ assert not np .any (np .isnan (nx .to_numpy (otda_unsup .cost_ ))), "cost is finite"
511+ assert_equal (otda_unsup .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
512+ assert np .all (np .isfinite (nx .to_numpy (otda_unsup .coupling_ ))), "coupling is finite"
498513 n_unsup = nx .sum (otda_unsup .cost_ )
499514
500515 otda_semi = ot .da .EMDTransport ()
501516 otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
502517 assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
518+ assert not np .any (np .isnan (nx .to_numpy (otda_semi .cost_ ))), "cost is finite"
519+ assert_equal (otda_semi .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
520+ assert np .all (np .isfinite (nx .to_numpy (otda_semi .coupling_ ))), "coupling is finite"
503521 n_semisup = nx .sum (otda_semi .cost_ )
504522
505523 # check that the cost matrix norms are indeed different
506- assert n_unsup != n_semisup , "semisupervised mode not working"
524+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
507525
508526 # check that the coupling forbids mass transport between labeled source
509527 # and labeled target samples
0 commit comments