|
4 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
5 | 5 | from unicodedata import category |
6 | 6 | from unittest import TestCase |
7 | | -from unittest.mock import MagicMock, patch |
| 7 | +from unittest.mock import MagicMock, patch, ANY |
8 | 8 |
|
9 | 9 | import pytest |
10 | 10 | from huggingface_hub.hf_api import HfApi, ModelInfo |
|
14 | 14 |
|
15 | 15 | from ads.aqua.common.errors import AquaRuntimeError |
16 | 16 | from ads.aqua.common.utils import get_hf_model_info |
17 | | -from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES |
| 17 | +from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES, AQUA_CHAT_TEMPLATE_METADATA_KEY |
18 | 18 | from ads.aqua.extension.errors import ReplyDetails |
19 | 19 | from ads.aqua.extension.model_handler import ( |
20 | 20 | AquaHuggingFaceHandler, |
21 | 21 | AquaModelHandler, |
22 | 22 | AquaModelLicenseHandler, |
23 | | - AquaModelTokenizerConfigHandler, |
| 23 | + AquaModelChatTemplateHandler |
24 | 24 | ) |
25 | 25 | from ads.aqua.model import AquaModelApp |
26 | 26 | from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary |
@@ -254,39 +254,114 @@ def test_get(self, mock_load_license): |
254 | 254 | mock_load_license.assert_called_with("test_model_id") |
255 | 255 |
|
256 | 256 |
|
257 | | -class ModelTokenizerConfigHandlerTestCase(TestCase): |
| 257 | +class AquaModelChatTemplateHandlerTestCase(TestCase): |
258 | 258 | @patch.object(IPythonHandler, "__init__") |
259 | 259 | def setUp(self, ipython_init_mock) -> None: |
260 | 260 | ipython_init_mock.return_value = None |
261 | | - self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler( |
| 261 | + self.model_chat_template_handler = AquaModelChatTemplateHandler( |
262 | 262 | MagicMock(), MagicMock() |
263 | 263 | ) |
264 | | - self.model_tokenizer_config_handler.finish = MagicMock() |
265 | | - self.model_tokenizer_config_handler.request = MagicMock() |
| 264 | + self.model_chat_template_handler.finish = MagicMock() |
| 265 | + self.model_chat_template_handler.request = MagicMock() |
| 266 | + self.model_chat_template_handler._headers = {} |
266 | 267 |
|
267 | | - @patch.object(AquaModelApp, "get_hf_tokenizer_config") |
| 268 | + @patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id") |
268 | 269 | @patch("ads.aqua.extension.model_handler.urlparse") |
269 | | - def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config): |
270 | | - request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer") |
| 270 | + def test_get_valid_path(self, mock_urlparse, mock_from_id): |
| 271 | + request_path = MagicMock(path="/aqua/models/ocid1.xx./chat-template") |
271 | 272 | mock_urlparse.return_value = request_path |
272 | | - self.model_tokenizer_config_handler.get(model_id="test_model_id") |
273 | | - self.model_tokenizer_config_handler.finish.assert_called_with( |
274 | | - mock_get_hf_tokenizer_config.return_value |
275 | | - ) |
276 | | - mock_get_hf_tokenizer_config.assert_called_with("test_model_id") |
277 | 273 |
|
278 | | - @patch.object(AquaModelApp, "get_hf_tokenizer_config") |
| 274 | + model_mock = MagicMock() |
| 275 | + model_mock.get_custom_metadata_artifact.return_value = "chat_template_string" |
| 276 | + mock_from_id.return_value = model_mock |
| 277 | + |
| 278 | + self.model_chat_template_handler.get(model_id="test_model_id") |
| 279 | + self.model_chat_template_handler.finish.assert_called_with("chat_template_string") |
| 280 | + model_mock.get_custom_metadata_artifact.assert_called_with("chat_template") |
| 281 | + |
279 | 282 | @patch("ads.aqua.extension.model_handler.urlparse") |
280 | | - def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config): |
281 | | - """Test invalid request path should raise HTTPError(400)""" |
282 | | - request_path = MagicMock(path="/invalid/path") |
| 283 | + def test_get_invalid_path(self, mock_urlparse): |
| 284 | + request_path = MagicMock(path="/wrong/path") |
283 | 285 | mock_urlparse.return_value = request_path |
284 | 286 |
|
285 | 287 | with self.assertRaises(HTTPError) as context: |
286 | | - self.model_tokenizer_config_handler.get(model_id="test_model_id") |
| 288 | + self.model_chat_template_handler.get("ocid1.test.chat") |
287 | 289 | self.assertEqual(context.exception.status_code, 400) |
288 | | - self.model_tokenizer_config_handler.finish.assert_not_called() |
289 | | - mock_get_hf_tokenizer_config.assert_not_called() |
| 290 | + |
| 291 | + @patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id", side_effect=Exception("Not found")) |
| 292 | + @patch("ads.aqua.extension.model_handler.urlparse") |
| 293 | + def test_get_model_not_found(self, mock_urlparse, mock_from_id): |
| 294 | + request_path = MagicMock(path="/aqua/models/ocid1.invalid/chat-template") |
| 295 | + mock_urlparse.return_value = request_path |
| 296 | + |
| 297 | + with self.assertRaises(HTTPError) as context: |
| 298 | + self.model_chat_template_handler.get("ocid1.invalid") |
| 299 | + self.assertEqual(context.exception.status_code, 404) |
| 300 | + |
| 301 | + @patch("ads.aqua.extension.model_handler.DataScienceModel.from_id") |
| 302 | + def test_post_valid(self, mock_from_id): |
| 303 | + model_mock = MagicMock() |
| 304 | + model_mock.create_custom_metadata_artifact.return_value = {"result": "success"} |
| 305 | + mock_from_id.return_value = model_mock |
| 306 | + |
| 307 | + self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "Hello <|user|>"}) |
| 308 | + result = self.model_chat_template_handler.post("ocid1.valid") |
| 309 | + self.model_chat_template_handler.finish.assert_called_with({"result": "success"}) |
| 310 | + |
| 311 | + model_mock.create_custom_metadata_artifact.assert_called_with( |
| 312 | + metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY, |
| 313 | + path_type=ANY, |
| 314 | + artifact_path_or_content=b"Hello <|user|>" |
| 315 | + ) |
| 316 | + |
| 317 | + @patch.object(AquaModelChatTemplateHandler, "write_error") |
| 318 | + def test_post_invalid_json(self, mock_write_error): |
| 319 | + self.model_chat_template_handler.get_json_body = MagicMock(side_effect=Exception("Invalid JSON")) |
| 320 | + self.model_chat_template_handler._headers = {} |
| 321 | + self.model_chat_template_handler.post("ocid1.test.invalidjson") |
| 322 | + |
| 323 | + mock_write_error.assert_called_once() |
| 324 | + |
| 325 | + kwargs = mock_write_error.call_args.kwargs |
| 326 | + exc_info = kwargs.get("exc_info") |
| 327 | + |
| 328 | + assert exc_info is not None |
| 329 | + exc_type, exc_instance, _ = exc_info |
| 330 | + |
| 331 | + assert isinstance(exc_instance, HTTPError) |
| 332 | + assert exc_instance.status_code == 400 |
| 333 | + assert "Invalid JSON body" in str(exc_instance) |
| 334 | + |
| 335 | + @patch.object(AquaModelChatTemplateHandler, "write_error") |
| 336 | + def test_post_missing_chat_template(self, mock_write_error): |
| 337 | + self.model_chat_template_handler.get_json_body = MagicMock(return_value={}) |
| 338 | + self.model_chat_template_handler._headers = {} |
| 339 | + |
| 340 | + self.model_chat_template_handler.post("ocid1.test.model") |
| 341 | + |
| 342 | + mock_write_error.assert_called_once() |
| 343 | + exc_info = mock_write_error.call_args.kwargs.get("exc_info") |
| 344 | + assert exc_info is not None |
| 345 | + _, exc_instance, _ = exc_info |
| 346 | + assert isinstance(exc_instance, HTTPError) |
| 347 | + assert exc_instance.status_code == 400 |
| 348 | + assert "Missing required field: 'chat_template'" in str(exc_instance) |
| 349 | + |
| 350 | + @patch("ads.aqua.extension.model_handler.DataScienceModel.from_id", side_effect=Exception("Not found")) |
| 351 | + @patch.object(AquaModelChatTemplateHandler, "write_error") |
| 352 | + def test_post_model_not_found(self, mock_write_error, mock_from_id): |
| 353 | + self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "test template"}) |
| 354 | + self.model_chat_template_handler._headers = {} |
| 355 | + |
| 356 | + self.model_chat_template_handler.post("ocid1.invalid.model") |
| 357 | + |
| 358 | + mock_write_error.assert_called_once() |
| 359 | + exc_info = mock_write_error.call_args.kwargs.get("exc_info") |
| 360 | + assert exc_info is not None |
| 361 | + _, exc_instance, _ = exc_info |
| 362 | + assert isinstance(exc_instance, HTTPError) |
| 363 | + assert exc_instance.status_code == 404 |
| 364 | + assert "Model not found" in str(exc_instance) |
290 | 365 |
|
291 | 366 |
|
292 | 367 | class TestAquaHuggingFaceHandler: |
|
0 commit comments