Skip to content

Commit 806fafb

Browse files
committed
More typing
1 parent 166a235 commit 806fafb

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

python/evalica/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Collection, Hashable
77
from dataclasses import dataclass
88
from types import MappingProxyType
9-
from typing import Generic, Literal, Protocol, TypeVar, runtime_checkable
9+
from typing import Generic, Literal, Protocol, TypeVar, cast, runtime_checkable
1010

1111
import numpy as np
1212
import numpy.typing as npt
@@ -979,7 +979,7 @@ def pairwise_scores(
979979
raise ScoreDimensionError(scores.ndim)
980980

981981
if solver == "naive":
982-
return pairwise_scores_naive(scores)
982+
return cast("npt.NDArray[np.float64]", pairwise_scores_naive(scores))
983983

984984
return pairwise_scores_pyo3(scores)
985985

python/evalica/naive.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, cast
3+
from typing import TYPE_CHECKING, TypeVar, cast
44

55
import numpy as np
66
import numpy.typing as npt
@@ -9,17 +9,19 @@
99

1010
if TYPE_CHECKING:
1111
from collections.abc import Collection
12-
from typing import Any
1312

13+
S = TypeVar("S", bound=npt.NBitBase)
14+
T = TypeVar("T")
1415

15-
def pairwise_scores(scores: npt.NDArray[np.number[Any]]) -> npt.NDArray[np.float64]:
16+
17+
def pairwise_scores(scores: npt.NDArray[np.number[S]]) -> npt.NDArray[np.number[S]]:
1618
if not scores.size:
17-
return np.zeros((0, 0))
19+
return np.zeros((0, 0), dtype=scores.dtype)
1820

1921
return np.nan_to_num(scores[:, np.newaxis] / (scores + scores[:, np.newaxis]))
2022

2123

22-
def _check_lengths(xs: Collection[Any], *rest: Collection[Any]) -> None:
24+
def _check_lengths(xs: Collection[T], *rest: Collection[T]) -> None:
2325
length = len(xs)
2426

2527
for collection in rest:
@@ -225,9 +227,9 @@ def eigen(
225227

226228

227229
def pagerank_matrix(
228-
matrix: npt.NDArray[np.floating[Any]],
230+
matrix: npt.NDArray[np.floating[S]],
229231
damping: float,
230-
) -> npt.NDArray[np.floating[Any]]:
232+
) -> npt.NDArray[np.floating[S]]:
231233
if not matrix.size:
232234
return np.zeros(0, dtype=matrix.dtype)
233235

0 commit comments

Comments
 (0)