15 changes: 5 additions & 10 deletions ads/model/deployment/common/utils.py
@@ -119,19 +119,14 @@ def send_request(
Returns:
A JSON representation of a requests.Response object.
"""
headers = dict()
if is_json_payload:
headers["Content-Type"] = (
header.get("content_type") or DEFAULT_CONTENT_TYPE_JSON
)
header["Content-Type"] = header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON) or DEFAULT_CONTENT_TYPE_JSON
request_kwargs = {"json": data}
else:
headers["Content-Type"] = (
header.get("content_type") or DEFAULT_CONTENT_TYPE_BYTES
)
header["Content-Type"] = header.pop("content_type", DEFAULT_CONTENT_TYPE_BYTES) or DEFAULT_CONTENT_TYPE_BYTES
request_kwargs = {"data": data} # should pass bytes when using data

request_kwargs["headers"] = headers
request_kwargs["headers"] = header

if dry_run:
request_kwargs["headers"]["Accept"] = "*/*"
@@ -140,7 +135,7 @@ def send_request(
return json.loads(req.body)
return req.body
else:
request_kwargs["auth"] = header.get("signer")
request_kwargs["auth"] = header.pop("signer")
Member:

Shall we pop signer for dry_run as well?

Author:

done
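
For context, a minimal sketch of what the resolved dry_run branch presumably looks like after this change; this is an assumption based on the thread above, not necessarily the exact committed code. The idea is to discard the signer entry so it never ends up in the prepared request headers:

    if dry_run:
        header.pop("signer", None)  # assumed: drop the signer here too, mirroring the non-dry-run pop below
        request_kwargs["headers"]["Accept"] = "*/*"
        req = requests.Request("POST", endpoint, **request_kwargs).prepare()
        if is_json_payload:
            return json.loads(req.body)
        return req.body
    else:
        request_kwargs["auth"] = header.pop("signer")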

return requests.post(endpoint, **request_kwargs).json()


24 changes: 21 additions & 3 deletions ads/model/deployment/model_deployment.py
@@ -17,7 +17,6 @@
from ads.common import auth as authutil
import pandas as pd
from ads.model.serde.model_input import JsonModelInputSERDE
from ads.common import auth, oci_client
from ads.common.oci_logging import (
LOG_INTERVAL,
LOG_RECORDS_LIMIT,
@@ -63,6 +62,7 @@

MODEL_DEPLOYMENT_KIND = "deployment"
MODEL_DEPLOYMENT_TYPE = "modelDeployment"
MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON"

MODEL_DEPLOYMENT_INSTANCE_SHAPE = "VM.Standard.E4.Flex"
MODEL_DEPLOYMENT_INSTANCE_OCPUS = 1
@@ -828,6 +828,8 @@ def predict(
data: Any = None,
serializer: "ads.model.ModelInputSerializer" = model_input_serializer,
auto_serialize_data: bool = False,
model_name: str = None,
model_version: str = None,
**kwargs,
) -> dict:
"""Returns prediction of input data run against the model deployment endpoint.
@@ -860,6 +862,10 @@
If `auto_serialize_data=False`, `data` required to be bytes or json serializable
and `json_input` required to be json serializable. If `auto_serialize_data` set
to True, data will be serialized before sending to model deployment endpoint.
model_name: str
Defaults to None. When `inference_server="triton"`, the name of the model to invoke.
model_version: str
Defaults to None. When `inference_server="triton"`, the version of the model to invoke.
kwargs:
content_type: str
Used to indicate the media type of the resource.
@@ -878,6 +884,7 @@
"signer": signer,
"content_type": kwargs.get("content_type", None),
}
header.update(kwargs.pop("headers", {}))

if data is None and json_input is None:
raise AttributeError(
Expand Down Expand Up @@ -916,9 +923,13 @@ def predict(
raise TypeError(
"`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
)

if model_name and model_version:
header['model-name'] = model_name
header['model-version'] = model_version
elif bool(model_version) ^ bool(model_name):
Member:
We can rearrange this if/elif -

if bool(model_version) ^ bool(model_name):
    raise ValueError("`model_name` and `model_version` have to be provided together.")
header['model-name'] = model_name
header['model-version'] = model_version
raise ValueError("`model_name` and `model_version` have to be provided together.")
prediction = send_request(
data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header
data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header,
)
return prediction
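
For illustration, a hedged usage sketch of the new model_name and model_version parameters against a Triton-served deployment; the deployment OCID, payload variable, and model name are placeholders, not values from this PR:

    # Hypothetical call; requires an existing Triton-backed model deployment.
    from ads.model.deployment import ModelDeployment

    deployment = ModelDeployment.from_id("ocid1.datasciencemodeldeployment.oc1..<unique_id>")
    prediction = deployment.predict(
        data=payload_bytes,                  # raw bytes, so auto_serialize_data stays False
        model_name="densenet_onnx",          # assumed name in the Triton model repository
        model_version="1",
        content_type="application/octet-stream",
    )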

@@ -1390,6 +1401,10 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
infrastructure.CONST_WEB_CONCURRENCY,
runtime.env.get("WEB_CONCURRENCY", None),
)
if runtime.env.get("CONTAINER_TYPE", None) == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
runtime.set_spec(
runtime.CONST_INFERENCE_SERVER, MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON.lower()
)

self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)
self.set_spec(self.CONST_RUNTIME, runtime)
@@ -1566,6 +1581,9 @@ def _build_model_deployment_configuration_details(self) -> Dict:
infrastructure.web_concurrency
)
runtime.set_spec(runtime.CONST_ENV, environment_variables)
if hasattr(runtime, "inference_server") and runtime.inference_server and runtime.inference_server.upper() == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
environment_variables["CONTAINER_TYPE"] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
runtime.set_spec(runtime.CONST_ENV, environment_variables)
environment_configuration_details = {
runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type,
runtime.CONST_ENVIRONMENT_VARIABLES: runtime.env,
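Taken together, the two hunks above bridge the builder API and the service payload: setting inference_server="triton" writes CONTAINER_TYPE=TRITON into the deployment's environment variables, and reading an existing deployment maps that variable back onto the runtime spec. A minimal, assumed illustration of that round trip (the builder names are taken from this diff; the rest is a simplification, not the library's internal code):

    from ads.model.deployment import ModelDeploymentContainerRuntime

    runtime = ModelDeploymentContainerRuntime().with_inference_server("triton")
    env = dict(runtime.env or {})
    # Outbound (_build_model_deployment_configuration_details): advertise Triton via CONTAINER_TYPE.
    if runtime.inference_server and runtime.inference_server.upper() == "TRITON":
        env["CONTAINER_TYPE"] = "TRITON"
    # Inbound (_update_from_oci_model): recover the builder-level setting from the stored variable.
    recovered = "triton" if env.get("CONTAINER_TYPE") == "TRITON" else None
    assert recovered == runtime.inference_server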
56 changes: 56 additions & 0 deletions ads/model/deployment/model_deployment_runtime.py
@@ -330,6 +330,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
CONST_ENTRYPOINT = "entrypoint"
CONST_SERVER_PORT = "serverPort"
CONST_HEALTH_CHECK_PORT = "healthCheckPort"
CONST_INFERENCE_SERVER = "inferenceServer"

attribute_map = {
**ModelDeploymentRuntime.attribute_map,
@@ -339,6 +340,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
CONST_ENTRYPOINT: "entrypoint",
CONST_SERVER_PORT: "server_port",
CONST_HEALTH_CHECK_PORT: "health_check_port",
CONST_INFERENCE_SERVER: "inference_server"
}

payload_attribute_map = {
@@ -532,3 +534,57 @@ def with_health_check_port(
The ModelDeploymentContainerRuntime instance (self).
"""
return self.set_spec(self.CONST_HEALTH_CHECK_PORT, health_check_port)

@property
def inference_server(self) -> str:
"""Returns the inference server.

Returns
-------
str
The inference server.
"""
return self.get_spec(self.CONST_INFERENCE_SERVER, None)

def with_inference_server(self, inference_server: str = "triton") -> "ModelDeploymentRuntime":
"""Sets the inference server. Current supported inference server is "triton".
Note if you are using byoc, you do not need to set the inference server.

Parameters
----------
inference_server: str
Set the inference server.

Returns
-------
ModelDeploymentRuntime
The ModelDeploymentRuntime instance (self).

Example
-------
>>> from ads.model.deployment import ModelDeployment, ModelDeploymentContainerRuntime, ModelDeploymentInfrastructure
>>> import ads
>>> ads.set_auth("resource_principal")
>>> infrastructure = ModelDeploymentInfrastructure()\
... .with_project_id(<project_id>)\
... .with_compartment_id(<compartment_id>)\
... .with_shape_name("VM.Standard.E4.Flex")\
... .with_replica(2)\
... .with_bandwidth_mbps(10)\
... .with_access_log(log_group_id=<deployment_log_group_id>, log_id=<deployment_access_log_id>)\
... .with_predict_log(log_group_id=<deployment_log_group_id>, log_id=<deployment_predict_log_id>)

>>> runtime = ModelDeploymentContainerRuntime()\
... .with_image(<container_image>)\
... .with_server_port(<server_port>)\
... .with_health_check_port(<health_check_port>)\
... .with_model_uri(<model_id>)\
... .with_env({"key":"value", "key2":"value2"})\
... .with_inference_server("triton")
>>> deployment = ModelDeployment()\
... .with_display_name("Triton Example")\
... .with_infrastructure(infrastructure)\
... .with_runtime(runtime)
>>> deployment.deploy()
"""
return self.set_spec(self.CONST_INFERENCE_SERVER, inference_server.lower())
@@ -308,6 +308,62 @@ def initialize_model_deployment_from_spec(self):
"runtime": runtime,
}
)

def initialize_model_deployment_triton_builder(self):
infrastructure = ModelDeploymentInfrastructure()\
.with_compartment_id("fakeid.compartment.oc1..xxx")\
.with_project_id("fakeid.datascienceproject.oc1.iad.xxx")\
.with_shape_name("VM.Standard.E4.Flex")\
.with_replica(2)\
.with_bandwidth_mbps(10)\

runtime = ModelDeploymentContainerRuntime()\
.with_image("fake_image")\
.with_server_port(5000)\
.with_health_check_port(5000)\
.with_model_uri("fake_model_id")\
.with_env({"key":"value", "key2":"value2"})\
.with_inference_server("triton")

deployment = ModelDeployment()\
.with_display_name("triton case")\
.with_infrastructure(infrastructure)\
.with_runtime(runtime)
return deployment

def initialize_model_deployment_triton_yaml(self):
yaml_string = """
kind: deployment
spec:
displayName: triton
infrastructure:
kind: infrastructure
spec:
bandwidthMbps: 10
compartmentId: fake_compartment_id
deploymentType: SINGLE_MODEL
policyType: FIXED_SIZE
replica: 2
shapeConfigDetails:
memoryInGBs: 16.0
ocpus: 1.0
shapeName: VM.Standard.E4.Flex
type: datascienceModelDeployment
runtime:
kind: runtime
spec:
env:
key: value
key2: value2
inference_server: triton
healthCheckPort: 8000
image: fake_image
modelUri: fake_model_id
serverPort: 8000
type: container
"""
deployment_from_yaml = ModelDeployment.from_yaml(yaml_string)
return deployment_from_yaml

def initialize_model_deployment_from_kwargs(self):
infrastructure = (
@@ -435,11 +491,34 @@ def test_initialize_model_deployment_with_error(self):
},
)


def test_initialize_model_deployment_with_spec_kwargs(self):
model_deployment_kwargs = self.initialize_model_deployment_from_kwargs()
model_deployment_builder = self.initialize_model_deployment()

assert model_deployment_kwargs.to_dict() == model_deployment_builder.to_dict()


def test_initialize_model_deployment_triton_builder(self):
temp_model_deployment = self.initialize_model_deployment_triton_builder()
assert isinstance(
temp_model_deployment.runtime, ModelDeploymentContainerRuntime
)
assert isinstance(
temp_model_deployment.infrastructure, ModelDeploymentInfrastructure
)
assert temp_model_deployment.runtime.inference_server == "triton"

def test_initialize_model_deployment_triton_yaml(self):
temp_model_deployment = self.initialize_model_deployment_triton_yaml()
assert isinstance(
temp_model_deployment.runtime, ModelDeploymentContainerRuntime
)
assert isinstance(
temp_model_deployment.infrastructure, ModelDeploymentInfrastructure
)
assert temp_model_deployment.runtime.inference_server == "triton"


def test_model_deployment_to_dict(self):
model_deployment = self.initialize_model_deployment()