Skip to content

Commit 7ceda9e

Browse files
authored
Merge pull request #1192 from vespa-engine/boeker/ann-tool-improvements
Make ANN parameter optimization faster
2 parents 68bcf6c + 88430a2 commit 7ceda9e

3 files changed

Lines changed: 87 additions & 43 deletions

File tree

tests/integration/test_integration_evaluation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1306,7 +1306,12 @@ def vector_to_query(vec_str: str, filter_value: int) -> dict:
13061306

13071307
print("Constructing optimizer object")
13081308
optimizer = VespaNNParameterOptimizer(
1309-
self.app, queries, 100, print_progress=True
1309+
self.app,
1310+
queries,
1311+
100,
1312+
print_progress=True,
1313+
benchmark_time_limit=1000,
1314+
recall_query_limit=10,
13101315
)
13111316

13121317
print("Running optimizer")

tests/unit/test_evaluator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3364,13 +3364,19 @@ def query_many(self, queries, max_concurrent=100, **kwargs):
33643364

33653365
app = MockVespaApp()
33663366
benchmarker = VespaQueryBenchmarker(
3367-
[{"yql": "foo"}, {"yql": "foo"}, {"yql": "foo"}], app
3367+
[{"yql": "foo"}, {"yql": "foo"}, {"yql": "foo"}],
3368+
app,
3369+
max_concurrent=10,
3370+
time_limit=11000,
33683371
)
33693372
benchmark = benchmarker.run()
3370-
self.assertEqual(len(benchmark), 3)
3371-
self.assertAlmostEqual(benchmark[0], 3000, delta=250)
3372-
self.assertAlmostEqual(benchmark[1], 3000, delta=250)
3373-
self.assertAlmostEqual(benchmark[2], 3000, delta=250)
3373+
self.assertEqual(6, len(benchmark))
3374+
self.assertAlmostEqual(4000, benchmark[0], delta=250)
3375+
self.assertAlmostEqual(2000, benchmark[1], delta=250)
3376+
self.assertAlmostEqual(4000, benchmark[2], delta=250)
3377+
self.assertAlmostEqual(2000, benchmark[3], delta=250)
3378+
self.assertAlmostEqual(4000, benchmark[4], delta=250)
3379+
self.assertAlmostEqual(2000, benchmark[5], delta=250)
33743380

33753381

33763382
class TestVespaNNParameterOptimizer(unittest.TestCase):

vespa/evaluation.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1864,15 +1864,22 @@ class VespaNNRecallEvaluator:
18641864
queries (Sequence[Mapping[str, Any]]): List of ANN queries.
18651865
hits (int): Number of hits to use. Should match the parameter targetHits in the used ANN queries.
18661866
app (Vespa): An instance of the Vespa application.
1867+
query_limit (int): Maximum number of queries to determine the recall for. Defaults to 20.
18671868
**kwargs (dict, optional): Additional HTTP request parameters. See: <https://docs.vespa.ai/en/reference/document-v1-api-reference.html#request-parameters>.
18681869
"""
18691870

18701871
def __init__(
1871-
self, queries: Sequence[Mapping[str, Any]], hits: int, app: Vespa, **kwargs
1872+
self,
1873+
queries: Sequence[Mapping[str, Any]],
1874+
hits: int,
1875+
app: Vespa,
1876+
query_limit: int = 20,
1877+
**kwargs,
18721878
):
18731879
self.queries = queries
18741880
self.hits = hits
18751881
self.app = app
1882+
self.query_limit = query_limit
18761883
self.parameters = kwargs
18771884

18781885
def _compute_recall(
@@ -1925,12 +1932,18 @@ def run(self) -> List[float]:
19251932
query_parameters_exact = dict(query_parameters, **VespaNNParameters.EXACT)
19261933

19271934
queries_with_parameters_exact = list(
1928-
map(lambda query: dict(query, **query_parameters_exact), self.queries)
1935+
map(
1936+
lambda query: dict(query, **query_parameters_exact),
1937+
self.queries[0 : self.query_limit],
1938+
)
19291939
)
19301940
responses_exact, _ = execute_queries(self.app, queries_with_parameters_exact)
19311941

19321942
queries_with_parameters = list(
1933-
map(lambda query: dict(query, **query_parameters), self.queries)
1943+
map(
1944+
lambda query: dict(query, **query_parameters),
1945+
self.queries[0 : self.query_limit],
1946+
)
19341947
)
19351948
responses, _ = execute_queries(self.app, queries_with_parameters)
19361949

@@ -1950,74 +1963,83 @@ class VespaQueryBenchmarker:
19501963
This class:
19511964
19521965
- Takes a list of queries.
1953-
- Runs the queries multiple times.
1966+
- Runs the queries for the given amount of time.
19541967
- Returns the searchtime (in milliseconds) of every query executed during the benchmark.
19551968
19561969
Args:
19571970
queries (Sequence[Mapping[str, Any]]): List of queries.
19581971
app (Vespa): An instance of the Vespa application.
1959-
repetitions (int, optional): Number of times to repeat the queries.
1972+
time_limit (int, optional): Time to run the benchmark for (in milliseconds). Defaults to 2000.
19601973
**kwargs (dict, optional): Additional HTTP request parameters. See: <https://docs.vespa.ai/en/reference/document-v1-api-reference.html#request-parameters>.
19611974
"""
19621975

