
Commit c57bad2

perf: fix multiprocessing timing measurement
- Move vector-to-bytes conversion outside timing measurements
- Track actual worker start times for accurate parallel timing
- Refactor worker function for compatibility with newer Python versions
1 parent 340c9e6 commit c57bad2
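
In plain terms, the parallel-timing fix works like this: each worker does its per-process client setup first, records `time.perf_counter()` just before it begins searching, and returns that timestamp together with its results; the parent seeds `min_start_time` with its own clock reading after launching the workers, lowers it to the earliest worker-reported start, and measures `total_time` from there, keeping setup work out of the timed interval as far as possible. Below is a minimal standalone sketch of that pattern (the worker body and task data are illustrative stand-ins, not the benchmark's code):

    import time
    from multiprocessing import Process, Queue


    def worker(task_chunk, result_queue):
        # Per-process setup (stand-in for init_client()/setup_search()) happens
        # before the timestamp, so it is excluded from the measured interval.
        time.sleep(0.1)

        start_time = time.perf_counter()        # actual work starts now
        results = [x * x for x in task_chunk]   # stand-in for process_chunk()
        result_queue.put((start_time, results))


    if __name__ == "__main__":
        chunks = [[1, 2, 3], [4, 5, 6]]
        queue = Queue()
        processes = [Process(target=worker, args=(chunk, queue)) for chunk in chunks]

        for p in processes:
            p.start()

        results = []
        # Seed with the parent's clock; lowered if a worker reports an earlier start.
        min_start_time = time.perf_counter()
        for _ in processes:
            proc_start, chunk_results = queue.get()
            results.extend(chunk_results)
            min_start_time = min(min_start_time, proc_start)

        total_time = time.perf_counter() - min_start_time
        for p in processes:
            p.join()

        print(f"{len(results)} results in {total_time:.3f}s of measured work")

The accompanying move of worker_function to module level fits the "newer Python versions" bullet: under the spawn/forkserver start methods (the default on macOS and Windows), a Process target must be picklable, which a closure nested inside search_all is not; that is the most likely reason for the refactor.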

2 files changed: +31 -30 lines changed

engine/base_client/search.py

Lines changed: 30 additions & 28 deletions
@@ -75,8 +75,12 @@ def search_all(
         search_one = functools.partial(self.__class__._search_one, top=top)
 
         # Convert queries to a list for potential reuse
-        queries_list = list(queries)
-
+        # Also, converts query vectors to bytes beforehand, preparing them for sending to client without affecting search time measurements
+        queries_list = []
+        for query in queries:
+            query.vector = np.array(query.vector).astype(np.float32).tobytes()
+            queries_list.append(query)
+
         # Handle MAX_QUERIES environment variable
         if MAX_QUERIES > 0:
             queries_list = queries_list[:MAX_QUERIES]
@@ -114,12 +118,12 @@ def cycling_query_generator(queries, total_count):
         total_query_count = len(used_queries)
 
         if parallel == 1:
-            # Single-threaded execution
-            start = time.perf_counter()
-
             # Create a progress bar with the correct total
             pbar = tqdm.tqdm(total=total_query_count, desc="Processing queries", unit="queries")
 
+            # Single-threaded execution
+            start = time.perf_counter()
+
             # Process queries with progress updates
             results = []
             for query in used_queries:
@@ -148,42 +152,32 @@ def cycling_query_generator(queries, total_count):
             # For lists, we can use the chunked_iterable function
             query_chunks = list(chunked_iterable(used_queries, chunk_size))
 
-            # Function to be executed by each worker process
-            def worker_function(chunk, result_queue):
-                self.__class__.init_client(
-                    self.host,
-                    distance,
-                    self.connection_params,
-                    self.search_params,
-                )
-                self.setup_search()
-                results = process_chunk(chunk, search_one)
-                result_queue.put(results)
-
             # Create a queue to collect results
             result_queue = Queue()
 
             # Create worker processes
            processes = []
             for chunk in query_chunks:
-                process = Process(target=worker_function, args=(chunk, result_queue))
+                process = Process(target=worker_function, args=(self, distance, search_one, chunk, result_queue))
                 processes.append(process)
 
-            # Start measuring time for the critical work
-            start = time.perf_counter()
-
             # Start worker processes
             for process in processes:
                 process.start()
 
             # Collect results from all worker processes
             results = []
+            min_start_time = time.perf_counter()
             for _ in processes:
-                chunk_results = result_queue.get()
+                proc_start_time, chunk_results = result_queue.get()
                 results.extend(chunk_results)
+
+                # Update min_start_time if necessary
+                if proc_start_time < min_start_time:
+                    min_start_time = proc_start_time
 
             # Stop measuring time for the critical work
-            total_time = time.perf_counter() - start
+            total_time = time.perf_counter() - min_start_time
 
             # Wait for all worker processes to finish
             for process in processes:
@@ -226,13 +220,21 @@ def chunked_iterable(iterable, size):
     while chunk := list(islice(it, size)):
         yield chunk
 
+# Function to be executed by each worker process
+def worker_function(self, distance, search_one, chunk, result_queue):
+    self.init_client(
+        self.host,
+        distance,
+        self.connection_params,
+        self.search_params,
+    )
+    self.setup_search()
+
+    start_time = time.perf_counter()
+    results = process_chunk(chunk, search_one)
+    result_queue.put((start_time, results))
 
 def process_chunk(chunk, search_one):
     """Process a chunk of queries using the search_one function."""
     # No progress bar in worker processes to avoid cluttering the output
     return [search_one(query) for query in chunk]
-
-
-def process_chunk_wrapper(chunk, search_one):
-    """Wrapper to process a chunk of queries."""
-    return process_chunk(chunk, search_one)

engine/clients/vectorsets/search.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 import random
 from typing import List, Tuple
 
-import numpy as np
 from redis import Redis, RedisCluster
 
 
@@ -42,7 +41,7 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
     @classmethod
     def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
         ef = cls.search_params["search_params"]["ef"]
-        response = cls.client.execute_command("VSIM", "idx", "FP32", np.array(vector).astype(np.float32).tobytes(), "WITHSCORES", "COUNT", top, "EF", ef)
+        response = cls.client.execute_command("VSIM", "idx", "FP32", vector, "WITHSCORES", "COUNT", top, "EF", ef)
         # decode responses
         # every even cell is id, every odd is the score
         # scores needs to be 1 - scores since on vector sets 1 is identical, 0 is opposite vector
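
This change follows from the first file: query vectors now reach search_one already packed as FP32 bytes (the conversion happens once in search_all, outside the timed region), so the per-query NumPy conversion and its import are dropped and the blob is passed straight through to VSIM. For reference, the encoding that now happens upstream is the usual NumPy round trip; the sample vector below is purely illustrative:

    import numpy as np

    # Pack a query vector as raw float32 bytes, as search_all now does up front;
    # the resulting blob is what VSIM's FP32 argument expects.
    vector = [0.12, -0.5, 0.33]
    fp32_blob = np.array(vector).astype(np.float32).tobytes()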
