
Commit c1f5874

NicolasHug authored and amueller committed
[MRG] Add elastic net penalty to LogisticRegression (scikit-learn#11646)
* First draft on elasticnet penaly for LogisticRegression
* Some basic tests
* Doc update
* First draft for LogisticRegressionCV. It seems to be working for binary classification and for multiclass when multi_class='ovr'. I'm having a hard time figuring out the intricacies of multi_class='multinomial'.
* Changed default to None for l1_ratio. added warning message is user sets l1_ratio while penalty is not elastic-net
* Some more doc
* Updated example to plot elastic net sparsity
* Fixed flake8
* Fixed test by not modifying attribute in fit
* Fixed doc issues
* WIP
* Partially fixed logistic_reg_CV for multinomial. Also added some comments that are hopefully clear. Still need to fix refit=False
* Fixed doc issue
* WIP
* Fixed test for refit=False in LogisticRegressionCV
* Fixed Python 2 numpy version issue
* minor doc updates
* Weird doc error...
* Added test to ensure that elastic net is at least as good as L1 or L2 once l1_ratio has been optimized with grid search. Also addressed minor reviews
* Fixed test
* addressed comments
* Added back ignore warning on tests
* Added a functional test
* Scale data in test... Now failing
* elastic-net --> elasticnet
* Updated doc for some attributes and checked their shape in tests
* Added l1_ratio dimension to coefs_paths and scores attr
* improve example + fix test
* FIX incorrect lagged_update in SAGA
* Add non-regression test for SAGA's bug
* FIX flake8 and warning
* Re fixed warning
* Updated some tests
* Addressed comments
* more comments and added dimension to LogisticRegressionCV.n_iter_ attribute
* Updated whatsnew for 0.21
* better doc shape looks
* Fixed whatnew entry after merges
* Added dot
* Addressed comments + standardized optional default param docstrings
* Addessed comments
* use swapaxes instead of unsupported moveaxis (hopefully fixes tests)
1 parent f6f7e3c · commit c1f5874


8 files changed: +617 -171 lines changed


doc/modules/linear_model.rst

Lines changed: 29 additions & 18 deletions
@@ -338,7 +338,7 @@ the algorithm to fit the coefficients.
 
 .. _elastic_net:
 
-Elastic Net
+Elastic-Net
 ===========
 :class:`ElasticNet` is a linear regression model trained with L1 and L2 prior
 as regularizer. This combination allows for learning a sparse model where
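For illustration only (not part of the diff), a minimal sketch of the :class:`ElasticNet` regressor described in the context above; the synthetic data and the alpha / l1_ratio values are arbitrary assumptions:

    import numpy as np
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = rng.randn(50, 10)
    # only the first two features carry signal, so a sparse model is appropriate
    y = X[:, 0] - 2 * X[:, 1] + 0.1 * rng.randn(50)

    # l1_ratio blends the L1 and L2 priors; alpha scales the overall penalty
    reg = ElasticNet(alpha=0.1, l1_ratio=0.7)
    reg.fit(X, y)
    print("non-zero coefficients:", np.sum(reg.coef_ != 0))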
@@ -390,7 +390,7 @@ the duality gap computation used for convergence control.
 
 .. _multi_task_elastic_net:
 
-Multi-task Elastic Net
+Multi-task Elastic-Net
 ======================
 
 The :class:`MultiTaskElasticNet` is an elastic-net model that estimates sparse
@@ -730,7 +730,7 @@ or the log-linear classifier. In this model, the probabilities describing the po
 
 The implementation of logistic regression in scikit-learn can be accessed from
 class :class:`LogisticRegression`. This implementation can fit binary, One-vs-
-Rest, or multinomial logistic regression with optional L2 or L1
+Rest, or multinomial logistic regression with optional L2, L1 or Elastic-Net
 regularization.
 
 As an optimization problem, binary class L2 penalized logistic regression
@@ -739,12 +739,22 @@ minimizes the following cost function:
 .. math:: \min_{w, c} \frac{1}{2}w^T w + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) .
 
 Similarly, L1 regularized logistic regression solves the following
-optimization problem
+optimization problem:
 
 .. math:: \min_{w, c} \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1).
 
+Elastic-Net regularization is a combination of L1 and L2, and minimizes the
+following cost function:
+
+.. math:: \min_{w, c} \frac{1 - \rho}{2}w^T w + \rho \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1),
+
+where :math:`\rho` controls the strength of L1 regularization vs L2
+regularization (it corresponds to the `l1_ratio` parameter).
+
 Note that, in this notation, it's assumed that the observation :math:`y_i` takes values in the set
-:math:`{-1, 1}` at trial :math:`i`.
+:math:`{-1, 1}` at trial :math:`i`. We can also see that Elastic-Net is
+equivalent to L1 when :math:`\rho = 1` and equivalent to L2 when
+:math:`\rho=0`.
 
 The solvers implemented in the class :class:`LogisticRegression`
 are "liblinear", "newton-cg", "lbfgs", "sag" and "saga":
@@ -772,10 +782,12 @@ than other solvers for large datasets, when both the number of samples and the
 number of features are large.
 
 The "saga" solver [7]_ is a variant of "sag" that also supports the
-non-smooth `penalty="l1"` option. This is therefore the solver of choice
-for sparse multinomial logistic regression.
+non-smooth `penalty="l1"`. This is therefore the solver of choice for sparse
+multinomial logistic regression. It is also the only solver that supports
+`penalty="elasticnet"`.
 
-In a nutshell, the following table summarizes the penalties supported by each solver:
+In a nutshell, the following table summarizes the penalties supported by
+each solver:
 
 +------------------------------+-----------------+-------------+-----------------+-----------+------------+
 |                              |                       **Solvers**                                        |
@@ -790,6 +802,8 @@ In a nutshell, the following table summarizes the penalties supported by each so
 +------------------------------+-----------------+-------------+-----------------+-----------+------------+
 | OVR + L1 penalty             | yes             | no          | no              | no        | yes        |
 +------------------------------+-----------------+-------------+-----------------+-----------+------------+
+| Elastic-Net                  | no              | no          | no              | no        | yes        |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+
 | **Behaviors**                |                                                                          |
 +------------------------------+-----------------+-------------+-----------------+-----------+------------+
 | Penalize the intercept (bad) | yes             | no          | no              | no        | no         |
@@ -799,8 +813,8 @@ In a nutshell, the following table summarizes the penalties supported by each so
 | Robust to unscaled datasets  | yes             | yes         | yes             | no        | no         |
 +------------------------------+-----------------+-------------+-----------------+-----------+------------+
 
-The "saga" solver is often the best choice but requires scaling. The "liblinear" solver is
-used by default for historical reasons.
+The "saga" solver is often the best choice but requires scaling. The
+"liblinear" solver is used by default for historical reasons.
 
 For large dataset, you may also consider using :class:`SGDClassifier`
 with 'log' loss.
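As a sanity check of the table above (an illustrative sketch, not part of the commit): only the "saga" solver is expected to accept `penalty="elasticnet"`, while the other solvers should reject it at fit time with a ValueError; the tiny dataset below is arbitrary:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    X = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
    y = np.array([0, 0, 1, 1])

    for solver in ('liblinear', 'lbfgs', 'newton-cg', 'sag', 'saga'):
        clf = LogisticRegression(penalty='elasticnet', solver=solver,
                                 l1_ratio=0.5)
        try:
            clf.fit(X, y)
            print(solver, '-> ok')
        except ValueError as exc:
            print(solver, '->', exc)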
@@ -838,14 +852,11 @@ with 'log' loss.
 thus be used to perform feature selection, as detailed in
 :ref:`l1_feature_selection`.
 
-:class:`LogisticRegressionCV` implements Logistic Regression with
-builtin cross-validation to find out the optimal C parameter.
-"newton-cg", "sag", "saga" and "lbfgs" solvers are found to be faster
-for high-dimensional dense data, due to warm-starting. For the
-multiclass case, if `multi_class` option is set to "ovr", an optimal C
-is obtained for each class and if the `multi_class` option is set to
-"multinomial", an optimal C is obtained by minimizing the cross-entropy
-loss.
+:class:`LogisticRegressionCV` implements Logistic Regression with built-in
+cross-validation support, to find the optimal `C` and `l1_ratio` parameters
+according to the ``scoring`` attribute. The "newton-cg", "sag", "saga" and
+"lbfgs" solvers are found to be faster for high-dimensional dense data, due
+to warm-starting (see :term:`Glossary <warm_start>`).
 
 .. topic:: References:
 
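A short sketch of the extended :class:`LogisticRegressionCV` usage described above, cross-validating both `C` and `l1_ratio` with the new `l1_ratios` grid (the grid values, `Cs` and `cv` settings here are illustrative assumptions, not taken from the diff):

    from sklearn.datasets import load_digits
    from sklearn.linear_model import LogisticRegressionCV
    from sklearn.preprocessing import StandardScaler

    X, y = load_digits(return_X_y=True)
    X = StandardScaler().fit_transform(X)  # 'saga' works best on scaled data
    y = (y > 4).astype(int)

    clf = LogisticRegressionCV(Cs=5, penalty='elasticnet', solver='saga',
                               l1_ratios=[0.1, 0.5, 0.9], cv=3, max_iter=1000)
    clf.fit(X, y)
    print("best C per class:", clf.C_)
    print("best l1_ratio per class:", clf.l1_ratio_)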

doc/tutorial/statistical_inference/supervised_learning.rst

Lines changed: 2 additions & 1 deletion
@@ -183,6 +183,7 @@ Linear models: :math:`y = X\beta + \epsilon`
 [ 0.30349955 -237.63931533  510.53060544  327.73698041 -814.13170937
   492.81458798  102.84845219  184.60648906  743.51961675   76.09517222]
 
+
 >>> # The mean square error
 >>> np.mean((regr.predict(diabetes_X_test) - diabetes_y_test)**2)
 ... # doctest: +ELLIPSIS
@@ -378,7 +379,7 @@ function or **logistic** function:
 ...     multi_class='multinomial')
 >>> log.fit(iris_X_train, iris_y_train)  # doctest: +NORMALIZE_WHITESPACE
 LogisticRegression(C=100000.0, class_weight=None, dual=False,
-          fit_intercept=True, intercept_scaling=1, max_iter=100,
+          fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=100,
           multi_class='multinomial', n_jobs=None, penalty='l2', random_state=None,
           solver='lbfgs', tol=0.0001, verbose=0, warm_start=False)
 

doc/whats_new/v0.21.rst

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,9 @@ random sampling procedures.
 
 - Decision trees and derived ensembles when both `max_depth` and
   `max_leaf_nodes` are set. |Fix|
+- :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` with 'saga' solver. |Fix|
+
 
 Details are listed in the changelog below.
 
@@ -146,6 +149,15 @@ Support for Python 3.4 and below has been officially dropped.
   affects all ensemble methods using decision trees.
   :pr:`12344` by :user:`Adrin Jalali <adrinjalali>`.
 
+:mod:`sklearn.linear_model`
+...........................
+
+- |Feature| :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
+  with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
+
+- |Fix| Fixed a bug in the 'saga' solver where the weights would not be
+  correctly updated in some cases. :issue:`11646` by `Tom Dupre la Tour`_.
 
 Multiple modules
 ................

examples/linear_model/plot_logistic_l1_l2_sparsity.py

Lines changed: 35 additions & 24 deletions
@@ -4,10 +4,11 @@
 ==============================================
 
 Comparison of the sparsity (percentage of zero coefficients) of solutions when
-L1 and L2 penalty are used for different values of C. We can see that large
-values of C give more freedom to the model. Conversely, smaller values of C
-constrain the model more. In the L1 penalty case, this leads to sparser
-solutions.
+L1, L2 and Elastic-Net penalty are used for different values of C. We can see
+that large values of C give more freedom to the model. Conversely, smaller
+values of C constrain the model more. In the L1 penalty case, this leads to
+sparser solutions. As expected, the Elastic-Net penalty sparsity is between
+that of L1 and L2.
 
 We classify 8x8 images of digits into two classes: 0-4 against 5-9.
 The visualization shows coefficients of the models for varying C.
@@ -35,45 +36,55 @@
 # classify small against large digits
 y = (y > 4).astype(np.int)
 
+l1_ratio = 0.5  # L1 weight in the Elastic-Net regularization
+
+fig, axes = plt.subplots(3, 3)
 
 # Set regularization parameter
-for i, C in enumerate((1, 0.1, 0.01)):
+for i, (C, axes_row) in enumerate(zip((1, 0.1, 0.01), axes)):
     # turn down tolerance for short training time
     clf_l1_LR = LogisticRegression(C=C, penalty='l1', tol=0.01, solver='saga')
     clf_l2_LR = LogisticRegression(C=C, penalty='l2', tol=0.01, solver='saga')
+    clf_en_LR = LogisticRegression(C=C, penalty='elasticnet', solver='saga',
+                                   l1_ratio=l1_ratio, tol=0.01)
     clf_l1_LR.fit(X, y)
     clf_l2_LR.fit(X, y)
+    clf_en_LR.fit(X, y)
 
     coef_l1_LR = clf_l1_LR.coef_.ravel()
     coef_l2_LR = clf_l2_LR.coef_.ravel()
+    coef_en_LR = clf_en_LR.coef_.ravel()
 
     # coef_l1_LR contains zeros due to the
     # L1 sparsity inducing norm
 
     sparsity_l1_LR = np.mean(coef_l1_LR == 0) * 100
     sparsity_l2_LR = np.mean(coef_l2_LR == 0) * 100
+    sparsity_en_LR = np.mean(coef_en_LR == 0) * 100
 
     print("C=%.2f" % C)
-    print("Sparsity with L1 penalty: %.2f%%" % sparsity_l1_LR)
-    print("score with L1 penalty: %.4f" % clf_l1_LR.score(X, y))
-    print("Sparsity with L2 penalty: %.2f%%" % sparsity_l2_LR)
-    print("score with L2 penalty: %.4f" % clf_l2_LR.score(X, y))
+    print("{:<40} {:.2f}%".format("Sparsity with L1 penalty:", sparsity_l1_LR))
+    print("{:<40} {:.2f}%".format("Sparsity with Elastic-Net penalty:",
+                                  sparsity_en_LR))
+    print("{:<40} {:.2f}%".format("Sparsity with L2 penalty:", sparsity_l2_LR))
+    print("{:<40} {:.2f}".format("Score with L1 penalty:",
+                                 clf_l1_LR.score(X, y)))
+    print("{:<40} {:.2f}".format("Score with Elastic-Net penalty:",
+                                 clf_en_LR.score(X, y)))
+    print("{:<40} {:.2f}".format("Score with L2 penalty:",
+                                 clf_l2_LR.score(X, y)))
 
-    l1_plot = plt.subplot(3, 2, 2 * i + 1)
-    l2_plot = plt.subplot(3, 2, 2 * (i + 1))
     if i == 0:
-        l1_plot.set_title("L1 penalty")
-        l2_plot.set_title("L2 penalty")
-
-    l1_plot.imshow(np.abs(coef_l1_LR.reshape(8, 8)), interpolation='nearest',
-                   cmap='binary', vmax=1, vmin=0)
-    l2_plot.imshow(np.abs(coef_l2_LR.reshape(8, 8)), interpolation='nearest',
-                   cmap='binary', vmax=1, vmin=0)
-    plt.text(-8, 3, "C = %.2f" % C)
-
-    l1_plot.set_xticks(())
-    l1_plot.set_yticks(())
-    l2_plot.set_xticks(())
-    l2_plot.set_yticks(())
+        axes_row[0].set_title("L1 penalty")
+        axes_row[1].set_title("Elastic-Net\nl1_ratio = %s" % l1_ratio)
+        axes_row[2].set_title("L2 penalty")
+
+    for ax, coefs in zip(axes_row, [coef_l1_LR, coef_en_LR, coef_l2_LR]):
+        ax.imshow(np.abs(coefs.reshape(8, 8)), interpolation='nearest',
+                  cmap='binary', vmax=1, vmin=0)
+        ax.set_xticks(())
+        ax.set_yticks(())
+
+    axes_row[0].set_ylabel('C = %s' % C)
 
 plt.show()
