Skip to content

Commit 3e2e1d5

Browse files
authored
Add LVQ strategy argument (#9)
Also fix distance and query type arguments.
1 parent e212c05 commit 3e2e1d5

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

src/svsbench/consts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,9 @@
5454
"float16": svs.DataType.float16,
5555
"float32": svs.DataType.float32,
5656
}
57+
58+
STR_TO_LVQ_STRATEGY: Final[dict[str, svs.LVQStrategy]] = {
59+
"auto": svs.LVQStrategy.Auto,
60+
"sequential": svs.LVQStrategy.Sequential,
61+
"turbo": svs.LVQStrategy.Turbo,
62+
}

src/svsbench/loader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def create_loader(
1717
compress: bool = False,
1818
leanvec_dims: int | None = None,
1919
leanvec_alignment: int = 32,
20+
lvq_strategy: svs.LVQStrategy | None = None,
2021
) -> svs.VectorDataLoader | svs.LVQLoader | svs.LeanVecLoader:
2122
"""Create loader."""
2223
unkown_msg = f"Unknown {svs_type=}"
@@ -47,17 +48,20 @@ def create_loader(
4748
if svs_type == "lvq4x4":
4849
primary = 4
4950
residual = 4
50-
strategy = svs.LVQStrategy.Turbo
51+
default_strategy = svs.LVQStrategy.Turbo
5152
elif svs_type == "lvq4x8":
5253
primary = 4
5354
residual = 8
54-
strategy = svs.LVQStrategy.Turbo
55+
default_strategy = svs.LVQStrategy.Turbo
5556
elif svs_type == "lvq8":
5657
primary = 8
5758
residual = 0
58-
strategy = svs.LVQStrategy.Sequential
59+
default_strategy = svs.LVQStrategy.Sequential
5960
else:
6061
raise ValueError(unkown_msg)
62+
strategy = (
63+
lvq_strategy if lvq_strategy is not None else default_strategy
64+
)
6165
padding = 32
6266
if vecs_path is not None or compress:
6367
loader = svs.LVQLoader(

src/svsbench/search.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import sys
88
import time
99
from pathlib import Path
10-
from typing import Final
1110

1211
import numpy as np
1312
import svs
@@ -16,13 +15,6 @@
1615
from . import consts, utils
1716
from .loader import create_loader
1817

19-
STR_TO_STRATEGY: Final[dict[str, svs.LVQStrategy]] = {
20-
"auto": svs.LVQStrategy.Auto,
21-
"sequential": svs.LVQStrategy.Sequential,
22-
"turbo": svs.LVQStrategy.Turbo,
23-
}
24-
25-
2618
logger = logging.getLogger(__file__)
2719

2820

@@ -38,7 +30,6 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
3830
help="Query type",
3931
choices=consts.STR_TO_DATA_TYPE.keys(),
4032
default="float32",
41-
type=consts.STR_TO_DATA_TYPE.get,
4233
)
4334
parser.add_argument("--idx_dir", help="Index dir", type=Path)
4435
parser.add_argument("--data_dir", help="Data dir", type=Path)
@@ -58,11 +49,10 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
5849
type=Path,
5950
)
6051
parser.add_argument(
61-
"--strategy",
52+
"--lvq_strategy",
6253
help="LVQ strategy",
63-
choices=tuple(STR_TO_STRATEGY.keys()),
54+
choices=tuple(consts.STR_TO_LVQ_STRATEGY.keys()),
6455
default="auto",
65-
type=STR_TO_STRATEGY.get,
6656
)
6757
parser.add_argument(
6858
"--leanvec_dims", help="LeanVec dimensionality", default=-4, type=int
@@ -115,7 +105,6 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
115105
"--distance",
116106
choices=tuple(consts.STR_TO_DISTANCE.keys()),
117107
default="mip",
118-
type=consts.STR_TO_DISTANCE.get,
119108
)
120109
parser.add_argument(
121110
"--load_from_static",
@@ -151,6 +140,7 @@ def search(
151140
calibration_query_path: Path | None = None,
152141
calibration_ground_truth_path: Path | None = None,
153142
load_from_static: bool = False,
143+
lvq_strategy: svs.LVQStrategy | None = None,
154144
):
155145
logger.info({"search_args": locals()})
156146
logger.info(utils.read_system_config())
@@ -178,6 +168,7 @@ def search(
178168
compress=compress,
179169
leanvec_dims=leanvec_dims,
180170
leanvec_alignment=leanvec_alignment,
171+
lvq_strategy=lvq_strategy,
181172
)
182173

183174
if static:
@@ -337,11 +328,12 @@ def main(argv: str | None = None) -> None:
337328
prefetch_steps=args.prefetch_step,
338329
num_rep=args.num_rep,
339330
static=args.static,
340-
distance=args.distance,
341-
query_type=args.query_type,
331+
distance=consts.STR_TO_DISTANCE[args.distance],
332+
query_type=consts.STR_TO_DATA_TYPE[args.query_type],
342333
calibration_query_path=args.calibration_query_file,
343334
calibration_ground_truth_path=args.calibration_ground_truth_file,
344335
load_from_static=args.load_from_static,
336+
lvq_strategy=consts.STR_TO_LVQ_STRATEGY.get(args.lvq_strategy, None),
345337
)
346338

347339

0 commit comments

Comments
 (0)