
Commit 848ace4

WIP typing

1 parent 0342c23 commit 848ace4

20 files changed: +325 -114 lines

.pre-commit-config.yaml

Lines changed: 9 additions & 1 deletion
@@ -1,3 +1,5 @@
+default_language_version:
+  python: python3.12
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
@@ -13,11 +15,17 @@ repos:
       - id: no-commit-to-branch
         args: ["--branch=main"]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.6
+    rev: v0.9.1
     hooks:
       - id: ruff
         args: ["--fix"]
       - id: ruff-format
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.14.1
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          - pynndescent
   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: v4.0.0-alpha.8
     hooks:

docs/conf.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 # This file only contains a selection of the most common options. For a full
 # list see the documentation:
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
+from __future__ import annotations

 import os

examples/rnn_dbscan_big.py

Lines changed: 13 additions & 2 deletions
@@ -7,17 +7,26 @@

 """

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 from joblib import Memory
 from sklearn import metrics
 from sklearn.datasets import fetch_openml

 from sklearn_ann.cluster.rnn_dbscan import simple_rnn_dbscan_pipeline

+if TYPE_CHECKING:
+    from typing import Any
+
+    from sklearn.utils import Bunch
+

 # #############################################################################
 # Generate sample data
-def fetch_mnist():
+def fetch_mnist() -> Bunch:
     print("Downloading mnist_784")
     mnist = fetch_openml("mnist_784")
     return mnist.data / 255, mnist.target
@@ -28,7 +37,9 @@ def fetch_mnist():
 X, y = memory.cache(fetch_mnist)()


-def run_rnn_dbscan(neighbor_transformer, n_neighbors, **kwargs):
+def run_rnn_dbscan(
+    neighbor_transformer: object, n_neighbors: int, **kwargs: Any
+) -> None:
     # #############################################################################
     # Compute RnnDBSCAN

examples/rnn_dbscan_simple.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 Mostly copypasted from sklearn's DBSCAN example.

 """
+from __future__ import annotations

 import numpy as np
 from sklearn import metrics
pyproject.toml

Lines changed: 13 additions & 0 deletions
@@ -74,6 +74,7 @@ select = [
     "PTH",  # Pathlib
     "RUF",  # Ruff’s own rules
     "T20",  # print statements
+    "TC",  # type checking
 ]
 ignore = [
     # Don’t complain about “confusables”
@@ -84,6 +85,10 @@ ignore = [
 "tests/*.py" = ["T20"]
 [tool.ruff.lint.isort]
 known-first-party = ["sklearn_ann"]
+required-imports = ["from __future__ import annotations"]
+[tool.ruff.lint.flake8-type-checking]
+exempt-modules = []
+strict = true

 [tool.hatch.envs.docs]
 installer = "uv"
@@ -97,6 +102,14 @@ features = ["tests", "annlibs"]
 [tool.hatch.build.targets.wheel]
 packages = ["src/sklearn_ann"]

+[tool.mypy]
+python_version = "3.11"
+mypy_path = ["src", "tests"]
+strict = true
+explicit_package_bases = true  # pytest doesn’t do __init__.py
+no_implicit_optional = true
+disallow_untyped_decorators = false  # e.g. pytest.mark.parametrize
+
 [build-system]
 requires = ["hatchling", "hatch-vcs", "hatch-fancy-pypi-readme"]
 build-backend = "hatchling.build"
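The new Ruff settings (a required "from __future__ import annotations" plus strict flake8-type-checking) combine with the mypy hook to push annotation-only imports behind an if TYPE_CHECKING block, which is exactly what the modules touched by this commit now do. Below is a minimal sketch of the resulting module layout; it is illustrative only and not a file from this commit:

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Only evaluated by static type checkers such as mypy; skipped at runtime.
    from numpy.typing import NDArray


def double(values: NDArray[np.float64]) -> NDArray[np.float64]:
    # With the __future__ import, the annotation is stored as a string,
    # so NDArray does not need to be importable when the module runs.
    return values * 2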

src/sklearn_ann/cluster/rnn_dbscan.py

Lines changed: 47 additions & 21 deletions
@@ -1,14 +1,25 @@
+from __future__ import annotations
+
 from collections import deque
-from typing import cast
+from typing import TYPE_CHECKING, cast

 import numpy as np
+from scipy.sparse import csr_matrix
 from sklearn.base import BaseEstimator, ClusterMixin
 from sklearn.neighbors import KNeighborsTransformer
 from sklearn.utils import Tags
 from sklearn.utils.validation import validate_data

 from ..utils import get_sparse_row

+if TYPE_CHECKING:
+    from collections.abc import Iterator
+    from typing import Literal, Self
+
+    from numpy.typing import NDArray
+    from sklearn.pipeline import Pipeline
+
+
 UNCLASSIFIED = -2
 NOISE = -1

@@ -37,7 +48,9 @@ def join(it1, it2):
         cur_it2 = next(it2, None)


-def neighborhood(is_core, knns, rev_knns, idx):
+def neighborhood(
+    is_core: NDArray[np.bool_], knns: csr_matrix, rev_knns: csr_matrix, idx: int
+) -> Iterator[tuple[int, float]]:
     # TODO: Make this inner bit faster
     knn_it = get_sparse_row(knns, idx)
     rev_core_knn_it = (
@@ -52,10 +65,12 @@ def neighborhood(is_core, knns, rev_knns, idx):
     )


-def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
+def rnn_dbscan_inner(
+    is_core: NDArray[np.bool_], knns: csr_matrix, rev_knns: csr_matrix, labels
+) -> list[float]:
     cluster = 0
-    cur_dens = 0
-    dens = []
+    cur_dens = 0.0
+    dens: list[float] = []
     for x_idx in range(len(labels)):
         if labels[x_idx] == UNCLASSIFIED:
             # Expand cluster
@@ -81,7 +96,7 @@ def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
                         elif labels[z_idx] == NOISE:
                             labels[z_idx] = cluster
                 dens.append(cur_dens)
-                cur_dens = 0
+                cur_dens = 0.0
                 cluster += 1
             else:
                 labels[x_idx] = NOISE
@@ -138,15 +153,20 @@ class RnnDBSCAN(ClusterMixin, BaseEstimator):
     """

     def __init__(
-        self, n_neighbors=5, *, input_guarantee="none", n_jobs=None, keep_knns=False
-    ):
+        self,
+        n_neighbors: int = 5,
+        *,
+        input_guarantee: Literal["none", "kneighbors"] = "none",
+        n_jobs: int | None = None,
+        keep_knns: bool = False,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.input_guarantee = input_guarantee
         self.n_jobs = n_jobs
         self.keep_knns = keep_knns

-    def fit(self, X, y=None):
-        X = validate_data(self, X, accept_sparse="csr")
+    def fit(self, X: NDArray[np.float64] | csr_matrix, y: None = None) -> Self:
+        X = cast(csr_matrix, validate_data(self, X, accept_sparse="csr"))
         if self.input_guarantee == "none":
             algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
             X = algorithm.fit_transform(X)
@@ -157,7 +177,7 @@ def fit(self, X, y=None):
                 "Expected input_guarantee to be one of 'none', 'kneighbors'"
             )

-        XT = X.transpose().tocsr(copy=True)
+        XT = cast(csr_matrix, X.transpose().tocsr(copy=True))
         if self.keep_knns:
             self.knns_ = X
             self.rev_knns_ = XT
@@ -176,11 +196,11 @@ def fit(self, X, y=None):

         return self

-    def fit_predict(self, X, y=None):
+    def fit_predict(self, X, y=None) -> NDArray[np.int32]:
         self.fit(X, y=y)
         return self.labels_

-    def drop_knns(self):
+    def drop_knns(self) -> None:
         del self.knns_
         del self.rev_knns_

@@ -191,22 +211,28 @@ def __sklearn_tags__(self) -> Tags:


 def simple_rnn_dbscan_pipeline(
-    neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs
-):
+    neighbor_transformer: object,
+    n_neighbors: int,
+    *,
+    n_jobs: int | None = None,
+    keep_knns: bool = False,
+    input_guarantee: Literal["none", "kneighbors"] = "none",
+) -> Pipeline:
     """
     Create a simple pipeline comprising a transformer and RnnDBSCAN.

     Parameters
     ----------
-    neighbor_transformer : class implementing KNeighborsTransformer interface
-    n_neighbors:
+    neighbor_transformer
+        class implementing KNeighborsTransformer interface
+    n_neighbors
         Passed to neighbor_transformer and RnnDBSCAN
-    n_jobs:
+    n_jobs
         Passed to neighbor_transformer and RnnDBSCAN
-    keep_knns:
+    keep_knns
+        Passed to RnnDBSCAN
+    input_guarantee
         Passed to RnnDBSCAN
-    kwargs:
-        Passed to neighbor_transformer
     """
     from sklearn.pipeline import make_pipeline

src/sklearn_ann/kneighbors/annoy.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import annoy
 import numpy as np
 from scipy.sparse import csr_matrix

src/sklearn_ann/kneighbors/faiss.py

Lines changed: 33 additions & 15 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import math
+from typing import TYPE_CHECKING, TypedDict

 import faiss
 import numpy as np
@@ -13,10 +14,22 @@

 from ..utils import TransformerChecksMixin, postprocess_knn_csr

+if TYPE_CHECKING:
+    from typing import Self
+
+    from numpy.typing import ArrayLike, NDArray
+
+
+class MetricInfo(TypedDict):
+    metric: int
+    normalize: bool
+    negate: bool
+
+
 L2_INFO = {"metric": faiss.METRIC_L2, "sqrt": True}


-METRIC_MAP = {
+METRIC_MAP: dict[str, MetricInfo] = {
     "cosine": {
         "metric": faiss.METRIC_INNER_PRODUCT,
         "normalize": True,
@@ -34,7 +47,12 @@
 }


-def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index:
+def mk_faiss_index(
+    feats: NDArray[np.float32],
+    inner_metric: int,
+    index_key: str = "",
+    nprobe: int = 128,
+) -> faiss.Index:
     size, dim = feats.shape
     if not index_key:
         if inner_metric == faiss.METRIC_INNER_PRODUCT:
@@ -64,15 +82,15 @@ def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index
 class FAISSTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator):
     def __init__(
         self,
-        n_neighbors=5,
+        n_neighbors: int = 5,
         *,
-        metric="euclidean",
-        index_key="",
-        n_probe=128,
-        n_jobs=-1,
-        include_fwd=True,
-        include_rev=False,
-    ):
+        metric: str = "euclidean",
+        index_key: str = "",
+        n_probe: int = 128,
+        n_jobs: int = -1,
+        include_fwd: bool = True,
+        include_rev: bool = False,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.metric = metric
         self.index_key = index_key
@@ -82,10 +100,10 @@ def __init__(
         self.include_rev = include_rev

     @property
-    def _metric_info(self):
+    def _metric_info(self) -> MetricInfo:
         return METRIC_MAP[self.metric]

-    def fit(self, X, y=None):
+    def fit(self, X: ArrayLike, y: None = None) -> Self:
         normalize = self._metric_info.get("normalize", False)
         X = validate_data(self, X, dtype=np.float32, copy=normalize)
         self.n_samples_fit_ = X.shape[0]
@@ -100,14 +118,14 @@ def fit(self, X, y=None):
         self.faiss_ = mk_faiss_index(X, inner_metric, self.index_key, self.n_probe)
         return self

-    def transform(self, X):
+    def transform(self, X: NDArray[np.number]) -> csr_matrix:
         normalize = self._metric_info.get("normalize", False)
         X = self._transform_checks(X, "faiss_", dtype=np.float32, copy=normalize)
         if normalize:
             normalize_L2(X)
         return self._transform(X)

-    def _transform(self, X):
+    def _transform(self, X: NDArray[np.float32]) -> csr_matrix:
         n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0]
         n_neighbors = self.n_neighbors + 1
         if X is None:
@@ -156,7 +174,7 @@ def _transform(self, X):
             mat, include_fwd=self.include_fwd, include_rev=self.include_rev
         )

-    def fit_transform(self, X, y=None):
+    def fit_transform(self, X: ArrayLike, y: None = None) -> csr_matrix:
         return self.fit(X, y=y)._transform(X=None)

     def __sklearn_tags__(self) -> Tags:
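As an illustration of the annotated FAISSTransformer API, a small usage sketch; it is not from this commit, assumes faiss is installed, and uses random data in place of real features:

from __future__ import annotations

import numpy as np

from sklearn_ann.kneighbors.faiss import FAISSTransformer

rng = np.random.default_rng(0)
X = rng.random((200, 32)).astype(np.float32)  # 200 samples, 32 features

# All constructor arguments after n_neighbors are keyword-only.
transformer = FAISSTransformer(n_neighbors=5, metric="euclidean")
knn_graph = transformer.fit_transform(X)  # annotated to return a csr_matrix

print(knn_graph.shape)  # (200, 200) sparse distance matrix
print(knn_graph.nnz)    # about n_samples * (n_neighbors + 1) stored distances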
