Skip to content

Commit b64cfad

Browse files
authored
Aqua/improve list model api (#565)
1 parent 253582e commit b64cfad

File tree

2 files changed

+49
-23
lines changed

2 files changed

+49
-23
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,8 @@ def read(self, model_id):
3232
@exception_handler
3333
def list(self):
3434
"""List Aqua models."""
35-
# If default is not specified,
36-
# jupyterlab will raise 400 error when argument is not provided by the HTTP request.
37-
compartment_id = self.get_argument(
38-
"compartment_id", default=os.environ.get(AQUA_MODEL_COMPARTMENT)
39-
)
40-
# project_id is optional.
35+
compartment_id = self.get_argument("compartment_id")
36+
# project_id is no needed.
4137
project_id = self.get_argument("project_id", default=None)
4238
return self.finish(AquaModelApp().list(compartment_id, project_id))
4339

ads/aqua/model.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,11 @@ def get(self, model_id) -> "AquaModel":
109109
task=oci_model.freeform_tags.get(Tags.TASK.value, UNKNOWN),
110110
license=oci_model.freeform_tags.get(Tags.LICENSE.value, UNKNOWN),
111111
organization=oci_model.freeform_tags.get(Tags.ORGANIZATION.value, UNKNOWN),
112-
is_fine_tuned_model=True
113-
if oci_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value)
114-
else False,
112+
is_fine_tuned_model=(
113+
True
114+
if oci_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value)
115+
else False
116+
),
115117
model_card=str(self._read_file(f"{artifact_path}/{README}")),
116118
)
117119

@@ -133,16 +135,28 @@ def list(
133135
List[AquaModelSummary]:
134136
The list of the `ads.aqua.model.AquaModelSummary`.
135137
"""
136-
compartment_id = compartment_id or COMPARTMENT_OCID
137-
138-
service_model = self.list_resource(
139-
self.client.list_models, compartment_id=ODSC_MODEL_COMPARTMENT_OCID
140-
)
141-
fine_tuned_models = self._rqs(compartment_id)
142-
models = fine_tuned_models + service_model
138+
models = []
139+
if compartment_id:
140+
logger.info(f"Fetching custom models from compartment_id={compartment_id}.")
141+
models = self._rqs(compartment_id)
142+
else:
143+
# TODO: remove project_id after policy for service-model compartment has been set.
144+
project_id = os.environ.get("TEST_PROJECT_ID")
145+
logger.info(
146+
f"Fetching service model from compartment_id={ODSC_MODEL_COMPARTMENT_OCID}, project_id={project_id}"
147+
)
148+
models = self.list_resource(
149+
self.client.list_models,
150+
compartment_id=ODSC_MODEL_COMPARTMENT_OCID,
151+
project_id=project_id,
152+
)
143153

144154
if not models:
145-
logger.error(f"No model found in compartment_id={compartment_id}.")
155+
logger.error(
156+
f"No model found in compartment_id={compartment_id or ODSC_MODEL_COMPARTMENT_OCID}."
157+
)
158+
159+
logger.info(f"Successuly fetch {len(models)} models.")
146160

147161
aqua_models = []
148162
# TODO: build index.json for service model as caching if needed.
@@ -168,17 +182,34 @@ def process_model(model):
168182
project_id=project_id or UNKNOWN,
169183
task=model.freeform_tags.get(Tags.TASK.value, UNKNOWN),
170184
time_created=model.time_created,
171-
is_fine_tuned_model=True
172-
if model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value)
173-
else False,
185+
is_fine_tuned_model=(
186+
True
187+
if model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value)
188+
else False
189+
),
174190
tags=tags,
175191
)
176192

177193
for model in models:
178-
aqua_models.append(process_model(model))
194+
# TODO: remove the check after policy issue resolved
195+
if self._temp_check(model, compartment_id):
196+
aqua_models.append(process_model(model))
179197

180198
return aqua_models
181199

200+
def _temp_check(self, model, compartment_id=None):
201+
# TODO: will remove it later
202+
TARGET_TAGS = model.freeform_tags.keys()
203+
if not Tags.AQUA_TAG.value in TARGET_TAGS:
204+
return False
205+
206+
if compartment_id:
207+
return (
208+
True if Tags.AQUA_FINE_TUNED_MODEL_TAG.value in TARGET_TAGS else False
209+
)
210+
211+
return True if Tags.AQUA_SERVICE_MODEL_TAG.value in TARGET_TAGS else False
212+
182213
def _if_show(self, model: "ModelSummary") -> bool:
183214
"""Determine if the given model should be return by `list`."""
184215
TARGET_TAGS = model.freeform_tags.keys()
@@ -235,10 +266,9 @@ def _rqs(self, compartment_id):
235266
"""Use RQS to fetch models in the user tenancy."""
236267
condition_tags = f"&& (freeformTags.key = '{Tags.AQUA_SERVICE_MODEL_TAG.value}' || freeformTags.key = '{Tags.AQUA_FINE_TUNED_MODEL_TAG.value}')"
237268
condition_lifecycle = "&& lifecycleState = 'ACTIVE'"
238-
# not support filtered by project_id
239269
query = f"query datasciencemodel resources where (compartmentId = '{compartment_id}' {condition_lifecycle} {condition_tags})"
240270
logger.info(query)
241-
271+
logger.info(f"tenant_id={TENANCY_OCID}")
242272
try:
243273
return OCIResource.search(
244274
query,

0 commit comments

Comments
 (0)