diff --git a/examples/over-sampling/plot_geometric_smote_generation_mechanism.py b/examples/over-sampling/plot_geometric_smote_generation_mechanism.py new file mode 100644 index 000000000..9f13b1c88 --- /dev/null +++ b/examples/over-sampling/plot_geometric_smote_generation_mechanism.py @@ -0,0 +1,211 @@ +""" +========================= +Data generation mechanism +========================= + +This example illustrates the Geometric SMOTE data +generation mechanism and the usage of its +hyperparameters. + +""" + +# Author: Georgios Douzas +# Licence: MIT + +import numpy as np +import matplotlib.pyplot as plt + +from sklearn.datasets import make_blobs +from imblearn.over_sampling import SMOTE, GeometricSMOTE + +print(__doc__) + +XLIM, YLIM = [-3.0, 3.0], [0.0, 4.0] +RANDOM_STATE = 5 + + +def generate_imbalanced_data( + n_maj_samples, n_min_samples, centers, cluster_std, *min_point +): + """Generate imbalanced data.""" + X_neg, _ = make_blobs( + n_samples=n_maj_samples, + centers=centers, + cluster_std=cluster_std, + random_state=RANDOM_STATE, + ) + X_pos = np.array(min_point) + X = np.vstack([X_neg, X_pos]) + y_pos = np.zeros(X_neg.shape[0], dtype=np.int8) + y_neg = np.ones(n_min_samples, dtype=np.int8) + y = np.hstack([y_pos, y_neg]) + return X, y + + +def plot_scatter(X, y, title): + """Function to plot some data as a scatter plot.""" + plt.figure() + plt.scatter(X[y == 1, 0], X[y == 1, 1], label="Positive Class") + plt.scatter(X[y == 0, 0], X[y == 0, 1], label="Negative Class") + plt.xlim(*XLIM) + plt.ylim(*YLIM) + plt.gca().set_aspect("equal", adjustable="box") + plt.legend() + plt.title(title) + + +def plot_hyperparameters(oversampler, X, y, param, vals, n_subplots): + """Function to plot resampled data for various + values of a geometric hyperparameter.""" + n_rows = n_subplots[0] + fig, ax_arr = plt.subplots(*n_subplots, figsize=(15, 7 if n_rows > 1 else 3.5)) + if n_rows > 1: + ax_arr = [ax for axs in ax_arr for ax in axs] + for ax, val in zip(ax_arr, vals): + oversampler.set_params(**{param: val}) + X_res, y_res = oversampler.fit_resample(X, y) + ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label="Positive Class") + ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label="Negative Class") + ax.set_title(f"{val}") + ax.set_xlim(*XLIM) + ax.set_ylim(*YLIM) + + +def plot_comparison(oversamplers, X, y): + """Function to compare SMOTE and Geometric SMOTE + generation of noisy samples.""" + fig, ax_arr = plt.subplots(1, 2, figsize=(15, 5)) + for ax, (name, ovs) in zip(ax_arr, oversamplers): + X_res, y_res = ovs.fit_resample(X, y) + ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label="Positive Class") + ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label="Negative Class") + ax.set_title(name) + ax.set_xlim(*XLIM) + ax.set_ylim(*YLIM) + + +############################################################################### +# Generate imbalanced data +############################################################################### + +############################################################################### +# We are generating a highly imbalanced non Gaussian data set. Only two samples +# from the minority (positive) class are included to illustrate the Geometric +# SMOTE data generation mechanism. + +X, y = generate_imbalanced_data( + 200, 2, [(-2.0, 2.25), (1.0, 2.0)], 0.25, [-0.7, 2.3], [-0.5, 3.1] +) +plot_scatter(X, y, "Imbalanced data") + +############################################################################### +# Geometric hyperparameters +############################################################################### + +############################################################################### +# Similarly to SMOTE and its variations, Geometric SMOTE uses the `k_neighbors` +# hyperparameter to select a random neighbor among the k nearest neighbors of a +# minority class instance. On the other hand, Geometric SMOTE expands the data +# generation area from the line segment of the SMOTE mechanism to a hypersphere +# that can be truncated and deformed. The characteristics of the above geometric +# area are determined by the hyperparameters ``truncation_factor``, +# ``deformation_factor`` and ``selection_strategy``. These are called geometric +# hyperparameters and allow the generation of diverse synthetic data as shown +# below. + +############################################################################### +# Truncation factor +# .............................................................................. +# +# The hyperparameter ``truncation_factor`` determines the degree of truncation +# that is applied on the initial geometric area. Selecting the values of +# geometric hyperparameters as `truncation_factor=0.0`, +# ``deformation_factor=0.0`` and ``selection_strategy='minority'``, the data +# generation area in 2D corresponds to a circle with center as one of the two +# minority class samples and radius equal to the distance between them. In the +# multi-dimensional case the corresponding area is a hypersphere. When +# truncation factor is increased, the hypersphere is truncated and for +# ``truncation_factor=1.0`` becomes a half-hypersphere. Negative values of +# ``truncation_factor`` have a similar effect but on the opposite direction. + +gsmote = GeometricSMOTE( + k_neighbors=1, + deformation_factor=0.0, + selection_strategy="minority", + random_state=RANDOM_STATE, +) +truncation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) +n_subplots = [2, 3] +plot_hyperparameters(gsmote, X, y, "truncation_factor", truncation_factors, n_subplots) +plot_hyperparameters(gsmote, X, y, "truncation_factor", -truncation_factors, n_subplots) + +############################################################################### +# Deformation factor +# .............................................................................. +# +# When the ``deformation_factor`` is increased, the data generation area deforms +# to an ellipsis and for ``deformation_factor=1.0`` becomes a line segment. + +gsmote = GeometricSMOTE( + k_neighbors=1, + truncation_factor=0.0, + selection_strategy="minority", + random_state=RANDOM_STATE, +) +deformation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) +n_subplots = [2, 3] +plot_hyperparameters(gsmote, X, y, "deformation_factor", truncation_factors, n_subplots) + +############################################################################### +# Selection strategy +# .............................................................................. +# +# The hyperparameter ``selection_strategy`` determines the selection mechanism +# of nearest neighbors. Initially, a minority class sample is selected randomly. +# When ``selection_strategy='minority'``, a second minority class sample is +# selected as one of the k nearest neighbors of it. For +# ``selection_strategy='majority'``, the second sample is its nearest majority +# class neighbor. Finally, for ``selection_strategy='combined'`` the two +# selection mechanisms are combined and the second sample is the nearest to the +# first between the two samples defined above. + +gsmote = GeometricSMOTE( + k_neighbors=1, + truncation_factor=0.0, + deformation_factor=0.5, + random_state=RANDOM_STATE, +) +selection_strategies = np.array(["minority", "majority", "combined"]) +n_subplots = [1, 3] +plot_hyperparameters( + gsmote, X, y, "selection_strategy", selection_strategies, n_subplots +) + +############################################################################### +# Noisy samples +############################################################################### + +############################################################################### +# We are adding a third minority class sample to illustrate the difference +# between SMOTE and Geometric SMOTE data generation mechanisms. + +X_new = np.vstack([X, np.array([2.0, 2.0])]) +y_new = np.hstack([y, np.ones(1, dtype=np.int8)]) +plot_scatter(X_new, y_new, "Imbalanced data") + +############################################################################### +# When the number of ``k_neighbors`` is increased, SMOTE results to the +# generation of noisy samples. On the other hand, Geometric SMOTE avoids this +# scenario when the ``selection_strategy`` values are either ``combined`` or +# ``majority``. + +oversamplers = [ + ("SMOTE", SMOTE(k_neighbors=2, random_state=RANDOM_STATE)), + ( + "Geometric SMOTE", + GeometricSMOTE( + k_neighbors=2, selection_strategy="combined", random_state=RANDOM_STATE + ), + ), +] +plot_comparison(oversamplers, X_new, y_new) diff --git a/examples/over-sampling/plot_geometric_smote_validation_curves.py b/examples/over-sampling/plot_geometric_smote_validation_curves.py new file mode 100644 index 000000000..76a133e83 --- /dev/null +++ b/examples/over-sampling/plot_geometric_smote_validation_curves.py @@ -0,0 +1,195 @@ +""" +========================== +Plotting validation curves +========================== + +In this example the impact of the Geometric SMOTE's hyperparameters is examined. +The validation scores of a Geometric SMOTE-GBC classifier is presented for +different values of the Geometric SMOTE's hyperparameters. + +""" + +# Author: Georgios Douzas +# Licence: MIT + +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.tree import DecisionTreeClassifier +from sklearn.svm import LinearSVC +from sklearn.model_selection import validation_curve +from sklearn.metrics import make_scorer +from sklearn.datasets import make_classification +from imblearn.pipeline import make_pipeline +from imblearn.metrics import geometric_mean_score + +from imblearn.over_sampling import GeometricSMOTE + +print(__doc__) + +RANDOM_STATE = 10 +SCORER = make_scorer(geometric_mean_score) + + +def generate_imbalanced_data(weights, n_samples, n_features, n_informative): + """Generate imbalanced data.""" + X, y = make_classification( + n_classes=2, + class_sep=2, + weights=weights, + n_informative=n_informative, + n_redundant=1, + flip_y=0, + n_features=n_features, + n_clusters_per_class=2, + n_samples=n_samples, + random_state=RANDOM_STATE, + ) + return X, y + + +def generate_validation_curve_info(estimator, X, y, param_range, param_name, scoring): + """Generate information for the validation curve.""" + _, test_scores = validation_curve( + estimator, + X, + y, + param_name=param_name, + param_range=param_range, + cv=3, + scoring=scoring, + n_jobs=-1, + ) + test_scores_mean = np.mean(test_scores, axis=1) + test_scores_std = np.std(test_scores, axis=1) + return test_scores_mean, test_scores_std, param_range + + +def plot_validation_curve(validation_curve_info, scoring_name, title): + """Plot the validation curve.""" + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + test_scores_mean, test_scores_std, param_range = validation_curve_info + plt.plot(param_range, test_scores_mean) + ax.fill_between( + param_range, + test_scores_mean + test_scores_std, + test_scores_mean - test_scores_std, + alpha=0.2, + ) + idx_max = np.argmax(test_scores_mean) + plt.scatter(param_range[idx_max], test_scores_mean[idx_max]) + plt.title(title) + plt.ylabel(scoring_name) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.get_xaxis().tick_bottom() + ax.get_yaxis().tick_left() + ax.spines["left"].set_position(("outward", 10)) + ax.spines["bottom"].set_position(("outward", 10)) + plt.ylim([0.9, 1.0]) + + +############################################################################### +# Low Imbalance Ratio or high Samples to Features Ratio +############################################################################### + +############################################################################### +# When :math:`\text{IR} = \frac{\text{\# majority samples}}{\text{\# minority +# samples}}` (Imbalance Ratio) is low or :math:`\text{SFR} = \frac{\text{\# +# samples}}{\text{\# features}}` (Samples to Features Ratio) is high then the +# minority selection strategy and higher absolute values of the truncation and +# deformation factors dominate as optimal hyperparameters. + +X, y = generate_imbalanced_data([0.3, 0.7], 2000, 6, 4) +gsmote_gbc = make_pipeline( + GeometricSMOTE(random_state=RANDOM_STATE), + DecisionTreeClassifier(random_state=RANDOM_STATE), +) + +scoring_name = "Geometric Mean Score" +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, X, y, range(1, 8), "geometricsmote__k_neighbors", SCORER +) +plot_validation_curve(validation_curve_info, scoring_name, "K Neighbors") + +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, + X, + y, + np.linspace(-1.0, 1.0, 9), + "geometricsmote__truncation_factor", + SCORER, +) +plot_validation_curve(validation_curve_info, scoring_name, "Truncation Factor") + +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, + X, + y, + np.linspace(0.0, 1.0, 5), + "geometricsmote__deformation_factor", + SCORER, +) +plot_validation_curve(validation_curve_info, scoring_name, "Deformation Factor") + +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, + X, + y, + ["minority", "majority", "combined"], + "geometricsmote__selection_strategy", + SCORER, +) +plot_validation_curve(validation_curve_info, scoring_name, "Selection Strategy") + +############################################################################### +# High Imbalance Ratio or low Samples to Features Ratio +############################################################################### + +############################################################################### +# When :math:`\text{IR}` is high or :math:`\text{SFR}` is low then the majority +# or combined selection strategies and lower absolute values of the truncation +# and deformation factors dominate as optimal hyperparameters. + +X, y = generate_imbalanced_data([0.1, 0.9], 2000, 400, 200) +gsmote_gbc = make_pipeline( + GeometricSMOTE(random_state=RANDOM_STATE), + LinearSVC(random_state=RANDOM_STATE, max_iter=1e5), +) + +scoring_name = "Geometric Mean Score" +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, X, y, range(1, 8), "geometricsmote__k_neighbors", SCORER +) +plot_validation_curve(validation_curve_info, scoring_name, "K Neighbors") + +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, + X, + y, + np.linspace(-1.0, 1.0, 9), + "geometricsmote__truncation_factor", + SCORER, +) +plot_validation_curve(validation_curve_info, scoring_name, "Truncation Factor") + +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, + X, + y, + np.linspace(0.0, 1.0, 5), + "geometricsmote__deformation_factor", + SCORER, +) +plot_validation_curve(validation_curve_info, scoring_name, "Deformation Factor") + +validation_curve_info = generate_validation_curve_info( + gsmote_gbc, + X, + y, + ["minority", "majority", "combined"], + "geometricsmote__selection_strategy", + SCORER, +) +plot_validation_curve(validation_curve_info, scoring_name, "Selection Strategy") diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py index a959cbb43..36504d3d4 100644 --- a/imblearn/over_sampling/__init__.py +++ b/imblearn/over_sampling/__init__.py @@ -8,6 +8,7 @@ from ._smote import SMOTE from ._smote import BorderlineSMOTE from ._smote import KMeansSMOTE +from ._smote import GeometricSMOTE from ._smote import SVMSMOTE from ._smote import SMOTENC from ._smote import SMOTEN @@ -16,6 +17,7 @@ "ADASYN", "RandomOverSampler", "KMeansSMOTE", + "GeometricSMOTE", "SMOTE", "BorderlineSMOTE", "SVMSMOTE", diff --git a/imblearn/over_sampling/_smote/__init__.py b/imblearn/over_sampling/_smote/__init__.py index aaf4dd348..42cd89ce8 100644 --- a/imblearn/over_sampling/_smote/__init__.py +++ b/imblearn/over_sampling/_smote/__init__.py @@ -4,6 +4,8 @@ from .cluster import KMeansSMOTE +from .geometric import GeometricSMOTE + from .filter import BorderlineSMOTE from .filter import SVMSMOTE @@ -12,6 +14,7 @@ "SMOTEN", "SMOTENC", "KMeansSMOTE", + "GeometricSMOTE", "BorderlineSMOTE", "SVMSMOTE", ] diff --git a/imblearn/over_sampling/_smote/geometric.py b/imblearn/over_sampling/_smote/geometric.py new file mode 100644 index 000000000..93f9aa134 --- /dev/null +++ b/imblearn/over_sampling/_smote/geometric.py @@ -0,0 +1,377 @@ +"""Class to perform over-sampling using Geometric SMOTE.""" + +# Author: Georgios Douzas +# Joao Fonseca +# License: BSD 3 clause + +import numpy as np +from numpy.linalg import norm +from scipy import sparse +from sklearn.utils import check_random_state +from ..base import BaseOverSampler +from imblearn.utils import check_neighbors_object, Substitution +from imblearn.utils._docstring import _random_state_docstring + +SELECTION_STRATEGY = ("combined", "majority", "minority") + + +def _make_geometric_sample( + center, surface_point, truncation_factor, deformation_factor, random_state +): + """A support function that returns an artificial point inside + the geometric region defined by the center and surface points. + + Parameters + ---------- + center : ndarray, shape (n_features, ) + Center point of the geometric region. + + surface_point : ndarray, shape (n_features, ) + Surface point of the geometric region. + + truncation_factor : float, optional (default=0.0) + The type of truncation. The values should be in the [-1.0, 1.0] range. + + deformation_factor : float, optional (default=0.0) + The type of geometry. The values should be in the [0.0, 1.0] range. + + random_state : int, RandomState instance or None + Control the randomization of the algorithm. + + Returns + ------- + point : ndarray, shape (n_features, ) + Synthetically generated sample. + + """ + + # Zero radius case + if np.array_equal(center, surface_point): + return center + + # Generate a point on the surface of a unit hyper-sphere + radius = norm(center - surface_point) + normal_samples = random_state.normal(size=center.size) + point_on_unit_sphere = normal_samples / norm(normal_samples) + point = (random_state.uniform(size=1) ** (1 / center.size)) * point_on_unit_sphere + + # Parallel unit vector + parallel_unit_vector = (surface_point - center) / norm(surface_point - center) + + # Truncation + close_to_opposite_boundary = ( + truncation_factor > 0 + and np.dot(point, parallel_unit_vector) < truncation_factor - 1 + ) + close_to_boundary = ( + truncation_factor < 0 + and np.dot(point, parallel_unit_vector) > truncation_factor + 1 + ) + if close_to_opposite_boundary or close_to_boundary: + point -= 2 * np.dot(point, parallel_unit_vector) * parallel_unit_vector + + # Deformation + parallel_point_position = np.dot(point, parallel_unit_vector) * parallel_unit_vector + perpendicular_point_position = point - parallel_point_position + point = ( + parallel_point_position + + (1 - deformation_factor) * perpendicular_point_position + ) + + # Translation + point = center + radius * point + + return point + + +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + random_state=_random_state_docstring, +) +class GeometricSMOTE(BaseOverSampler): + """Class to to perform over-sampling using Geometric SMOTE. + + This algorithm is an implementation of Geometric SMOTE, a geometrically enhanced + drop-in replacement for SMOTE as presented in [1]_. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + {sampling_strategy} + + {random_state} + + truncation_factor : float, optional (default=0.0) + The type of truncation. The values should be in the [-1.0, 1.0] range. + + deformation_factor : float, optional (default=0.0) + The type of geometry. The values should be in the [0.0, 1.0] range. + + selection_strategy : str, optional (default='combined') + The type of Geometric SMOTE algorithm with the following options: + ``'combined'``, ``'majority'``, ``'minority'``. + + k_neighbors : int or object, optional (default=5) + If ``int``, number of nearest neighbours to use when synthetic + samples are constructed for the minority method. If object, an estimator + that inherits from :class:`sklearn.neighbors.base.KNeighborsMixin` that + will be used to find the k_neighbors. + + n_jobs : int, optional (default=1) + The number of threads to open if possible. + + Attributes + ---------- + sampling_strategy_ : dict + Dictionary containing the information to sample the dataset. The keys + corresponds to the class labels from which to sample and the values + are the number of samples to sample. + + n_features_in_ : int + Number of features in the input dataset. + + nns_pos_ : estimator object + Validated k-nearest neighbours created from the `k_neighbors` parameter. It is + used to find the nearest neighbors of the same class of a selected + observation. + + nn_neg_ : estimator object + Validated k-nearest neighbours created from the `k_neighbors` parameter. It is + used to find the nearest neighbor of the remaining classes (k=1) of a selected + observation. + + random_state_ : instance of RandomState + If the `random_state` parameter is None, it is a RandomState singleton used by + np.random. If `random_state` is an int, it is a RandomState instance seeded with + seed. If `random_state` is already a RandomState instance, it is the same + object. + + See Also + -------- + SMOTE : Over-sample using SMOTE. + + SMOTEN : Over-sample using the SMOTE variant specifically for categorical + features only. + + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + SVMSMOTE : Over-sample using SVM-SMOTE variant. + + BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. + + ADASYN : Over-sample using ADASYN. + + KMeansSMOTE : Over-sample applying a clustering before to oversample using + SMOTE. + + Notes + ----- + See the original paper: [1]_ for more details. + + Supports multi-class resampling. A one-vs.-rest scheme is used as + originally proposed in [2]_. + + References + ---------- + .. [1] G. Douzas, F. Bacao, "Geometric SMOTE: + a geometrically enhanced drop-in replacement for SMOTE", + Information Sciences, vol. 501, pp. 118-135, 2019. + + .. [2] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique", Journal of Artificial + Intelligence Research, vol. 16, pp. 321-357, 2002. + + Examples + -------- + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from imblearn.over_sampling import \ +GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape %s' % Counter(y)) + Original dataset shape Counter({{1: 900, 0: 100}}) + >>> gsmote = GeometricSMOTE(random_state=1) + >>> X_res, y_res = gsmote.fit_resample(X, y) + >>> print('Resampled dataset shape %s' % Counter(y_res)) + Resampled dataset shape Counter({{0: 900, 1: 900}}) + """ + + def __init__( + self, + sampling_strategy="auto", + random_state=None, + truncation_factor=1.0, + deformation_factor=0.0, + selection_strategy="combined", + k_neighbors=5, + n_jobs=1, + ): + super(GeometricSMOTE, self).__init__(sampling_strategy=sampling_strategy) + self.random_state = random_state + self.truncation_factor = truncation_factor + self.deformation_factor = deformation_factor + self.selection_strategy = selection_strategy + self.k_neighbors = k_neighbors + self.n_jobs = n_jobs + + def _validate_estimator(self): + """Create the necessary attributes for Geometric SMOTE.""" + + # Check random state + self.random_state_ = check_random_state(self.random_state) + + # Validate strategy + if self.selection_strategy not in SELECTION_STRATEGY: + error_msg = ( + "Unknown selection_strategy for Geometric SMOTE algorithm. " + "Choices are {}. Got {} instead." + ) + raise ValueError( + error_msg.format(SELECTION_STRATEGY, self.selection_strategy) + ) + + # Create nearest neighbors object for positive class + if self.selection_strategy in ("minority", "combined"): + self.nns_pos_ = check_neighbors_object( + "nns_positive", self.k_neighbors, additional_neighbor=1 + ) + self.nns_pos_.set_params(n_jobs=self.n_jobs) + + # Create nearest neighbors object for negative class + if self.selection_strategy in ("majority", "combined"): + self.nn_neg_ = check_neighbors_object("nn_negative", nn_object=1) + self.nn_neg_.set_params(n_jobs=self.n_jobs) + + def _make_geometric_samples(self, X, y, pos_class_label, n_samples): + """A support function that returns an artificials samples inside + the geometric region defined by nearest neighbors. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Matrix containing the data which have to be sampled. + y : array-like, shape (n_samples, ) + Corresponding label for each sample in X. + pos_class_label : str or int + The minority class (positive class) target value. + n_samples : int + The number of samples to generate. + + Returns + ------- + X_new : ndarray, shape (n_samples_new, n_features) + Synthetically generated samples. + y_new : ndarray, shape (n_samples_new, ) + Target values for synthetic samples. + + """ + + # Return zero new samples + if n_samples == 0: + return ( + np.array([], dtype=X.dtype).reshape(0, X.shape[1]), + np.array([], dtype=y.dtype), + ) + + # Select positive class samples + X_pos = X[y == pos_class_label] + + # Force minority strategy if no negative class samples are present + self.selection_strategy_ = ( + "minority" if X.shape[0] == X_pos.shape[0] else self.selection_strategy + ) + + # Minority or combined strategy + if self.selection_strategy_ in ("minority", "combined"): + self.nns_pos_.fit(X_pos) + points_pos = self.nns_pos_.kneighbors(X_pos)[1][:, 1:] + samples_indices = self.random_state_.randint( + low=0, high=len(points_pos.flatten()), size=n_samples + ) + rows = np.floor_divide(samples_indices, points_pos.shape[1]) + cols = np.mod(samples_indices, points_pos.shape[1]) + + # Majority or combined strategy + if self.selection_strategy_ in ("majority", "combined"): + X_neg = X[y != pos_class_label] + self.nn_neg_.fit(X_neg) + points_neg = self.nn_neg_.kneighbors(X_pos)[1] + if self.selection_strategy_ == "majority": + samples_indices = self.random_state_.randint( + low=0, high=len(points_neg.flatten()), size=n_samples + ) + rows = np.floor_divide(samples_indices, points_neg.shape[1]) + cols = np.mod(samples_indices, points_neg.shape[1]) + + # Generate new samples + X_new = np.zeros((n_samples, X.shape[1])) + for ind, (row, col) in enumerate(zip(rows, cols)): + + # Define center point + center = X_pos[row] + + # Minority strategy + if self.selection_strategy_ == "minority": + surface_point = X_pos[points_pos[row, col]] + + # Majority strategy + elif self.selection_strategy_ == "majority": + surface_point = X_neg[points_neg[row, col]] + + # Combined strategy + else: + surface_point_pos = X_pos[points_pos[row, col]] + surface_point_neg = X_neg[points_neg[row, 0]] + radius_pos = norm(center - surface_point_pos) + radius_neg = norm(center - surface_point_neg) + surface_point = ( + surface_point_neg if radius_pos > radius_neg else surface_point_pos + ) + + # Append new sample + X_new[ind] = _make_geometric_sample( + center, + surface_point, + self.truncation_factor, + self.deformation_factor, + self.random_state_, + ) + + # Create new samples for target variable + y_new = np.array([pos_class_label] * len(samples_indices)) + + return X_new, y_new + + def _fit_resample(self, X, y): + + # Validate estimator's parameters + self._validate_estimator() + + # Ensure the input data is dense + X_dense = X.toarray() if sparse.issparse(X) else X + + # Copy data + X_resampled, y_resampled = [X_dense.copy()], [y.copy()] + + # Resample data + for class_label, n_samples in self.sampling_strategy_.items(): + + # Apply gsmote mechanism + X_new, y_new = self._make_geometric_samples( + X_dense, y, class_label, n_samples + ) + + X_resampled.append(X_new) + y_resampled.append(y_new) + + # Append new data + if sparse.issparse(X): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled).astype(X.dtype) + y_resampled = np.hstack(y_resampled).astype(y.dtype) + + return X_resampled, y_resampled diff --git a/imblearn/over_sampling/_smote/tests/test_geometric_smote.py b/imblearn/over_sampling/_smote/tests/test_geometric_smote.py new file mode 100644 index 000000000..3c14bdfec --- /dev/null +++ b/imblearn/over_sampling/_smote/tests/test_geometric_smote.py @@ -0,0 +1,209 @@ +""" +Test the geometric_smote module. +""" + +from collections import Counter + +import pytest +import numpy as np +from numpy.linalg import norm +from sklearn.utils import check_random_state +from sklearn.datasets import make_classification + +from ..geometric import _make_geometric_sample, GeometricSMOTE, SELECTION_STRATEGY + +RND_SEED = 0 +RANDOM_STATE = check_random_state(RND_SEED) +CENTERS = [ + RANDOM_STATE.random_sample((2,)), + 2.6 * RANDOM_STATE.random_sample((4,)), + 3.2 * RANDOM_STATE.random_sample((10,)), + -0.5 * RANDOM_STATE.random_sample((1,)), +] +SURFACE_POINTS = [ + RANDOM_STATE.random_sample((2,)), + 5.2 * RANDOM_STATE.random_sample((4,)), + -3.5 * RANDOM_STATE.random_sample((10,)), + -10.9 * RANDOM_STATE.random_sample((1,)), +] +TRUNCATION_FACTORS = [-1.0, -0.5, 0.0, 0.5, 1.0] +DEFORMATION_FACTORS = [0.0, 0.25, 0.5, 0.75, 1.0] + + +@pytest.mark.parametrize( + "center,surface_point", + [ + (CENTERS[0], SURFACE_POINTS[0]), + (CENTERS[1], SURFACE_POINTS[1]), + (CENTERS[2], SURFACE_POINTS[2]), + (CENTERS[3], SURFACE_POINTS[3]), + ], +) +def test_make_geometric_sample_hypersphere(center, surface_point): + """Test the generation of points inside a hypersphere.""" + point = _make_geometric_sample(center, surface_point, 0.0, 0.0, RANDOM_STATE) + rel_point = point - center + rel_surface_point = surface_point - center + np.testing.assert_array_less(0.0, norm(rel_surface_point) - norm(rel_point)) + + +@pytest.mark.parametrize( + "surface_point,deformation_factor", + [ + (np.array([1.0, 0.0]), 0.0), + (2.6 * np.array([0.0, 1.0]), 0.25), + (3.2 * np.array([0.0, 1.0, 0.0, 0.0]), 0.50), + (0.5 * np.array([0.0, 0.0, 1.0]), 0.75), + (6.7 * np.array([0.0, 0.0, 1.0, 0.0, 0.0]), 1.0), + ], +) +def test_make_geometric_sample_half_hypersphere(surface_point, deformation_factor): + """Test the generation of points inside a hypersphere.""" + center = np.zeros(surface_point.shape) + point = _make_geometric_sample( + center, surface_point, 1.0, deformation_factor, RANDOM_STATE + ) + np.testing.assert_array_less(0.0, norm(surface_point) - norm(point)) + np.testing.assert_array_less(0.0, np.dot(point, surface_point)) + + +@pytest.mark.parametrize( + "center,surface_point,truncation_factor", + [ + (center, surface_point, truncation_factor) + for center, surface_point in zip(CENTERS, SURFACE_POINTS) + for truncation_factor in TRUNCATION_FACTORS + ], +) +def test_make_geometric_sample_line_segment(center, surface_point, truncation_factor): + """Test the generation of points on a line segment.""" + point = _make_geometric_sample( + center, surface_point, truncation_factor, 1.0, RANDOM_STATE + ) + rel_point = point - center + rel_surface_point = surface_point - center + dot_product = np.dot(rel_point, rel_surface_point) + norms_product = norm(rel_point) * norm(rel_surface_point) + np.testing.assert_array_less(0.0, norm(rel_surface_point) - norm(rel_point)) + dot_product = ( + np.abs(dot_product) if truncation_factor == 0.0 else (-1) * dot_product + ) + np.testing.assert_allclose(np.abs(dot_product) / norms_product, 1.0) + + +def test_gsmote_default_init(): + """Test the intialization with default parameters.""" + gsmote = GeometricSMOTE() + assert gsmote.sampling_strategy == "auto" + assert gsmote.random_state is None + assert gsmote.truncation_factor == 1.0 + assert gsmote.deformation_factor == 0.0 + assert gsmote.selection_strategy == "combined" + assert gsmote.k_neighbors == 5 + assert gsmote.n_jobs == 1 + + +def test_gsmote_fit(): + """Test fit method.""" + n_samples, weights = 200, [0.6, 0.4] + X, y = make_classification( + random_state=RND_SEED, n_samples=n_samples, weights=weights + ) + gsmote = GeometricSMOTE(random_state=RANDOM_STATE).fit(X, y) + assert gsmote.sampling_strategy_ == {1: 40} + + +def test_gsmote_invalid_selection_strategy(): + """Test invalid selection strategy.""" + n_samples, weights = 200, [0.6, 0.4] + X, y = make_classification( + random_state=RND_SEED, n_samples=n_samples, weights=weights + ) + gsmote = GeometricSMOTE(random_state=RANDOM_STATE, selection_strategy="Minority") + with pytest.raises(ValueError): + gsmote.fit_resample(X, y) + + +@pytest.mark.parametrize("selection_strategy", ["combined", "minority", "majority"]) +def test_gsmote_nn(selection_strategy): + """Test nearest neighbors object.""" + n_samples, weights = 200, [0.6, 0.4] + X, y = make_classification( + random_state=RND_SEED, n_samples=n_samples, weights=weights + ) + gsmote = GeometricSMOTE( + random_state=RANDOM_STATE, selection_strategy=selection_strategy + ) + _ = gsmote.fit_resample(X, y) + if selection_strategy in ("minority", "combined"): + assert gsmote.nns_pos_.n_neighbors == gsmote.k_neighbors + 1 + if selection_strategy in ("majority", "combined"): + assert gsmote.nn_neg_.n_neighbors == 1 + + +@pytest.mark.parametrize( + "selection_strategy, truncation_factor, deformation_factor", + [ + (selection_strategy, truncation_factor, deformation_factor) + for selection_strategy in SELECTION_STRATEGY + for truncation_factor in TRUNCATION_FACTORS + for deformation_factor in DEFORMATION_FACTORS + ], +) +def test_gsmote_fit_resample_binary( + selection_strategy, truncation_factor, deformation_factor +): + """Test fit and sample for binary class case.""" + n_maj, n_min, step, min_coor, max_coor = 12, 5, 0.5, 0.0, 8.5 + X = np.repeat(np.arange(min_coor, max_coor, step), 2).reshape(-1, 2) + y = np.concatenate([np.repeat(0, n_maj), np.repeat(1, n_min)]) + max_radius = np.sqrt(np.sum((X[-1] - X[n_maj - 1]) ** 2)) + k_neighbors = 1 + gsmote = GeometricSMOTE( + "auto", + RANDOM_STATE, + truncation_factor, + deformation_factor, + selection_strategy, + k_neighbors, + ) + X_resampled, y_resampled = gsmote.fit_resample(X, y) + assert gsmote.sampling_strategy_ == {1: (n_maj - n_min)} + assert y_resampled.sum() == n_maj + np.testing.assert_array_less(X[-1] - max_radius, X_resampled[n_maj + n_min]) + + +@pytest.mark.parametrize( + "selection_strategy, truncation_factor, deformation_factor", + [ + (selection_strategy, truncation_factor, deformation_factor) + for selection_strategy in SELECTION_STRATEGY + for truncation_factor in TRUNCATION_FACTORS + for deformation_factor in DEFORMATION_FACTORS + ], +) +def test_gsmote_fit_resample_multiclass( + selection_strategy, truncation_factor, deformation_factor +): + """Test fit and sample for multiclass case.""" + n_samples, weights = 100, [0.75, 0.15, 0.10] + X, y = make_classification( + random_state=RND_SEED, + n_samples=n_samples, + weights=weights, + n_classes=3, + n_informative=5, + ) + k_neighbors, majority_label = 1, 0 + gsmote = GeometricSMOTE( + "auto", + RANDOM_STATE, + truncation_factor, + deformation_factor, + selection_strategy, + k_neighbors, + ) + _, y_resampled = gsmote.fit_resample(X, y) + assert majority_label not in gsmote.sampling_strategy_.keys() + np.testing.assert_array_equal(np.unique(y), np.unique(y_resampled)) + assert len(set(Counter(y_resampled).values())) == 1