1717
1818from scipy import linalg
1919import autograd .numpy as np
20- from pymanopt . function import Autograd
21- from pymanopt . manifolds import Stiefel
22- from pymanopt import Problem
23- from pymanopt .solvers import SteepestDescent , TrustRegions
20+
21+ import pymanopt
22+ import pymanopt . manifolds
23+ import pymanopt .optimizers
2424
2525
2626def dist (x1 , x2 ):
@@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
3838 ui = np .ones ((M .shape [0 ],))
3939 vi = np .ones ((M .shape [1 ],))
4040 for i in range (k ):
41- vi = w2 / (np .dot (K .T , ui ))
42- ui = w1 / (np .dot (K , vi ))
41+ vi = w2 / (np .dot (K .T , ui ) + 1e-50 )
42+ ui = w1 / (np .dot (K , vi ) + 1e-50 )
4343 G = ui .reshape ((M .shape [0 ], 1 )) * K * vi .reshape ((1 , M .shape [1 ]))
4444 return G
4545
@@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
222222 else :
223223 regmean = np .ones ((len (xc ), len (xc )))
224224
225- @Autograd
225+ manifold = pymanopt .manifolds .Stiefel (d , p )
226+
227+ @pymanopt .function .autograd (manifold )
226228 def cost (P ):
227229 # wda loss
228230 loss_b = 0
@@ -243,21 +245,21 @@ def cost(P):
243245 return loss_w / loss_b
244246
245247 # declare manifold and problem
246- manifold = Stiefel ( d , p )
247- problem = Problem (manifold = manifold , cost = cost )
248+
249+ problem = pymanopt . Problem (manifold = manifold , cost = cost )
248250
249251 # declare solver and solve
250252 if solver is None :
251- solver = SteepestDescent (maxiter = maxiter , logverbosity = verbose )
253+ solver = pymanopt . optimizers . SteepestDescent (max_iterations = maxiter , log_verbosity = verbose )
252254 elif solver in ['tr' , 'TrustRegions' ]:
253- solver = TrustRegions (maxiter = maxiter , logverbosity = verbose )
255+ solver = pymanopt . optimizers . TrustRegions (max_iterations = maxiter , log_verbosity = verbose )
254256
255- Popt = solver .solve (problem , x = P0 )
257+ Popt = solver .run (problem , initial_point = P0 )
256258
257259 def proj (X ):
258- return (X - mx .reshape ((1 , - 1 ))).dot (Popt )
260+ return (X - mx .reshape ((1 , - 1 ))).dot (Popt . point )
259261
260- return Popt , proj
262+ return Popt . point , proj
261263
262264
263265def projection_robust_wasserstein (X , Y , a , b , tau , U0 = None , reg = 0.1 , k = 2 , stopThr = 1e-3 , maxiter = 100 , verbose = 0 ):
0 commit comments