Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions imblearn/over_sampling/smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class SMOTE(BaseBinarySampler):
The type of SMOTE algorithm to use one of the following options:
'regular', 'borderline1', 'borderline2', 'svm'.

svm_estimator : object, optional (default=SVC())
If `kind='svm'`, a parametrized `sklearn.svm.SVC` classifier can
be passed.

n_jobs : int, optional (default=1)
The number of threads to open if possible.

Expand Down Expand Up @@ -128,16 +132,16 @@ class SMOTE(BaseBinarySampler):

def __init__(self, ratio='auto', random_state=None, k=None, k_neighbors=5,
m=None, m_neighbors=10, out_step=0.5, kind='regular',
n_jobs=1, **kwargs):
svm_estimator=None, n_jobs=1):
super(SMOTE, self).__init__(ratio=ratio, random_state=random_state)
self.kind = kind
self.k = k
self.k_neighbors = k_neighbors
self.m = m
self.m_neighbors = m_neighbors
self.out_step = out_step
self.svm_estimator = svm_estimator
self.n_jobs = n_jobs
self.kwargs = kwargs

def _in_danger_noise(self, samples, y, kind='danger'):
"""Estimate if a set of sample are in danger or noise.
Expand Down Expand Up @@ -316,8 +320,13 @@ def _validate_estimator(self):
# in danger (near the boundary). The level of extrapolation is
# controlled by the out_step.
if self.kind == 'svm':
# Store SVM object with any parameters
self.svm = SVC(random_state=self.random_state, **self.kwargs)
if self.svm_estimator is None:
# Store SVM object with any parameters
self.svm_estimator_ = SVC(random_state=self.random_state)
elif isinstance(self.svm_estimator, SVC):
self.svm_estimator_ = self.svm_estimator
else:
raise ValueError('`svm_estimator` has to be an SVC object')

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
Expand Down Expand Up @@ -503,11 +512,11 @@ def _sample(self, X, y):
# belonging to each class.

# Fit SVM to the full data#
self.svm.fit(X, y)
self.svm_estimator_.fit(X, y)

# Find the support vectors and their corresponding indexes
support_index = self.svm.support_[y[self.svm.support_] ==
self.min_c_]
support_index = self.svm_estimator_.support_[
y[self.svm_estimator_.support_] == self.min_c_]
support_vector = X[support_index]

# First, find the nn of all the samples to identify samples
Expand Down
55 changes: 55 additions & 0 deletions imblearn/over_sampling/tests/test_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.datasets import make_classification
from sklearn.utils.estimator_checks import check_estimator
from sklearn.neighbors import NearestNeighbors
from sklearn.svm import SVC

from imblearn.over_sampling import SMOTE

Expand Down Expand Up @@ -446,3 +447,57 @@ def test_wrong_nn():
k_neighbors=nn_k)

assert_raises(ValueError, smote.fit_sample, X, Y)


def test_sample_regular_with_nn_svm():
    """Check SVM SMOTE when a NearestNeighbors and an SVC object are given."""

    # Build the sampler from user-supplied estimator objects instead of
    # relying on the defaults created internally.
    neighbors = NearestNeighbors(n_neighbors=6)
    classifier = SVC(random_state=RND_SEED)
    sampler = SMOTE(random_state=RND_SEED, kind='svm',
                    k_neighbors=neighbors, svm_estimator=classifier)

    X_resampled, y_resampled = sampler.fit_sample(X, Y)

    # Reference output the resampled data is compared against
    expected_X = np.array([[0.11622591, -0.0317206],
                           [0.77481731, 0.60935141],
                           [1.25192108, -0.22367336],
                           [0.53366841, -0.30312976],
                           [1.52091956, -0.49283504],
                           [-0.28162401, -2.10400981],
                           [0.83680821, 1.72827342],
                           [0.3084254, 0.33299982],
                           [0.70472253, -0.73309052],
                           [0.28893132, -0.38761769],
                           [1.15514042, 0.0129463],
                           [0.88407872, 0.35454207],
                           [1.31301027, -0.92648734],
                           [-1.11515198, -0.93689695],
                           [-0.18410027, -0.45194484],
                           [0.9281014, 0.53085498],
                           [-0.14374509, 0.27370049],
                           [-0.41635887, -0.38299653],
                           [0.08711622, 0.93259929],
                           [1.70580611, -0.11219234],
                           [0.47436888, -0.2645749],
                           [1.07844561, -0.19435291],
                           [1.44015515, -1.30621303]])
    expected_y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
                           0, 1, 0, 0, 0, 0])
    assert_array_almost_equal(X_resampled, expected_X)
    assert_array_equal(y_resampled, expected_y)


def test_sample_regular_wrong_svm():
    """Check that SVM SMOTE rejects an `svm_estimator` that is not an SVC."""

    # A plain string is passed where an SVC object is expected
    neighbors = NearestNeighbors(n_neighbors=6)
    sampler = SMOTE(random_state=RND_SEED, kind='svm',
                    k_neighbors=neighbors, svm_estimator='rnd')

    # Fitting and sampling is expected to fail with a ValueError
    assert_raises(ValueError, sampler.fit_sample, X, Y)
2 changes: 1 addition & 1 deletion imblearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Pipeline(pipeline.Pipeline):
>>> X_train, X_test, y_train, y_test = tts(X, y, random_state=42)
>>> pipeline.fit(X_train, y_train)
Pipeline(steps=[('smt', SMOTE(k=None, k_neighbors=5, kind='regular', m=None, m_neighbors=10, n_jobs=1,
out_step=0.5, random_state=42, ratio='auto')), ('pca', PCA(copy=True, n_components=None, whiten=False)), ('knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
out_step=0.5, random_state=42, ratio='auto', svm_estimator=None)), ('pca', PCA(copy=True, n_components=None, whiten=False)), ('knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=5, p=2,
weights='uniform'))])
>>> y_hat = pipeline.predict(X_test)
Expand Down
56 changes: 45 additions & 11 deletions imblearn/under_sampling/cluster_centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class ClusterCentroids(BaseMulticlassSampler):
If None, the random number generator is the RandomState instance used
by np.random.

estimator : object, optional (default=KMeans())
Pass a `sklearn.cluster.KMeans` estimator.

n_jobs : int, optional (default=1)
The number of threads to open if possible.

**kwargs : keywords
Parameter to use for the KMeans object.

Attributes
----------
min_c_ : str or int
Expand Down Expand Up @@ -79,11 +79,47 @@ class ClusterCentroids(BaseMulticlassSampler):

"""

