7
7
import sys
8
8
import time
9
9
from pathlib import Path
10
- from typing import Final
11
10
12
11
import numpy as np
13
12
import svs
16
15
from . import consts , utils
17
16
from .loader import create_loader
18
17
19
- STR_TO_STRATEGY : Final [dict [str , svs .LVQStrategy ]] = {
20
- "auto" : svs .LVQStrategy .Auto ,
21
- "sequential" : svs .LVQStrategy .Sequential ,
22
- "turbo" : svs .LVQStrategy .Turbo ,
23
- }
24
-
25
-
26
18
logger = logging .getLogger (__file__ )
27
19
28
20
@@ -38,7 +30,6 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
38
30
help = "Query type" ,
39
31
choices = consts .STR_TO_DATA_TYPE .keys (),
40
32
default = "float32" ,
41
- type = consts .STR_TO_DATA_TYPE .get ,
42
33
)
43
34
parser .add_argument ("--idx_dir" , help = "Index dir" , type = Path )
44
35
parser .add_argument ("--data_dir" , help = "Data dir" , type = Path )
@@ -58,11 +49,10 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
58
49
type = Path ,
59
50
)
60
51
parser .add_argument (
61
- "--strategy " ,
52
+ "--lvq_strategy " ,
62
53
help = "LVQ strategy" ,
63
- choices = tuple (STR_TO_STRATEGY .keys ()),
54
+ choices = tuple (consts . STR_TO_LVQ_STRATEGY .keys ()),
64
55
default = "auto" ,
65
- type = STR_TO_STRATEGY .get ,
66
56
)
67
57
parser .add_argument (
68
58
"--leanvec_dims" , help = "LeanVec dimensionality" , default = - 4 , type = int
@@ -115,7 +105,6 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
115
105
"--distance" ,
116
106
choices = tuple (consts .STR_TO_DISTANCE .keys ()),
117
107
default = "mip" ,
118
- type = consts .STR_TO_DISTANCE .get ,
119
108
)
120
109
parser .add_argument (
121
110
"--load_from_static" ,
@@ -151,6 +140,7 @@ def search(
151
140
calibration_query_path : Path | None = None ,
152
141
calibration_ground_truth_path : Path | None = None ,
153
142
load_from_static : bool = False ,
143
+ lvq_strategy : svs .LVQStrategy | None = None ,
154
144
):
155
145
logger .info ({"search_args" : locals ()})
156
146
logger .info (utils .read_system_config ())
@@ -178,6 +168,7 @@ def search(
178
168
compress = compress ,
179
169
leanvec_dims = leanvec_dims ,
180
170
leanvec_alignment = leanvec_alignment ,
171
+ lvq_strategy = lvq_strategy ,
181
172
)
182
173
183
174
if static :
@@ -337,11 +328,12 @@ def main(argv: str | None = None) -> None:
337
328
prefetch_steps = args .prefetch_step ,
338
329
num_rep = args .num_rep ,
339
330
static = args .static ,
340
- distance = args .distance ,
341
- query_type = args .query_type ,
331
+ distance = consts . STR_TO_DISTANCE [ args .distance ] ,
332
+ query_type = consts . STR_TO_DATA_TYPE [ args .query_type ] ,
342
333
calibration_query_path = args .calibration_query_file ,
343
334
calibration_ground_truth_path = args .calibration_ground_truth_file ,
344
335
load_from_static = args .load_from_static ,
336
+ lvq_strategy = consts .STR_TO_LVQ_STRATEGY .get (args .lvq_strategy , None ),
345
337
)
346
338
347
339
0 commit comments