Skip to content

Commit fa3ffe5

Browse files
committed
MNT add sparse input support, complete documentation and format code (#881)
1 parent 0e80574 commit fa3ffe5

File tree

1 file changed

+63
-24
lines changed

1 file changed

+63
-24
lines changed

imblearn/over_sampling/_smote/geometric.py

+63-24
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""Class to perform over-sampling using Geometric SMOTE."""
22

33
# Author: Georgios Douzas <[email protected]>
4+
# Joao Fonseca <[email protected]>
45
# License: BSD 3 clause
56

67
import numpy as np
78
from numpy.linalg import norm
9+
from scipy import sparse
810
from sklearn.utils import check_random_state
9-
from imblearn.over_sampling.base import BaseOverSampler
11+
from ..base import BaseOverSampler
1012
from imblearn.utils import check_neighbors_object, Substitution
1113
from imblearn.utils._docstring import _random_state_docstring
1214

13-
SELECTION_STRATEGY = ('combined', 'majority', 'minority')
15+
SELECTION_STRATEGY = ("combined", "majority", "minority")
1416

1517

1618
def _make_geometric_sample(
@@ -119,6 +121,33 @@ class GeometricSMOTE(BaseOverSampler):
119121
n_jobs : int, optional (default=1)
120122
The number of threads to open if possible.
121123
124+
Attributes
125+
----------
126+
127+
sampling_strategy_ : dict
128+
Dictionary containing the information to sample the dataset. The keys
129+
correspond to the class labels from which to sample and the values
130+
are the number of samples to sample.
131+
132+
n_features_in_ : int
133+
Number of features in the input dataset.
134+
135+
nns_pos_ : estimator object
136+
Validated k-nearest neighbours created from the `k_neighbors` parameter. It is
137+
used to find the nearest neighbors of the same class of a selected
138+
observation.
139+
140+
nn_neg_ : estimator object
141+
Validated nearest neighbours estimator using a single neighbor. It is
142+
used to find the nearest neighbor of the remaining classes (k=1) of a selected
143+
observation.
144+
145+
random_state_ : instance of RandomState
146+
If the `random_state` parameter is None, it is a RandomState singleton used by
147+
np.random. If `random_state` is an int, it is a RandomState instance seeded with
148+
that integer. If `random_state` is already a RandomState instance, it is the same
149+
object.
150+
122151
Notes
123152
-----
124153
See the original paper: [1]_ for more details.
@@ -142,7 +171,8 @@ class GeometricSMOTE(BaseOverSampler):
142171
143172
>>> from collections import Counter
144173
>>> from sklearn.datasets import make_classification
145-
>>> from gsmote import GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
174+
>>> from imblearn.over_sampling import \
175+
GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
146176
>>> X, y = make_classification(n_classes=2, class_sep=2,
147177
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
148178
... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
@@ -157,11 +187,11 @@ class GeometricSMOTE(BaseOverSampler):
157187

158188
def __init__(
159189
self,
160-
sampling_strategy='auto',
190+
sampling_strategy="auto",
161191
random_state=None,
162192
truncation_factor=1.0,
163193
deformation_factor=0.0,
164-
selection_strategy='combined',
194+
selection_strategy="combined",
165195
k_neighbors=5,
166196
n_jobs=1,
167197
):
@@ -182,23 +212,23 @@ def _validate_estimator(self):
182212
# Validate strategy
183213
if self.selection_strategy not in SELECTION_STRATEGY:
184214
error_msg = (
185-
'Unknown selection_strategy for Geometric SMOTE algorithm. '
186-
'Choices are {}. Got {} instead.'
215+
"Unknown selection_strategy for Geometric SMOTE algorithm. "
216+
"Choices are {}. Got {} instead."
187217
)
188218
raise ValueError(
189219
error_msg.format(SELECTION_STRATEGY, self.selection_strategy)
190220
)
191221

192222
# Create nearest neighbors object for positive class
193-
if self.selection_strategy in ('minority', 'combined'):
223+
if self.selection_strategy in ("minority", "combined"):
194224
self.nns_pos_ = check_neighbors_object(
195-
'nns_positive', self.k_neighbors, additional_neighbor=1
225+
"nns_positive", self.k_neighbors, additional_neighbor=1
196226
)
197227
self.nns_pos_.set_params(n_jobs=self.n_jobs)
198228

199229
# Create nearest neighbors object for negative class
200-
if self.selection_strategy in ('majority', 'combined'):
201-
self.nn_neg_ = check_neighbors_object('nn_negative', nn_object=1)
230+
if self.selection_strategy in ("majority", "combined"):
231+
self.nn_neg_ = check_neighbors_object("nn_negative", nn_object=1)
202232
self.nn_neg_.set_params(n_jobs=self.n_jobs)
203233

204234
def _make_geometric_samples(self, X, y, pos_class_label, n_samples):
@@ -237,11 +267,11 @@ def _make_geometric_samples(self, X, y, pos_class_label, n_samples):
237267

238268
# Force minority strategy if no negative class samples are present
239269
self.selection_strategy_ = (
240-
'minority' if len(X) == len(X_pos) else self.selection_strategy
270+
"minority" if X.shape[0] == X_pos.shape[0] else self.selection_strategy
241271
)
242272

243273
# Minority or combined strategy
244-
if self.selection_strategy_ in ('minority', 'combined'):
274+
if self.selection_strategy_ in ("minority", "combined"):
245275
self.nns_pos_.fit(X_pos)
246276
points_pos = self.nns_pos_.kneighbors(X_pos)[1][:, 1:]
247277
samples_indices = self.random_state_.randint(
@@ -251,11 +281,11 @@ def _make_geometric_samples(self, X, y, pos_class_label, n_samples):
251281
cols = np.mod(samples_indices, points_pos.shape[1])
252282

253283
# Majority or combined strategy
254-
if self.selection_strategy_ in ('majority', 'combined'):
284+
if self.selection_strategy_ in ("majority", "combined"):
255285
X_neg = X[y != pos_class_label]
256286
self.nn_neg_.fit(X_neg)
257287
points_neg = self.nn_neg_.kneighbors(X_pos)[1]
258-
if self.selection_strategy_ == 'majority':
288+
if self.selection_strategy_ == "majority":
259289
samples_indices = self.random_state_.randint(
260290
low=0, high=len(points_neg.flatten()), size=n_samples
261291
)
@@ -270,11 +300,11 @@ def _make_geometric_samples(self, X, y, pos_class_label, n_samples):
270300
center = X_pos[row]
271301

272302
# Minority strategy
273-
if self.selection_strategy_ == 'minority':
303+
if self.selection_strategy_ == "minority":
274304
surface_point = X_pos[points_pos[row, col]]
275305

276306
# Majority strategy
277-
elif self.selection_strategy_ == 'majority':
307+
elif self.selection_strategy_ == "majority":
278308
surface_point = X_neg[points_neg[row, col]]
279309

280310
# Combined strategy
@@ -306,19 +336,28 @@ def _fit_resample(self, X, y):
306336
# Validate estimator's parameters
307337
self._validate_estimator()
308338

339+
# Ensure the input data is dense
340+
X_dense = X.toarray() if sparse.issparse(X) else X
341+
309342
# Copy data
310-
X_resampled, y_resampled = X.copy(), y.copy()
343+
X_resampled, y_resampled = [X_dense.copy()], [y.copy()]
311344

312345
# Resample data
313346
for class_label, n_samples in self.sampling_strategy_.items():
314347

315348
# Apply gsmote mechanism
316-
X_new, y_new = self._make_geometric_samples(X, y, class_label, n_samples)
317-
318-
# Append new data
319-
X_resampled, y_resampled = (
320-
np.vstack((X_resampled, X_new)),
321-
np.hstack((y_resampled, y_new)),
349+
X_new, y_new = self._make_geometric_samples(
350+
X_dense, y, class_label, n_samples
322351
)
323352

353+
X_resampled.append(X_new)
354+
y_resampled.append(y_new)
355+
356+
# Stack the original and generated data, keeping the input's sparse format
357+
if sparse.issparse(X):
358+
X_resampled = sparse.vstack(X_resampled, format=X.format)
359+
else:
360+
X_resampled = np.vstack(X_resampled).astype(X.dtype)
361+
y_resampled = np.hstack(y_resampled).astype(y.dtype)
362+
324363
return X_resampled, y_resampled

0 commit comments

Comments
 (0)