99#
1010# License: MIT License
1111
12- import warnings
1312import numpy as np
1413from scipy .optimize import minimize , Bounds
1514
1615from ..backend import get_backend
17- from ..utils import list_to_array , get_parameter_pair
16+ from ..utils import list_to_array , get_parameter_pair , fun_to_numpy
1817
1918
2019def _get_loss_unbalanced (a , b , c , M , reg , reg_m1 , reg_m2 , reg_div = "kl" , regm_div = "kl" ):
@@ -46,9 +45,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div
4645 Divergence used for regularization.
4746 Can be one of three strings, 'entropy' (negative entropy),
4847 'kl' (Kullback-Leibler) or 'l2' (half-squared), or a tuple
49- of two calable functions returning the reg term and its derivative.
48+ of two callable functions returning the reg term and its derivative.
5049 Note that the callable functions should be able to handle Numpy arrays
51- and not tesors from the backend
50+ and not tensors from the backend
5251 regm_div: string, optional
5352 Divergence to quantify the difference between the marginals.
5453 Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
@@ -206,26 +205,27 @@ def lbfgsb_unbalanced(
206205 loss matrix
207206 reg: float
208207 regularization term >=0
209- c : array-like (dim_a, dim_b), optional (default = None)
210- Reference measure for the regularization.
211- If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
212208 reg_m: float or indexable object of length 1 or 2
213209 Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
214210 If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
215211 then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
216212 If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
217- reg_div: string, optional
213+ c : array-like (dim_a, dim_b), optional (default = None)
214+ Reference measure for the regularization.
215+ If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
216+ reg_div: string or pair of callable functions, optional (default = 'kl')
218217 Divergence used for regularization.
219218 Can be one of three strings, 'entropy' (negative entropy),
220219 'kl' (Kullback-Leibler) or 'l2' (half-squared), or a tuple
221- of two calable functions returning the reg term and its derivative.
220+ of two callable functions returning the reg term and its derivative.
222221 Note that the callable functions should be able to handle Numpy arrays
223- and not tesors from the backend
224- regm_div: string, optional
222+ and not tensors from the backend, otherwise functions will be converted to Numpy
223+ leading to a computational overhead.
224+ regm_div: string, optional (default = 'kl')
225225 Divergence to quantify the difference between the marginals.
226226 Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
227- G0: array-like (dim_a, dim_b)
228- Initialization of the transport matrix
227+ G0: array-like (dim_a, dim_b), optional (default = None)
228+ Initialization of the transport matrix. If None, the product of the marginals :math:`\mathbf{a} \mathbf{b}^T` is used.
229229 numItermax : int, optional
230230 Max number of iterations
231231 stopThr : float, optional
@@ -267,26 +267,14 @@ def lbfgsb_unbalanced(
267267 ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
268268 """
269269
270- # wrap the callable function to handle numpy arrays
271- if isinstance (reg_div , tuple ):
272- f0 , df0 = reg_div
273- try :
274- f0 (G0 )
275- df0 (G0 )
276- except BaseException :
277- warnings .warn (
278- "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead"
279- )
280-
281- def f (x ):
282- return nx .to_numpy (f0 (nx .from_numpy (x , type_as = M0 )))
283-
284- def df (x ):
285- return nx .to_numpy (df0 (nx .from_numpy (x , type_as = M0 )))
286-
287- reg_div = (f , df )
270+ # test settings
271+ regm_div = regm_div .lower ()
272+ if regm_div not in ["kl" , "l2" , "tv" ]:
273+ raise ValueError (
274+ "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'" .format (regm_div )
275+ )
288276
289- else :
277+ if isinstance ( reg_div , str ) :
290278 reg_div = reg_div .lower ()
291279 if reg_div not in ["entropy" , "kl" , "l2" ]:
292280 raise ValueError (
@@ -295,16 +283,11 @@ def df(x):
295283 )
296284 )
297285
298- regm_div = regm_div .lower ()
299- if regm_div not in ["kl" , "l2" , "tv" ]:
300- raise ValueError (
301- "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'" .format (regm_div )
302- )
303-
286+ # convert all inputs to numpy arrays
304287 reg_m1 , reg_m2 = get_parameter_pair (reg_m )
305288
306289 M , a , b = list_to_array (M , a , b )
307- nx = get_backend (M , a , b )
290+ nx = get_backend (M , a , b , G0 )
308291 M0 = M
309292
310293 dim_a , dim_b = M .shape
@@ -315,10 +298,22 @@ def df(x):
315298 b = nx .ones (dim_b , type_as = M ) / dim_b
316299
317300 # convert to numpy
318- a , b , M , reg_m1 , reg_m2 , reg = nx .to_numpy (a , b , M , reg_m1 , reg_m2 , reg )
301+ if nx .__name__ == "numpy" : # remaining parameters which can be arrays
302+ reg_m1 , reg_m2 , reg = nx .to_numpy (reg_m1 , reg_m2 , reg )
303+ else :
304+ a , b , M , reg_m1 , reg_m2 , reg = nx .to_numpy (a , b , M , reg_m1 , reg_m2 , reg )
305+
319306 G0 = a [:, None ] * b [None , :] if G0 is None else nx .to_numpy (G0 )
320307 c = a [:, None ] * b [None , :] if c is None else nx .to_numpy (c )
321308
309+ # potentially convert the callable function to handle numpy arrays
310+ if isinstance (reg_div , tuple ):
311+ f0 , df0 = reg_div
312+ f = fun_to_numpy (f0 , G0 , nx , warn = True )
313+ df = fun_to_numpy (df0 , G0 , nx , warn = True )
314+
315+ reg_div = (f , df )
316+
322317 _func = _get_loss_unbalanced (a , b , c , M , reg , reg_m1 , reg_m2 , reg_div , regm_div )
323318
324319 res = minimize (
@@ -399,26 +394,27 @@ def lbfgsb_unbalanced2(
399394 loss matrix
400395 reg: float
401396 regularization term >=0
402- c : array-like (dim_a, dim_b), optional (default = None)
403- Reference measure for the regularization.
404- If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
405397 reg_m: float or indexable object of length 1 or 2
406398 Marginal relaxation term: nonnegative (including 0) but cannot be infinity.
407399 If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
408400 then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
409401 If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array.
410- reg_div: string, optional
402+ c : array-like (dim_a, dim_b), optional (default = None)
403+ Reference measure for the regularization.
404+ If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
405+ reg_div: string or pair of callable functions, optional (default = 'kl')
411406 Divergence used for regularization.
412407 Can be one of three strings, 'entropy' (negative entropy),
413408 'kl' (Kullback-Leibler) or 'l2' (half-squared), or a tuple
414- of two calable functions returning the reg term and its derivative.
409+ of two callable functions returning the reg term and its derivative.
415410 Note that the callable functions should be able to handle Numpy arrays
416- and not tesors from the backend
417- regm_div: string, optional
411+ and not tensors from the backend, otherwise functions will be converted to Numpy
412+ leading to a computational overhead.
413+ regm_div: string, optional (default = 'kl')
418414 Divergence to quantify the difference between the marginals.
419415 Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation)
420- G0: array-like (dim_a, dim_b)
421- Initialization of the transport matrix
416+ G0: array-like (dim_a, dim_b), optional (default = None)
417+ Initialization of the transport matrix. None corresponds to uniform product.
422418 returnCost: string, optional (default = "linear")
423419 If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
424420 If `returnCost` = "total", then return the total unbalanced OT loss.
0 commit comments