Commit 3643b4a

seanzhougoogle authored and copybara-github committed
feat: Allow toolset to process llm_request before tools returned by it
PiperOrigin-RevId: 785480813
1 parent: cec400a · commit: 3643b4a

File tree: 4 files changed, +299 −5 lines


src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 18 additions & 5 deletions
@@ -42,6 +42,7 @@
 from ...telemetry import trace_call_llm
 from ...telemetry import trace_send_data
 from ...telemetry import tracer
+from ...tools.base_toolset import BaseToolset
 from ...tools.tool_context import ToolContext
 
 if TYPE_CHECKING:
@@ -341,13 +342,25 @@ async def _preprocess_async(
         yield event
 
     # Run processors for tools.
-    for tool in await agent.canonical_tools(
-        ReadonlyContext(invocation_context)
-    ):
+    for tool_union in agent.tools:
       tool_context = ToolContext(invocation_context)
-      await tool.process_llm_request(
-          tool_context=tool_context, llm_request=llm_request
+
+      # If it's a toolset, process it first
+      if isinstance(tool_union, BaseToolset):
+        await tool_union.process_llm_request(
+            tool_context=tool_context, llm_request=llm_request
+        )
+
+      from ...agents.llm_agent import _convert_tool_union_to_tools
+
+      # Then process all tools from this tool union
+      tools = await _convert_tool_union_to_tools(
+          tool_union, ReadonlyContext(invocation_context)
       )
+      for tool in tools:
+        await tool.process_llm_request(
+            tool_context=tool_context, llm_request=llm_request
+        )
 
   async def _postprocess_async(
       self,
src/google/adk/tools/base_toolset.py

Lines changed: 22 additions & 0 deletions
@@ -20,11 +20,16 @@
 from typing import Optional
 from typing import Protocol
 from typing import runtime_checkable
+from typing import TYPE_CHECKING
 from typing import Union
 
 from ..agents.readonly_context import ReadonlyContext
 from .base_tool import BaseTool
 
+if TYPE_CHECKING:
+  from ..models.llm_request import LlmRequest
+  from .tool_context import ToolContext
+
 
 @runtime_checkable
 class ToolPredicate(Protocol):
@@ -96,3 +101,20 @@ def _is_tool_selected(
       return tool.name in self.tool_filter
 
     return False
+
+  async def process_llm_request(
+      self, *, tool_context: ToolContext, llm_request: LlmRequest
+  ) -> None:
+    """Processes the outgoing LLM request for this toolset. This method is
+    called before each tool of this toolset processes the llm request.
+
+    Use cases:
+      - Instead of letting each tool process the llm request, the toolset can
+        process it as a whole, e.g. ComputerUseToolset can add the computer
+        use tool to the llm request.
+
+    Args:
+      tool_context: The context of the tool.
+      llm_request: The outgoing LLM request, mutable by this method.
+    """
+    pass
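As a usage illustration of the docstring's use case (a minimal sketch, not part of this commit; `ScreenControlToolset` and the `take_screenshot` declaration are hypothetical), a toolset can override the hook to do request-wide setup once rather than in each tool:

# Illustrative sketch only; ScreenControlToolset and take_screenshot are
# hypothetical, not part of this commit.
from typing import Optional

from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.models.llm_request import LlmRequest
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.tool_context import ToolContext
from google.genai import types


class ScreenControlToolset(BaseToolset):
  """Hypothetical toolset that configures the request at the toolset level."""

  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> list[BaseTool]:
    return []

  async def close(self) -> None:
    pass

  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    # Runs once for the toolset, before each of its tools processes the
    # request, so request-wide setup can live here instead of in every tool.
    llm_request.config = llm_request.config or types.GenerateContentConfig()
    llm_request.config.tools = llm_request.config.tools or []
    llm_request.config.tools.append(
        types.Tool(
            function_declarations=[
                types.FunctionDeclaration(name='take_screenshot')
            ]
        )
    )

The default implementation stays a no-op, so existing toolsets are unaffected unless they opt in by overriding the method.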
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for BaseLlmFlow toolset integration."""

from unittest.mock import AsyncMock

from google.adk.agents import Agent
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.tools.base_toolset import BaseToolset
from google.genai import types
import pytest

from ... import testing_utils


class BaseLlmFlowForTesting(BaseLlmFlow):
  """Test implementation of BaseLlmFlow for testing purposes."""

  pass


@pytest.mark.asyncio
async def test_preprocess_calls_toolset_process_llm_request():
  """Test that _preprocess_async calls process_llm_request on toolsets."""

  # Create a mock toolset that tracks if process_llm_request was called
  class _MockToolset(BaseToolset):

    def __init__(self):
      super().__init__()
      self.process_llm_request_called = False
      self.process_llm_request = AsyncMock(side_effect=self._track_call)

    async def _track_call(self, **kwargs):
      self.process_llm_request_called = True

    async def get_tools(self, readonly_context=None):
      return []

    async def close(self):
      pass

  mock_toolset = _MockToolset()

  # Create a mock model that returns a simple response
  mock_response = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Test response')]
      ),
      partial=False,
  )

  mock_model = testing_utils.MockModel.create(responses=[mock_response])

  # Create agent with the mock toolset
  agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset])
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )

  flow = BaseLlmFlowForTesting()

  # Call _preprocess_async
  llm_request = LlmRequest()
  events = []
  async for event in flow._preprocess_async(invocation_context, llm_request):
    events.append(event)

  # Verify that process_llm_request was called on the toolset
  assert mock_toolset.process_llm_request_called


@pytest.mark.asyncio
async def test_preprocess_handles_mixed_tools_and_toolsets():
  """Test that _preprocess_async properly handles both tools and toolsets."""
  from google.adk.tools.base_tool import BaseTool
  from google.adk.tools.function_tool import FunctionTool

  # Create a mock tool
  class _MockTool(BaseTool):

    def __init__(self):
      super().__init__(name='mock_tool', description='Mock tool')
      self.process_llm_request_called = False
      self.process_llm_request = AsyncMock(side_effect=self._track_call)

    async def _track_call(self, **kwargs):
      self.process_llm_request_called = True

    async def call(self, **kwargs):
      return 'mock result'

  # Create a mock toolset
  class _MockToolset(BaseToolset):

    def __init__(self):
      super().__init__()
      self.process_llm_request_called = False
      self.process_llm_request = AsyncMock(side_effect=self._track_call)

    async def _track_call(self, **kwargs):
      self.process_llm_request_called = True

    async def get_tools(self, readonly_context=None):
      return []

    async def close(self):
      pass

  def _test_function():
    """Test function tool."""
    return 'function result'

  mock_tool = _MockTool()
  mock_toolset = _MockToolset()

  # Create agent with mixed tools and toolsets
  agent = Agent(
      name='test_agent', tools=[mock_tool, _test_function, mock_toolset]
  )

  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )

  flow = BaseLlmFlowForTesting()

  # Call _preprocess_async
  llm_request = LlmRequest()
  events = []
  async for event in flow._preprocess_async(invocation_context, llm_request):
    events.append(event)

  # Verify that process_llm_request was called on both tools and toolsets
  assert mock_tool.process_llm_request_called
  assert mock_toolset.process_llm_request_called
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for BaseToolset."""

from typing import Optional

from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.llm_request import LlmRequest
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.tool_context import ToolContext
import pytest


class _TestingToolset(BaseToolset):
  """A test implementation of BaseToolset."""

  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> list[BaseTool]:
    return []

  async def close(self) -> None:
    pass


@pytest.mark.asyncio
async def test_process_llm_request_default_implementation():
  """Test that the default process_llm_request implementation does nothing."""
  toolset = _TestingToolset()

  # Create test objects
  session_service = InMemorySessionService()
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )
  agent = SequentialAgent(name='test_agent')
  invocation_context = InvocationContext(
      invocation_id='test_id',
      agent=agent,
      session=session,
      session_service=session_service,
  )
  tool_context = ToolContext(invocation_context)
  llm_request = LlmRequest()

  # The default implementation should not modify the request
  original_request = LlmRequest.model_validate(llm_request.model_dump())

  await toolset.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  # Verify the request was not modified
  assert llm_request.model_dump() == original_request.model_dump()


@pytest.mark.asyncio
async def test_process_llm_request_can_be_overridden():
  """Test that process_llm_request can be overridden by subclasses."""

  class _CustomToolset(_TestingToolset):

    async def process_llm_request(
        self, *, tool_context: ToolContext, llm_request: LlmRequest
    ) -> None:
      # Add some custom processing
      if not llm_request.contents:
        llm_request.contents = []
      llm_request.contents.append('Custom processing applied')

  toolset = _CustomToolset()

  # Create test objects
  session_service = InMemorySessionService()
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )
  agent = SequentialAgent(name='test_agent')
  invocation_context = InvocationContext(
      invocation_id='test_id',
      agent=agent,
      session=session,
      session_service=session_service,
  )
  tool_context = ToolContext(invocation_context)
  llm_request = LlmRequest()

  await toolset.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  # Verify the custom processing was applied
  assert llm_request.contents == ['Custom processing applied']