19631976
def __init__(
19641977
self,
19651978
queries: Sequence[Mapping[str, Any]],
19661979
app: Vespa,
1967-
repetitions: int = 10,
1980+
time_limit: int = 2000,
19681981
max_concurrent: int = 10,
19691982
**kwargs,
19701983
):
19711984
self.queries = queries
19721985
self.app = app
1973-
self.repetitions = repetitions
1986+
self.time_limit = time_limit
19741987
self.max_concurrent = max_concurrent
19751988
self.parameters = kwargs
19761989

1977-
def _run_benchmark(self) -> List[float]:
1978-
"""
1979-
Run all queries once and extract the searchtime.
1980-
1981-
Returns:
1982-
List[float]: List of searchtimes, corresponding to the supplied queries.
1983-
"""
1984-
queries_with_parameters = list(
1990+
self.queries_with_parameters = list(
19851991
map(
19861992
lambda query: dict(
19871993
query, **self.parameters, **{"presentation.timing": True}
19881994
),
19891995
self.queries,
19901996
)
19911997
)
1992-
_, response_times = execute_queries(
1993-
self.app, queries_with_parameters, max_concurrent=self.max_concurrent
1994-
)
1995-
return response_times
1998+
self.query_chunks = [
1999+
self.queries_with_parameters[x : x + self.max_concurrent]
2000+
for x in range(0, len(self.queries_with_parameters), self.max_concurrent)
2001+
]
19962002

1997-
def run(self) -> List[float]:
2003+
def _run_benchmark(self, time_limit) -> List[float]:
19982004
"""
1999-
Runs the benchmark (including a warm-up run not included in the result).
2005+
Run the query chunks in turn, cycling through them, until the summed searchtimes reach time_limit (in milliseconds).
20002006
20012007
Returns:
20022008
List[float]: Searchtimes (in milliseconds) of all executed queries; a supplied query may appear several times or not at all.
20032009
"""
2004-
# Two warmup runs
2005-
for i in range(0, self.repetitions):
2006-
self._run_benchmark()
2010+
all_response_times = []
2011+
time_taken = 0
20072012

2008-
# Actual benchmark runs
2009-
response_times_sum = [0] * len(self.queries)
2010-
for i in range(0, self.repetitions):
2011-
response_times = self._run_benchmark()
2012-
response_times_ms = list(map(lambda x: 1000 * x, response_times))
2013-
response_times_sum = list(
2014-
map(
2015-
lambda pair: pair[0] + pair[1],
2016-
zip(response_times_sum, response_times_ms),
2017-
)
2013+
current_chunk = 0
2014+
while time_taken < time_limit:
2015+
_, response_times = execute_queries(
2016+
self.app,
2017+
self.query_chunks[current_chunk],
2018+
max_concurrent=self.max_concurrent,
20182019
)
20192020

2020-
return list(map(lambda x: x / self.repetitions, response_times_sum))
2021+
response_times_ms = list(map(lambda x: 1000 * x, response_times))
2022+
all_response_times.extend(response_times_ms)
2023+
time_taken += max(
2024+
sum(response_times_ms), 1
2025+
) # At least add something in every iteration
2026+
2027+
current_chunk = (current_chunk + 1) % len(self.query_chunks)
2028+
2029+
return all_response_times
2030+
2031+
def run(self) -> List[float]:
2032+
"""
2033+
Runs the benchmark (including a warm-up run not included in the result).
2034+
2035+
Returns:
2036+
List[float]: Searchtimes (in milliseconds) of all queries executed during the actual benchmark; a supplied query may appear several times or not at all.
2037+
"""
2038+
# Warmup run for 100ms
2039+
_ = self._run_benchmark(100)
2040+
2041+
# Actual benchmark
2042+
return self._run_benchmark(self.time_limit)
20212043

20222044

20232045
class BucketedMetricResults:
@@ -2100,6 +2122,9 @@ class VespaNNParameterOptimizer:
21002122
hits (int): Number of hits to use in recall computations. Has to match the parameter targetHits in the used ANN queries.
21012123
buckets_per_percent (int, optional): How many buckets are created for every percent point, "resolution" of the suggestions. Defaults to 2.
21022124
print_progress (bool, optional): Whether to print progress information while determining suggestions. Defaults to False.
2125+
benchmark_time_limit (int): Time in milliseconds to spend per bucket benchmark. Defaults to 5000.
2126+
recall_query_limit (int): Maximum number of queries per bucket to compute the recall for. Defaults to 20.
2127+
max_concurrent (int): Number of queries to execute concurrently during benchmark/recall calculation. Defaults to 10.
21032128
"""
21042129

21052130
def __init__(
@@ -2109,6 +2134,8 @@ def __init__(
21092134
hits: int,
21102135
buckets_per_percent: int = 2,
21112136
print_progress: bool = False,
2137+
benchmark_time_limit: int = 5000,
2138+
recall_query_limit: int = 20,
21122139
max_concurrent: int = 10,
21132140
):
21142141
self.app = app
@@ -2120,6 +2147,8 @@ def __init__(
21202147
self.buckets = [[] for _ in range(100 * buckets_per_percent)]
21212148

21222149
self.print_progress = print_progress
2150+
self.benchmark_time_limit = benchmark_time_limit
2151+
self.recall_query_limit = recall_query_limit
21232152
self.max_concurrent = max_concurrent
21242153

21252154
def get_bucket_interval_width(self) -> float:
@@ -2438,7 +2467,11 @@ def benchmark(self, **kwargs) -> BucketedMetricResults:
24382467
)
24392468
processed_buckets += 1
24402469
benchmarker = VespaQueryBenchmarker(
2441-
bucket, self.app, max_concurrent=self.max_concurrent, **kwargs
2470+
bucket,
2471+
self.app,
2472+
time_limit=self.benchmark_time_limit,
2473+
max_concurrent=self.max_concurrent,
2474+
**kwargs,
24422475
)
24432476
response_times = benchmarker.run()
24442477
results.append(response_times)
@@ -2479,7 +2512,7 @@ def compute_average_recalls(self, **kwargs) -> BucketedMetricResults:
24792512
end="",
24802513
)
24812514
recall_evaluator = VespaNNRecallEvaluator(
2482-
bucket, self.hits, self.app, **kwargs
2515+
bucket, self.hits, self.app, self.recall_query_limit, **kwargs
24832516
)
24842517
recall_list = recall_evaluator.run()
24852518
results.append(recall_list)

0 commit comments

Comments
 (0)