|
2 | 2 | import time
|
3 | 3 | from multiprocessing import get_context
|
4 | 4 | from typing import Iterable, List, Optional, Tuple
|
| 5 | +from itertools import islice |
5 | 6 |
|
6 | 7 | import numpy as np
|
7 | 8 | import tqdm
|
@@ -79,22 +80,31 @@ def search_all(
|
79 | 80 | else:
|
80 | 81 | ctx = get_context(self.get_mp_start_method())
|
81 | 82 |
|
82 |
| -with ctx.Pool( |
83 |
| -processes=parallel, |
84 |
| -initializer=self.__class__.init_client, |
85 |
| -initargs=( |
| 83 | +def process_initializer(): |
| 84 | +"""Initialize each process before starting the search.""" |
| 85 | +self.__class__.init_client( |
86 | 86 | self.host,
|
87 | 87 | distance,
|
88 | 88 | self.connection_params,
|
89 | 89 | self.search_params,
|
90 |
| -), |
| 90 | +) |
| 91 | +self.setup_search() |
| 92 | + |
| 93 | +# Dynamically chunk the generator |
| 94 | +query_chunks = list(chunked_iterable(queries, max(1, parallel))) |
| 95 | + |
| 96 | +with ctx.Pool( |
| 97 | +processes=parallel, |
| 98 | +initializer=process_initializer, |
91 | 99 | ) as pool:
|
92 | 100 | if parallel > 10:
|
93 | 101 | time.sleep(15) # Wait for all processes to start
|
94 | 102 | start = time.perf_counter()
|
95 |
| -precisions, latencies = list( |
96 |
| -zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(queries))) |
| 103 | +results = pool.starmap( |
| 104 | +process_chunk, |
| 105 | +[(chunk, search_one) for chunk in query_chunks], |
97 | 106 | )
|
| 107 | +precisions, latencies = zip(*[result for chunk in results for result in chunk]) |
98 | 108 |
|
99 | 109 | total_time = time.perf_counter() - start
|
100 | 110 |
|
@@ -123,3 +133,15 @@ def post_search(self):
|
@classmethod
def delete_client(cls):
    """Do nothing and return ``None``; this implementation has no teardown logic."""
|
| 136 | + |
| 137 | + |
def chunked_iterable(iterable, size):
    """Yield successive lists of at most ``size`` items from ``iterable``.

    The input is consumed lazily through a single iterator, so generators
    and other one-shot iterables are supported. Every chunk holds exactly
    ``size`` items except possibly the last one, which holds the remainder.
    An empty input yields nothing.

    Args:
        iterable: Any iterable of items to split up.
        size: Maximum number of items per chunk; must be at least 1.

    Yields:
        Lists of consecutive items, in their original order.

    Raises:
        ValueError: If ``size`` is smaller than 1. Because this is a
            generator, the error surfaces on the first iteration step.
    """
    # A size of 0 would make the first islice() return an empty list, ending
    # the loop immediately and silently discarding every item.
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size!r}")
    it = iter(iterable)
    # An empty list is falsy, so the loop stops once the iterator is exhausted.
    while chunk := list(islice(it, size)):
        yield chunk
| 143 | + |
| 144 | + |
def process_chunk(chunk, search_one):
    """Apply ``search_one`` to every query in ``chunk``.

    Args:
        chunk: Iterable of queries to evaluate.
        search_one: Callable invoked once per query, in iteration order.

    Returns:
        A list with one ``search_one`` result per query, in the same order
        as the queries appear in ``chunk``.
    """
    outcomes = list(map(search_one, chunk))
    return outcomes
0 commit comments