
[BUG] load_classifier_results() does not work for multiclass time series #330

Closed
Mithrillion opened this issue Feb 3, 2025 · 2 comments · Fixed by #326
Labels
bug Something isn't working

Comments

@Mithrillion

Describe the bug

When loading an experiment file using load_classifier_results(), if the time series data has more than two classes, sklearn fails to compute the precision / recall scores, because it cannot compute binary classification scores on a multiclass label series.

This section of the code is causing issues:

if self.sensitivity is None or overwrite:
    self.sensitivity = recall_score(
        self.class_labels,
        self.predictions,
        average="binary",
        pos_label=self._minority_class,
        zero_division=0.0,
    )
if self.specificity is None or overwrite:
    self.specificity = recall_score(
        self.class_labels,
        self.predictions,
        average="binary",
        pos_label=self._majority_class,
        zero_division=0.0,
    )

It might be necessary to explicitly convert the class labels to binary first, for example:

self.sensitivity = recall_score(
    self.class_labels == self._minority_class,
    self.predictions == self._minority_class,
    average="binary",
    zero_division=0.0,
)
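A minimal, runnable sketch of this workaround (binarizing both label arrays against the positive class before scoring; the label values below are illustrative, and `pos_label` stands in for `self._minority_class`):

```python
import numpy as np
from sklearn.metrics import recall_score

# Multiclass labels; class 2 is treated as the "positive" (minority) class.
y_true = np.array([0, 2, 1, 2, 1, 2, 1, 1, 0])
y_pred = np.array([1, 1, 1, 1, 2, 2, 2, 2, 0])
pos_label = 2

# Comparing against pos_label yields boolean arrays, which sklearn treats
# as a binary target, so average="binary" no longer raises.
sensitivity = recall_score(
    y_true == pos_label,
    y_pred == pos_label,
    average="binary",
    zero_division=0.0,
)
print(sensitivity)  # only 1 of the 3 true class-2 samples is predicted as 2
```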

Steps/Code to reproduce the bug

This is an sklearn API restriction, which can be reproduced with the minimal example below:

from sklearn.metrics import recall_score

recall_score(
    [0, 2, 1, 2, 1, 2, 1, 1, 0],
    [1, 1, 1, 1, 2, 2, 2, 2, 0],
    average="binary",
    pos_label=2,
    zero_division=0.0,
)

Expected results

The multiclass classification problem should be converted to a binary one based on pos_label.

Actual results

(reproduced directly on recall_score() itself rather than the full load_classifier_results() function)

ValueError                                Traceback (most recent call last)
Cell In[2], line 1
----> 1 recall_score(
      2     [0, 2, 1, 2, 1, 2, 1, 1, 0],
      3     [1, 1, 1, 1, 2, 2, 2, 2, 0],
      4     average="binary",
      5     pos_label=2,
      6     zero_division=0.0,
      7 )

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:213, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    207 try:
    208     with config_context(
    209         skip_parameter_validation=(
    210             prefer_skip_nested_validation or global_skip_validation
    211         )
    212     ):
--> 213         return func(*args, **kwargs)
    214 except InvalidParameterError as e:
    215     # When the function is just a wrapper around an estimator, we allow
    216     # the function to delegate validation to the estimator, but we replace
    217     # the name of the estimator by the name of the function in the error
    218     # message to avoid confusion.
    219     msg = re.sub(
    220         r"parameter of \w+ must be",
    221         f"parameter of {func.__qualname__} must be",
    222         str(e),
    223     )

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/sklearn/metrics/_classification.py:2385, in recall_score(y_true, y_pred, labels, pos_label, average, sample_weight, zero_division)
   2217 @validate_params(
   2218     {
   2219         "y_true": ["array-like", "sparse matrix"],
   (...)
   2244     zero_division="warn",
   2245 ):
   2246     """Compute the recall.
   2247 
   2248     The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
   (...)
   2383     array([1. , 1. , 0.5])
   2384     """
-> 2385     _, r, _, _ = precision_recall_fscore_support(
   2386         y_true,
   2387         y_pred,
   2388         labels=labels,
   2389         pos_label=pos_label,
   2390         average=average,
   2391         warn_for=("recall",),
   2392         sample_weight=sample_weight,
   2393         zero_division=zero_division,
   2394     )
   2395     return r

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:186, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    184 global_skip_validation = get_config()["skip_parameter_validation"]
    185 if global_skip_validation:
--> 186     return func(*args, **kwargs)
    188 func_sig = signature(func)
    190 # Map *args/**kwargs to the function signature

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1789, in precision_recall_fscore_support(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight, zero_division)
   1626 """Compute precision, recall, F-measure and support for each class.
   1627 
   1628 The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
   (...)
   1786  array([2, 2, 2]))
   1787 """
   1788 _check_zero_division(zero_division)
-> 1789 labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
   1791 # Calculate tp_sum, pred_sum, true_sum ###
   1792 samplewise = average == "samples"

File /mnt/Nova/Envs/ml_dev/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1578, in _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
   1576         if y_type == "multiclass":
   1577             average_options.remove("samples")
-> 1578         raise ValueError(
   1579             "Target is %s but average='binary'. Please "
   1580             "choose another average setting, one of %r." % (y_type, average_options)
   1581         )
   1582 elif pos_label not in (None, 1):
   1583     warnings.warn(
   1584         "Note that pos_label (set to %r) is ignored when "
   1585         "average != 'binary' (got %r). You may use "
   (...)
   1588         UserWarning,
   1589     )

ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].
@Mithrillion Mithrillion added the bug Something isn't working label Feb 3, 2025
@MatthewMiddlehurst
Member

Ah, sorry, this was a recent change. I have a fix in a branch and will push it now.

@MatthewMiddlehurst
Member

This should be fixed on main now, using a different method of generating the statistic for non-binary problems.
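The actual fix lives in #326 and is not reproduced here. As a sketch only, one common way to generalize sensitivity/specificity beyond the binary case is a one-vs-rest confusion count per class; `binary_rates` is a hypothetical helper, not the merged implementation:

```python
import numpy as np

def binary_rates(y_true, y_pred, pos):
    """One-vs-rest sensitivity and specificity for class `pos`."""
    t = np.asarray(y_true) == pos
    p = np.asarray(y_pred) == pos
    tp = np.sum(t & p)    # pos predicted as pos
    fn = np.sum(t & ~p)   # pos predicted as something else
    tn = np.sum(~t & ~p)  # non-pos correctly not predicted as pos
    fp = np.sum(~t & p)   # non-pos wrongly predicted as pos
    sens = tp / (tp + fn) if tp + fn else 0.0
    spec = tn / (tn + fp) if tn + fp else 0.0
    return sens, spec

# Same labels as the reproduction above, scoring class 2 one-vs-rest.
sens, spec = binary_rates(
    [0, 2, 1, 2, 1, 2, 1, 1, 0],
    [1, 1, 1, 1, 2, 2, 2, 2, 0],
    pos=2,
)
print(sens, spec)
```

This sidesteps `average="binary"` entirely, so it works for any number of classes.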
