@@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
 
     nx = get_backend(a, b, M)
 
-    if nx.sum(a) > 1 or nx.sum(b) > 1:
+    if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15:  # 1e-15 for numerical errors
         raise ValueError("Problem infeasible. Check that a and b are in the "
                          "simplex")
 
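Note on the relaxed simplex check: weights normalized in floating point can sum to a value a few ulps above 1, which the old strict `> 1` test rejected. A minimal NumPy sketch of that failure mode (illustrative only, not part of the patch):

```python
import numpy as np

# Weights normalized in floating point may sum to slightly more than 1,
# e.g. 1.0000000000000002, because of accumulated rounding error.
a = np.random.rand(1000)
a = a / a.sum()
print(a.sum())               # can be a hair above 1.0
print(a.sum() > 1)           # may be True -> old check raised ValueError
print(a.sum() > 1 + 1e-15)   # False -> relaxed check accepts valid input
```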
@@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
 
     nx = get_backend(a, b, M)
 
+    dim_a, dim_b = M.shape
+    if len(a) == 0:
+        a = nx.ones(dim_a, type_as=a) / dim_a
+    if len(b) == 0:
+        b = nx.ones(dim_b, type_as=b) / dim_b
+
     if m is None:
         return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
     elif m < 0:
         raise ValueError("Problem infeasible. Parameter m should be greater"
                          " than 0.")
-    elif m > nx.min((nx.sum(a), nx.sum(b))):
+    elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
         raise ValueError("Problem infeasible. Parameter m should lower or"
                          " equal than min(|a|_1, |b|_1).")
 
-    a0, b0, M0 = a, b, M
-    # convert to humpy
-    a, b, M = nx.to_numpy(a, b, M)
-
-    b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
-    a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
-    M_extended = np.zeros((len(a_extended), len(b_extended)))
-    M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
-    M_extended[:len(a), :len(b)] = M
+    b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
+    b_extended = nx.concatenate((b, b_extension))
+    a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
+    a_extended = nx.concatenate((a, a_extension))
+    M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
+    M_extended = nx.concatenate(
+        (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1),
+         nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)),
+        axis=0
+    )
 
     gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
                          **kwargs)
 
-    gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
+    gamma = gamma[:len(a), :len(b)]
 
     if log_emd['warning'] is not None:
         raise ValueError("Error in the EMD resolution: try to increase the"
                          " number of dummy points")
-    log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
-    log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
-    log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
+    log_emd['partial_w_dist'] = nx.sum(M * gamma)
+    log_emd['u'] = log_emd['u'][:len(a)]
+    log_emd['v'] = log_emd['v'][:len(b)]
 
     if log:
         return gamma, log_emd
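For context, the rewritten block keeps the standard dummy-point reduction of partial OT to balanced OT, only expressed with backend primitives (`nx.concatenate`, `nx.ones`) instead of a NumPy round-trip. A NumPy-only sketch of the same construction, with a hypothetical helper name, could look like:

```python
import numpy as np

# Hypothetical helper (not in the patch) sketching the dummy-point trick:
# pad both marginals so the extended problem is balanced, copy M into the
# top-left block, and price the dummy-dummy block at 2 * max(M) so no mass
# escapes through it for free.
def extend_partial_problem(a, b, M, m, nb_dummies=1):
    b_ext = np.concatenate((b, np.full(nb_dummies, (a.sum() - m) / nb_dummies)))
    a_ext = np.concatenate((a, np.full(nb_dummies, (b.sum() - m) / nb_dummies)))
    M_ext = np.zeros((len(a_ext), len(b_ext)))
    M_ext[:len(a), :len(b)] = M
    M_ext[len(a):, len(b):] = 2 * M.max()
    return a_ext, b_ext, M_ext
```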
@@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
         NeurIPS.
     """
 
+    a, b, M = list_to_array(a, b, M)
+
+    nx = get_backend(a, b, M)
+
     partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
                                             **kwargs)
     log_w['T'] = partial_gw
 
     if log:
-        return np.sum(partial_gw * M), log_w
+        return nx.sum(partial_gw * M), log_w
     else:
-        return np.sum(partial_gw * M)
+        return nx.sum(partial_gw * M)
 
 
 def gwgrad_partial(C1, C2, T):
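Usage sketch (assuming the standard `ot.partial` API together with the `ot.unif` and `ot.dist` helpers): after this change `partial_wasserstein2` computes the loss with the inferred backend rather than NumPy, so torch or jax inputs would keep their type.

```python
import numpy as np
import ot

# Illustrative call; with torch tensors for a, b, M the returned loss
# would be a torch scalar instead of a NumPy float.
x = np.random.randn(20, 2)
y = np.random.randn(30, 2)
a, b = ot.unif(20), ot.unif(30)
M = ot.dist(x, y)

loss = ot.partial.partial_wasserstein2(a, b, M, m=0.5)
print(float(loss))
```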
@@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
     ot.partial.partial_wasserstein: exact Partial Wasserstein
     """
 
-    a = np.asarray(a, dtype=np.float64)
-    b = np.asarray(b, dtype=np.float64)
-    M = np.asarray(M, dtype=np.float64)
+    a, b, M = list_to_array(a, b, M)
+
+    nx = get_backend(a, b, M)
 
     dim_a, dim_b = M.shape
-    dx = np.ones(dim_a, dtype=np.float64)
-    dy = np.ones(dim_b, dtype=np.float64)
+    dx = nx.ones(dim_a, type_as=a)
+    dy = nx.ones(dim_b, type_as=b)
 
     if len(a) == 0:
-        a = np.ones(dim_a, dtype=np.float64) / dim_a
+        a = nx.ones(dim_a, type_as=a) / dim_a
     if len(b) == 0:
-        b = np.ones(dim_b, dtype=np.float64) / dim_b
+        b = nx.ones(dim_b, type_as=b) / dim_b
 
     if m is None:
-        m = np.min((np.sum(a), np.sum(b))) * 1.0
+        m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
     if m < 0:
         raise ValueError("Problem infeasible. Parameter m should be greater"
                          " than 0.")
-    if m > np.min((np.sum(a), np.sum(b))):
+    if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
         raise ValueError("Problem infeasible. Parameter m should lower or"
                          " equal than min(|a|_1, |b|_1).")
 
     log_e = {'err': []}
 
-    # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
-    K = np.empty(M.shape, dtype=M.dtype)
-    np.divide(M, -reg, out=K)
-    np.exp(K, out=K)
-    np.multiply(K, m / np.sum(K), out=K)
+    if type(a) == type(b) == type(M) == np.ndarray:
+        # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
+        K = np.empty(M.shape, dtype=M.dtype)
+        np.divide(M, -reg, out=K)
+        np.exp(K, out=K)
+        np.multiply(K, m / np.sum(K), out=K)
+    else:
+        K = nx.exp(-M / reg)
+        K = K * m / nx.sum(K)
 
     err, cpt = 1, 0
-    q1 = np.ones(K.shape)
-    q2 = np.ones(K.shape)
-    q3 = np.ones(K.shape)
+    q1 = nx.ones(K.shape, type_as=K)
+    q2 = nx.ones(K.shape, type_as=K)
+    q3 = nx.ones(K.shape, type_as=K)
 
     while (err > stopThr and cpt < numItermax):
         Kprev = K
         K = K * q1
-        K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+        K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
         q1 = q1 * Kprev / K1
         K1prev = K1
         K1 = K1 * q2
-        K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+        K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
         q2 = q2 * K1prev / K2
         K2prev = K2
         K2 = K2 * q3
-        K = K2 * (m / np.sum(K2))
+        K = K2 * (m / nx.sum(K2))
         q3 = q3 * K2prev / K
 
-        if np.any(np.isnan(K)) or np.any(np.isinf(K)):
+        if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
             print('Warning: numerical errors at iteration', cpt)
             break
         if cpt % 10 == 0:
-            err = np.linalg.norm(Kprev - K)
+            err = nx.norm(Kprev - K)
             if log:
                 log_e['err'].append(err)
             if verbose:
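The branch on `type(a) == type(b) == type(M) == np.ndarray` keeps the in-place ufunc path for NumPy inputs while falling back to generic backend calls otherwise. A standalone sketch checking that both paths agree (illustrative values only):

```python
import numpy as np

M = np.random.rand(300, 400)
reg, m = 0.1, 0.5

# In-place path: no temporary arrays beyond K itself.
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)             # K = -M / reg
np.exp(K, out=K)                      # K = exp(-M / reg)
np.multiply(K, m / np.sum(K), out=K)  # rescale so K sums to m

# Generic backend path (here with NumPy standing in for nx).
K_ref = np.exp(-M / reg)
K_ref = K_ref * m / np.sum(K_ref)

print(np.allclose(K, K_ref))  # True
```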
@@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
                 print('{:5d}|{:8e}|'.format(cpt, err))
 
         cpt = cpt + 1
-    log_e['partial_w_dist'] = np.sum(M * K)
+    log_e['partial_w_dist'] = nx.sum(M * K)
     if log:
         return K, log_e
     else:
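Usage sketch for the backend-agnostic entropic solver (assuming the usual `ot.partial.entropic_partial_wasserstein` signature): the scaling iterations now run entirely in the inferred backend, and `log['partial_w_dist']` comes back in that backend's type.

```python
import numpy as np
import ot

x = np.random.randn(15, 2)
y = np.random.randn(25, 2)
a, b = ot.unif(15), ot.unif(25)
M = ot.dist(x, y)

# Transport m = 0.4 units of mass with entropic regularization reg = 0.1.
K, log = ot.partial.entropic_partial_wasserstein(a, b, M, reg=0.1, m=0.4, log=True)
print(log['partial_w_dist'])
```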