|
6 | 6 | from dotenv import load_dotenv |
7 | 7 | from fastapi import Request |
8 | 8 | from datetime import datetime |
9 | | -from unittest.mock import AsyncMock, patch |
| 9 | +from unittest.mock import AsyncMock, patch, MagicMock |
10 | 10 |
|
11 | 11 | sys.path.insert( |
12 | 12 | 0, os.path.abspath("../..") |
@@ -553,9 +553,67 @@ def test_initialize_router_endpoints(): |
553 | 553 | assert hasattr(router, "aanthropic_messages") |
554 | 554 | assert hasattr(router, "aresponses") |
555 | 555 | assert hasattr(router, "responses") |
556 | | - |
| 556 | + assert hasattr(router, "aget_responses") |
| 557 | + assert hasattr(router, "adelete_responses") |
557 | 558 | # Verify the endpoints are callable |
558 | 559 | assert callable(router.amoderation) |
559 | 560 | assert callable(router.aanthropic_messages) |
560 | 561 | assert callable(router.aresponses) |
561 | 562 | assert callable(router.responses) |
| 563 | + assert callable(router.aget_responses) |
| 564 | + assert callable(router.adelete_responses) |
| 565 | + |
| 566 | + |
| 567 | +@pytest.mark.asyncio |
| 568 | +async def test_init_responses_api_endpoints(): |
| 569 | + """ |
| 570 | + A simpler test for _init_responses_api_endpoints that focuses on the basic functionality |
| 571 | + """ |
| 572 | + from litellm.responses.utils import ResponsesAPIRequestUtils |
| 573 | + # Create a router with a basic model |
| 574 | + router = Router( |
| 575 | + model_list=[ |
| 576 | + { |
| 577 | + "model_name": "test-model", |
| 578 | + "litellm_params": { |
| 579 | + "model": "openai/test-model", |
| 580 | + "api_key": "fake-api-key", |
| 581 | + }, |
| 582 | + } |
| 583 | + ] |
| 584 | + ) |
| 585 | + |
| 586 | + # Just mock the _ageneric_api_call_with_fallbacks method |
| 587 | + router._ageneric_api_call_with_fallbacks = AsyncMock() |
| 588 | + |
| 589 | + # Add a mock implementation of _get_model_id_from_response_id to the Router instance |
| 590 | + ResponsesAPIRequestUtils.get_model_id_from_response_id = MagicMock(return_value=None) |
| 591 | + |
| 592 | + # Call without a response_id (no model extraction should happen) |
| 593 | + await router._init_responses_api_endpoints( |
| 594 | + original_function=AsyncMock(), |
| 595 | + thread_id="thread_xyz" |
| 596 | + ) |
| 597 | + |
| 598 | + # Verify _ageneric_api_call_with_fallbacks was called but model wasn't changed |
| 599 | + first_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs |
| 600 | + assert "model" not in first_call_kwargs |
| 601 | + assert first_call_kwargs["thread_id"] == "thread_xyz" |
| 602 | + |
| 603 | + # Reset the mock |
| 604 | + router._ageneric_api_call_with_fallbacks.reset_mock() |
| 605 | + |
| 606 | + # Change the return value for the second call |
| 607 | + ResponsesAPIRequestUtils.get_model_id_from_response_id.return_value = "claude-3-sonnet" |
| 608 | + |
| 609 | + # Call with a response_id |
| 610 | + await router._init_responses_api_endpoints( |
| 611 | + original_function=AsyncMock(), |
| 612 | + response_id="resp_claude_123" |
| 613 | + ) |
| 614 | + |
| 615 | + # Verify model was updated in the kwargs |
| 616 | + second_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs |
| 617 | + assert second_call_kwargs["model"] == "claude-3-sonnet" |
| 618 | + assert second_call_kwargs["response_id"] == "resp_claude_123" |
| 619 | + |
0 commit comments