Skip to content

Commit c70e14c

Browse files
committed
change convergence conditions
1 parent fe5065d commit c70e14c

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

numpy_ml/factorization/factors.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,18 @@ def fit(self, X, W=None, H=None, n_initializations=10, verbose=False):
140140

141141
def _fit(self, X, W, H, verbose):
142142
self._init_factor_matrices(X, W, H)
143-
prev_loss = loss = np.inf
144143
W, H = self.W, self.H
145144

146145
for i in range(self.max_iter):
147-
prev_loss = loss
148146
W = self._update_factor(X, H.T)
149147
H = self._update_factor(X.T, W).T
150148

151149
loss = self._loss(X, W @ H)
152150

153151
if verbose:
154-
print("[Iter {}] Loss: {:.6f}".format(i + 1, loss))
152+
print("[Iter {}] Loss: {:.8f}".format(i + 1, loss))
155153

156-
if (prev_loss - loss) <= self.tol:
154+
if loss <= self.tol:
157155
break
158156

159157
return W, H, loss
@@ -252,8 +250,7 @@ def _loss(self, X, Xhat):
252250

253251
def _update_H(self, X, W, H):
254252
"""Perform the fast HALS update for H"""
255-
# eps = np.finfo(float).eps
256-
eps = 1e-16
253+
eps = np.finfo(float).eps
257254
XtW = X.T @ W # dim: (M, K)
258255
WtW = W.T @ W # dim: (K, K)
259256

@@ -264,7 +261,7 @@ def _update_H(self, X, W, H):
264261

265262
def _update_W(self, X, W, H):
266263
"""Perform the fast HALS update for W"""
267-
eps = 1e-16 # np.finfo(float).eps
264+
eps = np.finfo(float).eps
268265
XHt = X @ H.T # dim: (N, K)
269266
HHt = H @ H.T # dim: (K, K)
270267

@@ -360,16 +357,14 @@ def _fit(self, X, W, H, verbose):
360357
self._init_factor_matrices(X, W, H)
361358

362359
W, H = self.W, self.H
363-
prev_loss = loss = np.inf
364360
for i in range(self.max_iter):
365-
prev_loss = loss
366361
H = self._update_H(X, W, H)
367362
W = self._update_W(X, W, H)
368363
loss = self._loss(X, W @ H)
369364

370365
if verbose:
371-
print("[Iter {}] Loss: {:.4f}".format(i + 1, loss))
366+
print("[Iter {}] Loss: {:.8f}".format(i + 1, loss))
372367

373-
if (prev_loss - loss) <= self.tol:
368+
if loss <= self.tol:
374369
break
375370
return W, H, loss

0 commit comments

Comments
 (0)