Skip to content
27 changes: 27 additions & 0 deletions ads/aqua/common/task_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python

# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from dataclasses import dataclass

from ads.common.extended_enum import ExtendedEnum
from ads.common.serializer import DataClassSerializable


class TaskStatusEnum(ExtendedEnum):
MODEL_VALIDATION_SUCCESSFUL = "MODEL_VALIDATION_SUCCESSFUL"
MODEL_DOWNLOAD_STARTED = "MODEL_DOWNLOAD_STARTED"
MODEL_DOWNLOAD_SUCCESSFUL = "MODEL_DOWNLOAD_SUCCESSFUL"
MODEL_UPLOAD_STARTED = "MODEL_UPLOAD_STARTED"
MODEL_UPLOAD_SUCCESSFUL = "MODEL_UPLOAD_SUCCESSFUL"
DATASCIENCE_MODEL_CREATED = "DATASCIENCE_MODEL_CREATED"
MODEL_REGISTRATION_SUCCESSFUL = "MODEL_REGISTRATION_SUCCESSFUL"
REGISTRATION_FAILED = "REGISTRATION_FAILED"
MODEL_DOWNLOAD_INPROGRESS = "MODEL_DOWNLOAD_INPROGRESS"


@dataclass
class TaskStatus(DataClassSerializable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not pydantic?

state: TaskStatusEnum
message: str
4 changes: 4 additions & 0 deletions ads/aqua/extension/aqua_ws_msg_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List

from tornado.web import HTTPError
from tornado.websocket import WebSocketHandler

from ads.aqua import logger
from ads.aqua.common.decorator import handle_exceptions
Expand Down Expand Up @@ -53,6 +54,9 @@ def process(self) -> BaseResponse:
"""
pass

def set_ws_connection(self, con: WebSocketHandler):
self.ws_connection = con

def write_error(self, status_code, **kwargs):
"""AquaWSMSGhandler errors are JSON, not human pages."""
reason = kwargs.get("reason")
Expand Down
93 changes: 73 additions & 20 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import threading
from logging import getLogger
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4

from tornado.web import HTTPError

Expand All @@ -12,17 +15,29 @@
CustomInferenceContainerTypeFamily,
)
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum
from ads.aqua.common.utils import (
get_hf_model_info,
is_valid_ocid,
list_hf_models,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.extension.status_manager import (
RegistrationStatus,
StatusTracker,
TaskNameEnum,
)
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
from ads.aqua.model.entities import (
AquaModel,
AquaModelSummary,
HFModelSummary,
)
from ads.aqua.ui import ModelFormat

logger = getLogger(__name__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the other places we use - from ads.aqua import logger.



class AquaModelHandler(AquaAPIhandler):
"""Handler for Aqua Model REST APIs."""
Expand Down Expand Up @@ -108,6 +123,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
HTTPError
Raises HTTPError if inputs are missing or are invalid
"""
task_id = str(uuid4())
try:
input_data = self.get_json_body()
except Exception as ex:
Expand Down Expand Up @@ -145,27 +161,64 @@ def post(self, *args, **kwargs): # noqa: ARG002
str(input_data.get("ignore_model_artifact_check", "false")).lower()
== "true"
)
async_mode = input_data.get("async_mode", False)

