
Commit d4c63fc

csvenja authored and copybara-github committed
chore: Add model tracking to LiteLlm and introduce a LiteLLM with fallbacks demo
Related: #2292

Co-authored-by: Shan Cao <caoshan@google.com>
PiperOrigin-RevId: 828024955
1 parent e25beb4 commit d4c63fc

7 files changed: +180 additions, −12 deletions

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@

# LiteLLM with Fallback Models

This agent is built for resilience using LiteLLM's built-in fallback mechanism. It automatically switches models to guard against common disruptions like token limit errors and connection failures, while ensuring full conversational context is preserved across all model changes.

To run this example, ensure your .env file includes the following variables:

```
GOOGLE_API_KEY=
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
```
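
The heart of the sample is the `fallbacks` list passed to `LiteLlm`. A minimal sketch of that wiring, reduced to just the model configuration used by the sample agent later in this commit:

```python
# Minimal sketch of the sample's model wiring (the full agent appears later in
# this commit). LiteLLM calls the primary model first and retries the
# fallbacks, in order, on errors such as token-limit or connection failures.
from google.adk import Agent
from google.adk.models.lite_llm import LiteLlm

root_agent = Agent(
    name='resilient_agent',
    model=LiteLlm(
        model='gemini/gemini-2.5-pro',  # primary model
        fallbacks=[                     # tried in order when the primary fails
            'anthropic/claude-sonnet-4-5-20250929',
            'openai/gpt-4o',
        ],
    ),
)
```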
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import agent
```
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from google.adk import Agent
from google.adk.models.lite_llm import LiteLlm
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def roll_die(sides: int, tool_context: ToolContext) -> int:
  """Roll a die and return the rolled result.

  Args:
    sides: The integer number of sides the die has.
    tool_context: The tool context to use for the die roll.

  Returns:
    An integer of the result of rolling the die.
    The result is also stored in the tool context for future use.
  """
  result = random.randint(1, sides)
  if 'rolls' not in tool_context.state:
    tool_context.state['rolls'] = []

  tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
  return result


async def before_model_callback(callback_context, llm_request):
  print('@before_model_callback')
  print(f'Beginning model choice: {llm_request.model}')
  callback_context.state['beginning_model_choice'] = llm_request.model
  return None


async def after_model_callback(callback_context, llm_response):
  print('@after_model_callback')
  print(f'Final model choice: {llm_response.model_version}')
  callback_context.state['final_model_choice'] = llm_response.model_version
  return None


root_agent = Agent(
    model=LiteLlm(
        model='gemini/gemini-2.5-pro',
        fallbacks=[
            'anthropic/claude-sonnet-4-5-20250929',
            'openai/gpt-4o',
        ],
    ),
    name='resilient_agent',
    description=(
        'hello world agent that can roll a die with a given number of sides.'
    ),
    instruction="""
      You roll dice and answer questions about the outcome of the dice rolls.
      You can roll dice of different sizes.
      It is ok to discuss previous dice rolls, and comment on the dice rolls.
      When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
      You should never roll a die on your own.
    """,
    tools=[
        roll_die,
    ],
    generate_content_config=types.GenerateContentConfig(
        safety_settings=[
            types.SafetySetting(  # avoid false alarm about rolling dice.
                category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                threshold=types.HarmBlockThreshold.OFF,
            ),
        ]
    ),
    before_model_callback=before_model_callback,
    after_model_callback=after_model_callback,
)
```
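
The two callbacks above capture both ends of a fallback: `beginning_model_choice` records the model the request was sent to, and `final_model_choice` records the `model_version` reported on the response. A hypothetical helper (not part of this commit) shows one way to compare them after a turn; it assumes the requested name carries a provider prefix (e.g. `gemini/gemini-2.5-pro`) while the reported version may not, hence the suffix check.

```python
# Hypothetical helper (not part of this commit): detect whether LiteLLM fell
# back to a different model by comparing the two state keys written above.
def fallback_used(state: dict) -> bool:
  requested = state.get('beginning_model_choice')  # e.g. 'gemini/gemini-2.5-pro'
  served = state.get('final_model_choice')         # e.g. 'gemini-2.5-pro' (assumed form)
  if not requested or not served:
    return False
  # The requested name usually includes a provider prefix; compare by suffix.
  return not requested.endswith(served)
```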

src/google/adk/models/lite_llm.py

Lines changed: 13 additions & 5 deletions

```diff
@@ -558,7 +558,9 @@ def _model_response_to_generate_content_response(
   if not message:
     raise ValueError("No message in response")
 
-  llm_response = _message_to_generate_content_response(message)
+  llm_response = _message_to_generate_content_response(
+      message, model_version=response.model
+  )
   if finish_reason:
     # If LiteLLM already provides a FinishReason enum (e.g., for Gemini), use
     # it directly. Otherwise, map the finish_reason string to the enum.
@@ -579,13 +581,14 @@ def _model_response_to_generate_content_response(
 
 
 def _message_to_generate_content_response(
-    message: Message, is_partial: bool = False
+    message: Message, *, is_partial: bool = False, model_version: str = None
 ) -> LlmResponse:
   """Converts a litellm message to LlmResponse.
 
   Args:
     message: The message to convert.
     is_partial: Whether the message is partial.
+    model_version: The model version used to generate the response.
 
   Returns:
     The LlmResponse.
@@ -606,7 +609,9 @@ def _message_to_generate_content_response(
     parts.append(part)
 
   return LlmResponse(
-      content=types.Content(role="model", parts=parts), partial=is_partial
+      content=types.Content(role="model", parts=parts),
+      partial=is_partial,
+      model_version=model_version,
   )
 
 
@@ -950,6 +955,7 @@ async def generate_content_async(
                     content=chunk.text,
                 ),
                 is_partial=True,
+                model_version=part.model,
             )
           elif isinstance(chunk, UsageMetadataChunk):
             usage_metadata = types.GenerateContentResponseUsageMetadata(
@@ -981,14 +987,16 @@
                         role="assistant",
                         content=text,
                         tool_calls=tool_calls,
-                    )
+                    ),
+                    model_version=part.model,
                 )
             )
             text = ""
             function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
-                ChatCompletionAssistantMessage(role="assistant", content=text)
+                ChatCompletionAssistantMessage(role="assistant", content=text),
+                model_version=part.model,
             )
             text = ""
 
```

src/google/adk/models/llm_response.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -55,6 +55,9 @@ class LlmResponse(BaseModel):
   )
   """The pydantic model config."""
 
+  model_version: Optional[str] = None
+  """Output only. The model version used to generate the response."""
+
   content: Optional[types.Content] = None
   """The generative content of the response.
@@ -159,6 +162,7 @@ def create(
             citation_metadata=candidate.citation_metadata,
             avg_logprobs=candidate.avg_logprobs,
             logprobs_result=candidate.logprobs_result,
+            model_version=generate_content_response.model_version,
         )
       else:
         return LlmResponse(
@@ -169,6 +173,7 @@
             finish_reason=candidate.finish_reason,
             avg_logprobs=candidate.avg_logprobs,
             logprobs_result=candidate.logprobs_result,
+            model_version=generate_content_response.model_version,
         )
     else:
       if generate_content_response.prompt_feedback:
@@ -177,10 +182,12 @@
             error_code=prompt_feedback.block_reason,
             error_message=prompt_feedback.block_reason_message,
             usage_metadata=usage_metadata,
+            model_version=generate_content_response.model_version,
         )
       else:
         return LlmResponse(
             error_code='UNKNOWN_ERROR',
             error_message='Unknown error.',
             usage_metadata=usage_metadata,
+            model_version=generate_content_response.model_version,
         )
```
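
With this change, `model_version` is an ordinary optional field on `LlmResponse`, so downstream code (callbacks, plugins, logging) can read it directly. A minimal sketch, constructing the response by hand rather than through a model call:

```python
# Minimal sketch: model_version is now a plain optional field on LlmResponse.
from google.adk.models.llm_response import LlmResponse
from google.genai import types

resp = LlmResponse(
    content=types.Content(role='model', parts=[types.Part(text='hi')]),
    model_version='gemini-2.5-pro',  # normally filled from the backend's reported model
)
print(resp.model_version)  # -> gemini-2.5-pro
```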

tests/unittests/models/test_litellm.py

Lines changed: 32 additions & 7 deletions

```diff
@@ -90,6 +90,7 @@
 
 STREAMING_MODEL_RESPONSE = [
     ModelResponse(
+        model="test_model",
         choices=[
             StreamingChoices(
                 finish_reason=None,
@@ -98,9 +99,10 @@
                     content="zero, ",
                 ),
             )
-        ]
+        ],
     ),
     ModelResponse(
+        model="test_model",
         choices=[
             StreamingChoices(
                 finish_reason=None,
@@ -109,9 +111,10 @@
                     content="one, ",
                 ),
             )
-        ]
+        ],
     ),
     ModelResponse(
+        model="test_model",
         choices=[
             StreamingChoices(
                 finish_reason=None,
@@ -120,9 +123,10 @@
                     content="two:",
                 ),
             )
-        ]
+        ],
     ),
     ModelResponse(
+        model="test_model",
         choices=[
             StreamingChoices(
                 finish_reason=None,
@@ -141,9 +145,10 @@
                     ],
                 ),
             )
-        ]
+        ],
     ),
     ModelResponse(
+        model="test_model",
        choices=[
             StreamingChoices(
                 finish_reason=None,
@@ -162,14 +167,15 @@
                     ],
                 ),
             )
-        ]
+        ],
     ),
     ModelResponse(
+        model="test_model",
         choices=[
             StreamingChoices(
                 finish_reason="tool_use",
             )
-        ]
+        ],
     ),
 ]
 
@@ -342,6 +348,7 @@
 @pytest.fixture
 def mock_response():
   return ModelResponse(
+      model="test_model",
       choices=[
           Choices(
               message=ChatCompletionAssistantMessage(
@@ -359,7 +366,7 @@ def mock_response():
               ],
           )
       )
-      ]
+      ],
   )
 
 
@@ -529,6 +536,7 @@ async def test_generate_content_async(mock_acompletion, lite_llm_instance):
         "test_arg": "test_value"
     }
     assert response.content.parts[1].function_call.id == "test_tool_call_id"
+    assert response.model_version == "test_model"
 
     mock_acompletion.assert_called_once()
 
@@ -1262,6 +1270,19 @@ def test_message_to_generate_content_response_tool_call():
   assert response.content.parts[0].function_call.id == "test_tool_call_id"
 
 
+def test_message_to_generate_content_response_with_model():
+  message = ChatCompletionAssistantMessage(
+      role="assistant",
+      content="Test response",
+  )
+  response = _message_to_generate_content_response(
+      message, model_version="gemini-2.5-pro"
+  )
+  assert response.content.role == "model"
+  assert response.content.parts[0].text == "Test response"
+  assert response.model_version == "gemini-2.5-pro"
+
+
 def test_get_content_text():
   parts = [types.Part.from_text(text="Test text")]
   content = _get_content(parts)
@@ -1556,16 +1577,20 @@ async def test_generate_content_async_stream(
     assert len(responses) == 4
     assert responses[0].content.role == "model"
     assert responses[0].content.parts[0].text == "zero, "
+    assert responses[0].model_version == "test_model"
     assert responses[1].content.role == "model"
     assert responses[1].content.parts[0].text == "one, "
+    assert responses[1].model_version == "test_model"
     assert responses[2].content.role == "model"
     assert responses[2].content.parts[0].text == "two:"
+    assert responses[2].model_version == "test_model"
     assert responses[3].content.role == "model"
     assert responses[3].content.parts[-1].function_call.name == "test_function"
     assert responses[3].content.parts[-1].function_call.args == {
         "test_arg": "test_value"
     }
     assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id"
+    assert responses[3].model_version == "test_model"
     mock_completion.assert_called_once()
 
     _, kwargs = mock_completion.call_args
```

tests/unittests/models/test_llm_response.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -334,3 +334,18 @@ def test_llm_response_create_empty_content_with_stop_reason():
 
   assert response.error_code is None
   assert response.content is not None
+
+
+def test_llm_response_create_includes_model_version():
+  """Test LlmResponse.create() includes model version."""
+  generate_content_response = types.GenerateContentResponse(
+      model_version='gemini-2.0-flash',
+      candidates=[
+          types.Candidate(
+              content=types.Content(parts=[types.Part(text='Response text')]),
+              finish_reason=types.FinishReason.STOP,
+          )
+      ],
+  )
+  response = LlmResponse.create(generate_content_response)
+  assert response.model_version == 'gemini-2.0-flash'
```
