File tree

2 files changed

+62
-66
lines changed

2 files changed

+62
-66
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,59 @@
11
from abc import ABC
2-
from typing import List, Type
2+
import importlib
3+
from typing import Dict, List, Type
34

45
from engine.base_client.client import (
56
BaseClient,
67
BaseConfigurator,
78
BaseSearcher,
89
BaseUploader,
910
)
10-
from engine.clients.elasticsearch import (
11-
ElasticConfigurator,
12-
ElasticSearcher,
13-
ElasticUploader,
14-
)
15-
from engine.clients.milvus import MilvusConfigurator, MilvusSearcher, MilvusUploader
16-
from engine.clients.opensearch import (
17-
OpenSearchConfigurator,
18-
OpenSearchSearcher,
19-
OpenSearchUploader,
20-
)
21-
from engine.clients.pgvector import (
22-
PgVectorConfigurator,
23-
PgVectorSearcher,
24-
PgVectorUploader,
25-
)
26-
from engine.clients.qdrant import QdrantConfigurator, QdrantSearcher, QdrantUploader
27-
from engine.clients.redis import RedisConfigurator, RedisSearcher, RedisUploader
28-
from engine.clients.weaviate import (
29-
WeaviateConfigurator,
30-
WeaviateSearcher,
31-
WeaviateUploader,
32-
)
3311

34-
from engine.clients.vectorsets import (
35-
RedisVsetConfigurator,
36-
RedisVsetSearcher,
37-
RedisVsetUploader,
38-
)
12+
# Dictionary to store dynamically imported client classes
13+
_engine_classes = {}
14+
15+
def _import_engine_classes(engine_name: str) -> Dict[str, Type]:
16+
"""
17+
Dynamically import client classes for a specific engine.
18+
19+
Args:
20+
engine_name: The name of the engine (e.g., 'redis', 'qdrant')
21+
22+
Returns:
23+
Dictionary with configurator, uploader, and searcher classes
24+
"""
25+
if engine_name in _engine_classes:
26+
return _engine_classes[engine_name]
3927

40-
ENGINE_CONFIGURATORS = {
41-
"qdrant": QdrantConfigurator,
42-
"weaviate": WeaviateConfigurator,
43-
"milvus": MilvusConfigurator,
44-
"elasticsearch": ElasticConfigurator,
45-
"opensearch": OpenSearchConfigurator,
46-
"redis": RedisConfigurator,
47-
"pgvector": PgVectorConfigurator,
48-
"vectorsets": RedisVsetConfigurator,
49-
}
50-
51-
ENGINE_UPLOADERS = {
52-
"qdrant": QdrantUploader,
53-
"weaviate": WeaviateUploader,
54-
"milvus": MilvusUploader,
55-
"elasticsearch": ElasticUploader,
56-
"opensearch": OpenSearchUploader,
57-
"redis": RedisUploader,
58-
"pgvector": PgVectorUploader,
59-
"vectorsets": RedisVsetUploader,
60-
}
61-
62-
ENGINE_SEARCHERS = {
63-
"qdrant": QdrantSearcher,
64-
"weaviate": WeaviateSearcher,
65-
"milvus": MilvusSearcher,
66-
"elasticsearch": ElasticSearcher,
67-
"opensearch": OpenSearchSearcher,
68-
"redis": RedisSearcher,
69-
"pgvector": PgVectorSearcher,
70-
"vectorsets": RedisVsetSearcher,
71-
}
28+
# Handle special case for vectorsets which uses redis prefix
29+
if engine_name == "vectorsets":
30+
module_name = f"engine.clients.vectorsets"
31+
class_prefix = "RedisVset"
32+
else:
33+
module_name = f"engine.clients.{engine_name}"
34+
# Convert first letter to uppercase for class name
35+
class_prefix = engine_name[0].upper() + engine_name[1:]
36+
37+
try:
38+
module = importlib.import_module(module_name)
39+
configurator_class = getattr(module, f"{class_prefix}Configurator")
40+
uploader_class = getattr(module, f"{class_prefix}Uploader")
41+
searcher_class = getattr(module, f"{class_prefix}Searcher")
42+
43+
_engine_classes[engine_name] = {
44+
"configurator": configurator_class,
45+
"uploader": uploader_class,
46+
"searcher": searcher_class
47+
}
48+
49+
return _engine_classes[engine_name]
50+
except (ImportError, AttributeError) as e:
51+
raise ImportError(f"Failed to import classes for engine '{engine_name}': {e}")
52+
53+
# Empty dictionaries that will be populated on demand
54+
ENGINE_CONFIGURATORS = {}
55+
ENGINE_UPLOADERS = {}
56+
ENGINE_SEARCHERS = {}
7257

7358

7459
class ClientFactory(ABC):
@@ -78,7 +63,17 @@ def __init__(self, host):
7863

7964
def _create_configurator(self, experiment) -> BaseConfigurator:
8065
self.engine = experiment["engine"]
81-
engine_configurator_class = ENGINE_CONFIGURATORS[experiment["engine"]]
66+
engine_name = experiment["engine"]
67+
68+
# Dynamically import engine classes if not already imported
69+
if engine_name not in _engine_classes:
70+
_import_engine_classes(engine_name)
71+
# Add to the global dictionaries for compatibility
72+
ENGINE_CONFIGURATORS[engine_name] = _engine_classes[engine_name]["configurator"]
73+
ENGINE_UPLOADERS[engine_name] = _engine_classes[engine_name]["uploader"]
74+
ENGINE_SEARCHERS[engine_name] = _engine_classes[engine_name]["searcher"]
75+
76+
engine_configurator_class = _engine_classes[engine_name]["configurator"]
8277
engine_configurator = engine_configurator_class(
8378
self.host,
8479
collection_params={**experiment.get("collection_params", {})},
@@ -87,7 +82,8 @@ def _create_configurator(self, experiment) -> BaseConfigurator:
8782
return engine_configurator
8883

8984
def _create_uploader(self, experiment) -> BaseUploader:
90-
engine_uploader_class = ENGINE_UPLOADERS[experiment["engine"]]
85+
engine_name = experiment["engine"]
86+
engine_uploader_class = _engine_classes[engine_name]["uploader"]
9187
engine_uploader = engine_uploader_class(
9288
self.host,
9389
connection_params={**experiment.get("connection_params", {})},
@@ -96,9 +92,8 @@ def _create_uploader(self, experiment) -> BaseUploader:
9692
return engine_uploader
9793

9894
def _create_searchers(self, experiment) -> List[BaseSearcher]:
99-
engine_searcher_class: Type[BaseSearcher] = ENGINE_SEARCHERS[
100-
experiment["engine"]
101-
]
95+
engine_name = experiment["engine"]
96+
engine_searcher_class: Type[BaseSearcher] = _engine_classes[engine_name]["searcher"]
10297

10398
engine_searchers = [
10499
engine_searcher_class(
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def run(
3939
for name, config in all_engines.items()
4040
if any(fnmatch.fnmatch(name, engine) for engine in engines)
4141
}
42+
4243
selected_datasets = {
4344
name: config
4445
for name, config in all_datasets.items()

0 commit comments

Comments
 (0)