@@ -253,46 +253,75 @@ def gwggrad(constC, hC1, hC2, T, nx=None):
253253 T , nx ) # [12] Prop. 2 misses a 2 factor
254254
255255
256- def update_square_loss (p , lambdas , T , Cs ):
256+ def update_square_loss (p , lambdas , T , Cs , nx = None ):
257257 r"""
258- Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
259- couplings calculated at each iteration
258+ Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S`
259+ :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
260+ barycenter problem in :ref:`[12]`:
261+
262+ .. math::
263+
264+ \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
265+
266+ Where :
267+
268+ - :math:`\mathbf{C}_s`: metric cost matrix
269+ - :math:`\mathbf{p}_s`: distribution
260270
261271 Parameters
262272 ----------
263273 p : array-like, shape (N,)
264274 Masses in the targeted barycenter.
265275 lambdas : list of float
266276 List of the `S` spaces' weights.
267- T : list of S array-like of shape (ns,N )
277+ T : list of S array-like of shape (N, ns )
268278 The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
269279 Cs : list of S array-like, shape(ns,ns)
270280 Metric cost matrices.
281+ nx : backend, optional
282+ If let to its default value None, a backend test will be conducted.
271283
272284 Returns
273285 ----------
274286 C : array-like, shape (`nt`, `nt`)
275287 Updated :math:`\mathbf{C}` matrix.
288+
289+ References
290+ ----------
291+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
292+ "Gromov-Wasserstein averaging of kernel and distance matrices."
293+ International Conference on Machine Learning (ICML). 2016.
294+
276295 """
277- T = list_to_array (* T )
278- Cs = list_to_array (* Cs )
279- p = list_to_array (p )
280- nx = get_backend (p , * T , * Cs )
296+ if nx is None :
297+ nx = get_backend (p , * T , * Cs )
281298
299+ # Correct order mistake in Equation 14 in [12]
282300 tmpsum = sum ([
283301 lambdas [s ] * nx .dot (
284- nx .dot (T [s ]. T , Cs [s ]),
285- T [s ]
302+ nx .dot (T [s ], Cs [s ]),
303+ T [s ]. T
286304 ) for s in range (len (T ))
287305 ])
288306 ppt = nx .outer (p , p )
289307
290308 return tmpsum / ppt
291309
292310
293- def update_kl_loss (p , lambdas , T , Cs ):
311+ def update_kl_loss (p , lambdas , T , Cs , nx = None ):
294312 r"""
295- Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
313+ Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S`
314+ :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
315+ barycenter problem in :ref:`[12]`:
316+
317+ .. math::
318+
319+ \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
320+
321+ Where :
322+
323+ - :math:`\mathbf{C}_s`: metric cost matrix
324+ - :math:`\mathbf{p}_s`: distribution
296325
297326
298327 Parameters
@@ -301,33 +330,41 @@ def update_kl_loss(p, lambdas, T, Cs):
301330 Weights in the targeted barycenter.
302331 lambdas : list of float
303332 List of the `S` spaces' weights
304- T : list of S array-like of shape (ns,N )
333+ T : list of S array-like of shape (N, ns )
305334 The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
306335 Cs : list of S array-like, shape(ns,ns)
307336 Metric cost matrices.
337+ nx : backend, optional
338+ If let to its default value None, a backend test will be conducted.
308339
309340 Returns
310341 ----------
311342 C : array-like, shape (`ns`, `ns`)
312343 updated :math:`\mathbf{C}` matrix
344+
345+ References
346+ ----------
347+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
348+ "Gromov-Wasserstein averaging of kernel and distance matrices."
349+ International Conference on Machine Learning (ICML). 2016.
350+
313351 """
314- Cs = list_to_array (* Cs )
315- T = list_to_array (* T )
316- p = list_to_array (p )
317- nx = get_backend (p , * T , * Cs )
352+ if nx is None :
353+ nx = get_backend (p , * T , * Cs )
318354
355+ # Correct order mistake in Equation 15 in [12]
319356 tmpsum = sum ([
320357 lambdas [s ] * nx .dot (
321- nx .dot (T [s ]. T , Cs [s ]),
322- T [s ]
358+ nx .dot (T [s ], nx . log ( nx . maximum ( Cs [s ], 1e-15 )) ),
359+ T [s ]. T
323360 ) for s in range (len (T ))
324361 ])
325362 ppt = nx .outer (p , p )
326363
327364 return nx .exp (tmpsum / ppt )
328365
329366
330- def update_feature_matrix (lambdas , Ys , Ts , p ):
367+ def update_feature_matrix (lambdas , Ys , Ts , p , nx = None ):
331368 r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
332369
333370
@@ -340,10 +377,12 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
340377 masses in the targeted barycenter
341378 lambdas : list of float
342379 List of the `S` spaces' weights
343- Ts : list of S array-like, shape (ns,N )
380+ Ts : list of S array-like, shape (N, ns )
344381 The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
345382 Ys : list of S array-like, shape (d,ns)
346383 The features.
384+ nx : backend, optional
385+ If let to its default value None, a backend test will be conducted.
347386
348387 Returns
349388 -------
@@ -357,10 +396,8 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
357396 "Optimal Transport for structured data with application on graphs"
358397 International Conference on Machine Learning (ICML). 2019.
359398 """
360- p = list_to_array (p )
361- Ts = list_to_array (* Ts )
362- Ys = list_to_array (* Ys )
363- nx = get_backend (* Ys , * Ts , p )
399+ if nx is None :
400+ nx = get_backend (* Ys , * Ts , p )
364401
365402 p = 1. / p
366403 tmpsum = sum ([
0 commit comments