Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions ads/aqua/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import oci
from oci.exceptions import ClientError, ServiceError

from ads.aqua.exception import AquaClientError, AquaServiceError
from ads.common.auth import default_signer
from ads.aqua.exception import AquaServiceError, AquaClientError
from oci.exceptions import ServiceError, ClientError
from ads.common import oci_client as oc
from ads.common.utils import extract_region


Expand All @@ -15,7 +15,9 @@ class AquaApp:

def __init__(self) -> None:
self._auth = default_signer()
self.client = oci.data_science.DataScienceClient(**self._auth)
self.ds_client = oc.OCIClientFactory(**self._auth).data_science
self.logging_client = oc.OCIClientFactory(**self._auth).logging_management
self.identity_client = oc.OCIClientFactory(**self._auth).identity
self.region = extract_region(self._auth)

def list_resource(
Expand Down
80 changes: 49 additions & 31 deletions ads/aqua/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,28 @@
ModelDeploymentContainerRuntime,
ModelDeploymentMode,
)
from ads.common.utils import get_console_link
from ads.common.serializer import DataClassSerializable
from ads.aqua.exception import AquaClientError, AquaServiceError
from ads.config import COMPARTMENT_OCID


# todo: move this to constants or have separate functions
AQUA_SERVICE_MODEL = "aqua_service_model"
CONSOLE_LINK_URL = (
"https://cloud.oracle.com/data-science/model-deployments/{}?region={}"
)

logger = logging.getLogger(__name__)


@dataclass
class ShapeInfo:
instance_shape: str
ocpus: float
memory_in_gbs: float


@dataclass(repr=False)
class AquaDeployment(DataClassSerializable):
"""Represents an Aqua Model Deployment"""

id: str
display_name: str
aqua_service_model: str
Expand All @@ -37,10 +44,8 @@ class AquaDeployment(DataClassSerializable):
created_on: str
created_by: str
endpoint: str
instance_shape: str
ocpus: float
memory_in_gbs: float
console_link: str
shape_info: ShapeInfo


class AquaDeploymentApp(AquaApp):
Expand Down Expand Up @@ -202,7 +207,9 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)

model_deployments = self.list_resource(
self.client.list_model_deployments, compartment_id=compartment_id, **kwargs
self.ds_client.list_model_deployments,
compartment_id=compartment_id,
**kwargs,
)

results = []
Expand Down Expand Up @@ -241,15 +248,16 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeployment":
# add error handler
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be cleaned up?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keeping for now, we have an action item to clean all this up once exception handler PR is finalized.

# if not kwargs.get("model_deployment_id", None):
# raise AquaClientError("Aqua model deployment ocid must be provided to fetch the deployment.")

