Skip to content

Commit 4d10e47

Browse files
authored
Fixed model deployment failed return value (#421)
2 parents c8ed767 + fabd11e commit 4d10e47

File tree

6 files changed

+29
-104
lines changed

6 files changed

+29
-104
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class LogNotConfiguredError(Exception): # pragma: no cover
8585
pass
8686

8787

88-
class ModelDeploymentFailedError(Exception): # pragma: no cover
88+
class ModelDeploymentPredictError(Exception): # pragma: no cover
8989
pass
9090

9191

@@ -607,11 +607,6 @@ def deploy(
607607
-------
608608
ModelDeployment
609609
The instance of ModelDeployment.
610-
611-
Raises
612-
------
613-
ModelDeploymentFailedError
614-
If model deployment fails to deploy
615610
"""
616611
create_model_deployment_details = (
617612
self._build_model_deployment_details()
@@ -626,11 +621,6 @@ def deploy(
626621
poll_interval=poll_interval,
627622
)
628623

629-
if response.lifecycle_state == State.FAILED.name:
630-
raise ModelDeploymentFailedError(
631-
f"Model deployment {response.id} failed to deploy: {response.lifecycle_details}"
632-
)
633-
634624
return self._update_from_oci_model(response)
635625

636626
def delete(
@@ -662,6 +652,7 @@ def delete(
662652
max_wait_time=max_wait_time,
663653
poll_interval=poll_interval,
664654
)
655+
665656
return self._update_from_oci_model(response)
666657

667658
def update(
@@ -890,6 +881,12 @@ def predict(
890881
Prediction results.
891882
892883
"""
884+
current_state = self.sync().lifecycle_state
885+
if current_state != State.ACTIVE.name:
886+
raise ModelDeploymentPredictError(
887+
"This model deployment is not in active state, you will not be able to use predict end point. "
888+
f"Current model deployment state: {current_state} "
889+
)
893890
endpoint = f"{self.url}/predict"
894891
signer = authutil.default_signer()["signer"]
895892
header = {
@@ -953,7 +950,7 @@ def predict(
953950
except oci.exceptions.ServiceError as ex:
954951
# When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend.
955952
if ex.status == 429:
956-
bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
953+
bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
957954
utils.get_logger().warning(
958955
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
959956
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."

ads/model/generic_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1574,7 +1574,10 @@ def from_model_deployment(
15741574

15751575
current_state = model_deployment.state.name.upper()
15761576
if current_state != ModelDeploymentState.ACTIVE.name:
1577-
raise NotActiveDeploymentError(current_state)
1577+
logger.warning(
1578+
"This model deployment is not in active state, you will not be able to use predict end point. "
1579+
f"Current model deployment state: `{current_state}`"
1580+
)
15781581

15791582
model = cls.from_model_catalog(
15801583
model_id=model_deployment.properties.model_id,

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,8 @@ def activate(
209209
)
210210
except Exception as e:
211211
logger.error(
212-
f"Error while trying to activate model deployment: {self.id}"
212+
"Error while trying to activate model deployment: " + str(e)
213213
)
214-
raise e
215214

216215
return self.sync()
217216
else:
@@ -261,9 +260,8 @@ def create(
261260
)
262261
except Exception as e:
263262
logger.error(
264-
f"Error while trying to create model deployment: {self.id}"
263+
"Error while trying to create model deployment: " + str(e)
265264
)
266-
raise e
267265

268266
return self.sync()
269267

@@ -325,9 +323,8 @@ def deactivate(
325323
)
326324
except Exception as e:
327325
logger.error(
328-
f"Error while trying to deactivate model deployment: {self.id}"
326+
"Error while trying to deactivate model deployment: " + str(e)
329327
)
330-
raise e
331328

332329
return self.sync()
333330
else:
@@ -396,9 +393,8 @@ def delete(
396393
)
397394
except Exception as e:
398395
logger.error(
399-
f"Error while trying to delete model deployment: {self.id}"
396+
"Error while trying to delete model deployment: " + str(e)
400397
)
401-
raise e
402398

403399
return self.sync()
404400

@@ -452,8 +448,9 @@ def update(
452448
)
453449
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
454450
except Exception as e:
455-
logger.error(f"Error while trying to update model deployment: {self.id}")
456-
raise e
451+
logger.error(
452+
"Error while trying to update model deployment: " + str(e)
453+
)
457454

458455
return self.sync()
459456

tests/unitary/default_setup/model_deployment/test_model_deployment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ class ModelDeploymentTestCase(unittest.TestCase):
2828
)
2929

3030
@patch("requests.post")
31-
def test_predict(self, mock_post):
31+
@patch("ads.model.deployment.model_deployment.ModelDeployment.sync")
32+
def test_predict(self, mock_sync, mock_post):
3233
"""Ensures predict model passes with valid input parameters."""
34+
mock_sync.return_value = Mock(lifecycle_state="ACTIVE")
3335
mock_post.return_value = Mock(
3436
status_code=200, json=lambda: {"result": "result"}
3537
)
@@ -50,8 +52,10 @@ def test_predict(self, mock_post):
5052
self.test_model_deployment.predict(data=np.array([1, 2, 3]))
5153

5254
@patch("requests.post")
53-
def test_predict_with_bytes(self, mock_post):
55+
@patch("ads.model.deployment.model_deployment.ModelDeployment.sync")
56+
def test_predict_with_bytes(self, mock_sync, mock_post):
5457
"""Ensures predict model passes with bytes input."""
58+
mock_sync.return_value = Mock(lifecycle_state="ACTIVE")
5559
byte_data = b"[[1,2,3,4]]"
5660
with patch.object(authutil, "default_signer") as mock_auth:
5761
auth = MagicMock()
@@ -66,8 +70,10 @@ def test_predict_with_bytes(self, mock_post):
6670
)
6771

6872
@patch("requests.post")
69-
def test_predict_with_auto_serialize_data(self, mock_post):
73+
@patch("ads.model.deployment.model_deployment.ModelDeployment.sync")
74+
def test_predict_with_auto_serialize_data(self, mock_sync, mock_post):
7075
"""Ensures predict model passes with valid input parameters."""
76+
mock_sync.return_value = Mock(lifecycle_state="ACTIVE")
7177
mock_post.return_value = Mock(
7278
status_code=200, json=lambda: {"result": "result"}
7379
)

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from ads.model.deployment.model_deployment import (
2222
ModelDeployment,
2323
ModelDeploymentLogType,
24-
ModelDeploymentFailedError,
2524
)
2625
from ads.model.deployment.model_deployment_infrastructure import (
2726
ModelDeploymentInfrastructure,
@@ -1148,44 +1147,6 @@ def test_deploy(
11481147
mock_create_model_deployment.assert_called_with(create_model_deployment_details)
11491148
mock_sync.assert_called()
11501149

1151-
@patch.object(OCIDataScienceMixin, "sync")
1152-
@patch.object(
1153-
oci.data_science.DataScienceClient,
1154-
"create_model_deployment",
1155-
)
1156-
@patch.object(DataScienceModel, "create")
1157-
def test_deploy_failed(
1158-
self, mock_create, mock_create_model_deployment, mock_sync
1159-
):
1160-
dsc_model = MagicMock()
1161-
dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx"
1162-
mock_create.return_value = dsc_model
1163-
response = oci.response.Response(
1164-
status=MagicMock(),
1165-
headers=MagicMock(),
1166-
request=MagicMock(),
1167-
data=oci.data_science.models.ModelDeployment(
1168-
id="test_model_deployment_id",
1169-
lifecycle_state="FAILED",
1170-
lifecycle_details="The specified log object is not found or user is not authorized.",
1171-
),
1172-
)
1173-
mock_sync.return_value = response.data
1174-
model_deployment = self.initialize_model_deployment()
1175-
create_model_deployment_details = (
1176-
model_deployment._build_model_deployment_details()
1177-
)
1178-
with pytest.raises(
1179-
ModelDeploymentFailedError,
1180-
match=f"Model deployment {response.data.id} failed to deploy: {response.data.lifecycle_details}",
1181-
):
1182-
model_deployment.deploy(wait_for_completion=False)
1183-
mock_create.assert_called()
1184-
mock_create_model_deployment.assert_called_with(
1185-
create_model_deployment_details
1186-
)
1187-
mock_sync.assert_called()
1188-
11891150
@patch.object(
11901151
OCIDataScienceModelDeployment,
11911152
"activate",

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,45 +1051,6 @@ def test_from_model_deployment(
10511051

10521052
assert test_result == test_model
10531053

1054-
@patch.object(
1055-
ModelDeployment,
1056-
"state",
1057-
new_callable=PropertyMock,
1058-
return_value=ModelDeploymentState.FAILED,
1059-
)
1060-
@patch.object(ModelDeployment, "from_id")
1061-
@patch("ads.common.auth.default_signer")
1062-
@patch("ads.common.oci_client.OCIClientFactory")
1063-
def test_from_model_deployment_fail(
1064-
self,
1065-
mock_client,
1066-
mock_default_signer,
1067-
mock_from_id,
1068-
mock_model_deployment_state,
1069-
):
1070-
"""Tests loading model from model deployment."""
1071-
test_auth_config = {"signer": {"config": "value"}}
1072-
mock_default_signer.return_value = test_auth_config
1073-
test_model_deployment_id = "md_ocid"
1074-
test_model_id = "model_ocid"
1075-
md_props = ModelDeploymentProperties(model_id=test_model_id)
1076-
md = ModelDeployment(properties=md_props)
1077-
mock_from_id.return_value = md
1078-
1079-
with pytest.raises(NotActiveDeploymentError):
1080-
GenericModel.from_model_deployment(
1081-
model_deployment_id=test_model_deployment_id,
1082-
model_file_name="test.pkl",
1083-
artifact_dir="test_dir",
1084-
auth=test_auth_config,
1085-
force_overwrite=True,
1086-
properties=None,
1087-
bucket_uri="test_bucket_uri",
1088-
remove_existing_artifact=True,
1089-
compartment_id="test_compartment_id",
1090-
)
1091-
mock_from_id.assert_called_with(test_model_deployment_id)
1092-
10931054
@patch.object(ModelDeployment, "update")
10941055
@patch.object(ModelDeployment, "from_id")
10951056
@patch("ads.common.auth.default_signer")

0 commit comments

Comments
 (0)