
Commit 8a8a4fa

fix: Improve handling of undeploying model without redistributing remaining traffic (#898)

- Add informative error when undeploying a model with traffic from an Endpoint with multiple deployed models, without providing a new traffic split.
- Improve accuracy of docstring for `Endpoint.undeploy()`.
- Add tests to cover CUJs.

Fixes [b/198290421](http://b/198290421) 🦕
1 parent 321cf9e commit 8a8a4fa
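
For orientation, a minimal sketch of the calling pattern this change enforces; the project, endpoint, and deployed-model IDs below are hypothetical placeholders, not values from this commit:

```python
from google.cloud import aiplatform

# Hypothetical Endpoint carrying multiple deployed models (placeholder IDs).
endpoint = aiplatform.Endpoint(
    "projects/my-project/locations/us-central1/endpoints/1234567890"
)

# Undeploying a model that still receives traffic from an Endpoint with
# multiple deployed models now requires an explicit new split over the
# remaining models; the percentages must sum to 100.
endpoint.undeploy(
    deployed_model_id="1111111111",
    traffic_split={"2222222222": 60, "3333333333": 40},
)
```

A model whose current share is 0% can still be undeployed without passing `traffic_split`, as the new tests below verify.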

File tree

- google/cloud/aiplatform/models.py
- tests/unit/aiplatform/test_endpoints.py

2 files changed (+137, -16 lines)


google/cloud/aiplatform/models.py

Lines changed: 29 additions & 10 deletions
```diff
@@ -997,21 +997,22 @@ def undeploy(
     ) -> None:
         """Undeploys a deployed model.

-        Proportionally adjusts the traffic_split among the remaining deployed
-        models of the endpoint.
+        The model to be undeployed should have no traffic or user must provide
+        a new traffic_split with the remaining deployed models. Refer
+        to `Endpoint.traffic_split` for the current traffic split mapping.

         Args:
             deployed_model_id (str):
                 Required. The ID of the DeployedModel to be undeployed from the
                 Endpoint.
             traffic_split (Dict[str, int]):
-                Optional. A map from a DeployedModel's ID to the percentage of
+                Optional. A map of DeployedModel IDs to the percentage of
                 this Endpoint's traffic that should be forwarded to that DeployedModel.
-                If a DeployedModel's ID is not listed in this map, then it receives
-                no traffic. The traffic percentage values must add up to 100, or
-                map must be empty if the Endpoint is to not accept any traffic at
-                the moment. Key for model being deployed is "0". Should not be
-                provided if traffic_percentage is provided.
+                Required if undeploying a model with non-zero traffic from an Endpoint
+                with multiple deployed models. The traffic percentage values must add
+                up to 100, or map must be empty if the Endpoint is to not accept any traffic
+                at the moment. If a DeployedModel's ID is not listed in this map, then it
+                receives no traffic.
             metadata (Sequence[Tuple[str, str]]):
                 Optional. Strings which should be sent along with the request as
                 metadata.
@@ -1026,6 +1027,19 @@ def undeploy(
                     "Sum of all traffic within traffic split needs to be 100."
                 )

+        # Two or more models deployed to Endpoint and remaining traffic will be zero
+        elif (
+            len(self.traffic_split) > 1
+            and deployed_model_id in self._gca_resource.traffic_split
+            and self._gca_resource.traffic_split[deployed_model_id] == 100
+        ):
+            raise ValueError(
+                f"Undeploying deployed model '{deployed_model_id}' would leave the remaining "
+                "traffic split at 0%. Traffic split must add up to 100% when models are "
+                "deployed. Please undeploy the other models first or provide an updated "
+                "traffic_split."
+            )
+
         self._undeploy(
             deployed_model_id=deployed_model_id,
             traffic_split=traffic_split,
```
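
To make the guard's behavior concrete, here is a sketch with hypothetical IDs, continuing the `endpoint` object from the example above and assuming its reported `traffic_split` is `{"a": 100, "b": 0}`; each call starts from that same state:

```python
# 1. The zero-traffic model can be undeployed without a new split.
endpoint.undeploy(deployed_model_id="b")

# 2. Undeploying the model holding 100% of traffic without a new split
#    now raises the informative ValueError added above.
endpoint.undeploy(deployed_model_id="a")

# 3. Reassigning all traffic to "b" in the same call succeeds: the provided
#    split sums to 100, so the new elif branch is never reached.
endpoint.undeploy(deployed_model_id="a", traffic_split={"b": 100})
```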
```diff
@@ -1282,8 +1296,13 @@ def undeploy_all(self, sync: bool = True) -> "Endpoint":
         """
         self._sync_gca_resource()

-        for deployed_model in self._gca_resource.deployed_models:
-            self._undeploy(deployed_model_id=deployed_model.id, sync=sync)
+        models_to_undeploy = sorted(  # Undeploy zero traffic models first
+            self._gca_resource.traffic_split.keys(),
+            key=lambda id: self._gca_resource.traffic_split[id],
+        )
+
+        for deployed_model in models_to_undeploy:
+            self._undeploy(deployed_model_id=deployed_model, sync=sync)

         return self

```
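The new ordering in `undeploy_all()` is easy to sanity-check in isolation. This standalone sketch reuses the `_TEST_LONG_TRAFFIC_SPLIT` values from the tests below; sorting IDs by traffic share puts zero-traffic models first, so the model still serving traffic is undeployed last:

```python
traffic_split = {"m1": 40, "m2": 10, "m3": 30, "m4": 0, "m5": 5, "m6": 8, "m7": 7}

# Same expression as in undeploy_all(): ascending order of traffic share.
order = sorted(traffic_split.keys(), key=lambda id: traffic_split[id])

print(order)  # ['m4', 'm5', 'm7', 'm6', 'm2', 'm3', 'm1']
```
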
tests/unit/aiplatform/test_endpoints.py

Lines changed: 108 additions & 6 deletions
```diff
@@ -15,6 +15,7 @@
 # limitations under the License.
 #

+import copy
 import pytest

 from unittest import mock
@@ -56,8 +57,10 @@

 _TEST_DISPLAY_NAME = "test-display-name"
 _TEST_DISPLAY_NAME_2 = "test-display-name-2"
+_TEST_DISPLAY_NAME_3 = "test-display-name-3"
 _TEST_ID = "1028944691210842416"
 _TEST_ID_2 = "4366591682456584192"
+_TEST_ID_3 = "5820582938582924817"
 _TEST_DESCRIPTION = "test-description"

 _TEST_ENDPOINT_NAME = (
@@ -80,6 +83,24 @@
 _TEST_DEPLOYED_MODELS = [
     gca_endpoint.DeployedModel(id=_TEST_ID, display_name=_TEST_DISPLAY_NAME),
     gca_endpoint.DeployedModel(id=_TEST_ID_2, display_name=_TEST_DISPLAY_NAME_2),
+    gca_endpoint.DeployedModel(id=_TEST_ID_3, display_name=_TEST_DISPLAY_NAME_3),
+]
+
+_TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0}
+
+_TEST_LONG_TRAFFIC_SPLIT = {
+    "m1": 40,
+    "m2": 10,
+    "m3": 30,
+    "m4": 0,
+    "m5": 5,
+    "m6": 8,
+    "m7": 7,
+}
+_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m5", "m7", "m6", "m2", "m3", "m1"]
+_TEST_LONG_DEPLOYED_MODELS = [
+    gca_endpoint.DeployedModel(id=id, display_name=f"{id}_display_name")
+    for id in _TEST_LONG_TRAFFIC_SPLIT.keys()
 ]

 _TEST_MACHINE_TYPE = "n1-standard-32"
@@ -200,6 +221,21 @@ def get_endpoint_with_models_mock():
             display_name=_TEST_DISPLAY_NAME,
             name=_TEST_ENDPOINT_NAME,
             deployed_models=_TEST_DEPLOYED_MODELS,
+            traffic_split=_TEST_TRAFFIC_SPLIT,
+        )
+        yield get_endpoint_mock
+
+
+@pytest.fixture
+def get_endpoint_with_many_models_mock():
+    with mock.patch.object(
+        endpoint_service_client.EndpointServiceClient, "get_endpoint"
+    ) as get_endpoint_mock:
+        get_endpoint_mock.return_value = gca_endpoint.Endpoint(
+            display_name=_TEST_DISPLAY_NAME,
+            name=_TEST_ENDPOINT_NAME,
+            deployed_models=_TEST_LONG_DEPLOYED_MODELS,
+            traffic_split=_TEST_LONG_TRAFFIC_SPLIT,
         )
         yield get_endpoint_mock

@@ -990,23 +1026,84 @@ def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync):
     @pytest.mark.usefixtures("get_endpoint_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_undeploy_raise_error_traffic_split_total(self, sync):
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError) as e:
             test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
             test_endpoint.undeploy(
                 deployed_model_id="model1", traffic_split={"model2": 99}, sync=sync
             )

+        assert e.match("Sum of all traffic within traffic split needs to be 100.")
+
     @pytest.mark.usefixtures("get_endpoint_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_undeploy_raise_error_undeployed_model_traffic(self, sync):
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError) as e:
             test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
             test_endpoint.undeploy(
                 deployed_model_id="model1",
                 traffic_split={"model1": 50, "model2": 50},
                 sync=sync,
             )

+        assert e.match("Model being undeployed should have 0 traffic.")
+
+    @pytest.mark.usefixtures("get_endpoint_with_models_mock")
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_undeploy_raises_error_on_zero_leftover_traffic(self, sync):
+        """
+        Attempting to undeploy model with 100% traffic on an Endpoint with
+        multiple models deployed without an updated traffic_split should
+        raise an informative error.
+        """
+
+        traffic_remaining = _TEST_TRAFFIC_SPLIT[_TEST_ID_2]
+
+        assert traffic_remaining == 100  # Confirm this model has all traffic
+        assert sum(_TEST_TRAFFIC_SPLIT.values()) == 100  # Mock traffic sums to 100%
+
+        with pytest.raises(ValueError) as e:
+            test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
+            test_endpoint.undeploy(
+                deployed_model_id=_TEST_ID_2, sync=sync,
+            )
+
+        assert e.match(
+            f"Undeploying deployed model '{_TEST_ID_2}' would leave the remaining "
+            f"traffic split at 0%."
+        )
+
+    @pytest.mark.usefixtures("get_endpoint_with_models_mock")
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_undeploy_zero_traffic_model_without_new_traffic_split(
+        self, undeploy_model_mock, sync
+    ):
+        """
+        Attempting to undeploy model with zero traffic without providing
+        a new traffic split should not raise any errors.
+        """
+
+        traffic_remaining = _TEST_TRAFFIC_SPLIT[_TEST_ID_3]
+
+        assert not traffic_remaining  # Confirm there is zero traffic
+
+        test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
+        test_endpoint.undeploy(
+            deployed_model_id=_TEST_ID_3, sync=sync,
+        )
+
+        if not sync:
+            test_endpoint.wait()
+
+        expected_new_traffic_split = copy.deepcopy(_TEST_TRAFFIC_SPLIT)
+        expected_new_traffic_split.pop(_TEST_ID_3)
+
+        undeploy_model_mock.assert_called_once_with(
+            endpoint=test_endpoint.resource_name,
+            deployed_model_id=_TEST_ID_3,
+            traffic_split=expected_new_traffic_split,
+            metadata=(),
+        )
+
     def test_predict(self, get_endpoint_mock, predict_client_predict_mock):

         test_endpoint = models.Endpoint(_TEST_ID)
@@ -1057,23 +1154,28 @@ def test_list_models(self, get_endpoint_with_models_mock):

         assert my_models == _TEST_DEPLOYED_MODELS

-    @pytest.mark.usefixtures("get_endpoint_with_models_mock")
+    @pytest.mark.usefixtures("get_endpoint_with_many_models_mock")
     @pytest.mark.parametrize("sync", [True, False])
     def test_undeploy_all(self, sdk_private_undeploy_mock, sync):

+        # Ensure mock traffic split deployed model IDs are same as expected IDs
+        assert set(_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS) == set(
+            _TEST_LONG_TRAFFIC_SPLIT.keys()
+        )
+
         ept = aiplatform.Endpoint(_TEST_ID)
         ept.undeploy_all(sync=sync)

         if not sync:
             ept.wait()

         # undeploy_all() results in an undeploy() call for each deployed_model
+        # Models are undeployed in ascending order of traffic percentage
         sdk_private_undeploy_mock.assert_has_calls(
             [
-                mock.call(deployed_model_id=deployed_model.id, sync=sync)
-                for deployed_model in _TEST_DEPLOYED_MODELS
+                mock.call(deployed_model_id=deployed_model_id, sync=sync)
+                for deployed_model_id in _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
             ],
-            any_order=True,
         )

     @pytest.mark.usefixtures("list_endpoints_mock")
```
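
One detail worth knowing when reading the new assertions: pytest's `ExceptionInfo.match()` hands its argument to `re.search()`, so the asserted messages are regular expressions rather than literal strings. A small standalone illustration:

```python
import re

message = "Undeploying deployed model '123' would leave the remaining traffic split at 0%."

# The "." in the pattern is a regex wildcard; here it happens to match the
# literal "." in the message. Metacharacters would need re.escape() for an
# exact-text match.
assert re.search("traffic split at 0%.", message)
```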
