- Notifications
You must be signed in to change notification settings - Fork 1.3k
[MRG] ROSE #754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[MRG] ROSE #754
Changes from all commits
Commits
Show all changes
68 commits Select commit Hold shift + click to select a range
53c7a8d Created empty test units
183c03f added ROSE empty class, modified __init__.py
0a3307b implemented ROSE, still some failed test
886694f PEP8 cleaning
07731c4 PEP8 linting
c0d7473 fixed linting errors.
013f7cc updated documentation and bibliography
2b34f47 cleaned ROSE test
8d0e99e added an exception for non binary datasets
8bbbd2e multiclass oversampling
9ac2797 removed non-binary exception
b41b06a removed unused import
d2dd6f4 minor fixes
b31a5c3 linting
d5ca24c linting
b6e95aa linting
6f7f8e1 linting
c391ec3 removed explicit pandas dataframe management
93ac868 added check_X_y() parsing
bdffda3 removed check_X_y test
cac7f0f local test 1: shrink factors
e59c4d3 test
c24f29e 1
4c83cfe 1
f3fb23b 1
9653709 1
93f7f8d 1
97269f5 1
36658bd 1
f2fd72b 1
4fc8476 1
7399233 1
5962627 1
4a202f7 1
ace2785 1
ecae868 1
335cd04 1
90c7082 1
f2c7dc0 1
216672c 1
c466b3d 1
2b154bc fixed sparse
a020e10 fixed all tests
d710f9e test added
354cb47 linting, submitted version
2d6b12a fixed test tolerance
94ecd4d pep8 fix
048090c tolerance adjustment
60777ad documentation
56949ea documentation
172496f documentation
a049e38 fixed bug in ROSE sampling strategy parsing
517c79b linting
17c9678 restored original docs
949a3be updated docstrings
72cfe44 removed whitespaces
8a30871 added more documentation
c8434c0 added math formulations to docs
5a9768c fixed revision issues
7e8031e added missing math directives
9f7823e dropped "see also" sections
9334fb2 fixed ROSE short description
b1e3b26 linting
c68d85c removed double newline in rose.py docstring
b5d05c8 trying to fix See Also missing
f45710c testing a fix for missing See Also section
9d33aa9 added just SMOTE to See Also section
73d0266 last minor typo fixes
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,202 @@ | ||
| """Class to perform over-sampling using ROSE.""" | ||
| | ||
| import numpy as np | ||
| from scipy import sparse | ||
| from sklearn.utils import check_random_state | ||
| from .base import BaseOverSampler | ||
| from ..utils._validation import _deprecate_positional_args | ||
| | ||
| | ||
| class ROSE(BaseOverSampler): | ||
| """Random Over-Sampling Examples (ROSE). | ||
| | ||
andrealorenzon marked this conversation as resolved. Show resolved Hide resolved | ||
| This object is the implementation of ROSE algorithm. | ||
| It generates new samples by a smoothed bootstrap approach, | ||
| taking a random subsample of original data and adding a | ||
| multivariate kernel density estimate :math:`f(x|Y_i)` around | ||
| them with a smoothing matrix :math:`H_j`, and finally sampling | ||
| from this distribution. A shrinking matrix can be provided, to | ||
| set the bandwidth of the gaussian kernel. | ||
| | ||
| Read more in the :ref:`User Guide <rose>`. | ||
| | ||
| Parameters | ||
| ---------- | ||
| sampling_strategy : float, str, dict or callable, default='auto' | ||
| Sampling information to resample the data set. | ||
| | ||
| - When ``float``, it corresponds to the desired ratio of the number of | ||
| samples in the minority class over the number of samples in the | ||
| majority class after resampling. Therefore, the ratio is expressed as | ||
| :math:`\\alpha_{os} = N_{rm} / N_{M}` where :math:`N_{rm}` is the | ||
| number of samples in the minority class after resampling and | ||
| :math:`N_{M}` is the number of samples in the majority class. | ||
| | ||
| .. warning:: | ||
| ``float`` is only available for **binary** classification. An | ||
| error is raised for multi-class classification. | ||
| | ||
| - When ``str``, specify the class targeted by the resampling. The | ||
| number of samples in the different classes will be equalized. | ||
| Possible choices are: | ||
| | ||
| ``'minority'``: resample only the minority class; | ||
| | ||
| ``'not minority'``: resample all classes but the minority class; | ||
| | ||
| ``'not majority'``: resample all classes but the majority class; | ||
| | ||
| ``'all'``: resample all classes; | ||
| | ||
| ``'auto'``: equivalent to ``'not majority'``. | ||
| | ||
| - When ``dict``, the keys correspond to the targeted classes. The | ||
| values correspond to the desired number of samples for each targeted | ||
| class. | ||
| | ||
| - When callable, function taking ``y`` and returns a ``dict``. The keys | ||
| correspond to the targeted classes. The values correspond to the | ||
| desired number of samples for each class. | ||
| | ||
| shrink_factors : dict, default= 1 for every class | ||
| Dict of {classes: shrinkfactors} items, applied to | ||
| the gaussian kernels. It can be used to compress/dilate the kernel. | ||
| | ||
| random_state : int, RandomState instance, default=None | ||
| Control the randomization of the algorithm. | ||
| | ||
| - If int, ``random_state`` is the seed used by the random number | ||
| generator; | ||
| - If ``RandomState`` instance, random_state is the random number | ||
| generator; | ||
| - If ``None``, the random number generator is the ``RandomState`` | ||
| instance used by ``np.random``. | ||
| | ||
| n_jobs : int, default=None | ||
| Number of CPU cores used during the cross-validation loop. | ||
| ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. | ||
| ``-1`` means using all processors. See | ||
| `Glossary <https://scikit-learn.org/stable/glossary.html#term-n-jobs>`_ | ||
| for more details. | ||
| | ||
| See Also | ||
| -------- | ||
| SMOTE : Over-sample using SMOTE. | ||
| | ||
| Notes | ||
| ----- | ||
| | ||
| References | ||
| ---------- | ||
| .. [1] N. Lunardon, G. Menardi, N.Torelli, "ROSE: A Package for Binary | ||
| Imbalanced Learning," R Journal, 6(1), 2014. | ||
| | ||
| .. [2] G Menardi, N. Torelli, "Training and assessing classification | ||
| rules with imbalanced data," Data Mining and Knowledge | ||
| Discovery, 28(1), pp.92-122, 2014. | ||
| | ||
| Examples | ||
| -------- | ||
| | ||
| >>> from imblearn.over_sampling import ROSE | ||
| >>> from sklearn.datasets import make_classification | ||
| >>> from collections import Counter | ||
| >>> r = ROSE(shrink_factors={0:1, 1:0.5, 2:0.7}) | ||
| >>> X, y = make_classification(n_classes=3, class_sep=2, | ||
| ... weights=[0.1, 0.7, 0.2], n_informative=3, n_redundant=1, flip_y=0, | ||
| ... n_features=20, n_clusters_per_class=1, n_samples=2000, random_state=10) | ||
| >>> print('Original dataset shape %s' % Counter(y)) | ||
| Original dataset shape Counter({1: 1400, 2: 400, 0: 200}) | ||
| >>> X_res, y_res = r.fit_resample(X, y) | ||
| >>> print('Resampled dataset shape %s' % Counter(y_res)) | ||
| Resampled dataset shape Counter({2: 1400, 1: 1400, 0: 1400}) | ||
| """ | ||
| | ||
| @_deprecate_positional_args | ||
| def __init__(self, *, sampling_strategy="auto", shrink_factors=None, | ||
| random_state=None, n_jobs=None): | ||
| super().__init__(sampling_strategy=sampling_strategy) | ||
| self.random_state = random_state | ||
| self.shrink_factors = shrink_factors | ||
| self.n_jobs = n_jobs | ||
| | ||
| def _make_samples(self, | ||
| X, | ||
| class_indices, | ||
| n_class_samples, | ||
| h_shrink): | ||
| """ A support function that returns artificial samples constructed | ||
| from a random subsample of the data, by adding a multiviariate | ||
| gaussian kernel and sampling from this distribution. An optional | ||
| shrink factor can be included, to compress/dilate the kernel. | ||
andrealorenzon marked this conversation as resolved. Show resolved Hide resolved | ||
| | ||
| Parameters | ||
| ---------- | ||
| X : {array-like, sparse matrix}, shape (n_samples, n_features) | ||
| Observations from which the samples will be created. | ||
| | ||
| class_indices : ndarray, shape (n_class_samples,) | ||
| The target class indices | ||
| | ||
| n_class_samples : int | ||
| The total number of samples per class to generate | ||
| | ||
| h_shrink : int | ||
| the shrink factor | ||
| | ||
| Returns | ||
| ------- | ||
| X_new : {ndarray, sparse matrix}, shape (n_samples, n_features) | ||
| Synthetically generated samples. | ||
| | ||
| y_new : ndarray, shape (n_samples,) | ||
| Target values for synthetic samples. | ||
| | ||
| """ | ||
| | ||
| number_of_features = X.shape[1] | ||
| random_state = check_random_state(self.random_state) | ||
| samples_indices = random_state.choice( | ||
| class_indices, size=n_class_samples, replace=True) | ||
| minimize_amise = (4 / ((number_of_features + 2) * len( | ||
| class_indices))) ** (1 / (number_of_features + 4)) | ||
| if sparse.issparse(X): | ||
| variances = np.diagflat( | ||
| np.std(X[class_indices, :].toarray(), axis=0, ddof=1)) | ||
| else: | ||
| variances = np.diagflat( | ||
| np.std(X[class_indices, :], axis=0, ddof=1)) | ||
| h_opt = h_shrink * minimize_amise * variances | ||
| randoms = random_state.standard_normal(size=(n_class_samples, | ||
| number_of_features)) | ||
| Xrose = np.matmul(randoms, h_opt) + X[samples_indices, :] | ||
| if sparse.issparse(X): | ||
| return sparse.csr_matrix(Xrose) | ||
| return Xrose | ||
| | ||
| def _fit_resample(self, X, y): | ||
| | ||
| X_resampled = X.copy() | ||
| y_resampled = y.copy() | ||
| | ||
| if self.shrink_factors is None: | ||
| self.shrink_factors = { | ||
| key: 1 for key in self.sampling_strategy_.keys()} | ||
| | ||
| for class_sample, n_samples in self.sampling_strategy_.items(): | ||
| class_indices = np.flatnonzero(y == class_sample) | ||
| n_class_samples = n_samples | ||
| X_new = self._make_samples(X, | ||
| class_indices, | ||
| n_samples, | ||
| self.shrink_factors[class_sample]) | ||
| y_new = np.array([class_sample] * n_class_samples) | ||
| | ||
| if sparse.issparse(X_new): | ||
| X_resampled = sparse.vstack([X_resampled, X_new]) | ||
| else: | ||
| X_resampled = np.concatenate((X_resampled, X_new)) | ||
| | ||
| y_resampled = np.hstack((y_resampled, y_new)) | ||
| | ||
| return X_resampled.astype(X.dtype), y_resampled.astype(y.dtype) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.