|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
| 18 | +import copy |
18 | 19 | import pytest |
19 | 20 |
|
20 | 21 | from unittest import mock |
|
56 | 57 |
|
# Display names and deployed-model IDs shared by the endpoint test fixtures.
# Suffix-less / _2 / _3 triples pair each ID with its display name.
_TEST_DISPLAY_NAME = "test-display-name"
_TEST_DISPLAY_NAME_2 = "test-display-name-2"
_TEST_DISPLAY_NAME_3 = "test-display-name-3"
_TEST_ID = "1028944691210842416"
_TEST_ID_2 = "4366591682456584192"
_TEST_ID_3 = "5820582938582924817"
_TEST_DESCRIPTION = "test-description"
62 | 65 |
|
63 | 66 | _TEST_ENDPOINT_NAME = ( |
|
# Three deployed models matching the ID/display-name constants above.
_TEST_DEPLOYED_MODELS = [
    gca_endpoint.DeployedModel(id=_TEST_ID, display_name=_TEST_DISPLAY_NAME),
    gca_endpoint.DeployedModel(id=_TEST_ID_2, display_name=_TEST_DISPLAY_NAME_2),
    gca_endpoint.DeployedModel(id=_TEST_ID_3, display_name=_TEST_DISPLAY_NAME_3),
]

# Split where the second model holds 100% of the traffic and the others 0%.
_TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0}

# Larger split (sums to 100) used to exercise undeploy ordering.
_TEST_LONG_TRAFFIC_SPLIT = {
    "m1": 40,
    "m2": 10,
    "m3": 30,
    "m4": 0,
    "m5": 5,
    "m6": 8,
    "m7": 7,
}
# IDs of _TEST_LONG_TRAFFIC_SPLIT sorted by ascending traffic share
# (0, 5, 7, 8, 10, 30, 40).
_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m5", "m7", "m6", "m2", "m3", "m1"]
# One DeployedModel per key of the long traffic split.
# `model_id` (not `id`) avoids shadowing the `id` builtin.
_TEST_LONG_DEPLOYED_MODELS = [
    gca_endpoint.DeployedModel(id=model_id, display_name=f"{model_id}_display_name")
    for model_id in _TEST_LONG_TRAFFIC_SPLIT
]
84 | 105 |
|
# Compute Engine machine type string used in deployment specs.
_TEST_MACHINE_TYPE = "n1-standard-32"
@@ -200,6 +221,21 @@ def get_endpoint_with_models_mock(): |
200 | 221 | display_name=_TEST_DISPLAY_NAME, |
201 | 222 | name=_TEST_ENDPOINT_NAME, |
202 | 223 | deployed_models=_TEST_DEPLOYED_MODELS, |
| 224 | + traffic_split=_TEST_TRAFFIC_SPLIT, |
| 225 | + ) |
| 226 | + yield get_endpoint_mock |
| 227 | + |
| 228 | + |
@pytest.fixture
def get_endpoint_with_many_models_mock():
    """Patch get_endpoint to return an Endpoint with seven deployed models.

    The returned Endpoint carries ``_TEST_LONG_DEPLOYED_MODELS`` and the
    matching ``_TEST_LONG_TRAFFIC_SPLIT`` so ordering-sensitive undeploy
    behavior can be exercised.
    """
    with mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "get_endpoint"
    ) as mocked_get_endpoint:
        mocked_get_endpoint.return_value = gca_endpoint.Endpoint(
            display_name=_TEST_DISPLAY_NAME,
            name=_TEST_ENDPOINT_NAME,
            deployed_models=_TEST_LONG_DEPLOYED_MODELS,
            traffic_split=_TEST_LONG_TRAFFIC_SPLIT,
        )
        yield mocked_get_endpoint
