Skip to content

Commit 47ee77a

Browse files
committed
ENH add GeometricSMOTE implementation (#881)
1 parent f1abf75 commit 47ee77a

File tree

3 files changed

+329
-0
lines changed

3 files changed

+329
-0
lines changed

imblearn/over_sampling/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ._smote import SMOTE
99
from ._smote import BorderlineSMOTE
1010
from ._smote import KMeansSMOTE
11+
from ._smote import GeometricSMOTE
1112
from ._smote import SVMSMOTE
1213
from ._smote import SMOTENC
1314
from ._smote import SMOTEN
@@ -16,6 +17,7 @@
1617
"ADASYN",
1718
"RandomOverSampler",
1819
"KMeansSMOTE",
20+
"GeometricSMOTE",
1921
"SMOTE",
2022
"BorderlineSMOTE",
2123
"SVMSMOTE",

imblearn/over_sampling/_smote/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from .cluster import KMeansSMOTE
66

7+
from .geometric import GeometricSMOTE
8+
79
from .filter import BorderlineSMOTE
810
from .filter import SVMSMOTE
911

@@ -12,6 +14,7 @@
1214
"SMOTEN",
1315
"SMOTENC",
1416
"KMeansSMOTE",
17+
"GeometricSMOTE",
1518
"BorderlineSMOTE",
1619
"SVMSMOTE",
1720
]
+324
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
"""Class to perform over-sampling using Geometric SMOTE."""
2+
3+
# Author: Georgios Douzas <[email protected]>
4+
# License: BSD 3 clause
5+
6+
import numpy as np
7+
from numpy.linalg import norm
8+
from sklearn.utils import check_random_state
9+
from imblearn.over_sampling.base import BaseOverSampler
10+
from imblearn.utils import check_neighbors_object, Substitution
11+
from imblearn.utils._docstring import _random_state_docstring
12+
13+
SELECTION_STRATEGY = ('combined', 'majority', 'minority')
14+
15+
16+
def _make_geometric_sample(
17+
center, surface_point, truncation_factor, deformation_factor, random_state
18+
):
19+
"""A support function that returns an artificial point inside
20+
the geometric region defined by the center and surface points.
21+
22+
Parameters
23+
----------
24+
center : ndarray, shape (n_features, )
25+
Center point of the geometric region.
26+
27+
surface_point : ndarray, shape (n_features, )
28+
Surface point of the geometric region.
29+
30+
truncation_factor : float, optional (default=0.0)
31+
The type of truncation. The values should be in the [-1.0, 1.0] range.
32+
33+
deformation_factor : float, optional (default=0.0)
34+
The type of geometry. The values should be in the [0.0, 1.0] range.
35+
36+
random_state : int, RandomState instance or None
37+
Control the randomization of the algorithm.
38+
39+
Returns
40+
-------
41+
point : ndarray, shape (n_features, )
42+
Synthetically generated sample.
43+
44+
"""
45+
46+
# Zero radius case
47+
if np.array_equal(center, surface_point):
48+
return center
49+
50+
# Generate a point on the surface of a unit hyper-sphere
51+
radius = norm(center - surface_point)
52+
normal_samples = random_state.normal(size=center.size)
53+
point_on_unit_sphere = normal_samples / norm(normal_samples)
54+
point = (random_state.uniform(size=1) ** (1 / center.size)) * point_on_unit_sphere
55+
56+
# Parallel unit vector
57+
parallel_unit_vector = (surface_point - center) / norm(surface_point - center)
58+
59+
# Truncation
60+
close_to_opposite_boundary = (
61+
truncation_factor > 0
62+
and np.dot(point, parallel_unit_vector) < truncation_factor - 1
63+
)
64+
close_to_boundary = (
65+
truncation_factor < 0
66+
and np.dot(point, parallel_unit_vector) > truncation_factor + 1
67+
)
68+
if close_to_opposite_boundary or close_to_boundary:
69+
point -= 2 * np.dot(point, parallel_unit_vector) * parallel_unit_vector
70+
71+
# Deformation
72+
parallel_point_position = np.dot(point, parallel_unit_vector) * parallel_unit_vector
73+
perpendicular_point_position = point - parallel_point_position
74+
point = (
75+
parallel_point_position
76+
+ (1 - deformation_factor) * perpendicular_point_position
77+
)
78+
79+
# Translation
80+
point = center + radius * point
81+
82+
return point
83+
84+
85+
@Substitution(
86+
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
87+
random_state=_random_state_docstring,
88+
)
89+
class GeometricSMOTE(BaseOverSampler):
90+
"""Class to to perform over-sampling using Geometric SMOTE.
91+
92+
This algorithm is an implementation of Geometric SMOTE, a geometrically
93+
enhanced drop-in replacement for SMOTE as presented in [1]_.
94+
95+
Read more in the :ref:`User Guide <user_guide>`.
96+
97+
Parameters
98+
----------
99+
{sampling_strategy}
100+
101+
{random_state}
102+
103+
truncation_factor : float, optional (default=0.0)
104+
The type of truncation. The values should be in the [-1.0, 1.0] range.
105+
106+
deformation_factor : float, optional (default=0.0)
107+
The type of geometry. The values should be in the [0.0, 1.0] range.
108+
109+
selection_strategy : str, optional (default='combined')
110+
The type of Geometric SMOTE algorithm with the following options:
111+
``'combined'``, ``'majority'``, ``'minority'``.
112+
113+
k_neighbors : int or object, optional (default=5)
114+
If ``int``, number of nearest neighbours to use when synthetic
115+
samples are constructed for the minority method. If object, an estimator
116+
that inherits from :class:`sklearn.neighbors.base.KNeighborsMixin` that
117+
will be used to find the k_neighbors.
118+
119+
n_jobs : int, optional (default=1)
120+
The number of threads to open if possible.
121+
122+
Notes
123+
-----
124+
See the original paper: [1]_ for more details.
125+
126+
Supports multi-class resampling. A one-vs.-rest scheme is used as
127+
originally proposed in [2]_.
128+
129+
References
130+
----------
131+
132+
.. [1] G. Douzas, F. Bacao, "Geometric SMOTE:
133+
a geometrically enhanced drop-in replacement for SMOTE",
134+
Information Sciences, vol. 501, pp. 118-135, 2019.
135+
136+
.. [2] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer, "SMOTE:
137+
synthetic minority over-sampling technique", Journal of Artificial
138+
Intelligence Research, vol. 16, pp. 321-357, 2002.
139+
140+
Examples
141+
--------
142+
143+
>>> from collections import Counter
144+
>>> from sklearn.datasets import make_classification
145+
>>> from gsmote import GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
146+
>>> X, y = make_classification(n_classes=2, class_sep=2,
147+
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
148+
... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
149+
>>> print('Original dataset shape %s' % Counter(y))
150+
Original dataset shape Counter({{1: 900, 0: 100}})
151+
>>> gsmote = GeometricSMOTE(random_state=1)
152+
>>> X_res, y_res = gsmote.fit_resample(X, y)
153+
>>> print('Resampled dataset shape %s' % Counter(y_res))
154+
Resampled dataset shape Counter({{0: 900, 1: 900}})
155+
156+
"""
157+
158+
def __init__(
159+
self,
160+
sampling_strategy='auto',
161+
random_state=None,
162+
truncation_factor=1.0,
163+
deformation_factor=0.0,
164+
selection_strategy='combined',
165+
k_neighbors=5,
166+
n_jobs=1,
167+
):
168+
super(GeometricSMOTE, self).__init__(sampling_strategy=sampling_strategy)
169+
self.random_state = random_state
170+
self.truncation_factor = truncation_factor
171+
self.deformation_factor = deformation_factor
172+
self.selection_strategy = selection_strategy
173+
self.k_neighbors = k_neighbors
174+
self.n_jobs = n_jobs
175+
176+
def _validate_estimator(self):
177+
"""Create the necessary attributes for Geometric SMOTE."""
178+
179+
# Check random state
180+
self.random_state_ = check_random_state(self.random_state)
181+
182+
# Validate strategy
183+
if self.selection_strategy not in SELECTION_STRATEGY:
184+
error_msg = (
185+
'Unknown selection_strategy for Geometric SMOTE algorithm. '
186+
'Choices are {}. Got {} instead.'
187+
)
188+
raise ValueError(
189+
error_msg.format(SELECTION_STRATEGY, self.selection_strategy)
190+
)
191+
192+
# Create nearest neighbors object for positive class
193+
if self.selection_strategy in ('minority', 'combined'):
194+
self.nns_pos_ = check_neighbors_object(
195+
'nns_positive', self.k_neighbors, additional_neighbor=1
196+
)
197+
self.nns_pos_.set_params(n_jobs=self.n_jobs)
198+
199+
# Create nearest neighbors object for negative class
200+
if self.selection_strategy in ('majority', 'combined'):
201+
self.nn_neg_ = check_neighbors_object('nn_negative', nn_object=1)
202+
self.nn_neg_.set_params(n_jobs=self.n_jobs)
203+
204+
def _make_geometric_samples(self, X, y, pos_class_label, n_samples):
205+
"""A support function that returns an artificials samples inside
206+
the geometric region defined by nearest neighbors.
207+
208+
Parameters
209+
----------
210+
X : array-like, shape (n_samples, n_features)
211+
Matrix containing the data which have to be sampled.
212+
y : array-like, shape (n_samples, )
213+
Corresponding label for each sample in X.
214+
pos_class_label : str or int
215+
The minority class (positive class) target value.
216+
n_samples : int
217+
The number of samples to generate.
218+
219+
Returns
220+
-------
221+
X_new : ndarray, shape (n_samples_new, n_features)
222+
Synthetically generated samples.
223+
y_new : ndarray, shape (n_samples_new, )
224+
Target values for synthetic samples.
225+
226+
"""
227+
228+
# Return zero new samples
229+
if n_samples == 0:
230+
return (
231+
np.array([], dtype=X.dtype).reshape(0, X.shape[1]),
232+
np.array([], dtype=y.dtype),
233+
)
234+
235+
# Select positive class samples
236+
X_pos = X[y == pos_class_label]
237+
238+
# Force minority strategy if no negative class samples are present
239+
self.selection_strategy_ = (
240+
'minority' if len(X) == len(X_pos) else self.selection_strategy
241+
)
242+
243+
# Minority or combined strategy
244+
if self.selection_strategy_ in ('minority', 'combined'):
245+
self.nns_pos_.fit(X_pos)
246+
points_pos = self.nns_pos_.kneighbors(X_pos)[1][:, 1:]
247+
samples_indices = self.random_state_.randint(
248+
low=0, high=len(points_pos.flatten()), size=n_samples
249+
)
250+
rows = np.floor_divide(samples_indices, points_pos.shape[1])
251+
cols = np.mod(samples_indices, points_pos.shape[1])
252+
253+
# Majority or combined strategy
254+
if self.selection_strategy_ in ('majority', 'combined'):
255+
X_neg = X[y != pos_class_label]
256+
self.nn_neg_.fit(X_neg)
257+
points_neg = self.nn_neg_.kneighbors(X_pos)[1]
258+
if self.selection_strategy_ == 'majority':
259+
samples_indices = self.random_state_.randint(
260+
low=0, high=len(points_neg.flatten()), size=n_samples
261+
)
262+
rows = np.floor_divide(samples_indices, points_neg.shape[1])
263+
cols = np.mod(samples_indices, points_neg.shape[1])
264+
265+
# Generate new samples
266+
X_new = np.zeros((n_samples, X.shape[1]))
267+
for ind, (row, col) in enumerate(zip(rows, cols)):
268+
269+
# Define center point
270+
center = X_pos[row]
271+
272+
# Minority strategy
273+
if self.selection_strategy_ == 'minority':
274+
surface_point = X_pos[points_pos[row, col]]
275+
276+
# Majority strategy
277+
elif self.selection_strategy_ == 'majority':
278+
surface_point = X_neg[points_neg[row, col]]
279+
280+
# Combined strategy
281+
else:
282+
surface_point_pos = X_pos[points_pos[row, col]]
283+
surface_point_neg = X_neg[points_neg[row, 0]]
284+
radius_pos = norm(center - surface_point_pos)
285+
radius_neg = norm(center - surface_point_neg)
286+
surface_point = (
287+
surface_point_neg if radius_pos > radius_neg else surface_point_pos
288+
)
289+
290+
# Append new sample
291+
X_new[ind] = _make_geometric_sample(
292+
center,
293+
surface_point,
294+
self.truncation_factor,
295+
self.deformation_factor,
296+
self.random_state_,
297+
)
298+
299+
# Create new samples for target variable
300+
y_new = np.array([pos_class_label] * len(samples_indices))
301+
302+
return X_new, y_new
303+
304+
def _fit_resample(self, X, y):
305+
306+
# Validate estimator's parameters
307+
self._validate_estimator()
308+
309+
# Copy data
310+
X_resampled, y_resampled = X.copy(), y.copy()
311+
312+
# Resample data
313+
for class_label, n_samples in self.sampling_strategy_.items():
314+
315+
# Apply gsmote mechanism
316+
X_new, y_new = self._make_geometric_samples(X, y, class_label, n_samples)
317+
318+
# Append new data
319+
X_resampled, y_resampled = (
320+
np.vstack((X_resampled, X_new)),
321+
np.hstack((y_resampled, y_new)),
322+
)
323+
324+
return X_resampled, y_resampled

0 commit comments

Comments
 (0)