Skip to content

Improve type annotations in sklearn.metrics._regression #357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 164 additions & 30 deletions stubs/sklearn/metrics/_regression.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,84 @@ from ..utils.validation import (

__ALL__: list = ...

@overload
def mean_absolute_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def mean_absolute_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> ndarray | Float: ...
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> float: ...
@overload
def mean_pinball_loss(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
alpha: float = 0.5,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def mean_pinball_loss(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
alpha: float = 0.5,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> ndarray | Float: ...
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> Float: ...
@overload
def mean_absolute_percentage_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def mean_absolute_percentage_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> ndarray | Float: ...
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> float: ...
@overload
def mean_squared_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def mean_squared_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> float: ...
@deprecated(
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error."
)
@overload
def mean_squared_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> ndarray | Float: ...
multioutput: Literal["raw_values"],
squared: bool,
) -> ndarray: ...
@deprecated(
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error."
)
Expand All @@ -73,17 +121,37 @@ def mean_squared_error(
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
squared: bool,
) -> ndarray | Float: ...
) -> float: ...
@overload
def mean_squared_log_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def mean_squared_log_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> float: ...
@deprecated(
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error."
)
@overload
def mean_squared_log_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> float | ndarray: ...
multioutput: Literal["raw_values"],
squared: bool,
) -> ndarray: ...
@deprecated(
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error."
)
Expand All @@ -93,40 +161,69 @@ def mean_squared_log_error(
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
squared: bool,
) -> float | ndarray: ...
) -> float: ...
@overload
def median_absolute_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
multioutput: Literal["raw_values"],
sample_weight: None | ArrayLike = None,
) -> ndarray | Float: ...
) -> ndarray: ...
@overload
def median_absolute_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
sample_weight: None | ArrayLike = None,
) -> Float: ...
@overload
def explained_variance_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
force_finite: bool = True,
) -> ndarray: ...
@overload
def explained_variance_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values", "uniform_average", "variance_weighted"] | ArrayLike = "uniform_average",
multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike = "uniform_average",
force_finite: bool = True,
) -> float: ...
@overload
def r2_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
force_finite: bool = True,
) -> float | ndarray: ...
) -> ndarray: ...
@overload
def r2_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: (Literal["raw_values", "uniform_average", "variance_weighted"] | None | ArrayLike) = "uniform_average",
multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike | None = "uniform_average",
force_finite: bool = True,
) -> ndarray | Float: ...
) -> float: ...
def max_error(y_true: ArrayLike, y_pred: ArrayLike) -> float: ...
def mean_tweedie_deviance(
y_true: ArrayLike,
y_pred: ArrayLike,
*,
sample_weight: None | ArrayLike = None,
power: Float = 0,
) -> Float: ...
) -> float: ...
def mean_poisson_deviance(y_true: ArrayLike, y_pred: ArrayLike, *, sample_weight: None | ArrayLike = None) -> Float: ...
def mean_gamma_deviance(y_true: ArrayLike, y_pred: ArrayLike, *, sample_weight: None | ArrayLike = None) -> float: ...
def d2_tweedie_score(
Expand All @@ -135,33 +232,70 @@ def d2_tweedie_score(
*,
sample_weight: None | ArrayLike = None,
power: Float = 0,
) -> float | ndarray: ...
) -> float: ...
@overload
def d2_pinball_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
alpha: Float = 0.5,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> float | ndarray: ...
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def d2_pinball_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
alpha: Float = 0.5,
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> Float: ...
@overload
def d2_absolute_error_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def d2_absolute_error_score(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> float | ndarray: ...
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> Float: ...
@overload
def root_mean_squared_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> float | ndarray: ...
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def root_mean_squared_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> float: ...
@overload
def root_mean_squared_log_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: Literal["raw_values"],
) -> ndarray: ...
@overload
def root_mean_squared_log_error(
y_true: MatrixLike | ArrayLike,
y_pred: MatrixLike | ArrayLike,
*,
sample_weight: None | ArrayLike = None,
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
) -> float | ndarray: ...
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
) -> float: ...