Skip to content

Commit 467c557

Browse files
committed
FIX imports in Geometric-SMOTE examples and format code (#881)
1 parent fa3ffe5 commit 467c557

File tree

3 files changed

+54
-56
lines changed

3 files changed

+54
-56
lines changed

examples/over-sampling/plot_geometric_smote_generation_mechanism.py

+23-25
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
Data generation mechanism
44
=========================
55
6-
This example illustrates the Geometric SMOTE data
7-
generation mechanism and the usage of its
6+
This example illustrates the Geometric SMOTE data
7+
generation mechanism and the usage of its
88
hyperparameters.
99
1010
"""
@@ -16,9 +16,7 @@
1616
import matplotlib.pyplot as plt
1717

1818
from sklearn.datasets import make_blobs
19-
from imblearn.over_sampling import SMOTE
20-
21-
from gsmote import GeometricSMOTE
19+
from imblearn.over_sampling import SMOTE, GeometricSMOTE
2220

2321
print(__doc__)
2422

@@ -47,11 +45,11 @@ def generate_imbalanced_data(
4745
def plot_scatter(X, y, title):
4846
"""Function to plot some data as a scatter plot."""
4947
plt.figure()
50-
plt.scatter(X[y == 1, 0], X[y == 1, 1], label='Positive Class')
51-
plt.scatter(X[y == 0, 0], X[y == 0, 1], label='Negative Class')
48+
plt.scatter(X[y == 1, 0], X[y == 1, 1], label="Positive Class")
49+
plt.scatter(X[y == 0, 0], X[y == 0, 1], label="Negative Class")
5250
plt.xlim(*XLIM)
5351
plt.ylim(*YLIM)
54-
plt.gca().set_aspect('equal', adjustable='box')
52+
plt.gca().set_aspect("equal", adjustable="box")
5553
plt.legend()
5654
plt.title(title)
5755

@@ -66,9 +64,9 @@ def plot_hyperparameters(oversampler, X, y, param, vals, n_subplots):
6664
for ax, val in zip(ax_arr, vals):
6765
oversampler.set_params(**{param: val})
6866
X_res, y_res = oversampler.fit_resample(X, y)
69-
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label='Positive Class')
70-
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label='Negative Class')
71-
ax.set_title(f'{val}')
67+
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label="Positive Class")
68+
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label="Negative Class")
69+
ax.set_title(f"{val}")
7270
ax.set_xlim(*XLIM)
7371
ax.set_ylim(*YLIM)
7472

@@ -79,8 +77,8 @@ def plot_comparison(oversamplers, X, y):
7977
fig, ax_arr = plt.subplots(1, 2, figsize=(15, 5))
8078
for ax, (name, ovs) in zip(ax_arr, oversamplers):
8179
X_res, y_res = ovs.fit_resample(X, y)
82-
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label='Positive Class')
83-
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label='Negative Class')
80+
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label="Positive Class")
81+
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label="Negative Class")
8482
ax.set_title(name)
8583
ax.set_xlim(*XLIM)
8684
ax.set_ylim(*YLIM)
@@ -98,7 +96,7 @@ def plot_comparison(oversamplers, X, y):
9896
X, y = generate_imbalanced_data(
9997
200, 2, [(-2.0, 2.25), (1.0, 2.0)], 0.25, [-0.7, 2.3], [-0.5, 3.1]
10098
)
101-
plot_scatter(X, y, 'Imbalanced data')
99+
plot_scatter(X, y, "Imbalanced data")
102100

103101
###############################################################################
104102
# Geometric hyperparameters
@@ -133,13 +131,13 @@ def plot_comparison(oversamplers, X, y):
133131
gsmote = GeometricSMOTE(
134132
k_neighbors=1,
135133
deformation_factor=0.0,
136-
selection_strategy='minority',
134+
selection_strategy="minority",
137135
random_state=RANDOM_STATE,
138136
)
139137
truncation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
140138
n_subplots = [2, 3]
141-
plot_hyperparameters(gsmote, X, y, 'truncation_factor', truncation_factors, n_subplots)
142-
plot_hyperparameters(gsmote, X, y, 'truncation_factor', -truncation_factors, n_subplots)
139+
plot_hyperparameters(gsmote, X, y, "truncation_factor", truncation_factors, n_subplots)
140+
plot_hyperparameters(gsmote, X, y, "truncation_factor", -truncation_factors, n_subplots)
143141

144142
###############################################################################
145143
# Deformation factor
@@ -151,12 +149,12 @@ def plot_comparison(oversamplers, X, y):
151149
gsmote = GeometricSMOTE(
152150
k_neighbors=1,
153151
truncation_factor=0.0,
154-
selection_strategy='minority',
152+
selection_strategy="minority",
155153
random_state=RANDOM_STATE,
156154
)
157155
deformation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
158156
n_subplots = [2, 3]
159-
plot_hyperparameters(gsmote, X, y, 'deformation_factor', truncation_factors, n_subplots)
157+
plot_hyperparameters(gsmote, X, y, "deformation_factor", truncation_factors, n_subplots)
160158

161159
###############################################################################
162160
# Selection strategy
@@ -177,10 +175,10 @@ def plot_comparison(oversamplers, X, y):
177175
deformation_factor=0.5,
178176
random_state=RANDOM_STATE,
179177
)
180-
selection_strategies = np.array(['minority', 'majority', 'combined'])
178+
selection_strategies = np.array(["minority", "majority", "combined"])
181179
n_subplots = [1, 3]
182180
plot_hyperparameters(
183-
gsmote, X, y, 'selection_strategy', selection_strategies, n_subplots
181+
gsmote, X, y, "selection_strategy", selection_strategies, n_subplots
184182
)
185183

186184
###############################################################################
@@ -193,7 +191,7 @@ def plot_comparison(oversamplers, X, y):
193191

194192
X_new = np.vstack([X, np.array([2.0, 2.0])])
195193
y_new = np.hstack([y, np.ones(1, dtype=np.int8)])
196-
plot_scatter(X_new, y_new, 'Imbalanced data')
194+
plot_scatter(X_new, y_new, "Imbalanced data")
197195

198196
###############################################################################
199197
# When the number of ``k_neighbors`` is increased, SMOTE results to the
@@ -202,11 +200,11 @@ def plot_comparison(oversamplers, X, y):
202200
# ``majority``.
203201

204202
oversamplers = [
205-
('SMOTE', SMOTE(k_neighbors=2, random_state=RANDOM_STATE)),
203+
("SMOTE", SMOTE(k_neighbors=2, random_state=RANDOM_STATE)),
206204
(
207-
'Geometric SMOTE',
205+
"Geometric SMOTE",
208206
GeometricSMOTE(
209-
k_neighbors=2, selection_strategy='combined', random_state=RANDOM_STATE
207+
k_neighbors=2, selection_strategy="combined", random_state=RANDOM_STATE
210208
),
211209
),
212210
]

examples/over-sampling/plot_geometric_smote_validation_curves.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from sklearn.tree import DecisionTreeClassifier
1919
from sklearn.svm import LinearSVC
2020
from sklearn.model_selection import validation_curve
21-
from sklearn.metrics import make_scorer, cohen_kappa_score
21+
from sklearn.metrics import make_scorer
2222
from sklearn.datasets import make_classification
2323
from imblearn.pipeline import make_pipeline
2424
from imblearn.metrics import geometric_mean_score
2525

26-
from gsmote import GeometricSMOTE
26+
from imblearn.over_sampling import GeometricSMOTE
2727

2828
print(__doc__)
2929

@@ -81,12 +81,12 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
8181
plt.scatter(param_range[idx_max], test_scores_mean[idx_max])
8282
plt.title(title)
8383
plt.ylabel(scoring_name)
84-
ax.spines['top'].set_visible(False)
85-
ax.spines['right'].set_visible(False)
84+
ax.spines["top"].set_visible(False)
85+
ax.spines["right"].set_visible(False)
8686
ax.get_xaxis().tick_bottom()
8787
ax.get_yaxis().tick_left()
88-
ax.spines['left'].set_position(('outward', 10))
89-
ax.spines['bottom'].set_position(('outward', 10))
88+
ax.spines["left"].set_position(("outward", 10))
89+
ax.spines["bottom"].set_position(("outward", 10))
9090
plt.ylim([0.9, 1.0])
9191

9292

@@ -107,11 +107,11 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
107107
DecisionTreeClassifier(random_state=RANDOM_STATE),
108108
)
109109

110-
scoring_name = 'Geometric Mean Score'
110+
scoring_name = "Geometric Mean Score"
111111
validation_curve_info = generate_validation_curve_info(
112112
gsmote_gbc, X, y, range(1, 8), "geometricsmote__k_neighbors", SCORER
113113
)
114-
plot_validation_curve(validation_curve_info, scoring_name, 'K Neighbors')
114+
plot_validation_curve(validation_curve_info, scoring_name, "K Neighbors")
115115

116116
validation_curve_info = generate_validation_curve_info(
117117
gsmote_gbc,
@@ -121,7 +121,7 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
121121
"geometricsmote__truncation_factor",
122122
SCORER,
123123
)
124-
plot_validation_curve(validation_curve_info, scoring_name, 'Truncation Factor')
124+
plot_validation_curve(validation_curve_info, scoring_name, "Truncation Factor")
125125

126126
validation_curve_info = generate_validation_curve_info(
127127
gsmote_gbc,
@@ -131,17 +131,17 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
131131
"geometricsmote__deformation_factor",
132132
SCORER,
133133
)
134-
plot_validation_curve(validation_curve_info, scoring_name, 'Deformation Factor')
134+
plot_validation_curve(validation_curve_info, scoring_name, "Deformation Factor")
135135

136136
validation_curve_info = generate_validation_curve_info(
137137
gsmote_gbc,
138138
X,
139139
y,
140-
['minority', 'majority', 'combined'],
140+
["minority", "majority", "combined"],
141141
"geometricsmote__selection_strategy",
142142
SCORER,
143143
)
144-
plot_validation_curve(validation_curve_info, scoring_name, 'Selection Strategy')
144+
plot_validation_curve(validation_curve_info, scoring_name, "Selection Strategy")
145145

146146
###############################################################################
147147
# High Imbalance Ratio or low Samples to Features Ratio
@@ -158,11 +158,11 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
158158
LinearSVC(random_state=RANDOM_STATE, max_iter=1e5),
159159
)
160160

161-
scoring_name = 'Geometric Mean Score'
161+
scoring_name = "Geometric Mean Score"
162162
validation_curve_info = generate_validation_curve_info(
163163
gsmote_gbc, X, y, range(1, 8), "geometricsmote__k_neighbors", SCORER
164164
)
165-
plot_validation_curve(validation_curve_info, scoring_name, 'K Neighbors')
165+
plot_validation_curve(validation_curve_info, scoring_name, "K Neighbors")
166166

167167
validation_curve_info = generate_validation_curve_info(
168168
gsmote_gbc,
@@ -172,7 +172,7 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
172172
"geometricsmote__truncation_factor",
173173
SCORER,
174174
)
175-
plot_validation_curve(validation_curve_info, scoring_name, 'Truncation Factor')
175+
plot_validation_curve(validation_curve_info, scoring_name, "Truncation Factor")
176176

177177
validation_curve_info = generate_validation_curve_info(
178178
gsmote_gbc,
@@ -182,14 +182,14 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
182182
"geometricsmote__deformation_factor",
183183
SCORER,
184184
)
185-
plot_validation_curve(validation_curve_info, scoring_name, 'Deformation Factor')
185+
plot_validation_curve(validation_curve_info, scoring_name, "Deformation Factor")
186186

187187
validation_curve_info = generate_validation_curve_info(
188188
gsmote_gbc,
189189
X,
190190
y,
191-
['minority', 'majority', 'combined'],
191+
["minority", "majority", "combined"],
192192
"geometricsmote__selection_strategy",
193193
SCORER,
194194
)
195-
plot_validation_curve(validation_curve_info, scoring_name, 'Selection Strategy')
195+
plot_validation_curve(validation_curve_info, scoring_name, "Selection Strategy")

imblearn/over_sampling/_smote/tests/test_geometric_smote.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
@pytest.mark.parametrize(
34-
'center,surface_point',
34+
"center,surface_point",
3535
[
3636
(CENTERS[0], SURFACE_POINTS[0]),
3737
(CENTERS[1], SURFACE_POINTS[1]),
@@ -48,7 +48,7 @@ def test_make_geometric_sample_hypersphere(center, surface_point):
4848

4949

5050
@pytest.mark.parametrize(
51-
'surface_point,deformation_factor',
51+
"surface_point,deformation_factor",
5252
[
5353
(np.array([1.0, 0.0]), 0.0),
5454
(2.6 * np.array([0.0, 1.0]), 0.25),
@@ -68,7 +68,7 @@ def test_make_geometric_sample_half_hypersphere(surface_point, deformation_facto
6868

6969

7070
@pytest.mark.parametrize(
71-
'center,surface_point,truncation_factor',
71+
"center,surface_point,truncation_factor",
7272
[
7373
(center, surface_point, truncation_factor)
7474
for center, surface_point in zip(CENTERS, SURFACE_POINTS)
@@ -94,11 +94,11 @@ def test_make_geometric_sample_line_segment(center, surface_point, truncation_fa
9494
def test_gsmote_default_init():
9595
"""Test the intialization with default parameters."""
9696
gsmote = GeometricSMOTE()
97-
assert gsmote.sampling_strategy == 'auto'
97+
assert gsmote.sampling_strategy == "auto"
9898
assert gsmote.random_state is None
9999
assert gsmote.truncation_factor == 1.0
100100
assert gsmote.deformation_factor == 0.0
101-
assert gsmote.selection_strategy == 'combined'
101+
assert gsmote.selection_strategy == "combined"
102102
assert gsmote.k_neighbors == 5
103103
assert gsmote.n_jobs == 1
104104

@@ -119,12 +119,12 @@ def test_gsmote_invalid_selection_strategy():
119119
X, y = make_classification(
120120
random_state=RND_SEED, n_samples=n_samples, weights=weights
121121
)
122-
gsmote = GeometricSMOTE(random_state=RANDOM_STATE, selection_strategy='Minority')
122+
gsmote = GeometricSMOTE(random_state=RANDOM_STATE, selection_strategy="Minority")
123123
with pytest.raises(ValueError):
124124
gsmote.fit_resample(X, y)
125125

126126

127-
@pytest.mark.parametrize('selection_strategy', ['combined', 'minority', 'majority'])
127+
@pytest.mark.parametrize("selection_strategy", ["combined", "minority", "majority"])
128128
def test_gsmote_nn(selection_strategy):
129129
"""Test nearest neighbors object."""
130130
n_samples, weights = 200, [0.6, 0.4]
@@ -135,14 +135,14 @@ def test_gsmote_nn(selection_strategy):
135135
random_state=RANDOM_STATE, selection_strategy=selection_strategy
136136
)
137137
_ = gsmote.fit_resample(X, y)
138-
if selection_strategy in ('minority', 'combined'):
138+
if selection_strategy in ("minority", "combined"):
139139
assert gsmote.nns_pos_.n_neighbors == gsmote.k_neighbors + 1
140-
if selection_strategy in ('majority', 'combined'):
140+
if selection_strategy in ("majority", "combined"):
141141
assert gsmote.nn_neg_.n_neighbors == 1
142142

143143

144144
@pytest.mark.parametrize(
145-
'selection_strategy, truncation_factor, deformation_factor',
145+
"selection_strategy, truncation_factor, deformation_factor",
146146
[
147147
(selection_strategy, truncation_factor, deformation_factor)
148148
for selection_strategy in SELECTION_STRATEGY
@@ -160,7 +160,7 @@ def test_gsmote_fit_resample_binary(
160160
radius = np.sqrt(0.5) * step
161161
k_neighbors = 1
162162
gsmote = GeometricSMOTE(
163-
'auto',
163+
"auto",
164164
RANDOM_STATE,
165165
truncation_factor,
166166
deformation_factor,
@@ -174,7 +174,7 @@ def test_gsmote_fit_resample_binary(
174174

175175

176176
@pytest.mark.parametrize(
177-
'selection_strategy, truncation_factor, deformation_factor',
177+
"selection_strategy, truncation_factor, deformation_factor",
178178
[
179179
(selection_strategy, truncation_factor, deformation_factor)
180180
for selection_strategy in SELECTION_STRATEGY
@@ -196,7 +196,7 @@ def test_gsmote_fit_resample_multiclass(
196196
)
197197
k_neighbors, majority_label = 1, 0
198198
gsmote = GeometricSMOTE(
199-
'auto',
199+
"auto",
200200
RANDOM_STATE,
201201
truncation_factor,
202202
deformation_factor,

0 commit comments

Comments
 (0)