@@ -459,6 +459,12 @@ def test_entropic_proximal_gromov(nx):
459459
460460 C1b , C2b , pb , qb , G0b = nx .from_numpy (C1 , C2 , p , q , G0 )
461461
462+ with pytest .raises (ValueError ):
463+ loss_fun = 'weird_loss_fun'
464+ G , log = ot .gromov .entropic_gromov_wasserstein (
465+ C1 , C2 , None , q , loss_fun , symmetric = None , G0 = G0 ,
466+ epsilon = 1e-1 , max_iter = 50 , solver = 'PPA' , verbose = True , log = True , numItermax = 1 )
467+
462468 G , log = ot .gromov .entropic_gromov_wasserstein (
463469 C1 , C2 , None , q , 'square_loss' , symmetric = None , G0 = G0 ,
464470 epsilon = 1e-1 , max_iter = 50 , solver = 'PPA' , verbose = True , log = True , numItermax = 1 )
@@ -606,6 +612,12 @@ def test_entropic_fgw(nx):
606612
607613 Mb , C1b , C2b , pb , qb , G0b = nx .from_numpy (M , C1 , C2 , p , q , G0 )
608614
615+ with pytest .raises (ValueError ):
616+ loss_fun = 'weird_loss_fun'
617+ G , log = ot .gromov .entropic_fused_gromov_wasserstein (
618+ M , C1 , C2 , None , None , loss_fun , symmetric = None , G0 = G0 ,
619+ epsilon = 1e-1 , max_iter = 10 , verbose = True , log = True )
620+
609621 G , log = ot .gromov .entropic_fused_gromov_wasserstein (
610622 M , C1 , C2 , None , None , 'square_loss' , symmetric = None , G0 = G0 ,
611623 epsilon = 1e-1 , max_iter = 10 , verbose = True , log = True )
@@ -812,20 +824,28 @@ def test_entropic_fgw_barycenter(nx):
812824 C2 = ot .dist (Xt )
813825 p1 = ot .unif (ns )
814826 p2 = ot .unif (nt )
815- n_samples = 2
827+ n_samples = 3
816828 p = ot .unif (n_samples )
817829
818830 ysb , ytb , C1b , C2b , p1b , p2b , pb = nx .from_numpy (ys , yt , C1 , C2 , p1 , p2 , p )
819831
832+ with pytest .raises (ValueError ):
833+ loss_fun = 'weird_loss_fun'
834+ X , C , log = ot .gromov .entropic_fused_gromov_barycenters (
835+ n_samples , [ys , yt ], [C1 , C2 ], None , p , [.5 , .5 ], loss_fun , 0.1 ,
836+ max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42 ,
837+ solver = 'PPA' , numItermax = 10 , log = True
838+ )
839+
820840 X , C , log = ot .gromov .entropic_fused_gromov_barycenters (
821841 n_samples , [ys , yt ], [C1 , C2 ], None , p , [.5 , .5 ], 'square_loss' , 0.1 ,
822842 max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42 ,
823- solver = 'PPA' , numItermax = 1 , log = True
843+ solver = 'PPA' , numItermax = 10 , log = True
824844 )
825845 Xb , Cb = ot .gromov .entropic_fused_gromov_barycenters (
826846 n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], None , [.5 , .5 ], 'square_loss' , 0.1 ,
827847 max_iter = 10 , tol = 1e-3 , verbose = False , warmstartT = True , random_state = 42 ,
828- solver = 'PPA' , numItermax = 1 , log = False )
848+ solver = 'PPA' , numItermax = 10 , log = False )
829849 Xb , Cb = nx .to_numpy (Xb , Cb )
830850
831851 np .testing .assert_allclose (C , Cb , atol = 1e-06 )
@@ -1052,6 +1072,13 @@ def test_gromov_entropic_barycenter(nx):
10521072
10531073 C1b , C2b , p1b , p2b , pb = nx .from_numpy (C1 , C2 , p1 , p2 , p )
10541074
1075+ with pytest .raises (ValueError ):
1076+ loss_fun = 'weird_loss_fun'
1077+ Cb = ot .gromov .entropic_gromov_barycenters (
1078+ n_samples , [C1 , C2 ], None , p , [.5 , .5 ], loss_fun , 1e-3 ,
1079+ max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42
1080+ )
1081+
10551082 Cb = ot .gromov .entropic_gromov_barycenters (
10561083 n_samples , [C1 , C2 ], None , p , [.5 , .5 ], 'square_loss' , 1e-3 ,
10571084 max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42
0 commit comments