2020import re
2121import shutil
2222import tempfile
23+ import requests
2324from typing import (
2425 Any ,
2526 Dict ,
3536from google .api_core import operation
3637from google .api_core import exceptions as api_exceptions
3738from google .auth import credentials as auth_credentials
39+ from google .auth .transport import requests as google_auth_requests
3840
3941from google .cloud import aiplatform
4042from google .cloud .aiplatform import base
43+ from google .cloud .aiplatform import constants
4144from google .cloud .aiplatform import explain
4245from google .cloud .aiplatform import initializer
4346from google .cloud .aiplatform import jobs
6972_DEFAULT_MACHINE_TYPE = "n1-standard-2"
7073_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
7174_SUCCESSFUL_HTTP_RESPONSE = 300
75+ _RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
76+ _RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"
7277
7378_LOGGER = base .Logger (__name__ )
7479
@@ -200,6 +205,8 @@ def __init__(
200205 location = self .location ,
201206 credentials = credentials ,
202207 )
208+ self .authorized_session = None
209+ self .raw_predict_request_url = None
203210
204211 def _skipped_getter_call (self ) -> bool :
205212 """Check if GAPIC resource was populated by call to get/list API methods
@@ -1389,16 +1396,15 @@ def update(
13891396 """Updates an endpoint.
13901397
13911398 Example usage:
1392-
1393- my_endpoint = my_endpoint.update(
1394- display_name='my-updated-endpoint',
1395- description='my updated description',
1396- labels={'key': 'value'},
1397- traffic_split={
1398- '123456': 20,
1399- '234567': 80,
1400- },
1401- )
1399+ my_endpoint = my_endpoint.update(
1400+ display_name='my-updated-endpoint',
1401+ description='my updated description',
1402+ labels={'key': 'value'},
1403+ traffic_split={
1404+ '123456': 20,
1405+ '234567': 80,
1406+ },
1407+ )
14021408
14031409 Args:
14041410 display_name (str):
@@ -1481,6 +1487,7 @@ def predict(
14811487 instances : List ,
14821488 parameters : Optional [Dict ] = None ,
14831489 timeout : Optional [float ] = None ,
1490+ use_raw_predict : Optional [bool ] = False ,
14841491 ) -> Prediction :
14851492 """Make a prediction against this Endpoint.
14861493
@@ -1505,29 +1512,80 @@ def predict(
15051512 [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
15061513 ``parameters_schema_uri``.
15071514 timeout (float): Optional. The timeout for this request in seconds.
1515+ use_raw_predict (bool):
1516+ Optional. Default value is False. If set to True, the underlying prediction call will be made
1517+ against Endpoint.raw_predict(). Note that model version information will
1518+ not be available in the prediction response using raw_predict.
15081519
15091520 Returns:
15101521 prediction (aiplatform.Prediction):
15111522 Prediction with returned predictions and Model ID.
15121523 """
15131524 self .wait ()
1525+ if use_raw_predict :
1526+ raw_predict_response = self .raw_predict (
1527+ body = json .dumps ({"instances" : instances , "parameters" : parameters }),
1528+ headers = {"Content-Type" : "application/json" },
1529+ )
1530+ json_response = json .loads (raw_predict_response .text )
1531+ return Prediction (
1532+ predictions = json_response ["predictions" ],
1533+ deployed_model_id = raw_predict_response .headers [
1534+ _RAW_PREDICT_DEPLOYED_MODEL_ID_KEY
1535+ ],
1536+ model_resource_name = raw_predict_response .headers [
1537+ _RAW_PREDICT_MODEL_RESOURCE_KEY
1538+ ],
1539+ )
1540+ else :
1541+ prediction_response = self ._prediction_client .predict (
1542+ endpoint = self ._gca_resource .name ,
1543+ instances = instances ,
1544+ parameters = parameters ,
1545+ timeout = timeout ,
1546+ )
15141547
1515- prediction_response = self ._prediction_client .predict (
1516- endpoint = self ._gca_resource .name ,
1517- instances = instances ,
1518- parameters = parameters ,
1519- timeout = timeout ,
1520- )
1548+ return Prediction (
1549+ predictions = [
1550+ json_format .MessageToDict (item )
1551+ for item in prediction_response .predictions .pb
1552+ ],
1553+ deployed_model_id = prediction_response .deployed_model_id ,
1554+ model_version_id = prediction_response .model_version_id ,
1555+ model_resource_name = prediction_response .model ,
1556+ )
15211557
1522- return Prediction (
1523- predictions = [
1524- json_format .MessageToDict (item )
1525- for item in prediction_response .predictions .pb
1526- ],
1527- deployed_model_id = prediction_response .deployed_model_id ,
1528- model_version_id = prediction_response .model_version_id ,
1529- model_resource_name = prediction_response .model ,
1530- )
def raw_predict(
    self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
    """Makes a prediction request using arbitrary headers.

    The authorized HTTP session and the rawPredict URL are created on the
    first call and cached on the instance (``self.authorized_session`` and
    ``self.raw_predict_request_url``) for reuse by later calls.

    Example usage:
        my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
        response = my_endpoint.raw_predict(
            body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
            headers = {'Content-Type':'application/json'},
        )
        status_code = response.status_code
        results = json.loads(response.text)

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
        headers (Dict[str, str]):
            The header of the request as a dictionary. There are no restrictions on the header.

    Returns:
        A requests.models.Response object containing the status code and prediction results.
    """
    if not self.authorized_session:
        # NOTE(review): this overwrites a private attribute on the shared
        # credentials object; presumably needed so the session's token carries
        # the default scopes - confirm there is no public alternative.
        self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
        self.authorized_session = google_auth_requests.AuthorizedSession(
            self.credentials
        )
        self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"

    # Keyword arguments are required here: Session.post's positional order is
    # (url, data, json), so a positional third argument would be bound to
    # ``json`` and the caller's headers would never be sent.
    return self.authorized_session.post(
        url=self.raw_predict_request_url,
        data=body,
        headers=headers,
    )
15311589
15321590 def explain (
15331591 self ,
@@ -2004,7 +2062,7 @@ def _http_request(
20042062 def predict (self , instances : List , parameters : Optional [Dict ] = None ) -> Prediction :
20052063 """Make a prediction against this PrivateEndpoint using a HTTP request.
20062064 This method must be called within the network the PrivateEndpoint is peered to.
2007- The predict() call will fail otherwise . To check, use `PrivateEndpoint.network`.
2065+ Otherwise, the predict() call will fail with error code 404 . To check, use `PrivateEndpoint.network`.
20082066
20092067 Example usage:
20102068 response = my_private_endpoint.predict(instances=[...])
@@ -2062,6 +2120,39 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
20622120 deployed_model_id = self ._gca_resource .deployed_models [0 ].id ,
20632121 )
20642122
def raw_predict(
    self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
    """Send a raw prediction request with caller-supplied body and headers.

    This method must be called within the network the PrivateEndpoint is peered to.
    Otherwise, the raw_predict() call will fail with error code 404. To check, use
    `PrivateEndpoint.network`.

    Example usage:
        my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
        response = my_endpoint.raw_predict(
            body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
            headers = {'Content-Type':'application/json'},
        )
        status_code = response.status_code
        results = json.loads(response.text)

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
        headers (Dict[str, str]):
            The header of the request as a dictionary. There are no restrictions on the header.

    Returns:
        The response object produced by `_http_request` for the POST to
        `predict_http_uri`, containing the status code and prediction results.
    """
    # wait() runs before the URI is read, matching the original ordering.
    self.wait()

    request_kwargs = {
        "method": "POST",
        "url": self.predict_http_uri,
        "body": body,
        "headers": headers,
    }
    return self._http_request(**request_kwargs)
2155+
20652156 def explain (self ):
20662157 raise NotImplementedError (
20672158 f"{ self .__class__ .__name__ } class does not support 'explain' as of now."
0 commit comments