return self.finish(
AquaModelApp().register(
model=model,
os_path=os_path,
download_from_hf=download_from_hf,
local_dir=local_dir,
cleanup_model_cache=cleanup_model_cache,
inference_container=inference_container,
finetuning_container=finetuning_container,
compartment_id=compartment_id,
project_id=project_id,
model_file=model_file,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
ignore_model_artifact_check=ignore_model_artifact_check,
def register_model(callback=None) -> AquaModel:
"""Wrapper method to help initialize callback in case of async mode"""
try:
registered_model = AquaModelApp().register(
model=model,
os_path=os_path,
download_from_hf=download_from_hf,
local_dir=local_dir,
cleanup_model_cache=cleanup_model_cache,
inference_container=inference_container,
finetuning_container=finetuning_container,
compartment_id=compartment_id,
project_id=project_id,
model_file=model_file,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
ignore_model_artifact_check=ignore_model_artifact_check,
callback=callback,
)
except Exception as e:
if async_mode:
StatusTracker.add_status(
RegistrationStatus(
task_id=task_id,
task_status=TaskStatus(
state=TaskStatusEnum.REGISTRATION_FAILED, message=str(e)
),
)
)
raise
else:
raise
return registered_model

if async_mode:
t = threading.Thread(
Copy link
Member

@mrDzurb mrDzurb Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to use a ThreadPool instead to control the number of the potential threads?

Something like:

THREAD_POOL_EXECUTOR = ThreadPoolExecutor(max_workers=10) if async_mode: # Submit the registration task to a thread pool. THREAD_POOL_EXECUTOR.submit(self._register_model, task_id, input_data, async_mode) output = { "state": "ACCEPTED", "task_id": task_id, "progress_url": f"ws://host:port/aqua/ws/{task_id}", } else: output = self._register_model(task_id, input_data, async_mode) 

Maybe we can introduce some global ThreadPoolExecutor for this?
I'm wondering if we can use a decorator for this, something similar that we do for in @threaded decorator.

THREAD_POOL_EXECUTOR = ThreadPoolExecutor(max_workers=10) def run_in_thread_if_async(func): """Decorator to run the function in a thread if async_mode is True.""" @wraps(func) def wrapper(self, async_mode, *args, **kwargs): if async_mode: task_id = str(uuid4()) future = THREAD_POOL_EXECUTOR.submit(func, self, task_id, *args, **kwargs) return { "state": "ACCEPTED", "task_id": task_id, "progress_url": f"ws://host:port/aqua/ws/{task_id}", } else: return func(self, None, *args, **kwargs) return wrapper 

I think the decorator could also take care of the StatusTracker.

In this case we just mark any desired function with the @run_in_thread_if_async decorator which will do all the related work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mrDzurb Does threadpool allow for daemon threads? I need daemon threads here.

target=register_model,
args=(
StatusTracker.prepare_status_callback(
TaskNameEnum.REGISTRATION_STATUS, task_id=task_id
),
),
daemon=True,
)
)
t.start()
output = {
"state": "ACCEPTED",
"task_id": task_id,
"progress_url": f"ws://host:port/aqua/ws/{task_id}",
}
else:
output = register_model()
return self.finish(output)

@handle_exceptions
def put(self, id):
Expand Down
14 changes: 14 additions & 0 deletions ads/aqua/extension/models/ws_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RequestResponseType(ExtendedEnum):
AdsVersion = "AdsVersion"
CompatibilityCheck = "CompatibilityCheck"
Error = "Error"
RegisterModelStatus = "RegisterModelStatus"


@dataclass
Expand Down Expand Up @@ -141,3 +142,16 @@ class AquaWsError(DataClassSerializable):
class ErrorResponse(BaseResponse):
data: AquaWsError
kind = RequestResponseType.Error


@dataclass
class RequestStatus(DataClassSerializable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: For the new code, pydantic would be better to use?

status: str
message: str


@dataclass
class ModelRegisterRequest(DataClassSerializable):
status: str
task_id: str
message: str = ""
45 changes: 43 additions & 2 deletions ads/aqua/extension/models_ws_msg_handler.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
#!/usr/bin/env python

# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
from logging import getLogger
from typing import List, Union

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
from ads.aqua.extension.models.ws_models import (
ListModelsResponse,
ModelDetailsResponse,
ModelRegisterRequest,
RequestResponseType,
)
from ads.aqua.extension.status_manager import (
RegistrationSubscriber,
StatusTracker,
TaskNameEnum,
)
from ads.aqua.model import AquaModelApp

logger = getLogger(__name__)

REGISTRATION_STATUS = "registration_status"


class AquaModelWSMsgHandler(AquaWSMsgHandler):
status_subscriber = {}
register_status = {} # Not threadsafe

def __init__(self, message: Union[str, bytes]):
super().__init__(message)

@staticmethod
def get_message_types() -> List[RequestResponseType]:
return [RequestResponseType.ListModels, RequestResponseType.ModelDetails]
return [
RequestResponseType.ListModels,
RequestResponseType.ModelDetails,
RequestResponseType.RegisterModelStatus,
]

@handle_exceptions
def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
Expand All @@ -47,3 +65,26 @@ def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
kind=RequestResponseType.ModelDetails,
data=response,
)
elif request.get("kind") == "RegisterModelStatus":
task_id = request.get("task_id")
StatusTracker.add_subscriber(
subscriber=RegistrationSubscriber(
task_id=task_id, subscriber=self.ws_connection
),
notify_latest_status=False,
)

latest_status = StatusTracker.get_latest_status(
TaskNameEnum.REGISTRATION_STATUS, task_id=task_id
)
logger.info(latest_status)
if latest_status:
return ModelRegisterRequest(
status=latest_status.state,
message=latest_status.message,
task_id=task_id,
)
else:
return ModelRegisterRequest(
status="SUBSCRIBED", task_id=task_id, message=""
)
Loading