Skip to content

Commit 9d7b804

Browse files
oleksandr-pavlykogrisel
authored andcommitted
MAINT: adjustments to test_logistic.py::test_dtype_match (scikit-learn#13645)
1 parent c315bf9 commit 9d7b804

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

sklearn/linear_model/tests/test_logistic.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,32 +1311,42 @@ def test_dtype_match(solver, multi_class):
13111311
X_64 = np.array(X).astype(np.float64)
13121312
y_64 = np.array(Y1).astype(np.float64)
13131313
X_sparse_32 = sp.csr_matrix(X, dtype=np.float32)
1314+
solver_tol = 5e-4
13141315

1316+
lr_templ = LogisticRegression(
1317+
solver=solver, multi_class=multi_class,
1318+
random_state=42, tol=solver_tol, fit_intercept=True)
13151319
# Check type consistency
1316-
lr_32 = LogisticRegression(solver=solver, multi_class=multi_class,
1317-
random_state=42)
1320+
lr_32 = clone(lr_templ)
13181321
lr_32.fit(X_32, y_32)
13191322
assert_equal(lr_32.coef_.dtype, X_32.dtype)
13201323

13211324
# check consistency with sparsity
1322-
lr_32_sparse = LogisticRegression(solver=solver,
1323-
multi_class=multi_class,
1324-
random_state=42)
1325+
lr_32_sparse = clone(lr_templ)
13251326
lr_32_sparse.fit(X_sparse_32, y_32)
13261327
assert_equal(lr_32_sparse.coef_.dtype, X_sparse_32.dtype)
13271328

13281329
# Check accuracy consistency
1329-
lr_64 = LogisticRegression(solver=solver, multi_class=multi_class,
1330-
random_state=42)
1330+
lr_64 = clone(lr_templ)
13311331
lr_64.fit(X_64, y_64)
13321332
assert_equal(lr_64.coef_.dtype, X_64.dtype)
13331333

1334-
rtol = 1e-6
1334+
# solver_tol bounds the norm of the loss gradient
1335+
# dw ~= inv(H)*grad ==> |dw| ~= |inv(H)| * solver_tol, where H - hessian
1336+
#
1337+
# See https://github.com/scikit-learn/scikit-learn/pull/13645
1338+
#
1339+
# with Z = np.hstack((np.ones((3,1)), np.array(X)))
1340+
# In [8]: np.linalg.norm(np.diag([0,2,2]) + np.linalg.inv((Z.T @ Z)/4))
1341+
# Out[8]: 1.7193336918135917
1342+
1343+
# factor of 2 to get the ball diameter
1344+
atol = 2 * 1.72 * solver_tol
13351345
if os.name == 'nt' and _IS_32BIT:
13361346
# FIXME
1337-
rtol = 1e-2
1347+
atol = 1e-2
13381348

1339-
assert_allclose(lr_32.coef_, lr_64.coef_.astype(np.float32), rtol=rtol)
1349+
assert_allclose(lr_32.coef_, lr_64.coef_.astype(np.float32), atol=atol)
13401350

13411351

13421352
def test_warm_start_converge_LR():

0 commit comments

Comments
 (0)