|
| 1 | +""" |
| 2 | +========================= |
| 3 | +Data generation mechanism |
| 4 | +========================= |
| 5 | +
|
| 6 | +This example illustrates the Geometric SMOTE data |
| 7 | +generation mechanism and the usage of its |
| 8 | +hyperparameters. |
| 9 | +
|
| 10 | +""" |
| 11 | + |
| 12 | +# Author: Georgios Douzas <[email protected]> |
| 13 | +# Licence: MIT |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +import matplotlib.pyplot as plt |
| 17 | + |
| 18 | +from sklearn.datasets import make_blobs |
| 19 | +from imblearn.over_sampling import SMOTE |
| 20 | + |
| 21 | +from gsmote import GeometricSMOTE |
| 22 | + |
| 23 | +print(__doc__) |
| 24 | + |
| 25 | +XLIM, YLIM = [-3.0, 3.0], [0.0, 4.0] |
| 26 | +RANDOM_STATE = 5 |
| 27 | + |
| 28 | + |
| 29 | +def generate_imbalanced_data( |
| 30 | + n_maj_samples, n_min_samples, centers, cluster_std, *min_point |
| 31 | +): |
| 32 | + """Generate imbalanced data.""" |
| 33 | + X_neg, _ = make_blobs( |
| 34 | + n_samples=n_maj_samples, |
| 35 | + centers=centers, |
| 36 | + cluster_std=cluster_std, |
| 37 | + random_state=RANDOM_STATE, |
| 38 | + ) |
| 39 | + X_pos = np.array(min_point) |
| 40 | + X = np.vstack([X_neg, X_pos]) |
| 41 | + y_pos = np.zeros(X_neg.shape[0], dtype=np.int8) |
| 42 | + y_neg = np.ones(n_min_samples, dtype=np.int8) |
| 43 | + y = np.hstack([y_pos, y_neg]) |
| 44 | + return X, y |
| 45 | + |
| 46 | + |
| 47 | +def plot_scatter(X, y, title): |
| 48 | + """Function to plot some data as a scatter plot.""" |
| 49 | + plt.figure() |
| 50 | + plt.scatter(X[y == 1, 0], X[y == 1, 1], label='Positive Class') |
| 51 | + plt.scatter(X[y == 0, 0], X[y == 0, 1], label='Negative Class') |
| 52 | + plt.xlim(*XLIM) |
| 53 | + plt.ylim(*YLIM) |
| 54 | + plt.gca().set_aspect('equal', adjustable='box') |
| 55 | + plt.legend() |
| 56 | + plt.title(title) |
| 57 | + |
| 58 | + |
| 59 | +def plot_hyperparameters(oversampler, X, y, param, vals, n_subplots): |
| 60 | + """Function to plot resampled data for various |
| 61 | + values of a geometric hyperparameter.""" |
| 62 | + n_rows = n_subplots[0] |
| 63 | + fig, ax_arr = plt.subplots(*n_subplots, figsize=(15, 7 if n_rows > 1 else 3.5)) |
| 64 | + if n_rows > 1: |
| 65 | + ax_arr = [ax for axs in ax_arr for ax in axs] |
| 66 | + for ax, val in zip(ax_arr, vals): |
| 67 | + oversampler.set_params(**{param: val}) |
| 68 | + X_res, y_res = oversampler.fit_resample(X, y) |
| 69 | + ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label='Positive Class') |
| 70 | + ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label='Negative Class') |
| 71 | + ax.set_title(f'{val}') |
| 72 | + ax.set_xlim(*XLIM) |
| 73 | + ax.set_ylim(*YLIM) |
| 74 | + |
| 75 | + |
| 76 | +def plot_comparison(oversamplers, X, y): |
| 77 | + """Function to compare SMOTE and Geometric SMOTE |
| 78 | + generation of noisy samples.""" |
| 79 | + fig, ax_arr = plt.subplots(1, 2, figsize=(15, 5)) |
| 80 | + for ax, (name, ovs) in zip(ax_arr, oversamplers): |
| 81 | + X_res, y_res = ovs.fit_resample(X, y) |
| 82 | + ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label='Positive Class') |
| 83 | + ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label='Negative Class') |
| 84 | + ax.set_title(name) |
| 85 | + ax.set_xlim(*XLIM) |
| 86 | + ax.set_ylim(*YLIM) |
| 87 | + |
| 88 | + |
| 89 | +############################################################################### |
| 90 | +# Generate imbalanced data |
| 91 | +############################################################################### |
| 92 | + |
| 93 | +############################################################################### |
| 94 | +# We are generating a highly imbalanced non Gaussian data set. Only two samples |
| 95 | +# from the minority (positive) class are included to illustrate the Geometric |
| 96 | +# SMOTE data generation mechanism. |
| 97 | + |
| 98 | +X, y = generate_imbalanced_data( |
| 99 | + 200, 2, [(-2.0, 2.25), (1.0, 2.0)], 0.25, [-0.7, 2.3], [-0.5, 3.1] |
| 100 | +) |
| 101 | +plot_scatter(X, y, 'Imbalanced data') |
| 102 | + |
| 103 | +############################################################################### |
| 104 | +# Geometric hyperparameters |
| 105 | +############################################################################### |
| 106 | + |
| 107 | +############################################################################### |
| 108 | +# Similarly to SMOTE and its variations, Geometric SMOTE uses the `k_neighbors` |
| 109 | +# hyperparameter to select a random neighbor among the k nearest neighbors of a |
| 110 | +# minority class instance. On the other hand, Geometric SMOTE expands the data |
| 111 | +# generation area from the line segment of the SMOTE mechanism to a hypersphere |
| 112 | +# that can be truncated and deformed. The characteristics of the above geometric |
| 113 | +# area are determined by the hyperparameters ``truncation_factor``, |
| 114 | +# ``deformation_factor`` and ``selection_strategy``. These are called geometric |
| 115 | +# hyperparameters and allow the generation of diverse synthetic data as shown |
| 116 | +# below. |
| 117 | + |
| 118 | +############################################################################### |
| 119 | +# Truncation factor |
| 120 | +# .............................................................................. |
| 121 | +# |
| 122 | +# The hyperparameter ``truncation_factor`` determines the degree of truncation |
| 123 | +# that is applied on the initial geometric area. Selecting the values of |
| 124 | +# geometric hyperparameters as `truncation_factor=0.0`, |
| 125 | +# ``deformation_factor=0.0`` and ``selection_strategy='minority'``, the data |
| 126 | +# generation area in 2D corresponds to a circle with center as one of the two |
| 127 | +# minority class samples and radius equal to the distance between them. In the |
| 128 | +# multi-dimensional case the corresponding area is a hypersphere. When |
| 129 | +# truncation factor is increased, the hypersphere is truncated and for |
| 130 | +# ``truncation_factor=1.0`` becomes a half-hypersphere. Negative values of |
| 131 | +# ``truncation_factor`` have a similar effect but on the opposite direction. |
| 132 | + |
| 133 | +gsmote = GeometricSMOTE( |
| 134 | + k_neighbors=1, |
| 135 | + deformation_factor=0.0, |
| 136 | + selection_strategy='minority', |
| 137 | + random_state=RANDOM_STATE, |
| 138 | +) |
| 139 | +truncation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) |
| 140 | +n_subplots = [2, 3] |
| 141 | +plot_hyperparameters(gsmote, X, y, 'truncation_factor', truncation_factors, n_subplots) |
| 142 | +plot_hyperparameters(gsmote, X, y, 'truncation_factor', -truncation_factors, n_subplots) |
| 143 | + |
| 144 | +############################################################################### |
| 145 | +# Deformation factor |
| 146 | +# .............................................................................. |
| 147 | +# |
| 148 | +# When the ``deformation_factor`` is increased, the data generation area deforms |
| 149 | +# to an ellipsis and for ``deformation_factor=1.0`` becomes a line segment. |
| 150 | + |
| 151 | +gsmote = GeometricSMOTE( |
| 152 | + k_neighbors=1, |
| 153 | + truncation_factor=0.0, |
| 154 | + selection_strategy='minority', |
| 155 | + random_state=RANDOM_STATE, |
| 156 | +) |
| 157 | +deformation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) |
| 158 | +n_subplots = [2, 3] |
| 159 | +plot_hyperparameters(gsmote, X, y, 'deformation_factor', truncation_factors, n_subplots) |
| 160 | + |
| 161 | +############################################################################### |
| 162 | +# Selection strategy |
| 163 | +# .............................................................................. |
| 164 | +# |
| 165 | +# The hyperparameter ``selection_strategy`` determines the selection mechanism |
| 166 | +# of nearest neighbors. Initially, a minority class sample is selected randomly. |
| 167 | +# When ``selection_strategy='minority'``, a second minority class sample is |
| 168 | +# selected as one of the k nearest neighbors of it. For |
| 169 | +# ``selection_strategy='majority'``, the second sample is its nearest majority |
| 170 | +# class neighbor. Finally, for ``selection_strategy='combined'`` the two |
| 171 | +# selection mechanisms are combined and the second sample is the nearest to the |
| 172 | +# first between the two samples defined above. |
| 173 | + |
| 174 | +gsmote = GeometricSMOTE( |
| 175 | + k_neighbors=1, |
| 176 | + truncation_factor=0.0, |
| 177 | + deformation_factor=0.5, |
| 178 | + random_state=RANDOM_STATE, |
| 179 | +) |
| 180 | +selection_strategies = np.array(['minority', 'majority', 'combined']) |
| 181 | +n_subplots = [1, 3] |
| 182 | +plot_hyperparameters( |
| 183 | + gsmote, X, y, 'selection_strategy', selection_strategies, n_subplots |
| 184 | +) |
| 185 | + |
| 186 | +############################################################################### |
| 187 | +# Noisy samples |
| 188 | +############################################################################### |
| 189 | + |
| 190 | +############################################################################### |
| 191 | +# We are adding a third minority class sample to illustrate the difference |
| 192 | +# between SMOTE and Geometric SMOTE data generation mechanisms. |
| 193 | + |
| 194 | +X_new = np.vstack([X, np.array([2.0, 2.0])]) |
| 195 | +y_new = np.hstack([y, np.ones(1, dtype=np.int8)]) |
| 196 | +plot_scatter(X_new, y_new, 'Imbalanced data') |
| 197 | + |
| 198 | +############################################################################### |
| 199 | +# When the number of ``k_neighbors`` is increased, SMOTE results to the |
| 200 | +# generation of noisy samples. On the other hand, Geometric SMOTE avoids this |
| 201 | +# scenario when the ``selection_strategy`` values are either ``combined`` or |
| 202 | +# ``majority``. |
| 203 | + |
| 204 | +oversamplers = [ |
| 205 | + ('SMOTE', SMOTE(k_neighbors=2, random_state=RANDOM_STATE)), |
| 206 | + ( |
| 207 | + 'Geometric SMOTE', |
| 208 | + GeometricSMOTE( |
| 209 | + k_neighbors=2, selection_strategy='combined', random_state=RANDOM_STATE |
| 210 | + ), |
| 211 | + ), |
| 212 | +] |
| 213 | +plot_comparison(oversamplers, X_new, y_new) |
0 commit comments