Skip to content

Commit 2db667f

Browse files
committed
MNT add sparse input support and complete documentation (scikit-learn-contrib#881)
1 parent 0e80574 commit 2db667f

File tree

1 file changed

+49
-10
lines changed

1 file changed

+49
-10
lines changed

imblearn/over_sampling/_smote/geometric.py

+49-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Class to perform over-sampling using Geometric SMOTE."""
22

33
# Author: Georgios Douzas <[email protected]>
4+
# Joao Fonseca <[email protected]>
45
# License: BSD 3 clause
56

67
import numpy as np
78
from numpy.linalg import norm
9+
from scipy import sparse
810
from sklearn.utils import check_random_state
9-
from imblearn.over_sampling.base import BaseOverSampler
11+
from ..base import BaseOverSampler
1012
from imblearn.utils import check_neighbors_object, Substitution
1113
from imblearn.utils._docstring import _random_state_docstring
1214

@@ -119,6 +121,33 @@ class GeometricSMOTE(BaseOverSampler):
119121
n_jobs : int, optional (default=1)
120122
The number of threads to open if possible.
121123
124+
Attributes
125+
----------
126+
127+
sampling_strategy_ : dict
128+
Dictionary containing the information to sample the dataset. The keys
129+
correspond to the class labels from which to sample and the values
130+
are the number of samples to sample.
131+
132+
n_features_in_ : int
133+
Number of features in the input dataset.
134+
135+
nns_pos_ : estimator object
136+
Validated k-nearest neighbors created from the `k_neighbors` parameter. It is
137+
used to find the nearest neighbors of the same class of a selected
138+
observation.
139+
140+
nn_neg_ : estimator object
141+
Validated k-nearest neighbors created from the `k_neighbors` parameter. It is
142+
used to find the nearest neighbor of the remaining classes (k=1) of a selected
143+
observation.
144+
145+
random_state_ : instance of RandomState
146+
If the `random_state` parameter is None, it is a RandomState singleton used by
147+
np.random. If `random_state` is an int, it is a RandomState instance seeded with
148+
seed. If `random_state` is already a RandomState instance, it is the same
149+
object.
150+
122151
Notes
123152
-----
124153
See the original paper: [1]_ for more details.
@@ -142,7 +171,8 @@ class GeometricSMOTE(BaseOverSampler):
142171
143172
>>> from collections import Counter
144173
>>> from sklearn.datasets import make_classification
145-
>>> from gsmote import GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
174+
>>> from imblearn.over_sampling import \
175+
GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
146176
>>> X, y = make_classification(n_classes=2, class_sep=2,
147177
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
148178
... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
@@ -237,7 +267,7 @@ def _make_geometric_samples(self, X, y, pos_class_label, n_samples):
237267

238268
# Force minority strategy if no negative class samples are present
239269
self.selection_strategy_ = (
240-
'minority' if len(X) == len(X_pos) else self.selection_strategy
270+
'minority' if X.shape[0] == X_pos.shape[0] else self.selection_strategy
241271
)
242272

243273
# Minority or combined strategy
@@ -306,19 +336,28 @@ def _fit_resample(self, X, y):
306336
# Validate estimator's parameters
307337
self._validate_estimator()
308338

339+
# Ensure the input data is dense
340+
X_dense = X.toarray() if sparse.issparse(X) else X
341+
309342
# Copy data
310-
X_resampled, y_resampled = X.copy(), y.copy()
343+
X_resampled, y_resampled = [X_dense.copy()], [y.copy()]
311344

312345
# Resample data
313346
for class_label, n_samples in self.sampling_strategy_.items():
314347

315348
# Apply gsmote mechanism
316-
X_new, y_new = self._make_geometric_samples(X, y, class_label, n_samples)
317-
318-
# Append new data
319-
X_resampled, y_resampled = (
320-
np.vstack((X_resampled, X_new)),
321-
np.hstack((y_resampled, y_new)),
349+
X_new, y_new = self._make_geometric_samples(
350+
X_dense, y, class_label, n_samples
322351
)
323352

353+
X_resampled.append(X_new)
354+
y_resampled.append(y_new)
355+
356+
# Append new data
357+
if sparse.issparse(X):
358+
X_resampled = sparse.vstack(X_resampled, format=X.format)
359+
else:
360+
X_resampled = np.vstack(X_resampled).astype(X.dtype)
361+
y_resampled = np.hstack(y_resampled).astype(y.dtype)
362+
324363
return X_resampled, y_resampled

0 commit comments

Comments
 (0)