def __init__(self, ratio='auto', random_state=None, n_jobs=1, **kwargs):
def __init__(self, ratio='auto', random_state=None, estimator=None,
n_jobs=1):
super(ClusterCentroids, self).__init__(ratio=ratio,
random_state=random_state)
self.estimator = estimator
self.n_jobs = n_jobs
self.kwargs = kwargs

def _validate_estimator(self):
    """Create the KMeans estimator used to compute the centroids.

    Sets ``self.estimator_`` to a default ``KMeans`` (built with the
    ``random_state`` and ``n_jobs`` of the sampler) when ``estimator`` is
    None, or to a clone of the ``KMeans`` instance given by the user.

    Raises
    ------
    ValueError
        If ``estimator`` is neither None nor a ``KMeans`` instance.

    """
    # Local import so that this method does not rely on a module-level
    # import of `clone`.
    from sklearn.base import clone

    if self.estimator is None:
        self.estimator_ = KMeans(random_state=self.random_state,
                                 n_jobs=self.n_jobs)
    elif isinstance(self.estimator, KMeans):
        # Work on a copy: `_sample` overrides `n_clusters` and refits the
        # estimator, which must not alter the object owned by the caller.
        self.estimator_ = clone(self.estimator)
    else:
        raise ValueError('`estimator` has to be a KMeans clustering.')

def fit(self, X, y):
    """Fit the sampler and set up the clustering estimator.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Matrix containing the data which have to be sampled.

    y : ndarray, shape (n_samples, )
        Corresponding label for each sample in X.

    Returns
    -------
    self : object,
        Return self.

    """
    # Let the parent class perform its own fitting first ...
    super(ClusterCentroids, self).fit(X, y)
    # ... then check `estimator` and populate `estimator_`.
    self._validate_estimator()
    return self

def _sample(self, X, y):
"""Resample the dataset.
Expand All @@ -105,17 +141,15 @@ def _sample(self, X, y):
The corresponding label of `X_resampled`

"""
random_state = check_random_state(self.random_state)

# Compute the number of cluster needed
if self.ratio == 'auto':
num_samples = self.stats_c_[self.min_c_]
else:
num_samples = int(self.stats_c_[self.min_c_] / self.ratio)

# Create the clustering object
kmeans = KMeans(n_clusters=num_samples, random_state=random_state)
kmeans.set_params(**self.kwargs)
# Set the number of sample for the estimator
self.estimator_.set_params(**{'n_clusters': num_samples})

# Start with the minority class
X_min = X[y == self.min_c_]
Expand All @@ -133,8 +167,8 @@ def _sample(self, X, y):
continue

# Find the centroids via k-means
kmeans.fit(X[y == key])
centroids = kmeans.cluster_centers_
self.estimator_.fit(X[y == key])
centroids = self.estimator_.cluster_centers_

