- Notifications
You must be signed in to change notification settings - Fork 59
Async register API support #1083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: aqua_apiserver
Are you sure you want to change the base?
Changes from all commits
53305b8 ee8dc5e 8be1a97 8180950 2d4e990 de82312 f6c8c65 ebbd175 b6cf7ba a244dc5 99871a1 7ca41c5 722595c fd3051e 876f826 16658b5 4813195 8de4ad8 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| #!/usr/bin/env python | ||
| | ||
| # Copyright (c) 2025 Oracle and/or its affiliates. | ||
| # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
| | ||
| from dataclasses import dataclass | ||
| | ||
| from ads.common.extended_enum import ExtendedEnum | ||
| from ads.common.serializer import DataClassSerializable | ||
| | ||
| | ||
| class TaskStatusEnum(ExtendedEnum): | ||
| """State identifiers used when reporting the status of a task. | ||
| | ||
| The members cover model validation, download, upload, Data Science model | ||
| creation, and the final registration success/failure outcome. Every | ||
| member's value is the same string as its name. | ||
| | ||
| NOTE(review): MODEL_DOWNLOAD_INPROGRESS is declared last, after the terminal | ||
| states -- confirm whether member order is significant to any consumer. | ||
| """ | ||
| MODEL_VALIDATION_SUCCESSFUL = "MODEL_VALIDATION_SUCCESSFUL" | ||
| MODEL_DOWNLOAD_STARTED = "MODEL_DOWNLOAD_STARTED" | ||
| MODEL_DOWNLOAD_SUCCESSFUL = "MODEL_DOWNLOAD_SUCCESSFUL" | ||
| MODEL_UPLOAD_STARTED = "MODEL_UPLOAD_STARTED" | ||
| MODEL_UPLOAD_SUCCESSFUL = "MODEL_UPLOAD_SUCCESSFUL" | ||
| DATASCIENCE_MODEL_CREATED = "DATASCIENCE_MODEL_CREATED" | ||
| MODEL_REGISTRATION_SUCCESSFUL = "MODEL_REGISTRATION_SUCCESSFUL" | ||
| REGISTRATION_FAILED = "REGISTRATION_FAILED" | ||
| MODEL_DOWNLOAD_INPROGRESS = "MODEL_DOWNLOAD_INPROGRESS" | ||
| | ||
| | ||
| @dataclass | ||
| class TaskStatus(DataClassSerializable): | ||
| """Serializable status record pairing a task state with a message.""" | ||
| # One of the TaskStatusEnum members describing the current state. | ||
| state: TaskStatusEnum | ||
| # Free-form text accompanying the state (the model handler passes str(exception) on failure). | ||
| message: str | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -2,8 +2,11 @@ | |
| # Copyright (c) 2024, 2025 Oracle and/or its affiliates. | ||
| # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
| | ||
| import threading | ||
| from logging import getLogger | ||
| from typing import Optional | ||
| from urllib.parse import urlparse | ||
| from uuid import uuid4 | ||
| | ||
| from tornado.web import HTTPError | ||
| | ||
| | @@ -12,17 +15,29 @@ | |
| CustomInferenceContainerTypeFamily, | ||
| ) | ||
| from ads.aqua.common.errors import AquaRuntimeError, AquaValueError | ||
| from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum | ||
| from ads.aqua.common.utils import ( | ||
| get_hf_model_info, | ||
| is_valid_ocid, | ||
| list_hf_models, | ||
| ) | ||
| from ads.aqua.extension.base_handler import AquaAPIhandler | ||
| from ads.aqua.extension.errors import Errors | ||
| from ads.aqua.extension.status_manager import ( | ||
| RegistrationStatus, | ||
| StatusTracker, | ||
| TaskNameEnum, | ||
| ) | ||
| from ads.aqua.model import AquaModelApp | ||
| from ads.aqua.model.entities import AquaModelSummary, HFModelSummary | ||
| from ads.aqua.model.entities import ( | ||
| AquaModel, | ||
| AquaModelSummary, | ||
| HFModelSummary, | ||
| ) | ||
| from ads.aqua.ui import ModelFormat | ||
| | ||
| logger = getLogger(__name__) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think in the other places we use - [code snippet not captured] | ||
| | ||
| | ||
| class AquaModelHandler(AquaAPIhandler): | ||
| """Handler for Aqua Model REST APIs.""" | ||
| | @@ -108,6 +123,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 | |
| HTTPError | ||
| Raises HTTPError if inputs are missing or are invalid | ||
| """ | ||
| task_id = str(uuid4()) | ||
| try: | ||
| input_data = self.get_json_body() | ||
| except Exception as ex: | ||
| | @@ -145,27 +161,64 @@ def post(self, *args, **kwargs): # noqa: ARG002 | |
| str(input_data.get("ignore_model_artifact_check", "false")).lower() | ||
| == "true" | ||
| ) | ||
| async_mode = input_data.get("async_mode", False) | ||
| | ||
| return self.finish( | ||
| AquaModelApp().register( | ||
| model=model, | ||
| os_path=os_path, | ||
| download_from_hf=download_from_hf, | ||
| local_dir=local_dir, | ||
| cleanup_model_cache=cleanup_model_cache, | ||
| inference_container=inference_container, | ||
| finetuning_container=finetuning_container, | ||
| compartment_id=compartment_id, | ||
| project_id=project_id, | ||
| model_file=model_file, | ||
| inference_container_uri=inference_container_uri, | ||
| allow_patterns=allow_patterns, | ||
| ignore_patterns=ignore_patterns, | ||
| freeform_tags=freeform_tags, | ||
| defined_tags=defined_tags, | ||
| ignore_model_artifact_check=ignore_model_artifact_check, | ||
| def register_model(callback=None) -> AquaModel: | ||
| """Run the model registration and record a failure status in async mode. | ||
| | ||
| Calls ``AquaModelApp().register`` with the values captured from the | ||
| enclosing handler scope and forwards ``callback`` unchanged. | ||
| | ||
| Parameters | ||
| ---------- | ||
| callback : optional | ||
| Passed through as the ``callback`` keyword of ``register``; ``None`` | ||
| when the wrapper is called without arguments. | ||
| | ||
| Returns | ||
| ------- | ||
| AquaModel | ||
| The value returned by ``AquaModelApp().register``. | ||
| | ||
| Raises | ||
| ------ | ||
| Exception | ||
| Any exception raised by ``register`` is re-raised. When ``async_mode`` | ||
| is truthy, a ``REGISTRATION_FAILED`` status carrying ``str(e)`` is | ||
| first added to ``StatusTracker`` under ``task_id``. | ||
| """ | ||
| try: | ||
| registered_model = AquaModelApp().register( | ||
| model=model, | ||
| os_path=os_path, | ||
| download_from_hf=download_from_hf, | ||
| local_dir=local_dir, | ||
| cleanup_model_cache=cleanup_model_cache, | ||
| inference_container=inference_container, | ||
| finetuning_container=finetuning_container, | ||
| compartment_id=compartment_id, | ||
| project_id=project_id, | ||
| model_file=model_file, | ||
| inference_container_uri=inference_container_uri, | ||
| allow_patterns=allow_patterns, | ||
| ignore_patterns=ignore_patterns, | ||
| freeform_tags=freeform_tags, | ||
| defined_tags=defined_tags, | ||
| ignore_model_artifact_check=ignore_model_artifact_check, | ||
| callback=callback, | ||
| ) | ||
| except Exception as e: | ||
| # In async mode the caller runs this wrapper on a background thread, so the | ||
| # failure is recorded against the task id before the exception propagates. | ||
| if async_mode: | ||
| StatusTracker.add_status( | ||
| RegistrationStatus( | ||
| task_id=task_id, | ||
| task_status=TaskStatus( | ||
| state=TaskStatusEnum.REGISTRATION_FAILED, message=str(e) | ||
| ), | ||
| ) | ||
| ) | ||
| raise | ||
| # NOTE(review): both branches re-raise, so this else branch is redundant. | ||
| else: | ||
| raise | ||
| return registered_model | ||
| | ||
| if async_mode: | ||
| t = threading.Thread( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Wouldn't it be better to use a ThreadPool instead, to control the number of potential threads? Something like: [code example not captured] Maybe we can introduce some global ThreadPoolExecutor for this? I think the decorator could also take care of the [text not captured]. In this case we just mark any desired function with the [decorator name not captured]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @mrDzurb Does threadpool allow for daemon threads? I need daemon threads here. | ||
| target=register_model, | ||
| args=( | ||
| StatusTracker.prepare_status_callback( | ||
| TaskNameEnum.REGISTRATION_STATUS, task_id=task_id | ||
| ), | ||
| ), | ||
| daemon=True, | ||
| ) | ||
| ) | ||
| t.start() | ||
| output = { | ||
| "state": "ACCEPTED", | ||
| "task_id": task_id, | ||
| "progress_url": f"ws://host:port/aqua/ws/{task_id}", | ||
| } | ||
| else: | ||
| output = register_model() | ||
| return self.finish(output) | ||
| | ||
| @handle_exceptions | ||
| def put(self, id): | ||
| | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -23,6 +23,7 @@ class RequestResponseType(ExtendedEnum): | |
| AdsVersion = "AdsVersion" | ||
| CompatibilityCheck = "CompatibilityCheck" | ||
| Error = "Error" | ||
| RegisterModelStatus = "RegisterModelStatus" | ||
| | ||
| | ||
| @dataclass | ||
| | @@ -141,3 +142,16 @@ class AquaWsError(DataClassSerializable): | |
| class ErrorResponse(BaseResponse): | ||
| data: AquaWsError | ||
| kind = RequestResponseType.Error | ||
| | ||
| | ||
| @dataclass | ||
| class RequestStatus(DataClassSerializable): | ||
| """Serializable payload holding a status string and an accompanying message.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: For the new code, pydantic would be better to use? | ||
| # Status value as plain text; the set of allowed values is not defined here. | ||
| status: str | ||
| # Text accompanying the status. | ||
| message: str | ||
| | ||
| | ||
| @dataclass | ||
| class ModelRegisterRequest(DataClassSerializable): | ||
| """Serializable payload carrying a status, a task id, and an optional message.""" | ||
| # Status value as plain text; the set of allowed values is not defined here. | ||
| status: str | ||
| # Identifier of the task this payload refers to. | ||
| # NOTE(review): presumably the task_id returned by the async register endpoint -- confirm. | ||
| task_id: str | ||
| # Optional text; defaults to an empty string. | ||
| message: str = "" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not pydantic?