1212
1313from sklearn .base import ClassifierMixin , clone
1414from sklearn .ensemble import RandomForestClassifier
15+ from sklearn .ensemble ._base import _set_random_states
1516from sklearn .model_selection import StratifiedKFold
17+ from sklearn .model_selection import cross_val_predict
18+ from sklearn .utils import check_random_state
1619from sklearn .utils import _safe_indexing
1720
1821from ..base import BaseUnderSampler
@@ -108,7 +111,7 @@ def __init__(
108111 self .cv = cv
109112 self .n_jobs = n_jobs
110113
111- def _validate_estimator (self ):
114+ def _validate_estimator (self , random_state ):
112115 """Private function to create the classifier"""
113116
114117 if (
@@ -117,6 +120,8 @@ def _validate_estimator(self):
117120 and hasattr (self .estimator , "predict_proba" )
118121 ):
119122 self .estimator_ = clone (self .estimator )
123+ _set_random_states (self .estimator_ , random_state )
124+
120125 elif self .estimator is None :
121126 self .estimator_ = RandomForestClassifier (
122127 n_estimators = 100 ,
@@ -131,22 +136,18 @@ def _validate_estimator(self):
131136 )
132137
133138 def _fit_resample (self , X , y ):
134- self ._validate_estimator ()
139+ random_state = check_random_state (self .random_state )
140+ self ._validate_estimator (random_state )
135141
136142 target_stats = Counter (y )
137- skf = StratifiedKFold (n_splits = self .cv , shuffle = False ).split (X , y )
138- probabilities = np .zeros (y .shape [0 ], dtype = float )
139-
140- for train_index , test_index in skf :
141- X_train = _safe_indexing (X , train_index )
142- X_test = _safe_indexing (X , test_index )
143- y_train = _safe_indexing (y , train_index )
144- y_test = _safe_indexing (y , test_index )
145-
146- self .estimator_ .fit (X_train , y_train )
147-
148- probs = self .estimator_ .predict_proba (X_test )
149- probabilities [test_index ] = probs [range (len (y_test )), y_test ]
143+ skf = StratifiedKFold (
144+ n_splits = self .cv , shuffle = True , random_state = random_state ,
145+ )
146+ probabilities = cross_val_predict (
147+ self .estimator_ , X , y , cv = skf , n_jobs = self .n_jobs ,
148+ method = 'predict_proba'
149+ )
150+ probabilities = probabilities [range (len (y )), y ]
150151
151152 idx_under = np .empty ((0 ,), dtype = int )
152153
0 commit comments