3
3
Data generation mechanism
4
4
=========================
5
5
6
- This example illustrates the Geometric SMOTE data
7
- generation mechanism and the usage of its
6
+ This example illustrates the Geometric SMOTE data
7
+ generation mechanism and the usage of its
8
8
hyperparameters.
9
9
10
10
"""
16
16
import matplotlib .pyplot as plt
17
17
18
18
from sklearn .datasets import make_blobs
19
- from imblearn .over_sampling import SMOTE
20
-
21
- from gsmote import GeometricSMOTE
19
+ from imblearn .over_sampling import SMOTE , GeometricSMOTE
22
20
23
21
print (__doc__ )
24
22
@@ -47,11 +45,11 @@ def generate_imbalanced_data(
47
45
def plot_scatter (X , y , title ):
48
46
"""Function to plot some data as a scatter plot."""
49
47
plt .figure ()
50
- plt .scatter (X [y == 1 , 0 ], X [y == 1 , 1 ], label = ' Positive Class' )
51
- plt .scatter (X [y == 0 , 0 ], X [y == 0 , 1 ], label = ' Negative Class' )
48
+ plt .scatter (X [y == 1 , 0 ], X [y == 1 , 1 ], label = " Positive Class" )
49
+ plt .scatter (X [y == 0 , 0 ], X [y == 0 , 1 ], label = " Negative Class" )
52
50
plt .xlim (* XLIM )
53
51
plt .ylim (* YLIM )
54
- plt .gca ().set_aspect (' equal' , adjustable = ' box' )
52
+ plt .gca ().set_aspect (" equal" , adjustable = " box" )
55
53
plt .legend ()
56
54
plt .title (title )
57
55
@@ -66,9 +64,9 @@ def plot_hyperparameters(oversampler, X, y, param, vals, n_subplots):
66
64
for ax , val in zip (ax_arr , vals ):
67
65
oversampler .set_params (** {param : val })
68
66
X_res , y_res = oversampler .fit_resample (X , y )
69
- ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = ' Positive Class' )
70
- ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = ' Negative Class' )
71
- ax .set_title (f' { val } ' )
67
+ ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = " Positive Class" )
68
+ ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = " Negative Class" )
69
+ ax .set_title (f" { val } " )
72
70
ax .set_xlim (* XLIM )
73
71
ax .set_ylim (* YLIM )
74
72
@@ -79,8 +77,8 @@ def plot_comparison(oversamplers, X, y):
79
77
fig , ax_arr = plt .subplots (1 , 2 , figsize = (15 , 5 ))
80
78
for ax , (name , ovs ) in zip (ax_arr , oversamplers ):
81
79
X_res , y_res = ovs .fit_resample (X , y )
82
- ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = ' Positive Class' )
83
- ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = ' Negative Class' )
80
+ ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = " Positive Class" )
81
+ ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = " Negative Class" )
84
82
ax .set_title (name )
85
83
ax .set_xlim (* XLIM )
86
84
ax .set_ylim (* YLIM )
@@ -98,7 +96,7 @@ def plot_comparison(oversamplers, X, y):
98
96
X , y = generate_imbalanced_data (
99
97
200 , 2 , [(- 2.0 , 2.25 ), (1.0 , 2.0 )], 0.25 , [- 0.7 , 2.3 ], [- 0.5 , 3.1 ]
100
98
)
101
- plot_scatter (X , y , ' Imbalanced data' )
99
+ plot_scatter (X , y , " Imbalanced data" )
102
100
103
101
###############################################################################
104
102
# Geometric hyperparameters
@@ -133,13 +131,13 @@ def plot_comparison(oversamplers, X, y):
133
131
gsmote = GeometricSMOTE (
134
132
k_neighbors = 1 ,
135
133
deformation_factor = 0.0 ,
136
- selection_strategy = ' minority' ,
134
+ selection_strategy = " minority" ,
137
135
random_state = RANDOM_STATE ,
138
136
)
139
137
truncation_factors = np .array ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
140
138
n_subplots = [2 , 3 ]
141
- plot_hyperparameters (gsmote , X , y , ' truncation_factor' , truncation_factors , n_subplots )
142
- plot_hyperparameters (gsmote , X , y , ' truncation_factor' , - truncation_factors , n_subplots )
139
+ plot_hyperparameters (gsmote , X , y , " truncation_factor" , truncation_factors , n_subplots )
140
+ plot_hyperparameters (gsmote , X , y , " truncation_factor" , - truncation_factors , n_subplots )
143
141
144
142
###############################################################################
145
143
# Deformation factor
@@ -151,12 +149,12 @@ def plot_comparison(oversamplers, X, y):
151
149
gsmote = GeometricSMOTE (
152
150
k_neighbors = 1 ,
153
151
truncation_factor = 0.0 ,
154
- selection_strategy = ' minority' ,
152
+ selection_strategy = " minority" ,
155
153
random_state = RANDOM_STATE ,
156
154
)
157
155
deformation_factors = np .array ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
158
156
n_subplots = [2 , 3 ]
159
- plot_hyperparameters (gsmote , X , y , ' deformation_factor' , truncation_factors , n_subplots )
157
+ plot_hyperparameters (gsmote , X , y , " deformation_factor" , truncation_factors , n_subplots )
160
158
161
159
###############################################################################
162
160
# Selection strategy
@@ -177,10 +175,10 @@ def plot_comparison(oversamplers, X, y):
177
175
deformation_factor = 0.5 ,
178
176
random_state = RANDOM_STATE ,
179
177
)
180
- selection_strategies = np .array ([' minority' , ' majority' , ' combined' ])
178
+ selection_strategies = np .array ([" minority" , " majority" , " combined" ])
181
179
n_subplots = [1 , 3 ]
182
180
plot_hyperparameters (
183
- gsmote , X , y , ' selection_strategy' , selection_strategies , n_subplots
181
+ gsmote , X , y , " selection_strategy" , selection_strategies , n_subplots
184
182
)
185
183
186
184
###############################################################################
@@ -193,7 +191,7 @@ def plot_comparison(oversamplers, X, y):
193
191
194
192
X_new = np .vstack ([X , np .array ([2.0 , 2.0 ])])
195
193
y_new = np .hstack ([y , np .ones (1 , dtype = np .int8 )])
196
- plot_scatter (X_new , y_new , ' Imbalanced data' )
194
+ plot_scatter (X_new , y_new , " Imbalanced data" )
197
195
198
196
###############################################################################
199
197
# When the number of ``k_neighbors`` is increased, SMOTE results to the
@@ -202,11 +200,11 @@ def plot_comparison(oversamplers, X, y):
202
200
# ``majority``.
203
201
204
202
oversamplers = [
205
- (' SMOTE' , SMOTE (k_neighbors = 2 , random_state = RANDOM_STATE )),
203
+ (" SMOTE" , SMOTE (k_neighbors = 2 , random_state = RANDOM_STATE )),
206
204
(
207
- ' Geometric SMOTE' ,
205
+ " Geometric SMOTE" ,
208
206
GeometricSMOTE (
209
- k_neighbors = 2 , selection_strategy = ' combined' , random_state = RANDOM_STATE
207
+ k_neighbors = 2 , selection_strategy = " combined" , random_state = RANDOM_STATE
210
208
),
211
209
),
212
210
]
0 commit comments