@@ -782,6 +782,7 @@ def create(
782782 enable_request_response_logging = False ,
783783 request_response_logging_sampling_rate : Optional [float ] = None ,
784784 request_response_logging_bq_destination_table : Optional [str ] = None ,
785+ dedicated_endpoint_enabled = False ,
785786 ) -> "Endpoint" :
786787 """Creates a new endpoint.
787788
@@ -849,6 +850,10 @@ def create(
849850 request_response_logging_bq_destination_table (str):
850851 Optional. The request response logging bigquery destination. If not set, will create a table with name:
851852 ``bq://{project_id}.logging_{endpoint_display_name}_{endpoint_id}.request_response_logging``.
853+ dedicated_endpoint_enabled (bool):
854+ Optional. If enabled, a dedicated dns will be created and your
855+ traffic will be fully isolated from other customers' traffic and
856+ latency will be reduced.
852857
853858 Returns:
854859 endpoint (aiplatform.Endpoint):
@@ -893,6 +898,7 @@ def create(
893898 create_request_timeout = create_request_timeout ,
894899 endpoint_id = endpoint_id ,
895900 predict_request_response_logging_config = predict_request_response_logging_config ,
901+ dedicated_endpoint_enabled = dedicated_endpoint_enabled ,
896902 )
897903
898904 @classmethod
@@ -918,6 +924,7 @@ def _create(
918924 private_service_connect_config : Optional [
919925 gca_service_networking .PrivateServiceConnectConfig
920926 ] = None ,
927+ dedicated_endpoint_enabled = False ,
921928 ) -> "Endpoint" :
922929 """Creates a new endpoint by calling the API client.
923930
@@ -984,6 +991,10 @@ def _create(
984991 private_service_connect_config (aiplatform.service_network.PrivateServiceConnectConfig):
985992 If enabled, the endpoint can be accessible via [Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect).
986993 Cannot be enabled when network is specified.
994+ dedicated_endpoint_enabled (bool):
995+ Optional. If enabled, a dedicated dns will be created and your
996+ traffic will be fully isolated from other customers' traffic and
997+ latency will be reduced.
987998
988999 Returns:
9891000 endpoint (aiplatform.Endpoint):
@@ -1002,6 +1013,7 @@ def _create(
10021013 network = network ,
10031014 predict_request_response_logging_config = predict_request_response_logging_config ,
10041015 private_service_connect_config = private_service_connect_config ,
1016+ dedicated_endpoint_enabled = dedicated_endpoint_enabled ,
10051017 )
10061018
10071019 operation_future = api_client .create_endpoint (
@@ -2167,9 +2179,18 @@ def predict(
21672179 parameters : Optional [Dict ] = None ,
21682180 timeout : Optional [float ] = None ,
21692181 use_raw_predict : Optional [bool ] = False ,
2182+ * ,
2183+ use_dedicated_endpoint : Optional [bool ] = False ,
21702184 ) -> Prediction :
21712185 """Make a prediction against this Endpoint.
21722186
2187+ For dedicated endpoint, set use_dedicated_endpoint = True:
2188+ ```
2189+ response = my_endpoint.predict(instances=[...],
2190+ use_dedicated_endpoint=True)
2191+ my_predictions = response.predictions
2192+ ```
2193+
21732194 Args:
21742195 instances (List):
21752196 Required. The instances that are the input to the
@@ -2194,6 +2215,9 @@ def predict(
21942215 use_raw_predict (bool):
21952216 Optional. Default value is False. If set to True, the underlying prediction call will be made
21962217 against Endpoint.raw_predict().
2218+ use_dedicated_endpoint (bool):
2219+ Optional. Default value is False. If set to True, the underlying prediction call will be made
2220+ using the dedicated endpoint dns.
21972221
21982222 Returns:
21992223 prediction (aiplatform.Prediction):
@@ -2204,6 +2228,7 @@ def predict(
22042228 raw_predict_response = self .raw_predict (
22052229 body = json .dumps ({"instances" : instances , "parameters" : parameters }),
22062230 headers = {"Content-Type" : "application/json" },
2231+ use_dedicated_endpoint = use_dedicated_endpoint ,
22072232 )
22082233 json_response = raw_predict_response .json ()
22092234 return Prediction (
@@ -2219,6 +2244,51 @@ def predict(
22192244 _RAW_PREDICT_MODEL_VERSION_ID_KEY , None
22202245 ),
22212246 )
2247+
2248+ if use_dedicated_endpoint :
2249+ self ._sync_gca_resource_if_skipped ()
2250+ if (
2251+ not self ._gca_resource .dedicated_endpoint_enabled
2252+ or self ._gca_resource .dedicated_endpoint_dns is None
2253+ ):
2254+ raise ValueError (
2255+ "Dedicated endpoint is not enabled or DNS is empty."
2256+ "Please make sure endpoint has dedicated endpoint enabled"
2257+ "and model are ready before making a prediction."
2258+ )
2259+
2260+ if not self .authorized_session :
2261+ self .credentials ._scopes = constants .base .DEFAULT_AUTHED_SCOPES
2262+ self .authorized_session = google_auth_requests .AuthorizedSession (
2263+ self .credentials
2264+ )
2265+
2266+ headers = {
2267+ "Content-Type" : "application/json" ,
2268+ }
2269+
2270+ url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :predict"
2271+ response = self .authorized_session .post (
2272+ url = url ,
2273+ data = json .dumps (
2274+ {
2275+ "instances" : instances ,
2276+ "parameters" : parameters ,
2277+ }
2278+ ),
2279+ headers = headers ,
2280+ )
2281+
2282+ prediction_response = json .loads (response .text )
2283+
2284+ return Prediction (
2285+ predictions = prediction_response .get ("predictions" ),
2286+ metadata = prediction_response .get ("metadata" ),
2287+ deployed_model_id = prediction_response .get ("deployedModelId" ),
2288+ model_resource_name = prediction_response .get ("model" ),
2289+ model_version_id = prediction_response .get ("modelVersionId" ),
2290+ )
2291+
22222292 else :
22232293 prediction_response = self ._prediction_client .predict (
22242294 endpoint = self ._gca_resource .name ,
@@ -2307,7 +2377,11 @@ async def predict_async(
23072377 )
23082378
23092379 def raw_predict (
2310- self , body : bytes , headers : Dict [str , str ]
2380+ self ,
2381+ body : bytes ,
2382+ headers : Dict [str , str ],
2383+ * ,
2384+ use_dedicated_endpoint : Optional [bool ] = False ,
23112385 ) -> requests .models .Response :
23122386 """Makes a prediction request using arbitrary headers.
23132387
@@ -2317,6 +2391,12 @@ def raw_predict(
23172391 body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
23182392 headers = {'Content-Type':'application/json'}
23192393 )
2394+ # For dedicated endpoint:
2395+ response = my_endpoint.raw_predict(
2396+ body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
2397+ headers = {'Content-Type':'application/json'},
2398+ dedicated_endpoint=True,
2399+ )
23202400 status_code = response.status_code
23212401 results = json.dumps(response.text)
23222402
@@ -2325,6 +2405,9 @@ def raw_predict(
23252405 The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
23262406 headers (Dict[str, str]):
23272407 The header of the request as a dictionary. There are no restrictions on the header.
2408+ use_dedicated_endpoint (bool):
2409+ Optional. Default value is False. If set to True, the underlying prediction call will be made
2410+ using the dedicated endpoint dns.
23282411
23292412 Returns:
23302413 A requests.models.Response object containing the status code and prediction results.
@@ -2338,12 +2421,29 @@ def raw_predict(
23382421 if self .raw_predict_request_url is None :
23392422 self .raw_predict_request_url = f"https://{ self .location } -{ constants .base .API_BASE_PATH } /v1/projects/{ self .project } /locations/{ self .location } /endpoints/{ self .name } :rawPredict"
23402423
2341- return self .authorized_session .post (
2342- url = self .raw_predict_request_url , data = body , headers = headers
2343- )
2424+ url = self .raw_predict_request_url
2425+
2426+ if use_dedicated_endpoint :
2427+ self ._sync_gca_resource_if_skipped ()
2428+ if (
2429+ not self ._gca_resource .dedicated_endpoint_enabled
2430+ or self ._gca_resource .dedicated_endpoint_dns is None
2431+ ):
2432+ raise ValueError (
2433+ "Dedicated endpoint is not enabled or DNS is empty."
2434+ "Please make sure endpoint has dedicated endpoint enabled"
2435+ "and model are ready before making a prediction."
2436+ )
2437+ url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :rawPredict"
2438+
2439+ return self .authorized_session .post (url = url , data = body , headers = headers )
23442440
23452441 def stream_raw_predict (
2346- self , body : bytes , headers : Dict [str , str ]
2442+ self ,
2443+ body : bytes ,
2444+ headers : Dict [str , str ],
2445+ * ,
2446+ use_dedicated_endpoint : Optional [bool ] = False ,
23472447 ) -> Iterator [requests .models .Response ]:
23482448 """Makes a streaming prediction request using arbitrary headers.
23492449
@@ -2358,13 +2458,28 @@ def stream_raw_predict(
23582458 stream_result = json.dumps(response.text)
23592459 ```
23602460
2461+ For dedicated endpoint:
2462+ ```
2463+ my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
2464+ for stream_response in my_endpoint.stream_raw_predict(
2465+ body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
2466+ headers = {'Content-Type':'application/json'},
2467+ use_dedicated_endpoint=True,
2468+ ):
2469+ status_code = response.status_code
2470+ stream_result = json.dumps(response.text)
2471+ ```
2472+
23612473 Args:
23622474 body (bytes):
23632475 The body of the prediction request in bytes. This must not
23642476 exceed 10 mb per request.
23652477 headers (Dict[str, str]):
23662478 The header of the request as a dictionary. There are no
23672479 restrictions on the header.
2480+ use_dedicated_endpoint (bool):
2481+ Optional. Default value is False. If set to True, the underlying prediction call will be made
2482+ using the dedicated endpoint dns.
23682483
23692484 Yields:
23702485 predictions (Iterator[requests.models.Response]):
@@ -2379,8 +2494,23 @@ def stream_raw_predict(
23792494 if self .stream_raw_predict_request_url is None :
23802495 self .stream_raw_predict_request_url = f"https://{ self .location } -{ constants .base .API_BASE_PATH } /v1/projects/{ self .project } /locations/{ self .location } /endpoints/{ self .name } :streamRawPredict"
23812496
2497+ url = self .raw_predict_request_url
2498+
2499+ if use_dedicated_endpoint :
2500+ self ._sync_gca_resource_if_skipped ()
2501+ if (
2502+ not self ._gca_resource .dedicated_endpoint_enabled
2503+ or self ._gca_resource .dedicated_endpoint_dns is None
2504+ ):
2505+ raise ValueError (
2506+ "Dedicated endpoint is not enabled or DNS is empty."
2507+ "Please make sure endpoint has dedicated endpoint enabled"
2508+ "and model are ready before making a prediction."
2509+ )
2510+ url = f"https://{ self ._gca_resource .dedicated_endpoint_dns } /v1/{ self .resource_name } :streamRawPredict"
2511+
23822512 with self .authorized_session .post (
2383- url = self . stream_raw_predict_request_url ,
2513+ url = url ,
23842514 data = body ,
23852515 headers = headers ,
23862516 stream = True ,
0 commit comments