# Concatenate to the minority class
X_resampled = np.concatenate((X_resampled, centroids), axis=0)
Expand Down
68 changes: 50 additions & 18 deletions imblearn/under_sampling/condensed_nearest_neighbour.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,18 @@ class CondensedNearestNeighbour(BaseMulticlassSampler):
NOTE: size_ngh is deprecated from 0.2 and will be replaced in 0.4
Use ``n_neighbors`` instead.

n_neighbors : int, optional (default=1)
Size of the neighbourhood to consider to compute the average
n_neighbors : int or object, optional (default=KNeighborsClassifier(n_neighbors=1))
If int, size of the neighbourhood to consider to compute the average
distance to the minority point samples.
If object, an object inherited from
`sklearn.neighbors.KNeighborsClassifier` should be passed.

n_seeds_S : int, optional (default=1)
Number of samples to extract in order to build the set S.

n_jobs : int, optional (default=1)
The number of threads to open if possible.

**kwargs : keywords
Parameter to use for the Neareast Neighbours object.


Attributes
----------
min_c_ : str or int
Expand Down Expand Up @@ -95,16 +93,55 @@ class CondensedNearestNeighbour(BaseMulticlassSampler):
"""

def __init__(self, return_indices=False, random_state=None,
size_ngh=None, n_neighbors=1, n_seeds_S=1, n_jobs=1,
**kwargs):
size_ngh=None, n_neighbors=None, n_seeds_S=1, n_jobs=1):
super(CondensedNearestNeighbour, self).__init__(
random_state=random_state)
self.return_indices = return_indices
self.size_ngh = size_ngh
self.n_neighbors = n_neighbors
self.n_seeds_S = n_seeds_S
self.n_jobs = n_jobs
self.kwargs = kwargs

def _validate_estimator(self):
    """Create the k-NN classifier stored in ``self.estimator_``.

    ``n_neighbors`` may be None (a 1-NN classifier is created), an
    integer (used as the number of neighbours) or a
    ``KNeighborsClassifier`` instance (a clone of it is used).

    Raises
    ------
    ValueError
        If ``n_neighbors`` is none of the above.

    """
    # Local import so that this method does not rely on a module-level
    # import of `clone`.
    from sklearn.base import clone

    if self.n_neighbors is None:
        self.estimator_ = KNeighborsClassifier(
            n_neighbors=1,
            n_jobs=self.n_jobs)
    elif isinstance(self.n_neighbors, (int, np.integer)):
        # NumPy integer scalars are accepted as well as built-in ints
        self.estimator_ = KNeighborsClassifier(
            n_neighbors=self.n_neighbors,
            n_jobs=self.n_jobs)
    elif isinstance(self.n_neighbors, KNeighborsClassifier):
        # Work on a copy: `_sample` refits the classifier repeatedly,
        # which must not alter the object owned by the caller.
        self.estimator_ = clone(self.n_neighbors)
    else:
        raise ValueError('`n_neighbors` has to be an int or an object'
                         ' inherited from KNeighborsClassifier.')

def fit(self, X, y):
    """Fit the sampler and set up the nearest-neighbours classifier.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Matrix containing the data which have to be sampled.

    y : ndarray, shape (n_samples, )
        Corresponding label for each sample in X.

    Returns
    -------
    self : object,
        Return self.

    """
    # Let the parent class perform its own fitting first ...
    super(CondensedNearestNeighbour, self).fit(X, y)
    # ... then check `n_neighbors` and populate `estimator_`.
    self._validate_estimator()
    return self

def _sample(self, X, y):
"""Resample the dataset.
Expand Down Expand Up @@ -167,13 +204,8 @@ def _sample(self, X, y):
S_x = X[y == key]
S_y = y[y == key]

# Create a k-NN classifier
knn = KNeighborsClassifier(n_neighbors=self.n_neighbors,
n_jobs=self.n_jobs,
**self.kwargs)

# Fit C into the knn
knn.fit(C_x, C_y)
self.estimator_.fit(C_x, C_y)

good_classif_label = idx_maj_sample.copy()
# Check each sample in S if we keep it or drop it
Expand All @@ -184,7 +216,7 @@ def _sample(self, X, y):
continue

# Classify on S
pred_y = knn.predict(x_sam.reshape(1, -1))
pred_y = self.estimator_.predict(x_sam.reshape(1, -1))

# If the prediction do not agree with the true label
# append it in C_x
Expand All @@ -198,12 +230,12 @@ def _sample(self, X, y):
idx_maj_sample.size))

# Fit C into the knn
knn.fit(C_x, C_y)
self.estimator_.fit(C_x, C_y)

# This experimental to speed up the search
# Classify all the element in S and avoid to test the
# well classified elements
pred_S_y = knn.predict(S_x)
pred_S_y = self.estimator_.predict(S_x)
good_classif_label = np.unique(
np.append(idx_maj_sample,
np.flatnonzero(pred_S_y == S_y)))
Expand Down
Loading