|
20 | 20 | import re
|
21 | 21 | import shutil
|
22 | 22 | import tempfile
|
| 23 | +import requests |
23 | 24 | from typing import (
|
24 | 25 | Any,
|
25 | 26 | Dict,
|
|
35 | 36 | from google.api_core import operation
|
36 | 37 | from google.api_core import exceptions as api_exceptions
|
37 | 38 | from google.auth import credentials as auth_credentials
|
| 39 | +from google.auth.transport import requests as google_auth_requests |
38 | 40 |
|
39 | 41 | from google.cloud import aiplatform
|
40 | 42 | from google.cloud.aiplatform import base
|
| 43 | +from google.cloud.aiplatform import constants |
41 | 44 | from google.cloud.aiplatform import explain
|
42 | 45 | from google.cloud.aiplatform import initializer
|
43 | 46 | from google.cloud.aiplatform import jobs
|
|
69 | 72 | _DEFAULT_MACHINE_TYPE = "n1-standard-2"
|
70 | 73 | _DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
|
71 | 74 | _SUCCESSFUL_HTTP_RESPONSE = 300
|
| 75 | +_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id" |
| 76 | +_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model" |
72 | 77 |
|
73 | 78 | _LOGGER = base.Logger(__name__)
|
74 | 79 |
|
@@ -200,6 +205,8 @@ def __init__(
|
200 | 205 | location=self.location,
|
201 | 206 | credentials=credentials,
|
202 | 207 | )
|
| 208 | +self.authorized_session = None |
| 209 | +self.raw_predict_request_url = None |
203 | 210 |
|
204 | 211 | def _skipped_getter_call(self) -> bool:
|
205 | 212 | """Check if GAPIC resource was populated by call to get/list API methods
|
@@ -1389,16 +1396,15 @@ def update(
|
1389 | 1396 | """Updates an endpoint.
|
1390 | 1397 |
|
1391 | 1398 | Example usage:
|
1392 |
| -
|
1393 |
| -my_endpoint = my_endpoint.update( |
1394 |
| -display_name='my-updated-endpoint', |
1395 |
| -description='my updated description', |
1396 |
| -labels={'key': 'value'}, |
1397 |
| -traffic_split={ |
1398 |
| -'123456': 20, |
1399 |
| -'234567': 80, |
1400 |
| -}, |
1401 |
| -) |
| 1399 | +my_endpoint = my_endpoint.update( |
| 1400 | +display_name='my-updated-endpoint', |
| 1401 | +description='my updated description', |
| 1402 | +labels={'key': 'value'}, |
| 1403 | +traffic_split={ |
| 1404 | +'123456': 20, |
| 1405 | +'234567': 80, |
| 1406 | +}, |
| 1407 | +) |
1402 | 1408 |
|
1403 | 1409 | Args:
|
1404 | 1410 | display_name (str):
|
@@ -1481,6 +1487,7 @@ def predict(
|
1481 | 1487 | instances: List,
|
1482 | 1488 | parameters: Optional[Dict] = None,
|
1483 | 1489 | timeout: Optional[float] = None,
|
| 1490 | +use_raw_predict: Optional[bool] = False, |
1484 | 1491 | ) -> Prediction:
|
1485 | 1492 | """Make a prediction against this Endpoint.
|
1486 | 1493 |
|
@@ -1505,29 +1512,80 @@ def predict(
|
1505 | 1512 | [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
|
1506 | 1513 | ``parameters_schema_uri``.
|
1507 | 1514 | timeout (float): Optional. The timeout for this request in seconds.
|
| 1515 | +use_raw_predict (bool): |
| 1516 | +Optional. Default value is False. If set to True, the underlying prediction call will be made |
| 1517 | +against Endpoint.raw_predict(). Note that model version information will |
| 1518 | +not be available in the prediciton response using raw_predict. |
1508 | 1519 |
|
1509 | 1520 | Returns:
|
1510 | 1521 | prediction (aiplatform.Prediction):
|
1511 | 1522 | Prediction with returned predictions and Model ID.
|
1512 | 1523 | """
|
1513 | 1524 | self.wait()
|
| 1525 | +if use_raw_predict: |
| 1526 | +raw_predict_response = self.raw_predict( |
| 1527 | +body=json.dumps({"instances": instances, "parameters": parameters}), |
| 1528 | +headers={"Content-Type": "application/json"}, |
| 1529 | +) |
| 1530 | +json_response = json.loads(raw_predict_response.text) |
| 1531 | +return Prediction( |
| 1532 | +predictions=json_response["predictions"], |
| 1533 | +deployed_model_id=raw_predict_response.headers[ |
| 1534 | +_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY |
| 1535 | +], |
| 1536 | +model_resource_name=raw_predict_response.headers[ |
| 1537 | +_RAW_PREDICT_MODEL_RESOURCE_KEY |
| 1538 | +], |
| 1539 | +) |
| 1540 | +else: |
| 1541 | +prediction_response = self._prediction_client.predict( |
| 1542 | +endpoint=self._gca_resource.name, |
| 1543 | +instances=instances, |
| 1544 | +parameters=parameters, |
| 1545 | +timeout=timeout, |
| 1546 | +) |
1514 | 1547 |
|
1515 |
| -prediction_response = self._prediction_client.predict( |
1516 |
| -endpoint=self._gca_resource.name, |
1517 |
| -instances=instances, |
1518 |
| -parameters=parameters, |
1519 |
| -timeout=timeout, |
1520 |
| -) |
| 1548 | +return Prediction( |
| 1549 | +predictions=[ |
| 1550 | +json_format.MessageToDict(item) |
| 1551 | +for item in prediction_response.predictions.pb |
| 1552 | +], |
| 1553 | +deployed_model_id=prediction_response.deployed_model_id, |
| 1554 | +model_version_id=prediction_response.model_version_id, |
| 1555 | +model_resource_name=prediction_response.model, |
| 1556 | +) |
1521 | 1557 |
|
1522 |
| -return Prediction( |
1523 |
| -predictions=[ |
1524 |
| -json_format.MessageToDict(item) |
1525 |
| -for item in prediction_response.predictions.pb |
1526 |
| -], |
1527 |
| -deployed_model_id=prediction_response.deployed_model_id, |
1528 |
| -model_version_id=prediction_response.model_version_id, |
1529 |
| -model_resource_name=prediction_response.model, |
1530 |
| -) |
| 1558 | +def raw_predict( |
| 1559 | +self, body: bytes, headers: Dict[str, str] |
| 1560 | +) -> requests.models.Response: |
| 1561 | +"""Makes a prediction request using arbitrary headers. |
| 1562 | +
|
| 1563 | +Example usage: |
| 1564 | +my_endpoint = aiplatform.Endpoint(ENDPOINT_ID) |
| 1565 | +response = my_endpoint.raw_predict( |
| 1566 | +body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}' |
| 1567 | +headers = {'Content-Type':'application/json'} |
| 1568 | +) |
| 1569 | +status_code = response.status_code |
| 1570 | +results = json.dumps(response.text) |
| 1571 | +
|
| 1572 | +Args: |
| 1573 | +body (bytes): |
| 1574 | +The body of the prediction request in bytes. This must not exceed 1.5 mb per request. |
| 1575 | +headers (Dict[str, str]): |
| 1576 | +The header of the request as a dictionary. There are no restrictions on the header. |
| 1577 | +
|
| 1578 | +Returns: |
| 1579 | +A requests.models.Response object containing the status code and prediction results. |
| 1580 | +""" |
| 1581 | +if not self.authorized_session: |
| 1582 | +self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES |
| 1583 | +self.authorized_session = google_auth_requests.AuthorizedSession( |
| 1584 | +self.credentials |
| 1585 | +) |
| 1586 | +self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict" |
| 1587 | + |
| 1588 | +return self.authorized_session.post(self.raw_predict_request_url, body, headers) |
1531 | 1589 |
|
1532 | 1590 | def explain(
|
1533 | 1591 | self,
|
@@ -2004,7 +2062,7 @@ def _http_request(
|
2004 | 2062 | def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
|
2005 | 2063 | """Make a prediction against this PrivateEndpoint using a HTTP request.
|
2006 | 2064 | This method must be called within the network the PrivateEndpoint is peered to.
|
2007 |
| -The predict() call will fail otherwise. To check, use `PrivateEndpoint.network`. |
| 2065 | +Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`. |
2008 | 2066 |
|
2009 | 2067 | Example usage:
|
2010 | 2068 | response = my_private_endpoint.predict(instances=[...])
|
@@ -2062,6 +2120,39 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
|
2062 | 2120 | deployed_model_id=self._gca_resource.deployed_models[0].id,
|
2063 | 2121 | )
|
2064 | 2122 |
|
    
2065 | 2156 | def explain(self):
|
2066 | 2157 | raise NotImplementedError(
|
2067 | 2158 | f"{self.__class__.__name__} class does not support 'explain' as of now."
|
|
0 commit comments