Skip to content

Commit 67284fc

Browse files
calvingilescopybara-github
authored andcommitted
feat: History Management Sample
Merge #891 This creates a sample relating to discussion #826 - how to manage context windows. COPYBARA_INTEGRATE_REVIEW=#891 from calvingiles:history-management-sample 2827817 PiperOrigin-RevId: 784920438
1 parent 0ec69d0 commit 67284fc

File tree

3 files changed

+205
-0
lines changed

3 files changed

+205
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
from google.adk import Agent
18+
from google.adk.agents.callback_context import CallbackContext
19+
from google.adk.models import LlmRequest
20+
from google.adk.tools.tool_context import ToolContext
21+
22+
23+
def roll_die(sides: int, tool_context: ToolContext) -> int:
24+
"""Roll a die and return the rolled result.
25+
26+
Args:
27+
sides: The integer number of sides the die has.
28+
29+
Returns:
30+
An integer of the result of rolling the die.
31+
"""
32+
result = random.randint(1, sides)
33+
if not 'rolls' in tool_context.state:
34+
tool_context.state['rolls'] = []
35+
36+
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
37+
return result
38+
39+
40+
async def check_prime(nums: list[int]) -> str:
41+
"""Check if a given list of numbers are prime.
42+
43+
Args:
44+
nums: The list of numbers to check.
45+
46+
Returns:
47+
A str indicating which number is prime.
48+
"""
49+
primes = set()
50+
for number in nums:
51+
number = int(number)
52+
if number <= 1:
53+
continue
54+
is_prime = True
55+
for i in range(2, int(number**0.5) + 1):
56+
if number % i == 0:
57+
is_prime = False
58+
break
59+
if is_prime:
60+
primes.add(number)
61+
return (
62+
'No prime numbers found.'
63+
if not primes
64+
else f"{', '.join(str(num) for num in primes)} are prime numbers."
65+
)
66+
67+
68+
def create_slice_history_callback(n_recent_turns):
69+
async def before_model_callback(callback_context: CallbackContext, llm_request: LlmRequest):
70+
if n_recent_turns < 1:
71+
return
72+
73+
user_indexes = [i for i, content in enumerate(llm_request.contents) if content.role == "user"]
74+
75+
if n_recent_turns > len(user_indexes):
76+
return
77+
78+
suffix_idx = user_indexes[-n_recent_turns]
79+
llm_request.contents = llm_request.contents[suffix_idx:]
80+
81+
return before_model_callback
82+
83+
84+
root_agent = Agent(
85+
model='gemini-2.0-flash',
86+
name='short_history_agent',
87+
description=(
88+
'an agent that maintains only the last turn in its context window.'
89+
' numbers.'
90+
),
91+
instruction="""
92+
You roll dice and answer questions about the outcome of the dice rolls.
93+
You can roll dice of different sizes.
94+
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
95+
It is ok to discuss previous dice roles, and comment on the dice rolls.
96+
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
97+
You should never roll a die on your own.
98+
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
99+
You should not check prime numbers before calling the tool.
100+
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
101+
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
102+
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
103+
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
104+
3. When you respond, you must include the roll_die result from step 1.
105+
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
106+
You should not rely on the previous history on prime results.
107+
""",
108+
tools=[roll_die, check_prime],
109+
before_model_callback=create_slice_history_callback(n_recent_turns=2),
110+
)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import time
17+
import warnings
18+
19+
import agent
20+
from dotenv import load_dotenv
21+
from google.adk import Runner
22+
from google.adk.artifacts import InMemoryArtifactService
23+
from google.adk.cli.utils import logs
24+
from google.adk.sessions import InMemorySessionService
25+
from google.adk.sessions import Session
26+
from google.genai import types
27+
28+
load_dotenv(override=True)
29+
warnings.filterwarnings('ignore', category=UserWarning)
30+
logs.log_to_tmp_folder()
31+
32+
33+
async def main():
34+
app_name = 'my_app'
35+
user_id_1 = 'user1'
36+
session_service = InMemorySessionService()
37+
artifact_service = InMemoryArtifactService()
38+
runner = Runner(
39+
app_name=app_name,
40+
agent=agent.root_agent,
41+
artifact_service=artifact_service,
42+
session_service=session_service,
43+
)
44+
session_11 = await session_service.create_session(
45+
app_name=app_name, user_id=user_id_1
46+
)
47+
48+
async def run_prompt(session: Session, new_message: str):
49+
content = types.Content(
50+
role='user', parts=[types.Part.from_text(text=new_message)]
51+
)
52+
print('** User says:', content.model_dump(exclude_none=True))
53+
async for event in runner.run_async(
54+
user_id=user_id_1,
55+
session_id=session.id,
56+
new_message=content,
57+
):
58+
if event.content.parts and event.content.parts[0].text:
59+
print(f'** {event.author}: {event.content.parts[0].text}')
60+
61+
start_time = time.time()
62+
print('Start time:', start_time)
63+
print('------------------------------------')
64+
await run_prompt(session_11, 'Hi')
65+
await run_prompt(session_11, 'Roll a die with 100 sides')
66+
await run_prompt(session_11, 'Roll a die again with 100 sides.')
67+
await run_prompt(session_11, 'What numbers did I got?')
68+
print(
69+
await artifact_service.list_artifact_keys(
70+
app_name=app_name, user_id=user_id_1, session_id=session_11.id
71+
)
72+
)
73+
end_time = time.time()
74+
print('------------------------------------')
75+
print('End time:', end_time)
76+
print('Total time:', end_time - start_time)
77+
78+
79+
if __name__ == '__main__':
80+
asyncio.run(main())

0 commit comments

Comments
 (0)