[MRG] implement InstanceHardnessCV #1125

Open · wants to merge 18 commits into master
119 changes: 119 additions & 0 deletions doc/cross_validation.rst
@@ -0,0 +1,119 @@
.. _cross_validation:

================
Cross validation
================

.. currentmodule:: imblearn.cross_validation


.. _instance_hardness_cv:

The term instance hardness is used in the literature to express how difficult
it is to classify an instance correctly. An instance for which the predicted
probability of the true class is low has large instance hardness. How these
hard-to-classify instances are distributed over the train and test sets in
cross validation has a significant effect on the test set performance metrics.
The `InstanceHardnessCV` splitter distributes samples with large instance
hardness evenly over the folds, resulting in more robust cross validation.

We will discuss instance hardness in this document and explain how to use the
`InstanceHardnessCV` splitter.

Instance hardness and average precision
=======================================
Instance hardness is defined as 1 minus the probability of the most probable class:

.. math::

H(x) = 1 - P(\hat{y}|x)

In this equation, :math:`H(x)` is the instance hardness of a sample with features
:math:`x`, and :math:`P(\hat{y}|x)` is the probability of the predicted label
:math:`\hat{y}` given those features. If the model predicts label 0 and its
`predict_proba` output is [0.9, 0.1], the probability of the most probable
class (0) is 0.9 and the instance hardness is 1 - 0.9 = 0.1.
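
As a minimal sketch of this definition (plain NumPy, independent of the
splitter's API), instance hardness can be computed directly from a
`predict_proba` output:

>>> import numpy as np
>>> proba = np.array([[0.9, 0.1], [0.4, 0.6]])  # one row per sample
>>> hardness = 1 - proba.max(axis=1)  # 1 minus most probable class probability
>>> hardness
array([0.1, 0.4])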

Samples with large instance hardness have a significant effect on the area
under the precision-recall curve, or average precision. In particular, samples
with label 0 and large instance hardness (i.e., the model predicts label 1)
lower the average precision considerably, because these points affect the left
side of the precision-recall curve, where the area is largest: precision is
reduced in the range of low recall and high thresholds. When doing cross
validation, for example during hyperparameter tuning or recursive feature
elimination, the random concentration of these points in some folds introduces
variance in the CV results and deteriorates the robustness of the cross
validation task. The `InstanceHardnessCV` splitter aims to distribute the
samples with large instance hardness evenly over the folds in order to reduce
this undesired variance. Note that this splitter should be used to make model
*selection* tasks such as hyperparameter tuning and feature selection more
robust, but not for model *performance estimation*, for which you also want to
know the variance of the performance to be expected in production.
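
The effect can be seen in isolation with a toy scoring example (the scores
below are arbitrary values chosen for illustration): a single hard negative
ranked above both positives roughly halves the average precision.

>>> from sklearn.metrics import average_precision_score
>>> y_true = [0, 0, 0, 0, 1, 1]
>>> round(float(average_precision_score(y_true, [0.1, 0.2, 0.1, 0.3, 0.8, 0.9])), 3)
1.0
>>> # same labels, but one hard negative now outranks both positives
>>> round(float(average_precision_score(y_true, [0.1, 0.2, 0.1, 0.95, 0.8, 0.9])), 3)
0.583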


Create imbalanced dataset with samples with large instance hardness
===================================================================

Let’s start by creating a dataset to work with. We create a dataset with 5% class
imbalance using scikit-learn’s `make_blobs` function.

>>> import numpy as np
>>> from matplotlib import pyplot as plt
>>> from sklearn.datasets import make_blobs
>>> random_state = 10
>>> X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)),
... random_state=random_state)
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
>>> plt.show()

.. image:: ./auto_examples/cross_validation/images/sphx_glr_plot_instance_hardness_cv_001.png
:target: ./auto_examples/cross_validation/plot_instance_hardness_cv.html
:align: center

Now we add some samples with large instance hardness:

>>> X_hard, y_hard = make_blobs(n_samples=10, centers=((3, 0), (-3, 0)),
... cluster_std=1,
... random_state=random_state)
>>> X = np.vstack((X, X_hard))
>>> y = np.hstack((y, y_hard))
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
>>> plt.show()

.. image:: ./auto_examples/cross_validation/images/sphx_glr_plot_instance_hardness_cv_002.png
:target: ./auto_examples/cross_validation/plot_instance_hardness_cv.html
:align: center

Assess cross validation performance variance using InstanceHardnessCV splitter
===============================================================================

Then we take a `LogisticRegression` classifier and assess the cross validation
performance using a `StratifiedKFold` cv splitter and the `cross_validate`
function.

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import StratifiedKFold, cross_validate
>>> clf = LogisticRegression(random_state=random_state)
>>> skf_cv = StratifiedKFold(n_splits=5, shuffle=True,
...                          random_state=random_state)
>>> skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")

Now, we do the same using an `InstanceHardnessCV` splitter. We provide our
classifier to the splitter; it is used to calculate instance hardness and to
distribute the samples with large instance hardness evenly over the folds.

>>> from imblearn.cross_validation import InstanceHardnessCV
>>> ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5,
...                            random_state=random_state)
>>> ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")

When we plot the test scores for both cv splitters, we see that the variance using
the `InstanceHardnessCV` splitter is lower than for the `StratifiedKFold` splitter.

>>> plt.boxplot([skf_result['test_score'], ih_result['test_score']],
... tick_labels=["StratifiedKFold", "InstanceHardnessCV"],
... vert=False)
>>> plt.xlabel('Average precision')
>>> plt.tight_layout()

.. image:: ./auto_examples/cross_validation/images/sphx_glr_plot_instance_hardness_cv_003.png
:target: ./auto_examples/cross_validation/plot_instance_hardness_cv.html
:align: center
23 changes: 23 additions & 0 deletions doc/references/cross_validation.rst
@@ -0,0 +1,23 @@
.. _cross_validation_ref:

Cross validation methods
========================

.. automodule:: imblearn.cross_validation
:no-members:
:no-inherited-members:

CV splitters
------------

.. automodule:: imblearn.cross_validation._cross_validation
:no-members:
:no-inherited-members:

.. currentmodule:: imblearn.cross_validation

.. autosummary::
:toctree: generated/
:template: class.rst

InstanceHardnessCV
1 change: 1 addition & 0 deletions doc/references/index.rst
@@ -18,5 +18,6 @@ This is the full API documentation of the `imbalanced-learn` toolbox.
miscellaneous
pipeline
metrics
cross_validation
datasets
utils
1 change: 1 addition & 0 deletions doc/user_guide.rst
@@ -19,6 +19,7 @@ User Guide
ensemble.rst
miscellaneous.rst
metrics.rst
cross_validation.rst
common_pitfalls.rst
Dataset loading utilities <datasets/index.rst>
developers_utils.rst
6 changes: 6 additions & 0 deletions examples/cross_validation/README.txt
@@ -0,0 +1,6 @@
.. _cross_validation_examples:

Examples using cross validation classes
=======================================

Cross validation classes to be used for classification problems with
imbalanced class distributions.
82 changes: 82 additions & 0 deletions examples/cross_validation/plot_instance_hardness_cv.py
@@ -0,0 +1,82 @@
"""
====================================================
Distribute hard-to-classify datapoints over CV folds
====================================================

'Instance hardness' refers to the difficulty of classifying an instance
correctly. The way hard-to-classify instances are distributed over the train
and test sets has a significant effect on the test set performance metrics.
In this example we show how to deal with this problem, comparing against
normal StratifiedKFold cv splitting.
"""

# Authors: Frits Hermans, https://fritshermans.github.io
# License: MIT

# %%
print(__doc__)

# %% [markdown]
# Create an imbalanced dataset with instance hardness
# ---------------------------------------------------
#
# We will create an imbalanced dataset using scikit-learn's `make_blobs`
# function and set the class imbalance to 5%; only 5% of the labels are positive.


import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)), random_state=10)
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()

# %%
# To introduce instance hardness in our dataset, we add some hard-to-classify samples:
X_hard, y_hard = make_blobs(n_samples=10, centers=((3, 0), (-3, 0)),
cluster_std=1,
random_state=10)
X = np.vstack((X, X_hard))
y = np.hstack((y, y_hard))
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()

# %% [markdown]
# Compare cross validation scores using StratifiedKFold and InstanceHardnessCV
# ----------------------------------------------------------------------------
#
# We calculate cross validation scores using `cross_validate` and a
# `LogisticRegression` classifier. We compare the results using a
# `StratifiedKFold` cv splitter and an `InstanceHardnessCV` splitter.
# As we are dealing with an imbalanced classification problem, we
# use `average_precision` for scoring.

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

from imblearn.cross_validation import InstanceHardnessCV

# %%
clf = LogisticRegression(random_state=10)

# %%
skf_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=10)
skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")

# %%
ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5, random_state=10)
ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")

# %%
# The boxplot below shows that the `InstanceHardnessCV` splitter results in
# less variation of the average precision than the `StratifiedKFold` splitter.
# When doing hyperparameter tuning or feature selection using a wrapper method
# (like `RFECV`), this will give more stable results; a sketch of such a setup
# is shown after the plot below.

# %%
plt.boxplot([skf_result['test_score'], ih_result['test_score']],
tick_labels=["StratifiedKFold", "InstanceHardnessCV"], vert=False)
plt.xlabel('Average precision')
plt.tight_layout()
plt.show()
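
# %% [markdown]
# As a minimal sketch (an illustrative assumption, not part of the comparison
# above), the same splitter can be plugged into a wrapper feature-selection
# method such as scikit-learn's `RFECV`. With only two features in this toy
# dataset the selection itself is trivial, but the pattern carries over.

# %%
from sklearn.feature_selection import RFECV

# reuse the classifier and the InstanceHardnessCV splitter defined above
rfecv = RFECV(estimator=clf, cv=ih_cv, scoring="average_precision")
rfecv.fit(X, y)
print(f"Number of selected features: {rfecv.n_features_}")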
3 changes: 3 additions & 0 deletions imblearn/cross_validation/__init__.py
@@ -0,0 +1,3 @@
from ._cross_validation import InstanceHardnessCV

__all__ = ["InstanceHardnessCV"]
111 changes: 111 additions & 0 deletions imblearn/cross_validation/_cross_validation.py
@@ -0,0 +1,111 @@
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import LeaveOneGroupOut, cross_val_predict


class InstanceHardnessCV:
"""Instance-hardness CV splitter

CV splitter that distributes samples with large instance hardness equally
over the folds

    Read more in the :ref:`User Guide <instance_hardness_cv>`.

Parameters
----------
    estimator : estimator object, default=None
        Classifier used to estimate the instance hardness of the samples.
        This classifier should implement `predict_proba`. If `None`, a
        `RandomForestClassifier` with balanced class weights is used.

    n_splits : int, default=5
        Number of folds. Must be at least 2.

    random_state : int, RandomState instance, default=None
        Determines the random state for reproducible results across multiple
        calls. Only used to seed the default `RandomForestClassifier` when
        `estimator` is `None`.

Examples
--------
>>> from imblearn.cross_validation import InstanceHardnessCV
>>> from sklearn.datasets import make_classification
>>> from sklearn.model_selection import cross_validate
>>> from sklearn.linear_model import LogisticRegression
>>> X, y = make_classification(weights=[0.9, 0.1], class_sep=2,
... n_informative=3, n_redundant=1, flip_y=0.05, n_samples=1000, random_state=10)
>>> estimator = LogisticRegression(random_state=10)
    >>> ih_cv = InstanceHardnessCV(estimator=estimator, n_splits=5, random_state=10)
>>> cv_result = cross_validate(estimator, X, y, cv=ih_cv)
>>> print(f"Standard deviation of test_scores: {cv_result['test_score'].std():.3f}")
Standard deviation of test_scores: 0.004
"""

    def __init__(self, estimator=None, n_splits=5, random_state=None):
self.n_splits = n_splits
self.estimator = estimator
self.random_state = random_state

def split(self, X, y, groups=None):
"""
Generate indices to split data into training and test set.

Parameters
----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

"""
        if self.estimator is not None:
            self.estimator_ = self.estimator
        else:
            # fall back to a default classifier when none is provided
            self.estimator_ = RandomForestClassifier(
                n_jobs=-1, class_weight="balanced", random_state=self.random_state
            )
        # estimate instance hardness from out-of-fold predicted probabilities
        probas = cross_val_predict(
            self.estimator_, X, y, cv=self.n_splits, method="predict_proba"
        )
        # by sorting first on y and then on the predicted probability of class 1,
        # rows are ordered by instance hardness within the group sharing a label
        sorted_indices = np.lexsort((probas[:, 1], y))
        # assign the sorted samples to folds round-robin so that every fold
        # receives an equal share of samples at each hardness level
        groups = np.zeros(len(X), dtype=int)
        groups[sorted_indices] = np.arange(len(X)) % self.n_splits
        cv = LeaveOneGroupOut()
        for train_index, test_index in cv.split(X, y, groups):
            yield train_index, test_index

def get_n_splits(self, X=None, y=None, groups=None):
"""
Returns the number of splitting iterations in the cross-validator.

Parameters
----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.

"""
return self.n_splits
Empty file.
32 changes: 32 additions & 0 deletions imblearn/cross_validation/tests/test_instance_hardness.py
@@ -0,0 +1,32 @@
import pytest

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate
from sklearn.utils._testing import assert_array_equal

from imblearn.cross_validation import InstanceHardnessCV

X, y = make_classification(
weights=[0.9, 0.1],
class_sep=2,
n_informative=3,
n_redundant=1,
flip_y=0.05,
n_samples=1000,
random_state=10,
)


def test_instancehardness_cv():
clf = LogisticRegression(random_state=10)
ih_cv = InstanceHardnessCV(estimator=clf, random_state=10)
cv_result = cross_validate(clf, X, y, cv=ih_cv)
assert_array_equal(cv_result['test_score'], [0.975, 0.965, 0.96, 0.955, 0.965])


@pytest.mark.parametrize("n_splits", [2, 3, 4])
def test_instancehardness_cv_n_splits(n_splits):
clf = LogisticRegression(random_state=10)
ih_cv = InstanceHardnessCV(estimator=clf, n_splits=n_splits, random_state=10)
assert ih_cv.get_n_splits() == n_splits
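

# A hedged sketch of an additional property check (not in the original PR):
# the test folds produced by `split` should together partition all sample
# indices exactly once.
import numpy as np  # would normally sit with the imports at the top


def test_instancehardness_cv_folds_partition_samples():
    clf = LogisticRegression(random_state=10)
    ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5, random_state=10)
    all_test = np.concatenate([test for _, test in ih_cv.split(X, y)])
    assert_array_equal(np.sort(all_test), np.arange(len(X)))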