@@ -160,15 +160,17 @@ def df(G):
 
     def df(G):
         return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
-    if loss_fun == 'kl_loss':
-        armijo = True  # there is no closed form line-search with KL
+
+    # removed since 0.9.2
+    # if loss_fun == 'kl_loss':
+    #     armijo = True  # there is no closed form line-search with KL
 
     if armijo:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
             return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
     else:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
-            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
+            return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)
     if log:
         res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
         log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
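With the transformed matrices `hC1`, `hC2` passed to the line-search, `kl_loss` no longer has to fall back on Armijo. A minimal usage sketch, assuming the public `ot.gromov.gromov_wasserstein` API of POT 0.9.2; the point clouds and sizes are illustrative:

```python
# Minimal sketch (assumes POT >= 0.9.2): the KL inner loss now goes through
# the closed-form line-search when armijo=False (the default).
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(30, 2)

# Strictly positive structure matrices (the KL loss takes log(C)) and uniform weights.
C1 = ot.dist(xs, xs) + 1e-2
C2 = ot.dist(xt, xt) + 1e-2
p, q = ot.unif(20), ot.unif(30)

T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, loss_fun='kl_loss',
                                      armijo=False, log=True)
print(log['gw_dist'])
```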
@@ -296,9 +298,13 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
     if loss_fun == 'square_loss':
         gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
         gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        gw = nx.set_gradients(gw, (p, q, C1, C2),
-                              (log_gw['u'] - nx.mean(log_gw['u']),
-                               log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+
+    gw = nx.set_gradients(gw, (p, q, C1, C2),
+                          (log_gw['u'] - nx.mean(log_gw['u']),
+                           log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
 
     if log:
         return gw, log_gw
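For reference, the new `kl_loss` branch follows from differentiating the GW objective with the KL decomposition used by POT, L(a, b) = a log(a/b) - a + b; a short derivation (the 1e-15 offsets in the code are only numerical guards):

```latex
% E(C^1, C^2, T) = \sum_{i,j,k,l} L(C^1_{ik}, C^2_{jl})\, T_{ij} T_{kl},
% with L(a, b) = a \log(a/b) - a + b and T having marginals p and q.
\frac{\partial E}{\partial C^1_{ik}}
  = \sum_{j,l} \big(\log C^1_{ik} - \log C^2_{jl}\big)\, T_{ij} T_{kl}
  = \log\!\big(C^1_{ik}\big)\, p_i p_k - \big[T \log(C^2)\, T^\top\big]_{ik}

\frac{\partial E}{\partial C^2_{jl}}
  = \sum_{i,k} \Big(1 - \frac{C^1_{ik}}{C^2_{jl}}\Big)\, T_{ij} T_{kl}
  = q_j q_l - \frac{\big[T^\top C^1 T\big]_{jl}}{C^2_{jl}}
```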
@@ -449,15 +455,16 @@ def df(G):
     def df(G):
         return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
 
-    if loss_fun == 'kl_loss':
-        armijo = True  # there is no closed form line-search with KL
+    # removed since 0.9.2
+    # if loss_fun == 'kl_loss':
+    #     armijo = True  # there is no closed form line-search with KL
 
     if armijo:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
             return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
     else:
         def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
-            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+            return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
     if log:
         res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
         log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
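Same change on the fused solver: the structure part of the line-search now receives `hC1`, `hC2`, while the feature term `(1 - alpha) * M` is unchanged. A minimal sketch, with illustrative random features:

```python
# Minimal sketch (assumes POT >= 0.9.2): fused GW with the KL inner loss,
# closed-form line-search on (hC1, hC2), feature cost weighted by (1 - alpha).
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(30, 2)
Fs, Ft = rng.randn(20, 3), rng.randn(30, 3)      # node features

C1 = ot.dist(xs, xs) + 1e-2
C2 = ot.dist(xt, xt) + 1e-2
M = ot.dist(Fs, Ft)                              # feature cost matrix
p, q = ot.unif(20), ot.unif(30)

T, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='kl_loss',
                                            alpha=0.5, log=True)
print(log['fgw_dist'])
```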
@@ -591,18 +598,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
     if loss_fun == 'square_loss':
         gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
         gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        if isinstance(alpha, int) or isinstance(alpha, float):
-            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
-                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                         log_fgw['v'] - nx.mean(log_fgw['v']),
-                                         alpha * gC1, alpha * gC2, (1 - alpha) * T))
-        else:
-
-            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
-                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                         log_fgw['v'] - nx.mean(log_fgw['v']),
-                                         alpha * gC1, alpha * gC2, (1 - alpha) * T,
-                                         gw_term - lin_term))
+    elif loss_fun == 'kl_loss':
+        gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+        gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
+    if isinstance(alpha, int) or isinstance(alpha, float):
+        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T))
+    else:
+        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T,
+                                     gw_term - lin_term))
 
     if log:
         return fgw_dist, log_fgw
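Since `set_gradients` is now called for both losses, a differentiable backend gets gradients with respect to `(p, q, C1, C2, M)` and, when `alpha` is itself an array, with respect to `alpha` as well. A sketch of what that enables, assuming the torch backend of POT is installed; sizes and data are illustrative:

```python
# Sketch (assumes POT >= 0.9.2 with torch installed): backprop through
# fused_gromov_wasserstein2 with the KL inner loss and a differentiable alpha.
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
C1 = torch.tensor(ot.dist(rng.randn(10, 2)) + 1e-2, requires_grad=True)
C2 = torch.tensor(ot.dist(rng.randn(15, 2)) + 1e-2, requires_grad=True)
M = torch.tensor(rng.rand(10, 15), requires_grad=True)
p = torch.tensor(ot.unif(10), requires_grad=True)
q = torch.tensor(ot.unif(15), requires_grad=True)
alpha = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)

fgw = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q,
                                          loss_fun='kl_loss', alpha=alpha)
fgw.backward()
print(C1.grad.shape, C2.grad.shape, alpha.grad)
```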
@@ -613,7 +622,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
 def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
                             alpha_min=None, alpha_max=None, nx=None, **kwargs):
     """
-    Solve the linesearch in the FW iterations
+    Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] <references-solve-linesearch>`.
 
     Parameters
     ----------
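The reworded summary refers to Proposition 1 of Peyré et al. [12]: the inner loss is assumed to decompose as

```latex
L(a, b) = f_1(a) + f_2(b) - h_1(a)\, h_2(b)
% square loss: f_1(a) = a^2,            f_2(b) = b^2,  h_1(a) = a,  h_2(b) = 2b
% KL loss:     f_1(a) = a \log a - a,   f_2(b) = b,    h_1(a) = a,  h_2(b) = \log b
```

so that the GW cost of a plan `T` reduces to `<constC, T> - <h1(C1) T h2(C2)^T, T>`, which is quadratic in `T` and admits an exact line-search for both supported losses.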
@@ -625,9 +634,11 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
     cost_G : float
         Value of the cost at `G`
     C1 : array-like (ns,ns), optional
-        Structure matrix in the source domain.
+        Transformed structure matrix in the source domain.
+        For 'square_loss' and 'kl_loss', this is hC1 as returned by ot.gromov.init_matrix.
     C2 : array-like (nt,nt), optional
-        Structure matrix in the target domain.
+        Transformed structure matrix in the target domain.
+        For 'square_loss' and 'kl_loss', this is hC2 as returned by ot.gromov.init_matrix.
     M : array-like (ns,nt)
         Cost matrix between the features.
     reg : float
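A sketch of how a caller obtains the transformed matrices the docstring now asks for; it assumes the `init_matrix`, `gwloss` and `gwggrad` helpers exposed by `ot.gromov` and runs a single hand-rolled Frank-Wolfe step:

```python
# Sketch: one Frank-Wolfe step using the transformed matrices hC1, hC2.
import numpy as np
import ot
from ot.gromov import init_matrix, gwloss, gwggrad, solve_gromov_linesearch

rng = np.random.RandomState(0)
C1 = ot.dist(rng.randn(10, 2)) + 1e-2
C2 = ot.dist(rng.randn(15, 2)) + 1e-2
p, q = ot.unif(10), ot.unif(15)

# constC, hC1, hC2 encode the loss decomposition (here the KL loss).
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun='kl_loss')

G = np.outer(p, q)                       # feasible starting plan
grad = gwggrad(constC, hC1, hC2, G)      # gradient of the GW cost at G
Gc = ot.emd(p, q, grad)                  # linear-minimization oracle
deltaG = Gc - G

cost_G = gwloss(constC, hC1, hC2, G)
step, n_calls, new_cost = solve_gromov_linesearch(G, deltaG, cost_G,
                                                  hC1, hC2, M=0., reg=1.)
G = G + step * deltaG
```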
@@ -649,11 +660,16 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
 
 
     .. _references-solve-linesearch:
+
     References
     ----------
     .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
         "Optimal Transport for structured data with application on graphs"
         International Conference on Machine Learning (ICML). 2019.
+    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
     """
     if nx is None:
         G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
@@ -664,8 +680,8 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
             nx = get_backend(G, deltaG, C1, C2, M)
 
     dot = nx.dot(nx.dot(C1, deltaG), C2.T)
-    a = -2 * reg * nx.sum(dot * deltaG)
-    b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
+    a = -reg * nx.sum(dot * deltaG)
+    b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
 
     alpha = solve_1d_linesearch_quad(a, b)
     if alpha_min is not None or alpha_max is not None:
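With `C1`, `C2` now holding the transformed matrices, the quadratic coefficients lose their hard-coded square-loss factor 2 (for the square loss it is carried by `h2(b) = 2b` inside `hC2`). Along a feasible direction `deltaG`, whose marginals are zero so the `constC` term drops out, the cost is the quadratic below; this is what `solve_1d_linesearch_quad(a, b)` minimizes over the step in [0, 1]:

```latex
g(t) = \mathrm{cost}(G + t\,\Delta G) = \mathrm{cost}(G) + b\,t + a\,t^2,
\quad
a = -\,\mathrm{reg}\,\big\langle h_1(C^1)\,\Delta G\,h_2(C^2)^\top,\ \Delta G \big\rangle,
\quad
b = \langle M, \Delta G\rangle
    - \mathrm{reg}\,\Big(\big\langle h_1(C^1)\,\Delta G\,h_2(C^2)^\top,\ G \big\rangle
    + \big\langle h_1(C^1)\,G\,h_2(C^2)^\top,\ \Delta G \big\rangle\Big)
```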
@@ -776,8 +792,9 @@ def gromov_barycenters(
     else:
         C = init_C
 
-    if loss_fun == 'kl_loss':
-        armijo = True
+    # removed since 0.9.2
+    # if loss_fun == 'kl_loss':
+    #     armijo = True
 
     cpt = 0
     err = 1
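The barycenter solvers keep the same default: with the closed-form line-search now covering `kl_loss`, the forced switch to Armijo is gone. A sketch of a call, assuming the `gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, ...)` signature with optional weights; the graphs are illustrative:

```python
# Sketch (assumes POT >= 0.9.2): GW barycenter with the KL inner loss,
# Armijo no longer forced. Structure matrices kept strictly positive.
import numpy as np
import ot

rng = np.random.RandomState(0)
Cs = [ot.dist(rng.randn(n, 2)) + 1e-2 for n in (12, 15, 18)]
ps = [ot.unif(C.shape[0]) for C in Cs]

C_bary = ot.gromov.gromov_barycenters(10, Cs, ps=ps, p=ot.unif(10),
                                      lambdas=[1 / 3] * 3, loss_fun='kl_loss',
                                      max_iter=100, tol=1e-6)
print(C_bary.shape)  # (10, 10)
```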
@@ -960,8 +977,9 @@ def fgw_barycenters(
 
     Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
 
-    if loss_fun == 'kl_loss':
-        armijo = True
+    # removed since 0.9.2
+    # if loss_fun == 'kl_loss':
+    #     armijo = True
 
     cpt = 0
     err_feature = 1