# add error handler
model_deployment = self.client.get_model_deployment(
model_deployment = self.ds_client.get_model_deployment(
model_deployment_id=model_deployment_id, **kwargs
).data

aqua_service_model=(
model_deployment.freeform_tags.get(AQUA_SERVICE_MODEL, None)
if model_deployment.freeform_tags else None

aqua_service_model = (
model_deployment.freeform_tags.get(AQUA_SERVICE_MODEL, None)
if model_deployment.freeform_tags
else None
)

# add error handler
Expand All @@ -259,9 +267,11 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeployment":
return AquaDeploymentApp.from_oci_model_deployment(
model_deployment, self.region
)

@classmethod
def from_oci_model_deployment(cls, oci_model_deployment, region) -> "AquaDeployment":
def from_oci_model_deployment(
cls, oci_model_deployment, region
) -> "AquaDeployment":
"""Converts oci model deployment response to AquaDeployment instance.

Parameters
Expand All @@ -277,31 +287,39 @@ def from_oci_model_deployment(cls, oci_model_deployment, region) -> "AquaDeploym
The instance of the Aqua model deployment.
"""
instance_configuration = (
oci_model_deployment
.model_deployment_configuration_details
.model_configuration_details
.instance_configuration
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
)
instance_shape_config_details = (
instance_configuration.model_deployment_instance_shape_config_details
)
shape_info = ShapeInfo(
instance_shape=instance_configuration.instance_shape_name,
ocpus=(
instance_shape_config_details.ocpus
if instance_shape_config_details
else None
),
memory_in_gbs=(
instance_shape_config_details.memory_in_gbs
if instance_shape_config_details
else None
),
)
return AquaDeployment(
id=oci_model_deployment.id,
display_name=oci_model_deployment.display_name,
aqua_service_model=oci_model_deployment.freeform_tags.get(AQUA_SERVICE_MODEL),
aqua_service_model=oci_model_deployment.freeform_tags.get(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be boolean? aqua_service_model=oci_model_deployment.freeform_tags.get(AQUA_SERVICE_MODEL) is not None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're currently treating aqua_service_model as a string field (something like key=aqua_service_model val=llama2), which is also passed along when deployment is created. This is due to an absence of a tag field in AquaModelSummary. Would it be better to treat this as boolean and have something like tags: Dict[str, str] in the dataclass?

AQUA_SERVICE_MODEL
),
shape_info=shape_info,
state=oci_model_deployment.lifecycle_state,
description=oci_model_deployment.description,
created_on=str(oci_model_deployment.time_created),
created_by=oci_model_deployment.created_by,
endpoint=oci_model_deployment.model_deployment_url,
instance_shape=instance_configuration.instance_shape_name,
ocpus=(
instance_shape_config_details.ocpus
if instance_shape_config_details else None
console_link=get_console_link(
resource="model-deployments",
ocid=oci_model_deployment.id,
region=region,
),
memory_in_gbs=(
instance_shape_config_details.memory_in_gbs
if instance_shape_config_details else None
),
console_link=CONSOLE_LINK_URL.format(oci_model_deployment.id, region)
)
)
2 changes: 2 additions & 0 deletions ads/aqua/extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from ads.aqua.extension.playground_handler import (
__handlers__ as __playground_handlers__,
)
from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__

__handlers__ = (
__playground_handlers__
+ __job_handlers__
+ __model_handlers__
+ __common_handlers__
+ __deployment_handlers__
+ __ui_handlers__
)


Expand Down
4 changes: 2 additions & 2 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class AquaDeploymentHandler(AquaAPIhandler):

def get(self, id=""):
"""Handle GET request."""
# todo: handle list, read and logs for model deployment
if not id:
return self.list()
return self.read(id)
Expand Down Expand Up @@ -112,7 +111,8 @@ def post(self, *args, **kwargs):
AquaDeploymentApp().create(
compartment_id=compartment_id,
project_id=project_id,
# todo: replace model_id with aqua_model.id
# todo: replace model_id with aqua_model.id, use current model for deploy
# but replace with model by reference is implemented
model_id=model_id,
aqua_service_model=aqua_service_model,
display_name=display_name,
Expand Down
75 changes: 75 additions & 0 deletions ads/aqua/extension/ui_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from tornado.web import HTTPError
from urllib.parse import urlparse
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.ui import AquaUIApp


class AquaUIHandler(AquaAPIhandler):
"""
Handler for Aqua UI REST APIs.

Methods
-------
get(self, id="")
Routes the request to fetch log groups, log ids details or compartments
list_log_groups(self, id: str)
Reads the AQUA deployment information.
list_logs(self, log_group_id: str, **kwargs)
Lists the specified log group's log objects.
list_compartments(self, **kwargs)
Lists the compartments in a compartment specified by ODSC_MODEL_COMPARTMENT_OCID env variable.

Raises
------
HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
"""

def get(self, id=""):
"""Handle GET request."""
url_parse = urlparse(self.request.path)
paths = url_parse.path.strip("/")
if paths.startswith("aqua/logging"):
if not id:
return self.list_log_groups()
return self.list_logs(id)
elif paths.startswith("aqua/compartments"):
return self.list_compartments()
else:
raise HTTPError(400, f"The request {self.request.path} is invalid.")

def list_log_groups(self, **kwargs):
"""Lists all log groups for the specified compartment or tenancy."""
compartment_id = self.get_argument("compartment_id")
try:
return self.finish(
AquaUIApp().list_log_groups(compartment_id=compartment_id, **kwargs)
)
except Exception as ex:
raise HTTPError(500, str(ex))

def list_logs(self, log_group_id: str, **kwargs):
"""Lists the specified log group's log objects."""
try:
return self.finish(
AquaUIApp().list_logs(log_group_id=log_group_id, **kwargs)
)
except Exception as ex:
raise HTTPError(500, str(ex))

def list_compartments(self, **kwargs):
"""Lists the compartments in a compartment specified by ODSC_MODEL_COMPARTMENT_OCID env variable."""
try:
return self.finish(AquaUIApp().list_compartments(**kwargs))
except Exception as ex:
raise HTTPError(500, str(ex))


__handlers__ = [
("logging/?([^/]*)", AquaUIHandler),
("compartments/?([^/]*)", AquaUIHandler),
]
24 changes: 11 additions & 13 deletions ads/aqua/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from typing import List

import fsspec
import oci
from oci.exceptions import ClientError, ServiceError
from oci.data_science.models import ModelSummary

from ads.aqua import logger
from ads.aqua.base import AquaApp
Expand Down Expand Up @@ -90,11 +89,10 @@ def get(self, model_id) -> "AquaModel":
The instance of the Aqua model.
"""
# add error handler
oci_model = self.client.get_model(model_id).data

oci_model = self.ds_client.get_model(model_id).data
# add error handler
# if not self._if_show(oci_model):
# raise AquaClientError(f"Target model {oci_model.id} is not Aqua model.")
if not self._if_show(oci_model):
raise AquaClientError(f"Target model {oci_model.id} is not Aqua model.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably will not be able to use the AQUA name. I would suggest to have a custom error, something like NotAquaCompatibleError.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we'll update this once exception handler PR is finalized.


custom_metadata_list = oci_model.custom_metadata_list
artifact_path = self._get_artifact_path(custom_metadata_list)
Expand All @@ -113,6 +111,8 @@ def get(self, model_id) -> "AquaModel":
if oci_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value)
else False,
model_card=str(self._read_file(f"{artifact_path}/{README}")),
# todo: add proper tags
tags={},
)

def list(
Expand All @@ -136,7 +136,9 @@ def list(
compartment_id = compartment_id or COMPARTMENT_OCID

service_model = self.list_resource(
self.client.list_models, compartment_id=ODSC_MODEL_COMPARTMENT_OCID
self.ds_client.list_models,
compartment_id=ODSC_MODEL_COMPARTMENT_OCID,
project_id=project_id,
)
fine_tuned_models = self._rqs(compartment_id)
models = fine_tuned_models + service_model
Expand All @@ -153,11 +155,7 @@ def process_model(model):
tags.update(model.defined_tags)
tags.update(model.freeform_tags)

model_id = (
model.id
if isinstance(model, oci.data_science.models.ModelSummary)
else model.identifier
)
model_id = model.id if isinstance(model, ModelSummary) else model.identifier
return AquaModelSummary(
compartment_id=model.compartment_id,
icon=icon,
Expand Down Expand Up @@ -215,7 +213,7 @@ def _get_artifact_path(self, custom_metadata_list: List) -> str:

def _read_file(self, file_path: str) -> str:
try:
with fsspec.open(file_path, "rb", **self._auth) as f:
with fsspec.open(file_path, "r", **self._auth) as f:
return f.read()
except Exception as e:
logger.error(f"Failed to retreive model icon. {e}")
Expand Down
Loading