Skip to content
Prev Previous commit
Next Next commit
fixed linting errors.
  • Loading branch information
Andrea Lorenzon committed Aug 26, 2020
commit c0d74739e47010e4445d1547b697a5ec3ad8a2eb
16 changes: 10 additions & 6 deletions imblearn/over_sampling/_rose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ class ROSE(BaseOverSampler):

"""Oversample using Random OverSampling Examples (ROSE) algorithm.

The algorithm generates new samples using a smoothed bootstrap approach.
Generating new examples corresponds to drawing data from the kernel
density estimate of f(x|Y_i), with a smoothing matrix H_j.
A shrinking matrix can be provided to set the bandwidth of the Gaussian
kernel.

Read more in the :ref:`User Guide <rose>`.

Parameters
----------
{sampling_strategy}
Expand Down Expand Up @@ -74,9 +81,6 @@ def _make_samples(self,
Target values for synthetic samples.

"""

# pdb.set_trace()

p = X.shape[1]

random_state = check_random_state(self.random_state)
Expand All @@ -94,12 +98,13 @@ def _make_samples(self,

def _fit_resample(self, X, y):

#random_state = check_random_state(self.random_state)
X_resampled = np.empty((0, X.shape[1]), dtype=X.dtype)
y_resampled = np.empty((0), dtype=X.dtype)

if self.shrink_factors is None:
self.shrink_factors = {key: 1 for key in self.sampling_strategy_.keys()}
self.shrink_factors = {
key: 1 for key in self.sampling_strategy_.keys()
}

for class_sample, n_samples in self.sampling_strategy_.items():
class_indices = np.flatnonzero(y == class_sample)
Expand All @@ -118,4 +123,3 @@ def _fit_resample(self, X, y):
y_resampled = np.hstack((y_resampled, y_new))

return X_resampled.astype(X.dtype), y_resampled.astype(y.dtype)

75 changes: 29 additions & 46 deletions imblearn/over_sampling/tests/test_rose.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,36 @@
from imblearn.over_sampling import ROSE


def test_testunit():
    """Smoke test: confirm that this test module can be run at all."""
    # A test must assert rather than return a value: pytest ignores the
    # return value (and warns about it), so ``return True`` checked nothing.
    assert True


def test_random_state():
    """Check that a seeded ``np.random.RandomState`` is reproducible."""
    first = np.random.RandomState(42)
    second = np.random.RandomState(42)

    # Asserting the bare object only tests its truthiness, which holds for
    # every instance. Check the type and the property the tests rely on:
    # identical seeds give identical draws.
    assert isinstance(first, np.random.RandomState)
    assert np.array_equal(first.uniform(size=5), second.uniform(size=5))


def test_instance():
    """Check that ROSE can be instantiated with its default arguments."""
    rose = ROSE()
    # ``id(obj)`` is always a non-zero integer, so asserting it can never
    # fail; check the type of the constructed object instead.
    assert isinstance(rose, ROSE)


# Shared fixture for the ROSE tests.
# Seed used to make the sampler's output reproducible.
RND_SEED = 0

# 20 samples with 2 features each.
X = np.array(
    [
        (0.11622591, -0.0317206),
        (0.77481731, 0.60935141),
        (1.25192108, -0.22367336),
        (0.53366841, -0.30312976),
        (1.52091956, -0.49283504),
        (-0.28162401, -2.10400981),
        (0.83680821, 1.72827342),
        (0.3084254, 0.33299982),
        (0.70472253, -0.73309052),
        (0.28893132, -0.38761769),
        (1.15514042, 0.0129463),
        (0.88407872, 0.35454207),
        (1.31301027, -0.92648734),
        (-1.11515198, -0.93689695),
        (-0.18410027, -0.45194484),
        (0.9281014, 0.53085498),
        (-0.14374509, 0.27370049),
        (-0.41635887, -0.38299653),
        (0.08711622, 0.93259929),
        (1.70580611, -0.11219234),
    ]
)

# Binary labels, one per row of X: 8 samples of class 0, 12 of class 1.
Y = np.array(
    [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]
)

# NOTE(review): presumably a relative tolerance for numeric comparisons;
# it is not referenced by the tests visible here - confirm it is needed.
R_TOL = 1e-4


def test_rose():
    """Check that ROSE resamples the shared fixture without changing its width.

    Uses the module-level ``X``, ``Y`` and ``RND_SEED``. The fixture is not
    redefined locally: rebinding ``X``/``Y`` inside the function would make
    them locals, so any use before the assignment raises
    ``UnboundLocalError``.
    """
    # Fit once, with a fixed seed so the resampled output is reproducible.
    X_res, y_res = ROSE(random_state=RND_SEED).fit_resample(X, Y)

    # NOTE(review): ``Y.all()`` and ``y_res.all()`` each reduce to a single
    # boolean, so this compares two booleans rather than the sets of class
    # labels. Kept as-is; confirm the intended check (e.g. comparing
    # ``np.unique(Y)`` with ``np.unique(y_res)``) against the sampler's
    # actual output before tightening it.
    assert np.unique(Y.all()) == np.unique(y_res.all())

    # Resampling must not add or drop feature columns.
    assert X_res.shape[1] == X.shape[1]