@@ -36,20 +36,18 @@ class CondensedNearestNeighbour(BaseMulticlassSampler):
3636 NOTE: size_ngh is deprecated from 0.2 and will be replaced in 0.4
3737 Use ``n_neighbors`` instead.
3838
39- n_neighbors : int, optional (default=1 )
40- Size of the neighbourhood to consider to compute the average
39+ n_neighbors : int or object , optional (default=KNeighborsClassifier(n_neighbors=1) )
40+ If int, size of the neighbourhood to consider to compute the average
4141 distance to the minority point samples.
42+ If object, an instance of a class inherited from
43+ `sklearn.neighbors.KNeighborsClassifier` should be passed.
4244
4345 n_seeds_S : int, optional (default=1)
4446 Number of samples to extract in order to build the set S.
4547
4648 n_jobs : int, optional (default=1)
4749 The number of threads to open if possible.
4850
49- **kwargs : keywords
50- Parameter to use for the Neareast Neighbours object.
51-
52-
5351 Attributes
5452 ----------
5553 min_c_ : str or int
@@ -95,16 +93,55 @@ class CondensedNearestNeighbour(BaseMulticlassSampler):
9593 """
9694
def __init__(self, return_indices=False, random_state=None,
             size_ngh=None, n_neighbors=None, n_seeds_S=1, n_jobs=1):
    """Record the sampler parameters on the instance.

    No validation is done here; the arguments are only stored as
    attributes of the same name. ``random_state`` is additionally
    forwarded to the parent class constructor.
    """
    super(CondensedNearestNeighbour, self).__init__(
        random_state=random_state)

    # Neighbourhood configuration (``size_ngh`` is the deprecated alias
    # of ``n_neighbors``, see the class docstring).
    self.size_ngh = size_ngh
    self.n_neighbors = n_neighbors

    # Remaining sampling options.
    self.n_seeds_S = n_seeds_S
    self.n_jobs = n_jobs
    self.return_indices = return_indices
104+
def _validate_estimator(self):
    """Create the nearest-neighbours estimator used by the sampler.

    The estimator is derived from ``self.n_neighbors`` and stored in
    ``self.estimator_``:

    - ``None``: a ``KNeighborsClassifier`` with ``n_neighbors=1``;
    - int: a ``KNeighborsClassifier`` with that many neighbours;
    - ``KNeighborsClassifier`` instance: used as-is.

    Estimators built here receive ``n_jobs=self.n_jobs``.

    Raises
    ------
    ValueError
        If ``n_neighbors`` is neither ``None``, an int, nor an instance
        of ``KNeighborsClassifier`` (or a subclass).

    """
    n_neighbors = self.n_neighbors

    if n_neighbors is None or isinstance(n_neighbors, int):
        # ``None`` falls back to a single-neighbour classifier.
        self.estimator_ = KNeighborsClassifier(
            n_neighbors=1 if n_neighbors is None else n_neighbors,
            n_jobs=self.n_jobs)
    elif isinstance(n_neighbors, KNeighborsClassifier):
        # NOTE(review): the user-supplied object is stored by reference,
        # not copied, so later use of ``estimator_`` acts on the caller's
        # object -- confirm this is intended.
        self.estimator_ = n_neighbors
    else:
        raise ValueError('`n_neighbors` has to be an int or an object'
                         ' inherited from KNeighborsClassifier.')
121+
def fit(self, X, y):
    """Compute the class statistics and prepare the NN estimator.

    The class statistics are gathered by the parent class ``fit``; the
    nearest-neighbours estimator is then created through
    ``_validate_estimator``.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Matrix containing the data to be sampled.

    y : ndarray, shape (n_samples, )
        Label of each sample in X.

    Returns
    -------
    self : object
        The fitted sampler itself.

    """
    # Delegate the generic fitting work to the parent class first.
    super(CondensedNearestNeighbour, self).fit(X, y)

    # Build (or check) the estimator now that the parameters are final.
    self._validate_estimator()

    return self
108145
109146 def _sample (self , X , y ):
110147 """Resample the dataset.
@@ -167,13 +204,8 @@ def _sample(self, X, y):
167204 S_x = X [y == key ]
168205 S_y = y [y == key ]
169206
170- # Create a k-NN classifier
171- knn = KNeighborsClassifier (n_neighbors = self .n_neighbors ,
172- n_jobs = self .n_jobs ,
173- ** self .kwargs )
174-
175207 # Fit C into the knn
176- knn .fit (C_x , C_y )
208+ self . estimator_ .fit (C_x , C_y )
177209
178210 good_classif_label = idx_maj_sample .copy ()
179211 # Check each sample in S if we keep it or drop it
@@ -184,7 +216,7 @@ def _sample(self, X, y):
184216 continue
185217
186218 # Classify on S
187- pred_y = knn .predict (x_sam .reshape (1 , - 1 ))
219+ pred_y = self . estimator_ .predict (x_sam .reshape (1 , - 1 ))
188220
189221 # If the prediction do not agree with the true label
190222 # append it in C_x
@@ -198,12 +230,12 @@ def _sample(self, X, y):
198230 idx_maj_sample .size ))
199231
200232 # Fit C into the knn
201- knn .fit (C_x , C_y )
233+ self . estimator_ .fit (C_x , C_y )
202234
203235 # This experimental to speed up the search
204236 # Classify all the element in S and avoid to test the
205237 # well classified elements
206- pred_S_y = knn .predict (S_x )
238+ pred_S_y = self . estimator_ .predict (S_x )
207239 good_classif_label = np .unique (
208240 np .append (idx_maj_sample ,
209241 np .flatnonzero (pred_S_y == S_y )))
0 commit comments