File tree

1 file changed: +29 additions, -7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import time
33
from multiprocessing import get_context
44
from typing import Iterable, List, Optional, Tuple
5+
from itertools import islice
56

67
import numpy as np
78
import tqdm
@@ -79,22 +80,31 @@ def search_all(
7980
else:
8081
ctx = get_context(self.get_mp_start_method())
8182

82-
with ctx.Pool(
83-
processes=parallel,
84-
initializer=self.__class__.init_client,
85-
initargs=(
83+
def process_initializer():
84+
"""Initialize each process before starting the search."""
85+
self.__class__.init_client(
8686
self.host,
8787
distance,
8888
self.connection_params,
8989
self.search_params,
90-
),
90+
)
91+
self.setup_search()
92+
93+
# Dynamically chunk the generator
94+
query_chunks = list(chunked_iterable(queries, max(1, parallel)))
95+
96+
with ctx.Pool(
97+
processes=parallel,
98+
initializer=process_initializer,
9199
) as pool:
92100
if parallel > 10:
93101
time.sleep(15) # Wait for all processes to start
94102
start = time.perf_counter()
95-
precisions, latencies = list(
96-
zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(queries)))
103+
results = pool.starmap(
104+
process_chunk,
105+
[(chunk, search_one) for chunk in query_chunks],
97106
)
107+
precisions, latencies = zip(*[result for chunk in results for result in chunk])
98108

99109
total_time = time.perf_counter() - start
100110

@@ -123,3 +133,15 @@ def post_search(self):
123133
@classmethod
124134
def delete_client(cls):
125135
pass
136+
137+
138+
def chunked_iterable(iterable, size):
    """Split *iterable* into consecutive lists of at most *size* items.

    Consumes the source lazily: each yielded chunk is a fresh list, and
    the final chunk may be shorter than *size*. Yields nothing for an
    empty iterable.
    """
    source = iter(iterable)
    while True:
        batch = list(islice(source, size))
        if not batch:
            # Source exhausted — stop without yielding an empty chunk.
            return
        yield batch
143+
144+
145+
def process_chunk(chunk, search_one):
    """Apply *search_one* to every query in *chunk*.

    Returns the results as a list, preserving the order of the queries.
    An empty chunk yields an empty list.
    """
    results = []
    for query in chunk:
        results.append(search_one(query))
    return results

0 commit comments

Comments
 (0)