File tree

7 files changed

+229
-34
lines changed

7 files changed

+229
-34
lines changed
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _are_futures_done(self) -> bool:
241241
return self.__latest_future is None
242242

243243
def wait(self):
244-
"""Helper method to that blocks until all futures are complete."""
244+
"""Helper method that blocks until all futures are complete."""
245245
future = self.__latest_future
246246
if future:
247247
futures.wait([future], return_when=futures.FIRST_EXCEPTION)
@@ -974,7 +974,11 @@ def _sync_object_with_future_result(
974974
"_gca_resource",
975975
"credentials",
976976
]
977-
optional_sync_attributes = ["_prediction_client"]
977+
optional_sync_attributes = [
978+
"_prediction_client",
979+
"_authorized_session",
980+
"_raw_predict_request_url",
981+
]
978982

979983
for attribute in sync_attributes:
980984
setattr(self, attribute, getattr(result, attribute))
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,6 @@
9292
# that is being used for usage metrics tracking purposes.
9393
# For more details on go/oneplatform-api-analytics
9494
USER_AGENT_SDK_COMMAND = ""
95+
96+
# Needed for Endpoint.raw_predict
97+
DEFAULT_AUTHED_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import re
2121
import shutil
2222
import tempfile
23+
import requests
2324
from typing import (
2425
Any,
2526
Dict,
@@ -35,9 +36,11 @@
3536
from google.api_core import operation
3637
from google.api_core import exceptions as api_exceptions
3738
from google.auth import credentials as auth_credentials
39+
from google.auth.transport import requests as google_auth_requests
3840

3941
from google.cloud import aiplatform
4042
from google.cloud.aiplatform import base
43+
from google.cloud.aiplatform import constants
4144
from google.cloud.aiplatform import explain
4245
from google.cloud.aiplatform import initializer
4346
from google.cloud.aiplatform import jobs
@@ -69,6 +72,8 @@
6972
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
7073
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
7174
_SUCCESSFUL_HTTP_RESPONSE = 300
75+
_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
76+
_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"
7277

7378
_LOGGER = base.Logger(__name__)
7479

@@ -200,6 +205,8 @@ def __init__(
200205
location=self.location,
201206
credentials=credentials,
202207
)
208+
self.authorized_session = None
209+
self.raw_predict_request_url = None
203210

204211
def _skipped_getter_call(self) -> bool:
205212
"""Check if GAPIC resource was populated by call to get/list API methods
@@ -1389,16 +1396,15 @@ def update(
13891396
"""Updates an endpoint.
13901397
13911398
Example usage:
1392-
1393-
my_endpoint = my_endpoint.update(
1394-
display_name='my-updated-endpoint',
1395-
description='my updated description',
1396-
labels={'key': 'value'},
1397-
traffic_split={
1398-
'123456': 20,
1399-
'234567': 80,
1400-
},
1401-
)
1399+
my_endpoint = my_endpoint.update(
1400+
display_name='my-updated-endpoint',
1401+
description='my updated description',
1402+
labels={'key': 'value'},
1403+
traffic_split={
1404+
'123456': 20,
1405+
'234567': 80,
1406+
},
1407+
)
14021408
14031409
Args:
14041410
display_name (str):
@@ -1481,6 +1487,7 @@ def predict(
14811487
instances: List,
14821488
parameters: Optional[Dict] = None,
14831489
timeout: Optional[float] = None,
1490+
use_raw_predict: Optional[bool] = False,
14841491
) -> Prediction:
14851492
"""Make a prediction against this Endpoint.
14861493
@@ -1505,29 +1512,80 @@ def predict(
15051512
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
15061513
``parameters_schema_uri``.
15071514
timeout (float): Optional. The timeout for this request in seconds.
1515+
use_raw_predict (bool):
1516+
Optional. Default value is False. If set to True, the underlying prediction call will be made
1517+
against Endpoint.raw_predict(). Note that model version information will
1518+
not be available in the prediction response using raw_predict.
15081519
15091520
Returns:
15101521
prediction (aiplatform.Prediction):
15111522
Prediction with returned predictions and Model ID.
15121523
"""
15131524
self.wait()
1525+
if use_raw_predict:
1526+
raw_predict_response = self.raw_predict(
1527+
body=json.dumps({"instances": instances, "parameters": parameters}),
1528+
headers={"Content-Type": "application/json"},
1529+
)
1530+
json_response = json.loads(raw_predict_response.text)
1531+
return Prediction(
1532+
predictions=json_response["predictions"],
1533+
deployed_model_id=raw_predict_response.headers[
1534+
_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY
1535+
],
1536+
model_resource_name=raw_predict_response.headers[
1537+
_RAW_PREDICT_MODEL_RESOURCE_KEY
1538+
],
1539+
)
1540+
else:
1541+
prediction_response = self._prediction_client.predict(
1542+
endpoint=self._gca_resource.name,
1543+
instances=instances,
1544+
parameters=parameters,
1545+
timeout=timeout,
1546+
)
15141547

1515-
prediction_response = self._prediction_client.predict(
1516-
endpoint=self._gca_resource.name,
1517-
instances=instances,
1518-
parameters=parameters,
1519-
timeout=timeout,
1520-
)
1548+
return Prediction(
1549+
predictions=[
1550+
json_format.MessageToDict(item)
1551+
for item in prediction_response.predictions.pb
1552+
],
1553+
deployed_model_id=prediction_response.deployed_model_id,
1554+
model_version_id=prediction_response.model_version_id,
1555+
model_resource_name=prediction_response.model,
1556+
)
15211557

1522-
return Prediction(
1523-
predictions=[
1524-
json_format.MessageToDict(item)
1525-
for item in prediction_response.predictions.pb
1526-
],
1527-
deployed_model_id=prediction_response.deployed_model_id,
1528-
model_version_id=prediction_response.model_version_id,
1529-
model_resource_name=prediction_response.model,
1530-
)
1558+
def raw_predict(
1559+
self, body: bytes, headers: Dict[str, str]
1560+
) -> requests.models.Response:
1561+
"""Makes a prediction request using arbitrary headers.
1562+
1563+
Example usage:
1564+
my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
1565+
response = my_endpoint.raw_predict(
1566+
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
1567+
headers = {'Content-Type':'application/json'}
1568+
)
1569+
status_code = response.status_code
1570+
results = json.dumps(response.text)
1571+
1572+
Args:
1573+
body (bytes):
1574+
The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
1575+
headers (Dict[str, str]):
1576+
The header of the request as a dictionary. There are no restrictions on the header.
1577+
1578+
Returns:
1579+
A requests.models.Response object containing the status code and prediction results.
1580+
"""
1581+
if not self.authorized_session:
1582+
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
1583+
self.authorized_session = google_auth_requests.AuthorizedSession(
1584+
self.credentials
1585+
)
1586+
self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
1587+
1588+
return self.authorized_session.post(self.raw_predict_request_url, body, headers)
15311589

15321590
def explain(
15331591
self,
@@ -2004,7 +2062,7 @@ def _http_request(
20042062
def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
20052063
"""Make a prediction against this PrivateEndpoint using a HTTP request.
20062064
This method must be called within the network the PrivateEndpoint is peered to.
2007-
The predict() call will fail otherwise. To check, use `PrivateEndpoint.network`.
2065+
Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.
20082066
20092067
Example usage:
20102068
response = my_private_endpoint.predict(instances=[...])
@@ -2062,6 +2120,39 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
20622120
deployed_model_id=self._gca_resource.deployed_models[0].id,
20632121
)
20642122

2123+
def raw_predict(
2124+
self, body: bytes, headers: Dict[str, str]
2125+
) -> requests.models.Response:
2126+
"""Make a prediction request using arbitrary headers.
2127+
This method must be called within the network the PrivateEndpoint is peered to.
2128+
Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.
2129+
2130+
Example usage:
2131+
my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
2132+
response = my_endpoint.raw_predict(
2133+
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
2134+
headers = {'Content-Type':'application/json'}
2135+
)
2136+
status_code = response.status_code
2137+
results = json.dumps(response.text)
2138+
2139+
Args:
2140+
body (bytes):
2141+
The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
2142+
headers (Dict[str, str]):
2143+
The header of the request as a dictionary. There are no restrictions on the header.
2144+
2145+
Returns:
2146+
A requests.models.Response object containing the status code and prediction results.
2147+
"""
2148+
self.wait()
2149+
return self._http_request(
2150+
method="POST",
2151+
url=self.predict_http_uri,
2152+
body=body,
2153+
headers=headers,
2154+
)
2155+
20652156
def explain(self):
20662157
raise NotImplementedError(
20672158
f"{self.__class__.__name__} class does not support 'explain' as of now."
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@
8080
"uvicorn >= 0.16.0",
8181
]
8282

83-
private_endpoints_extra_require = [
84-
"urllib3 >=1.21.1, <1.27",
85-
]
83+
endpoint_extra_require = ["requests >= 2.28.1"]
84+
85+
private_endpoints_extra_require = ["urllib3 >=1.21.1, <1.27", "requests >= 2.28.1"]
8686
full_extra_require = list(
8787
set(
8888
tensorboard_extra_require
@@ -92,6 +92,7 @@
9292
+ featurestore_extra_require
9393
+ pipelines_extra_require
9494
+ datasets_extra_require
95+
+ endpoint_extra_require
9596
+ vizier_extra_require
9697
+ prediction_extra_require
9798
+ private_endpoints_extra_require
@@ -136,6 +137,7 @@
136137
"google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
137138
),
138139
extras_require={
140+
"endpoint": endpoint_extra_require,
139141
"full": full_extra_require,
140142
"metadata": metadata_extra_require,
141143
"tensorboard": tensorboard_extra_require,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2022 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import json
19+
20+
from google.cloud import aiplatform
21+
22+
from tests.system.aiplatform import e2e_base
23+
24+
_PERMANENT_IRIS_ENDPOINT_ID = "4966625964059525120"
25+
_PREDICTION_INSTANCE = {
26+
"petal_length": "3.0",
27+
"petal_width": "3.0",
28+
"sepal_length": "3.0",
29+
"sepal_width": "3.0",
30+
}
31+
32+
33+
class TestModelInteractions(e2e_base.TestEndToEnd):
34+
_temp_prefix = ""
35+
endpoint = aiplatform.Endpoint(_PERMANENT_IRIS_ENDPOINT_ID)
36+
37+
def test_prediction(self):
38+
# test basic predict
39+
prediction_response = self.endpoint.predict(instances=[_PREDICTION_INSTANCE])
40+
assert len(prediction_response.predictions) == 1
41+
42+
# test predict(use_raw_predict = True)
43+
prediction_with_raw_predict = self.endpoint.predict(
44+
instances=[_PREDICTION_INSTANCE], use_raw_predict=True
45+
)
46+
assert (
47+
prediction_with_raw_predict.deployed_model_id
48+
== prediction_response.deployed_model_id
49+
)
50+
assert (
51+
prediction_with_raw_predict.model_resource_name
52+
== prediction_response.model_resource_name
53+
)
54+
55+
# test raw_predict
56+
raw_prediction_response = self.endpoint.raw_predict(
57+
json.dumps({"instances": [_PREDICTION_INSTANCE]}),
58+
{"Content-Type": "application/json"},
59+
)
60+
assert raw_prediction_response.status_code == 200
61+
assert len(json.loads(raw_prediction_response.text)) == 1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
@pytest.mark.usefixtures("delete_staging_bucket", "tear_down_resources")
32-
class TestModel(e2e_base.TestEndToEnd):
32+
class TestModelUploadAndUpdate(e2e_base.TestEndToEnd):
3333

3434
_temp_prefix = "temp_vertex_sdk_e2e_model_upload_test"
3535

@@ -65,9 +65,8 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
6565
# See https://github.com/googleapis/python-aiplatform/issues/773
6666
endpoint = model.deploy(machine_type="n1-standard-2")
6767
shared_state["resources"].append(endpoint)
68-
predict_response = endpoint.predict(instances=[[0, 0, 0]])
69-
assert len(predict_response.predictions) == 1
7068

69+
# test model update
7170
model = model.update(
7271
display_name="new_name",
7372
description="new_description",

0 commit comments

Comments
 (0)