205 | 241 |
|
@@ -990,23 +1026,84 @@ def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync): |
990 | 1026 | @pytest.mark.usefixtures("get_endpoint_mock") |
991 | 1027 | @pytest.mark.parametrize("sync", [True, False]) |
992 | 1028 | def test_undeploy_raise_error_traffic_split_total(self, sync): |
993 | | - with pytest.raises(ValueError): |
| 1029 | + with pytest.raises(ValueError) as e: |
994 | 1030 | test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
995 | 1031 | test_endpoint.undeploy( |
996 | 1032 | deployed_model_id="model1", traffic_split={"model2": 99}, sync=sync |
997 | 1033 | ) |
998 | 1034 |
|
| 1035 | + assert e.match("Sum of all traffic within traffic split needs to be 100.") |
| 1036 | + |
999 | 1037 | @pytest.mark.usefixtures("get_endpoint_mock") |
1000 | 1038 | @pytest.mark.parametrize("sync", [True, False]) |
1001 | 1039 | def test_undeploy_raise_error_undeployed_model_traffic(self, sync): |
1002 | | - with pytest.raises(ValueError): |
| 1040 | + with pytest.raises(ValueError) as e: |
1003 | 1041 | test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
1004 | 1042 | test_endpoint.undeploy( |
1005 | 1043 | deployed_model_id="model1", |
1006 | 1044 | traffic_split={"model1": 50, "model2": 50}, |
1007 | 1045 | sync=sync, |
1008 | 1046 | ) |
1009 | 1047 |
|
| 1048 | + assert e.match("Model being undeployed should have 0 traffic.") |
| 1049 | + |
| 1050 | + @pytest.mark.usefixtures("get_endpoint_with_models_mock") |
| 1051 | + @pytest.mark.parametrize("sync", [True, False]) |
| 1052 | + def test_undeploy_raises_error_on_zero_leftover_traffic(self, sync): |
| 1053 | + """ |
| 1054 | + Attempting to undeploy model with 100% traffic on an Endpoint with |
| 1055 | + multiple models deployed without an updated traffic_split should |
| 1056 | + raise an informative error. |
| 1057 | + """ |
| 1058 | + |
| 1059 | + traffic_remaining = _TEST_TRAFFIC_SPLIT[_TEST_ID_2] |
| 1060 | + |
| 1061 | + assert traffic_remaining == 100 # Confirm this model has all traffic |
| 1062 | + assert sum(_TEST_TRAFFIC_SPLIT.values()) == 100 # Mock traffic sums to 100% |
| 1063 | + |
| 1064 | + with pytest.raises(ValueError) as e: |
| 1065 | + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 1066 | + test_endpoint.undeploy( |
| 1067 | + deployed_model_id=_TEST_ID_2, sync=sync, |
| 1068 | + ) |
| 1069 | + |
| 1070 | + assert e.match( |
| 1071 | + f"Undeploying deployed model '{_TEST_ID_2}' would leave the remaining " |
| 1072 | + f"traffic split at 0%." |
| 1073 | + ) |
| 1074 | + |
| 1075 | + @pytest.mark.usefixtures("get_endpoint_with_models_mock") |
| 1076 | + @pytest.mark.parametrize("sync", [True, False]) |
| 1077 | + def test_undeploy_zero_traffic_model_without_new_traffic_split( |
| 1078 | + self, undeploy_model_mock, sync |
| 1079 | + ): |
| 1080 | + """ |
| 1081 | + Attempting to undeploy model with zero traffic without providing |
| 1082 | + a new traffic split should not raise any errors. |
| 1083 | + """ |
| 1084 | + |
| 1085 | + traffic_remaining = _TEST_TRAFFIC_SPLIT[_TEST_ID_3] |
| 1086 | + |
| 1087 | + assert not traffic_remaining # Confirm there is zero traffic |
| 1088 | + |
| 1089 | + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) |
| 1090 | + test_endpoint.undeploy( |
| 1091 | + deployed_model_id=_TEST_ID_3, sync=sync, |
| 1092 | + ) |
| 1093 | + |
| 1094 | + if not sync: |
| 1095 | + test_endpoint.wait() |
| 1096 | + |
| 1097 | + expected_new_traffic_split = copy.deepcopy(_TEST_TRAFFIC_SPLIT) |
| 1098 | + expected_new_traffic_split.pop(_TEST_ID_3) |
| 1099 | + |
| 1100 | + undeploy_model_mock.assert_called_once_with( |
| 1101 | + endpoint=test_endpoint.resource_name, |
| 1102 | + deployed_model_id=_TEST_ID_3, |
| 1103 | + traffic_split=expected_new_traffic_split, |
| 1104 | + metadata=(), |
| 1105 | + ) |
| 1106 | + |
1010 | 1107 | def test_predict(self, get_endpoint_mock, predict_client_predict_mock): |
1011 | 1108 |
|
1012 | 1109 | test_endpoint = models.Endpoint(_TEST_ID) |
@@ -1057,23 +1154,28 @@ def test_list_models(self, get_endpoint_with_models_mock): |
1057 | 1154 |
|
1058 | 1155 | assert my_models == _TEST_DEPLOYED_MODELS |
1059 | 1156 |
|
1060 | | - @pytest.mark.usefixtures("get_endpoint_with_models_mock") |
| 1157 | + @pytest.mark.usefixtures("get_endpoint_with_many_models_mock") |
1061 | 1158 | @pytest.mark.parametrize("sync", [True, False]) |
1062 | 1159 | def test_undeploy_all(self, sdk_private_undeploy_mock, sync): |
1063 | 1160 |
|
| 1161 | + # Ensure mock traffic split deployed model IDs are same as expected IDs |
| 1162 | + assert set(_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS) == set( |
| 1163 | + _TEST_LONG_TRAFFIC_SPLIT.keys() |
| 1164 | + ) |
| 1165 | + |
1064 | 1166 | ept = aiplatform.Endpoint(_TEST_ID) |
1065 | 1167 | ept.undeploy_all(sync=sync) |
1066 | 1168 |
|
1067 | 1169 | if not sync: |
1068 | 1170 | ept.wait() |
1069 | 1171 |
|
1070 | 1172 | # undeploy_all() results in an undeploy() call for each deployed_model |
| 1173 | + # Models are undeployed in ascending order of traffic percentage |
1071 | 1174 | sdk_private_undeploy_mock.assert_has_calls( |
1072 | 1175 | [ |
1073 | | - mock.call(deployed_model_id=deployed_model.id, sync=sync) |
1074 | | - for deployed_model in _TEST_DEPLOYED_MODELS |
| 1176 | + mock.call(deployed_model_id=deployed_model_id, sync=sync) |
| 1177 | + for deployed_model_id in _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS |
1075 | 1178 | ], |
1076 | | - any_order=True, |
1077 | 1179 | ) |
1078 | 1180 |
|
1079 | 1181 | @pytest.mark.usefixtures("list_endpoints_mock") |
|
0 commit comments