Skip to content

Commit 66157c3

Browse files
fix: ruff error
1 parent 8bbc027 commit 66157c3

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

Diff for: python/evalica/__init__.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Collection, Hashable
77
from dataclasses import dataclass
88
from types import MappingProxyType
9-
from typing import Generic, Literal, Protocol, TypeVar, cast, runtime_checkable, Callable
9+
from typing import Callable, Generic, Literal, Protocol, TypeVar, cast, runtime_checkable
1010

1111
import numpy as np
1212
import numpy.typing as npt
@@ -1000,21 +1000,23 @@ def pairwise_frame(scores: pd.Series[float]) -> pd.DataFrame:
10001000

10011001
@dataclass
10021002
class BootstrapConfidenceInterval:
1003-
score_method: Literal['elo', 'bradley-terry', 'newman'] # TODO: more?
1003+
"""Generate confidence interval by bootstrap."""
1004+
1005+
score_method: Literal["elo", "bradley-terry", "newman"] # TODO: more?
10041006
num_rounds: int = 100
10051007
sample_rate: float = 1.0
10061008
with_replace: bool = True
10071009

10081010
def fit(self,
10091011
df: pd.DataFrame,
1010-
left_column: str = 'left',
1011-
right_column: str = 'right',
1012-
winner_colum: str = 'winner',
1012+
left_column: str = "left",
1013+
right_column: str = "right",
1014+
winner_colum: str = "winner",
10131015
weights: Collection[float] | None = None,
10141016
win_weight: float = 1.0,
10151017
tie_weight: float = 0.5,
10161018
solver: Literal["naive", "pyo3"] = "pyo3",
1017-
**kwargs,
1019+
**kwargs: float, # TODO: change to Unpack?
10181020
) -> pd.DataFrame:
10191021
"""
10201022
Calculate confidence interval by bootstrap.
@@ -1032,15 +1034,15 @@ def fit(self,
10321034
10331035
Returns:
10341036
The dataframe with scores from all bootstrap rounds.
1035-
"""
10361037
1038+
"""
10371039
score_function = self._get_score_method()
10381040
*_, index = indexing(
10391041
xs=df[left_column],
10401042
ys=df[right_column],
10411043
)
10421044

1043-
bootstrap: list["pd.Series[float]"] = []
1045+
bootstrap: list[pd.Series[float]] = []
10441046
for r in range(self.num_rounds):
10451047
df_sample = df.sample(frac=self.sample_rate, replace=self.with_replace, random_state=r)
10461048
result_sample = score_function(
@@ -1052,31 +1054,33 @@ def fit(self,
10521054
win_weight=win_weight,
10531055
tie_weight=tie_weight,
10541056
solver=solver,
1055-
**kwargs
1057+
**kwargs,
10561058
)
10571059

10581060
bootstrap.append(result_sample.scores)
10591061

1060-
df_bootstrap = pd.DataFrame(bootstrap)
10611062
# TODO: calculate the quantiles in here?
1062-
return df_bootstrap
1063+
return pd.DataFrame(bootstrap)
10631064

1064-
def plot(self, df: pd.DataFrame):
1065-
pass
1065+
def plot(self, df: pd.DataFrame) -> None:
1066+
"""Plot confidence interval by plotly."""
1067+
raise NotImplementedError
10661068

1067-
def _get_score_method(self) -> Callable[..., Generic[T]]:
1069+
def _get_score_method(self) -> Callable[..., EloResult[T] | BradleyTerryResult[T] | NewmanResult[T]]:
10681070
score_method_map = {
1069-
'elo': elo,
1070-
'bradley-terry': bradley_terry,
1071-
'newman': newman,
1071+
"elo": elo,
1072+
"bradley-terry": bradley_terry,
1073+
"newman": newman,
10721074
}
10731075
if (score_method := score_method_map.get(self.score_method)) is None:
1074-
ValueError(f"{self.score_method=}, which is not supported!")
1076+
error_msg = f"Unsupported score method: {self.score_method}!"
1077+
raise ValueError(error_msg)
10751078
return score_method
10761079

10771080

10781081
__all__ = [
10791082
"WINNERS",
1083+
"BootstrapConfidenceInterval",
10801084
"BradleyTerryResult",
10811085
"CountingResult",
10821086
"EigenResult",
@@ -1100,5 +1104,4 @@ def _get_score_method(self) -> Callable[..., Generic[T]]:
11001104
"pagerank",
11011105
"pairwise_frame",
11021106
"pairwise_scores",
1103-
"BootstrapConfidenceInterval",
11041107
]

0 commit comments

Comments
 (0)