
[ENH] Allow fit_resample to receive metadata routed parameters #1111

Description

@ShimantoRahman

Is your feature request related to a problem? Please describe

In cost-sensitive learning, resampling techniques are used to address the asymmetrical importance of data points. These techniques require the amount of resampling to depend on instance-specific parameters, such as the cost weights associated with individual data points. These costs are usually given as a cost matrix for each data point $i$:

| | Actual Positive ($y_i = 1$) | Actual Negative ($y_i = 0$) |
|---|---|---|
| Predicted Positive ($\hat y_i = 1$) | $C_{TP_i}$ | $C_{FP_i}$ |
| Predicted Negative ($\hat y_i = 0$) | $C_{FN_i}$ | $C_{TN_i}$ |

Since these costs depend on the individual data point, they cannot be fixed at initialization (in __init__); instead, they must be supplied together with the input data when fit_resample is called.
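To make the dependence on the data concrete, here is a minimal sketch, assuming the per-instance cost matrices are stored as a NumPy array of shape (n_samples, 2, 2) laid out like the table, i.e. cost_matrix[i] == [[C_TP_i, C_FP_i], [C_FN_i, C_TN_i]] (the same shape as the np.random.rand(X.shape[0], 2, 2) call in the workflow example). The helper misclassification_costs is hypothetical, and cost-proportionate oversampling is only one illustrative scheme, not necessarily what Empulse's samplers do.

```python
import numpy as np


def misclassification_costs(y, cost_matrix):
    """Hypothetical helper: per-instance cost of predicting the wrong class.

    Assumes cost_matrix[i] == [[C_TP_i, C_FP_i], [C_FN_i, C_TN_i]].
    """
    c_tp = cost_matrix[:, 0, 0]
    c_fp = cost_matrix[:, 0, 1]
    c_fn = cost_matrix[:, 1, 0]
    c_tn = cost_matrix[:, 1, 1]
    # A misclassified positive becomes a false negative instead of a true
    # positive; a misclassified negative becomes a false positive instead
    # of a true negative.
    return np.where(y == 1, c_fn - c_tp, c_fp - c_tn)


y = np.array([1, 0, 1])
cost_matrix = np.array([
    [[0.0, 1.0], [5.0, 0.0]],  # positive with a high false-negative cost
    [[0.0, 2.0], [1.0, 0.0]],  # negative with false-positive cost 2
    [[0.0, 1.0], [1.0, 0.0]],  # positive with false-negative cost 1
])
costs = misclassification_costs(y, cost_matrix)
print(costs)  # [5. 2. 1.]

# One possible use: repeat each row in proportion to its cost.
counts = np.rint(costs / costs.min()).astype(int)
indices = np.repeat(np.arange(len(y)), counts)
print(indices)  # [0 0 0 0 0 1 1 2]
```

The resampled indices can only be computed once y and cost_matrix for the actual training rows are known, which is why the costs have to arrive at fit_resample time.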

The current implementation of the imbalanced-learn Pipeline does not support routing metadata to a sampler step's fit_resample method. Metadata routing, which would let instance-dependent parameters flow through the pipeline, is essential for implementing cost-sensitive learning workflows.

Desired workflow (DOES NOT CURRENTLY WORK)

```python
import numpy as np
from imblearn.pipeline import Pipeline
from sklearn import set_config
from sklearn.utils._metadata_requests import MetadataRequest, RequestMethod
from sklearn.base import BaseEstimator
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


set_config(enable_metadata_routing=True)

class CostSensitiveSampler(BaseEstimator):

    _estimator_type = "sampler"
    __metadata_request__fit_resample = {'cost_matrix': True}

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit_resample(self, X, y, cost_matrix=None):
        # resample based on cost_matrix
        # ...
        return X, y

    def _get_metadata_request(self):
        routing = MetadataRequest(owner=self.__class__.__name__)
        routing.fit_resample.add_request(param='cost_matrix', alias=True)
        return routing

    set_fit_resample_request = RequestMethod('fit_resample', ['cost_matrix'])

X, y = make_classification()
cost_matrix = np.random.rand(X.shape[0], 2, 2)
pipeline = Pipeline([
    ('sampler', CostSensitiveSampler().set_fit_resample_request(cost_matrix=True)),
    ('model', LogisticRegression())
])
pipeline.fit(X, y, cost_matrix=cost_matrix)
```
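Until routing is supported, the only option I can see is to call the sampler outside the Pipeline. The sketch below assumes the same (n_samples, 2, 2) cost layout as above; CostProportionalSampler is a hypothetical stand-in (not part of imbalanced-learn or Empulse) that draws a bootstrap sample with row probabilities proportional to the cost of misclassifying each row.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


class CostProportionalSampler:
    """Hypothetical cost-sensitive sampler (illustration only)."""

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit_resample(self, X, y, cost_matrix):
        # Assumed layout per row: [[C_TP, C_FP], [C_FN, C_TN]]
        cost = np.where(
            y == 1,
            cost_matrix[:, 1, 0] - cost_matrix[:, 0, 0],  # C_FN - C_TP
            cost_matrix[:, 0, 1] - cost_matrix[:, 1, 1],  # C_FP - C_TN
        )
        cost = np.clip(cost, 0.0, None)
        rng = np.random.default_rng(self.random_state)
        # Bootstrap sample: each row is drawn with probability proportional to its cost
        idx = rng.choice(len(y), size=len(y), p=cost / cost.sum())
        return X[idx], y[idx]


X, y = make_classification(random_state=0)
rng = np.random.default_rng(0)
cost_matrix = np.zeros((X.shape[0], 2, 2))
cost_matrix[:, 0, 1] = 1.0                            # C_FP: constant
cost_matrix[:, 1, 0] = rng.uniform(1.0, 10.0, len(y))  # C_FN: varies per row

# Workaround: resample outside the Pipeline, passing the costs explicitly,
# then fit the downstream model on the resampled data.
X_res, y_res = CostProportionalSampler(random_state=0).fit_resample(
    X, y, cost_matrix=cost_matrix
)
model = LogisticRegression().fit(X_res, y_res)
```

The drawback is that resampling now happens before any cross-validation split, so duplicated rows can leak across folds and the costs must be sliced by hand per fold. Routing cost_matrix through Pipeline.fit would avoid exactly that.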

Describe the solution you'd like

From what I understand of the metadata routing implementation of the Pipeline object, only a couple of changes are needed:

  1. The SIMPLE_METHODS constant found here needs to include "fit_resample":
```python
SIMPLE_METHODS = [
    "fit",
    "partial_fit",
    "fit_resample",  # add line here
    "predict",
    "predict_proba",
    "predict_log_proba",
    "decision_function",
    "score",
    "split",
    "transform",
    "inverse_transform",
]
```

Note that this requires imbalanced-learn to redefine the classes and functions that use the SIMPLE_METHODS constant internally, such as MetadataRequest and _MetadataRequester, which are currently imported from scikit-learn when scikit-learn 1.4 or higher is installed.
2. A method mapping from caller "fit" to callee "fit_resample" has to be added in the get_metadata_routing(self) method found here, and the filter_resample parameter of the self._iter method needs to be set to False:

```python
def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Returns
    -------
    routing : MetadataRouter
        A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    router = MetadataRouter(owner=self.__class__.__name__)

    # first we add all steps except the last one
    for _, name, trans in self._iter(
        with_final=False, filter_passthrough=True, filter_resample=False
    ):  # change filter_resample to False
        method_mapping = MethodMapping()
        # fit, fit_predict, and fit_transform call fit_transform if it
        # exists, or else fit and transform
        if hasattr(trans, "fit_transform"):
            (
                method_mapping.add(caller="fit", callee="fit_transform")
                .add(caller="fit_transform", callee="fit_transform")
                .add(caller="fit_predict", callee="fit_transform")
                .add(caller="fit_resample", callee="fit_transform")
            )
        else:
            (
                method_mapping.add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="fit_transform", callee="fit")
                .add(caller="fit_transform", callee="transform")
                .add(caller="fit_predict", callee="fit")
                .add(caller="fit_predict", callee="transform")
                .add(caller="fit_resample", callee="fit")
                .add(caller="fit_resample", callee="transform")
            )

        (
            method_mapping.add(caller="predict", callee="transform")
            .add(caller="predict", callee="transform")
            .add(caller="predict_proba", callee="transform")
            .add(caller="decision_function", callee="transform")
            .add(caller="predict_log_proba", callee="transform")
            .add(caller="transform", callee="transform")
            .add(caller="inverse_transform", callee="inverse_transform")
            .add(caller="score", callee="transform")
            .add(caller="fit_resample", callee="transform")
            .add(caller="fit", callee="fit_resample")  # add this line
        )

        router.add(method_mapping=method_mapping, **{name: trans})
    # add final estimator method mapping
    ...
```
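Why change 1 matters: as far as I can tell, scikit-learn's MetadataRequest creates one per-method request object for each name in SIMPLE_METHODS, so routing.fit_resample in the sampler example above does not exist today. The toy model below illustrates that mechanism; ToyMetadataRequest and ToyMethodRequest are hypothetical simplifications, not scikit-learn's classes, and the method lists are abbreviated.

```python
# Toy model only: hypothetical simplifications, not scikit-learn's
# MetadataRequest / MethodMetadataRequest.


class ToyMethodRequest:
    """Stores which metadata a single method (e.g. fit_resample) requests."""

    def __init__(self, method):
        self.method = method
        self.requests = {}

    def add_request(self, *, param, alias):
        self.requests[param] = alias
        return self


class ToyMetadataRequest:
    """Creates one ToyMethodRequest attribute per name in `methods`."""

    def __init__(self, owner, methods):
        self.owner = owner
        for method in methods:
            setattr(self, method, ToyMethodRequest(method))


# Abbreviated stand-ins for SIMPLE_METHODS without and with the proposed entry
current_methods = ["fit", "partial_fit", "predict", "transform"]
proposed_methods = current_methods + ["fit_resample"]

current = ToyMetadataRequest("CostSensitiveSampler", current_methods)
# False: current.fit_resample would raise AttributeError
print(hasattr(current, "fit_resample"))

proposed = ToyMetadataRequest("CostSensitiveSampler", proposed_methods)
proposed.fit_resample.add_request(param="cost_matrix", alias=True)
print(proposed.fit_resample.requests)  # {'cost_matrix': True}
```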

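Why change 2 matters: a router only forwards metadata from a caller method to the callee methods it is mapped to. The pure-Python toy below sketches that idea; route, mappings, and step_requests are hypothetical names, not scikit-learn's MethodMapping or MetadataRouter API. Unlike scikit-learn, which raises an error for metadata that no step requests, this toy silently drops it.

```python
from collections import defaultdict


def route(mappings, step_requests, caller, **params):
    """Toy router: which metadata each step's callee method receives.

    mappings:      {step_name: [(caller, callee), ...]}
    step_requests: {step_name: {callee: set_of_requested_param_names}}
    """
    routed = defaultdict(dict)
    for step, pairs in mappings.items():
        for call, callee in pairs:
            if call != caller:
                continue
            requested = step_requests.get(step, {}).get(callee, set())
            forwarded = {k: v for k, v in params.items() if k in requested}
            if forwarded:
                routed[step][callee] = forwarded
    return dict(routed)


mappings = {
    # ("fit", "fit_resample") plays the role of the proposed added line
    "sampler": [("fit", "fit_resample"), ("fit_resample", "transform")],
    "model": [("fit", "fit")],
}
step_requests = {"sampler": {"fit_resample": {"cost_matrix"}}}

result = route(mappings, step_requests, "fit", cost_matrix="C", sample_weight="w")
print(result)  # {'sampler': {'fit_resample': {'cost_matrix': 'C'}}}
```

Without the ("fit", "fit_resample") pair, route would return an empty dict for caller "fit", which mirrors why Pipeline.fit currently cannot deliver cost_matrix to the sampler.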
Additional context

I am a PhD researcher who used these methods in my paper, and I am the author of the Python package Empulse, which implements samplers that require cost parameters to be passed to fit_resample, as in the dummy example above (see Empulse/Samplers). I find the metadata routing implementation very confusing, so apologies if I have made any mistakes in my reasoning.
