23 files changed

+796
-37
lines changed
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,10 @@ def project(self) -> str:
274274
project_not_found_exception_str = (
275275
"Unable to find your project. Please provide a project ID by:"
276276
"\n- Passing a constructor argument"
277-
"\n- Using aiplatform.init()"
277+
"\n- Using vertexai.init()"
278278
"\n- Setting project using 'gcloud config set project my-project'"
279279
"\n- Setting a GCP environment variable"
280+
"\n- To create a Google Cloud project, please follow guidance at https://developers.google.com/workspace/guides/create-project"
280281
)
281282

282283
try:
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ message MatchRequest {
4242
// The list of restricts.
4343
repeated Namespace restricts = 4;
4444

45+
//The list of numeric restricts.
46+
repeated NumericNamespace numeric_restricts = 11;
47+
4548
// Crowding is a constraint on a neighbor list produced by nearest neighbor
4649
// search requiring that no more than some value k' of the k neighbors
4750
// returned have the same value of crowding_attribute.
@@ -88,6 +91,9 @@ message Embedding {
8891
// The list of restricts.
8992
repeated Namespace restricts = 3;
9093

94+
// The list of numeric restricts.
95+
repeated NumericNamespace numeric_restricts = 5;
96+
9197
// The attribute value used for crowding. The maximum number of neighbors
9298
// to return per crowding attribute value
9399
// (per_crowding_attribute_num_neighbors) is configured per-query.
@@ -175,6 +181,7 @@ message BatchMatchResponse {
175181

176182
// Namespace specifies the rules for determining the datapoints that are
177183
// eligible for each matching query, overall query is an AND across namespaces.
184+
// This uses categorical tokens.
178185
message Namespace {
179186
// The string name of the namespace that this proto is specifying,
180187
// such as "color", "shape", "geo", or "tags".
@@ -192,4 +199,53 @@ message Namespace {
192199
// query will match datapoints that are red or blue, but if those points are
193200
// also purple, then they will be excluded even if they are red/blue.
194201
repeated string deny_tokens = 3;
195-
}
202+
}
203+
204+
// NumericNamespace specifies the rules for determining the datapoints that are
205+
// eligible for each matching query, overall query is an AND across namespaces.
206+
// This uses numeric comparisons.
207+
message NumericNamespace {
208+
209+
// The string name of the namespace that this proto is specifying,
210+
// such as "size" or "cost".
211+
string name = 1;
212+
213+
// The type of Value must be consistent for all datapoints with a given
214+
// namespace name. This is verified at runtime.
215+
oneof Value {
216+
// Represents 64 bit integer.
217+
int64 value_int = 2;
218+
// Represents 32 bit float.
219+
float value_float = 3;
220+
// Represents 64 bit float.
221+
double value_double = 4;
222+
}
223+
224+
// Which comparison operator to use. Should be specified for queries only;
225+
// specifying this for a datapoint is an error.
226+
//
227+
// Datapoints for which Operator is true relative to the query's Value
228+
// field will be allowlisted.
229+
enum Operator {
230+
// Default value of the enum.
231+
OPERATOR_UNSPECIFIED = 0;
232+
233+
// Datapoints are eligible iff their value is < the query's.
234+
LESS = 1;
235+
236+
// Datapoints are eligible iff their value is <= the query's.
237+
LESS_EQUAL = 2;
238+
239+
// Datapoints are eligible iff their value is == the query's.
240+
EQUAL = 3;
241+
242+
// Datapoints are eligible iff their value is >= the query's.
243+
GREATER_EQUAL = 4;
244+
245+
// Datapoints are eligible iff their value is > the query's.
246+
GREATER = 5;
247+
}
248+
249+
// Which comparison operator to use.
250+
Operator op = 5;
251+
}
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
)
217217
self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)
218218

219+
self._public_match_client = None
219220
if self.public_endpoint_domain_name:
220221
self._public_match_client = self._instantiate_public_match_client()
221222

@@ -518,6 +519,36 @@ def _instantiate_public_match_client(
518519
api_path_override=self.public_endpoint_domain_name,
519520
)
520521

522+
def _instantiate_private_match_service_stub(
523+
self,
524+
deployed_index_id: str,
525+
) -> match_service_pb2_grpc.MatchServiceStub:
526+
"""Helper method to instantiate private match service stub.
527+
Args:
528+
deployed_index_id (str):
529+
Required. The user specified ID of the
530+
DeployedIndex.
531+
Returns:
532+
stub (match_service_pb2_grpc.MatchServiceStub):
533+
Initialized match service stub.
534+
"""
535+
# Find the deployed index by id
536+
deployed_indexes = [
537+
deployed_index
538+
for deployed_index in self.deployed_indexes
539+
if deployed_index.id == deployed_index_id
540+
]
541+
542+
if not deployed_indexes:
543+
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
544+
545+
# Retrieve server ip from deployed index
546+
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
547+
548+
# Set up channel and stub
549+
channel = grpc.insecure_channel("{}:10000".format(server_ip))
550+
return match_service_pb2_grpc.MatchServiceStub(channel)
551+
521552
@property
522553
def public_endpoint_domain_name(self) -> Optional[str]:
523554
"""Public endpoint DNS name."""
@@ -1233,7 +1264,8 @@ def read_index_datapoints(
12331264
deployed_index_id: str,
12341265
ids: List[str] = [],
12351266
) -> List[gca_index_v1beta1.IndexDatapoint]:
1236-
"""Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.
1267+
"""Reads the datapoints/vectors of the given IDs on the specified
1268+
deployed index which is deployed to public or private endpoint.
12371269
12381270
```
12391271
Example Usage:
@@ -1252,9 +1284,25 @@ def read_index_datapoints(
12521284
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
12531285
"""
12541286
if not self._public_match_client:
1255-
raise ValueError(
1256-
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
1287+
# Call private match service stub with BatchGetEmbeddings request
1288+
response = self._batch_get_embeddings(
1289+
deployed_index_id=deployed_index_id, ids=ids
12571290
)
1291+
return [
1292+
gca_index_v1beta1.IndexDatapoint(
1293+
datapoint_id=embedding.id,
1294+
feature_vector=embedding.float_val,
1295+
restricts=gca_index_v1beta1.IndexDatapoint.Restriction(
1296+
namespace=embedding.restricts.name,
1297+
allow_list=embedding.restricts.allow_tokens,
1298+
),
1299+
deny_list=embedding.restricts.deny_tokens,
1300+
crowding_attributes=gca_index_v1beta1.CrowdingEmbedding(
1301+
str(embedding.crowding_tag)
1302+
),
1303+
)
1304+
for embedding in response.embeddings
1305+
]
12581306

12591307
# Create the ReadIndexDatapoints request
12601308
read_index_datapoints_request = (
@@ -1273,6 +1321,38 @@ def read_index_datapoints(
12731321
# Wrap the results and return
12741322
return response.datapoints
12751323

1324+
def _batch_get_embeddings(
1325+
self,
1326+
*,
1327+
deployed_index_id: str,
1328+
ids: List[str] = [],
1329+
) -> List[match_service_pb2.Embedding]:
1330+
"""
1331+
Reads the datapoints/vectors of the given IDs on the specified index
1332+
which is deployed to private endpoint.
1333+
1334+
Args:
1335+
deployed_index_id (str):
1336+
Required. The ID of the DeployedIndex to match the queries against.
1337+
ids (List[str]):
1338+
Required. IDs of the datapoints to be searched for.
1339+
Returns:
1340+
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
1341+
"""
1342+
stub = self._instantiate_private_match_service_stub(
1343+
deployed_index_id=deployed_index_id
1344+
)
1345+
1346+
# Create the batch get embeddings request
1347+
batch_request = match_service_pb2.BatchGetEmbeddingsRequest()
1348+
batch_request.deployed_index_id = deployed_index_id
1349+
1350+
for id in ids:
1351+
batch_request.id.append(id)
1352+
response = stub.BatchGetEmbeddings(batch_request)
1353+
1354+
return response.embeddings
1355+
12761356
def match(
12771357
self,
12781358
deployed_index_id: str,
@@ -1310,23 +1390,9 @@ def match(
13101390
Returns:
13111391
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
13121392
"""
1313-
1314-
# Find the deployed index by id
1315-
deployed_indexes = [
1316-
deployed_index
1317-
for deployed_index in self.deployed_indexes
1318-
if deployed_index.id == deployed_index_id
1319-
]
1320-
1321-
if not deployed_indexes:
1322-
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
1323-
1324-
# Retrieve server ip from deployed index
1325-
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
1326-
1327-
# Set up channel and stub
1328-
channel = grpc.insecure_channel("{}:10000".format(server_ip))
1329-
stub = match_service_pb2_grpc.MatchServiceStub(channel)
1393+
stub = self._instantiate_private_match_service_stub(
1394+
deployed_index_id=deployed_index_id
1395+
)
13301396

13311397
# Create the batch match request
13321398
batch_request = match_service_pb2.BatchMatchRequest()
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,18 @@ def __init__(self, address: Optional[str]) -> None:
111111
" failed to start Head node properly because custom service account isn't supported.",
112112
)
113113
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
114+
cluster = _gapic_utils.persistent_resource_to_cluster(
115+
persistent_resource=self.response
116+
)
117+
if cluster is None:
118+
raise ValueError(
119+
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
120+
)
121+
local_ray_version = _validation_utils.get_local_ray_version()
122+
if cluster.ray_version != local_ray_version:
123+
raise ValueError(
124+
f"[Ray on Vertex AI]: Local runtime has Ray version {local_ray_version}, but the cluster runtime has {cluster.ray_version}. Please ensure that the Ray versions match."
125+
)
114126
super().__init__(address)
115127

116128
def connect(self) -> _VertexRayClientContext:
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import google.auth
1919
import google.auth.transport.requests
2020
import logging
21+
import ray
2122
import re
2223

2324
from google.cloud.aiplatform import initializer
@@ -68,6 +69,13 @@ def maybe_reconstruct_resource_name(address) -> str:
6869
return address
6970

7071

72+
def get_local_ray_version():
73+
ray_version = ray.__version__.split(".")
74+
if len(ray_version) == 3:
75+
ray_version = ray_version[:2]
76+
return "_".join(ray_version)
77+
78+
7179
def get_image_uri(ray_version, python_version, enable_cuda):
7280
"""Image uri for a given ray version and python version."""
7381
if ray_version not in ["2_4"]:
Original file line numberDiff line numberDiff line change
@@ -1204,3 +1204,52 @@ def mock_autolog():
12041204
with patch.object(aiplatform, "autolog") as mock_autolog_method:
12051205
mock_autolog_method.return_value = None
12061206
yield mock_autolog_method
1207+
1208+
1209+
"""
1210+
----------------------------------------------------------------------------
1211+
Vector Search Fixtures
1212+
----------------------------------------------------------------------------
1213+
"""
1214+
1215+
1216+
@pytest.fixture
1217+
def mock_index():
1218+
mock = MagicMock(aiplatform.MatchingEngineIndex)
1219+
yield mock
1220+
1221+
1222+
@pytest.fixture
1223+
def mock_index_endpoint():
1224+
mock = MagicMock(aiplatform.MatchingEngineIndexEndpoint)
1225+
yield mock
1226+
1227+
1228+
@pytest.fixture
1229+
def mock_index_init(mock_index):
1230+
with patch.object(aiplatform, "MatchingEngineIndex") as mock:
1231+
mock.return_value = mock_index
1232+
yield mock
1233+
1234+
1235+
@pytest.fixture
1236+
def mock_index_upsert_datapoints(mock_index):
1237+
with patch.object(mock_index, "upsert_datapoints") as mock_upsert:
1238+
mock_upsert.return_value = None
1239+
yield mock_upsert
1240+
1241+
1242+
@pytest.fixture
1243+
def mock_index_endpoint_init(mock_index_endpoint):
1244+
with patch.object(aiplatform, "MatchingEngineIndexEndpoint") as mock:
1245+
mock.return_value = mock_index_endpoint
1246+
yield mock
1247+
1248+
1249+
@pytest.fixture
1250+
def mock_index_endpoint_find_neighbors(mock_index_endpoint):
1251+
with patch.object(
1252+
mock_index_endpoint, "find_neighbors"
1253+
) as mock_find_neighbors:
1254+
mock_find_neighbors.return_value = None
1255+
yield mock_find_neighbors
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,12 @@
338338
)
339339
TENSORBOARD_EXPERIMENT_NAME = "my-tensorboard-experiment"
340340
TENSORBOARD_PLUGIN_PROFILE_NAME = "profile"
341+
342+
# Vector Search
343+
VECTOR_SEARCH_INDEX = "123"
344+
VECTOR_SEARCH_INDEX_DATAPOINTS = [
345+
{"datapoint_id": "datapoint_id_1", "feature_vector": [0.1]}
346+
]
347+
VECTOR_SEARCH_INDEX_ENDPOINT = "456"
348+
VECTOR_SEARCH_DEPLOYED_INDEX_ID = "789"
349+
VECTOR_SEARCH_INDEX_QUERIES = [[0.1]]

0 commit comments

Comments
 (0)