Skip to content

Commit ac2ff4a

Browse files
committed
Merge pull request scikit-learn#5497 from MechCoder/ransac_residual
[MRG] Deprecate residual_metric and add support for loss in RANSAC
2 parents 41526cb + 761b1f7 commit ac2ff4a

File tree

3 files changed

+111
-11
lines changed

3 files changed

+111
-11
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ Enhancements
7979
- :class:`linear_model.RANSACRegressor` now supports ``sample_weights``.
8080
By `Imaculate`_.
8181

82+
- Add parameter ``loss`` to :class:`linear_model.RANSACRegressor` to measure the
83+
error on the samples for every trial. By `Manoj Kumar`_.
84+
8285
Bug fixes
8386
.........
8487

@@ -114,6 +117,8 @@ API changes summary
114117
the :mod:`model_selection` module.
115118
(`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
116119

120+
- ``residual_metric`` has been deprecated in :class:`linear_model.RANSACRegressor`.
121+
Use ``loss`` instead. By `Manoj Kumar`_.
117122

118123
.. _changes_0_17:
119124

sklearn/linear_model/ransac.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# License: BSD 3 clause
66

77
import numpy as np
8+
import warnings
89

910
from ..base import BaseEstimator, MetaEstimatorMixin, RegressorMixin, clone
1011
from ..utils import check_random_state, check_array, check_consistent_length
@@ -134,6 +135,22 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
134135
135136
lambda dy: np.sum(np.abs(dy), axis=1)
136137
138+
NOTE: residual_metric is deprecated from 0.18 and will be removed in 0.20
139+
Use ``loss`` instead.
140+
141+
loss : string, callable, optional, default "absolute_loss"
142+
String inputs, "absolute_loss" and "squared_loss" are supported which
143+
find the absolute loss and squared loss per sample
144+
respectively.
145+
146+
If ``loss`` is a callable, then it should be a function that takes
147+
two arrays as inputs, the true and predicted value and returns a 1-D
148+
array with the ``i``th value of the array corresponding to the loss
149+
on `X[i]`.
150+
151+
If the loss on a sample is greater than the ``residual_threshold``, then
152+
this sample is classified as an outlier.
153+
137154
random_state : integer or numpy.RandomState, optional
138155
The generator used to initialize the centers. If an integer is
139156
given, it fixes the seed. Defaults to the global numpy random
@@ -163,7 +180,7 @@ def __init__(self, base_estimator=None, min_samples=None,
163180
is_model_valid=None, max_trials=100,
164181
stop_n_inliers=np.inf, stop_score=np.inf,
165182
stop_probability=0.99, residual_metric=None,
166-
random_state=None):
183+
loss='absolute_loss', random_state=None):
167184

168185
self.base_estimator = base_estimator
169186
self.min_samples = min_samples
@@ -176,6 +193,7 @@ def __init__(self, base_estimator=None, min_samples=None,
176193
self.stop_probability = stop_probability
177194
self.residual_metric = residual_metric
178195
self.random_state = random_state
196+
self.loss = loss
179197

180198
def fit(self, X, y, sample_weight=None):
181199
"""Fit estimator using RANSAC algorithm.
@@ -236,10 +254,33 @@ def fit(self, X, y, sample_weight=None):
236254
else:
237255
residual_threshold = self.residual_threshold
238256

239-
if self.residual_metric is None:
240-
residual_metric = lambda dy: np.sum(np.abs(dy), axis=1)
257+
if self.residual_metric is not None:
258+
warnings.warn(
259+
"'residual_metric' will be removed in version 0.20. Use "
260+
"'loss' instead.", DeprecationWarning)
261+
262+
if self.loss == "absolute_loss":
263+
if y.ndim == 1:
264+
loss_function = lambda y_true, y_pred: np.abs(y_true - y_pred)
265+
else:
266+
loss_function = lambda \
267+
y_true, y_pred: np.sum(np.abs(y_true - y_pred), axis=1)
268+
269+
elif self.loss == "squared_loss":
270+
if y.ndim == 1:
271+
loss_function = lambda y_true, y_pred: (y_true - y_pred) ** 2
272+
else:
273+
loss_function = lambda \
274+
y_true, y_pred: np.sum((y_true - y_pred) ** 2, axis=1)
275+
276+
elif callable(self.loss):
277+
loss_function = self.loss
278+
241279
else:
242-
residual_metric = self.residual_metric
280+
raise ValueError(
281+
"loss should be 'absolute_loss', 'squared_loss' or a callable. "
282+
"Got %s." % self.loss)
283+
243284

244285
random_state = check_random_state(self.random_state)
245286

@@ -298,10 +339,15 @@ def fit(self, X, y, sample_weight=None):
298339

299340
# residuals of all data for current random sample model
300341
y_pred = base_estimator.predict(X)
301-
diff = y_pred - y
302-
if diff.ndim == 1:
303-
diff = diff.reshape(-1, 1)
304-
residuals_subset = residual_metric(diff)
342+
343+
# XXX: Deprecation: Remove this if block in 0.20
344+
if self.residual_metric is not None:
345+
diff = y_pred - y
346+
if diff.ndim == 1:
347+
diff = diff.reshape(-1, 1)
348+
residuals_subset = self.residual_metric(diff)
349+
else:
350+
residuals_subset = loss_function(y, y_pred)
305351

306352
# classify data into inliers and outliers
307353
inlier_mask_subset = residuals_subset < residual_threshold

sklearn/linear_model/tests/test_ransac.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1+
from scipy import sparse
2+
13
import numpy as np
2-
from numpy.testing import assert_equal, assert_raises, assert_array_equal,assert_array_almost_equal
3-
from sklearn.utils.testing import assert_raises_regexp, assert_almost_equal, assert_less
44
from scipy import sparse
5+
6+
from numpy.testing import assert_equal, assert_raises
7+
from numpy.testing import assert_array_almost_equal
8+
from numpy.testing import assert_array_equal
9+
510
from sklearn.utils import check_random_state
6-
from sklearn.linear_model import LinearRegression, RANSACRegressor,Lasso
11+
from sklearn.utils.testing import assert_raises_regexp
12+
from sklearn.utils.testing import assert_less
13+
from sklearn.utils.testing import assert_warns
14+
from sklearn.utils.testing import assert_almost_equal
15+
from sklearn.linear_model import LinearRegression, RANSACRegressor, Lasso
716
from sklearn.linear_model.ransac import _dynamic_max_trials
817

918

@@ -265,6 +274,7 @@ def test_ransac_multi_dimensional_targets():
265274
assert_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
266275

267276

277+
# XXX: Remove in 0.20
268278
def test_ransac_residual_metric():
269279
residual_metric1 = lambda dy: np.sum(np.abs(dy), axis=1)
270280
residual_metric2 = lambda dy: np.sum(dy ** 2, axis=1)
@@ -281,6 +291,38 @@ def test_ransac_residual_metric():
281291
residual_threshold=5, random_state=0,
282292
residual_metric=residual_metric2)
283293

294+
# multi-dimensional
295+
ransac_estimator0.fit(X, yyy)
296+
assert_warns(DeprecationWarning, ransac_estimator1.fit, X, yyy)
297+
assert_warns(DeprecationWarning, ransac_estimator2.fit, X, yyy)
298+
assert_array_almost_equal(ransac_estimator0.predict(X),
299+
ransac_estimator1.predict(X))
300+
assert_array_almost_equal(ransac_estimator0.predict(X),
301+
ransac_estimator2.predict(X))
302+
303+
# one-dimensional
304+
ransac_estimator0.fit(X, y)
305+
assert_warns(DeprecationWarning, ransac_estimator2.fit, X, y)
306+
assert_array_almost_equal(ransac_estimator0.predict(X),
307+
ransac_estimator2.predict(X))
308+
309+
def test_ransac_residual_loss():
310+
loss_multi1 = lambda y_true, y_pred: np.sum(np.abs(y_true - y_pred), axis=1)
311+
loss_multi2 = lambda y_true, y_pred: np.sum((y_true - y_pred) ** 2, axis=1)
312+
313+
loss_mono = lambda y_true, y_pred: np.abs(y_true - y_pred)
314+
yyy = np.column_stack([y, y, y])
315+
316+
base_estimator = LinearRegression()
317+
ransac_estimator0 = RANSACRegressor(base_estimator, min_samples=2,
318+
residual_threshold=5, random_state=0)
319+
ransac_estimator1 = RANSACRegressor(base_estimator, min_samples=2,
320+
residual_threshold=5, random_state=0,
321+
loss=loss_multi1)
322+
ransac_estimator2 = RANSACRegressor(base_estimator, min_samples=2,
323+
residual_threshold=5, random_state=0,
324+
loss=loss_multi2)
325+
284326
# multi-dimensional
285327
ransac_estimator0.fit(X, yyy)
286328
ransac_estimator1.fit(X, yyy)
@@ -292,9 +334,16 @@ def test_ransac_residual_metric():
292334

293335
# one-dimensional
294336
ransac_estimator0.fit(X, y)
337+
ransac_estimator2.loss = loss_mono
295338
ransac_estimator2.fit(X, y)
296339
assert_array_almost_equal(ransac_estimator0.predict(X),
297340
ransac_estimator2.predict(X))
341+
ransac_estimator3 = RANSACRegressor(base_estimator, min_samples=2,
342+
residual_threshold=5, random_state=0,
343+
loss="squared_loss")
344+
ransac_estimator3.fit(X, y)
345+
assert_array_almost_equal(ransac_estimator0.predict(X),
346+
ransac_estimator3.predict(X))
298347

299348

300349
def test_ransac_default_residual_threshold():

0 commit comments

Comments
 (0)