55# License: BSD 3 clause
66
77import numpy as np
8+ import warnings
89
910from ..base import BaseEstimator , MetaEstimatorMixin , RegressorMixin , clone
1011from ..utils import check_random_state , check_array , check_consistent_length
@@ -134,6 +135,22 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
134135
135136 lambda dy: np.sum(np.abs(dy), axis=1)
136137
138+ NOTE: residual_metric is deprecated from version 0.18 and will be
139+ removed in version 0.20. Use ``loss`` instead.
140+
141+ loss : string, callable, optional, default "absolute_loss"
142+ String inputs, "absolute_loss" and "squared_loss" are supported which
143+ find the absolute loss and squared loss per sample
144+ respectively.
145+
146+ If ``loss`` is a callable, then it should be a function that takes
147+ two arrays as inputs, the true and predicted values, and returns
148+ a 1-D array with the ``i``-th value of the array corresponding to
149+ the loss on ``X[i]``.
150+
151+ If the loss on a sample is greater than the ``residual_threshold``, then
152+ this sample is classified as an outlier.
153+
137154 random_state : integer or numpy.RandomState, optional
138155 The generator used to initialize the centers. If an integer is
139156 given, it fixes the seed. Defaults to the global numpy random
@@ -163,7 +180,7 @@ def __init__(self, base_estimator=None, min_samples=None,
163180 is_model_valid = None , max_trials = 100 ,
164181 stop_n_inliers = np .inf , stop_score = np .inf ,
165182 stop_probability = 0.99 , residual_metric = None ,
166- random_state = None ):
183+ loss = 'absolute_loss' , random_state = None ):
167184
168185 self .base_estimator = base_estimator
169186 self .min_samples = min_samples
@@ -176,6 +193,7 @@ def __init__(self, base_estimator=None, min_samples=None,
176193 self .stop_probability = stop_probability
177194 self .residual_metric = residual_metric
178195 self .random_state = random_state
196+ self .loss = loss
179197
180198 def fit (self , X , y , sample_weight = None ):
181199 """Fit estimator using RANSAC algorithm.
@@ -236,10 +254,33 @@ def fit(self, X, y, sample_weight=None):
236254 else :
237255 residual_threshold = self .residual_threshold
238256
239- if self .residual_metric is None :
240- residual_metric = lambda dy : np .sum (np .abs (dy ), axis = 1 )
257+ if self .residual_metric is not None :
258+ warnings .warn (
259+ "'residual_metric' will be removed in version 0.20. Use "
260+ "'loss' instead." , DeprecationWarning )
261+
262+ if self .loss == "absolute_loss" :
263+ if y .ndim == 1 :
264+ loss_function = lambda y_true , y_pred : np .abs (y_true - y_pred )
265+ else :
266+ loss_function = lambda \
267+ y_true , y_pred : np .sum (np .abs (y_true - y_pred ), axis = 1 )
268+
269+ elif self .loss == "squared_loss" :
270+ if y .ndim == 1 :
271+ loss_function = lambda y_true , y_pred : (y_true - y_pred ) ** 2
272+ else :
273+ loss_function = lambda \
274+ y_true , y_pred : np .sum ((y_true - y_pred ) ** 2 , axis = 1 )
275+
276+ elif callable (self .loss ):
277+ loss_function = self .loss
278+
241279 else :
242- residual_metric = self .residual_metric
280+ raise ValueError (
281+ "loss should be 'absolute_loss', 'squared_loss' or a callable."
282+ "Got %s. " % self .loss )
283+
243284
244285 random_state = check_random_state (self .random_state )
245286
@@ -298,10 +339,15 @@ def fit(self, X, y, sample_weight=None):
298339
299340 # residuals of all data for current random sample model
300341 y_pred = base_estimator .predict (X )
301- diff = y_pred - y
302- if diff .ndim == 1 :
303- diff = diff .reshape (- 1 , 1 )
304- residuals_subset = residual_metric (diff )
342+
343+ # XXX: Deprecation: Remove this if block in 0.20
344+ if self .residual_metric is not None :
345+ diff = y_pred - y
346+ if diff .ndim == 1 :
347+ diff = diff .reshape (- 1 , 1 )
348+ residuals_subset = self .residual_metric (diff )
349+ else :
350+ residuals_subset = loss_function (y , y_pred )
305351
306352 # classify data into inliers and outliers
307353 inlier_mask_subset = residuals_subset < residual_threshold
0 commit comments