Skip to content

Commit 88032cf

Browse files
wukathcopybara-github
authored andcommitted
feat: Support MCP prompts
Add support for MCP prompts via the McpInstructionProvider class, which can be specified as an agent's instruction. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 828166051
1 parent 11571c3 commit 88032cf

File tree

5 files changed

+351
-9
lines changed

5 files changed

+351
-9
lines changed

contributing/samples/mcp_sse_agent/agent.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,27 @@
1616
import os
1717

1818
from google.adk.agents.llm_agent import LlmAgent
19+
from google.adk.agents.mcp_instruction_provider import McpInstructionProvider
1920
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
2021
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
2122

2223
_allowed_path = os.path.dirname(os.path.abspath(__file__))
2324

25+
connection_params = SseConnectionParams(
26+
url='http://localhost:3000/sse',
27+
headers={'Accept': 'text/event-stream'},
28+
)
29+
2430
root_agent = LlmAgent(
2531
model='gemini-2.0-flash',
2632
name='enterprise_assistant',
27-
instruction=f"""\
28-
Help user accessing their file systems.
29-
30-
Allowed directory: {_allowed_path}
31-
""",
33+
instruction=McpInstructionProvider(
34+
connection_params=connection_params,
35+
prompt_name='file_system_prompt',
36+
),
3237
tools=[
3338
MCPToolset(
34-
connection_params=SseConnectionParams(
35-
url='http://localhost:3000/sse',
36-
headers={'Accept': 'text/event-stream'},
37-
),
39+
connection_params=connection_params,
3840
# don't want agent to do write operation
3941
# you can also do below
4042
# tool_filter=lambda tool, ctx=None: tool.name

contributing/samples/mcp_sse_agent/filesystem_server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def get_cwd() -> str:
4545
return str(Path.cwd())
4646

4747

48+
# Add a prompt for accessing file systems
49+
@mcp.prompt(name="file_system_prompt")
50+
def file_system_prompt() -> str:
51+
return f"""\
52+
Help the user access their file systems."""
53+
54+
4855
# Graceful shutdown handler
4956
async def shutdown(signal, loop):
5057
"""Cleanup tasks tied to the service's shutdown."""

src/google/adk/agents/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
16+
import sys
17+
1518
from .base_agent import BaseAgent
1619
from .invocation_context import InvocationContext
1720
from .live_request_queue import LiveRequest
@@ -35,3 +38,16 @@
3538
'LiveRequestQueue',
3639
'RunConfig',
3740
]
41+
42+
if sys.version_info < (3, 10):
43+
logger = logging.getLogger('google_adk.' + __name__)
44+
logger.warning(
45+
'MCP requires Python 3.10 or above. Please upgrade your Python'
46+
' version in order to use it.'
47+
)
48+
else:
49+
from .mcp_instruction_provider import McpInstructionProvider
50+
51+
__all__.extend([
52+
'McpInstructionProvider',
53+
])
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Provides instructions to an agent by fetching prompts from an MCP server."""
16+
17+
from __future__ import annotations
18+
19+
import logging
20+
import sys
21+
from typing import Any
22+
from typing import Dict
23+
from typing import TextIO
24+
25+
from .llm_agent import InstructionProvider
26+
from .readonly_context import ReadonlyContext
27+
28+
# Attempt to import MCP Session Manager from the MCP library, and hints user to
29+
# upgrade their Python version to 3.10 if it fails.
30+
try:
31+
from mcp import types
32+
33+
from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager
34+
except ImportError as e:
35+
if sys.version_info < (3, 10):
36+
raise ImportError(
37+
"MCP Session Manager requires Python 3.10 or above. Please upgrade"
38+
" your Python version."
39+
) from e
40+
else:
41+
raise e
42+
43+
44+
class McpInstructionProvider(InstructionProvider):
45+
"""Fetches agent instructions from an MCP server."""
46+
47+
def __init__(
48+
self,
49+
connection_params: Any,
50+
prompt_name: str,
51+
errlog: TextIO = sys.stderr,
52+
):
53+
"""Initializes the McpInstructionProvider.
54+
55+
Args:
56+
connection_params: Parameters for connecting to the MCP server.
57+
prompt_name: The name of the MCP Prompt to fetch.
58+
errlog: TextIO stream for error logging.
59+
"""
60+
self._connection_params = connection_params
61+
self._errlog = errlog or logging.getLogger(__name__)
62+
self._mcp_session_manager = MCPSessionManager(
63+
connection_params=self._connection_params,
64+
errlog=self._errlog,
65+
)
66+
self.prompt_name = prompt_name
67+
68+
async def __call__(self, context: ReadonlyContext) -> str:
69+
"""Fetches the instruction from the MCP server.
70+
71+
Args:
72+
context: The read-only context of the agent.
73+
74+
Returns:
75+
The instruction string.
76+
"""
77+
session = await self._mcp_session_manager.create_session()
78+
# Fetch prompt definition to get the required argument names
79+
prompt_definitions = await session.list_prompts()
80+
prompt_definition = next(
81+
(p for p in prompt_definitions.prompts if p.name == self.prompt_name),
82+
None,
83+
)
84+
85+
# Fetch arguments from context state if the prompt requires them
86+
prompt_args: Dict[str, Any] = {}
87+
if prompt_definition and prompt_definition.arguments:
88+
arg_names = {arg.name for arg in prompt_definition.arguments}
89+
prompt_args = {
90+
k: v for k, v in (context.state or {}).items() if k in arg_names
91+
}
92+
93+
# Fetch the specific prompt by name with arguments from context state
94+
prompt_result: types.GetPromptResult = await session.get_prompt(
95+
self.prompt_name, arguments=prompt_args
96+
)
97+
98+
if prompt_result and prompt_result.messages:
99+
# Concatenate content of all messages to form the instruction.
100+
instruction = "".join(
101+
message.content.text
102+
for message in prompt_result.messages
103+
if message.content.type == "text"
104+
)
105+
return instruction
106+
else:
107+
raise ValueError(f"Failed to load MCP prompt '{self.prompt_name}'.")
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for McpInstructionProvider."""
16+
import sys
17+
from unittest.mock import AsyncMock
18+
from unittest.mock import MagicMock
19+
from unittest.mock import patch
20+
21+
from google.adk.agents.readonly_context import ReadonlyContext
22+
import pytest
23+
24+
# Skip all tests in this module if Python version is less than 3.10
25+
pytestmark = pytest.mark.skipif(
26+
sys.version_info < (3, 10),
27+
reason="MCP instruction provider requires Python 3.10+",
28+
)
29+
30+
# Import dependencies with version checking
31+
try:
32+
from google.adk.agents.mcp_instruction_provider import McpInstructionProvider
33+
except ImportError as e:
34+
if sys.version_info < (3, 10):
35+
# Create dummy classes to prevent NameError during test collection
36+
# Tests will be skipped anyway due to pytestmark
37+
class DummyClass:
38+
pass
39+
40+
McpInstructionProvider = DummyClass
41+
else:
42+
raise e
43+
44+
45+
class TestMcpInstructionProvider:
46+
"""Unit tests for McpInstructionProvider."""
47+
48+
def setup_method(self):
49+
"""Sets up the test environment."""
50+
self.connection_params = {"host": "localhost", "port": 8000}
51+
self.prompt_name = "test_prompt"
52+
self.mock_mcp_session_manager_cls = patch(
53+
"google.adk.agents.mcp_instruction_provider.MCPSessionManager"
54+
).start()
55+
self.mock_mcp_session_manager = (
56+
self.mock_mcp_session_manager_cls.return_value
57+
)
58+
self.mock_session = MagicMock()
59+
self.mock_session.list_prompts = AsyncMock()
60+
self.mock_session.get_prompt = AsyncMock()
61+
self.mock_mcp_session_manager.create_session = AsyncMock(
62+
return_value=self.mock_session
63+
)
64+
self.provider = McpInstructionProvider(
65+
self.connection_params, self.prompt_name
66+
)
67+
68+
@pytest.mark.asyncio
69+
async def test_call_success_no_args(self):
70+
"""Tests __call__ with a prompt that has no arguments."""
71+
mock_prompt = MagicMock()
72+
mock_prompt.name = self.prompt_name
73+
mock_prompt.arguments = None
74+
self.mock_session.list_prompts.return_value = MagicMock(
75+
prompts=[mock_prompt]
76+
)
77+
78+
mock_msg1 = MagicMock()
79+
mock_msg1.content.type = "text"
80+
mock_msg1.content.text = "instruction part 1. "
81+
mock_msg2 = MagicMock()
82+
mock_msg2.content.type = "text"
83+
mock_msg2.content.text = "instruction part 2"
84+
self.mock_session.get_prompt.return_value = MagicMock(
85+
messages=[mock_msg1, mock_msg2]
86+
)
87+
88+
mock_invocation_context = MagicMock()
89+
mock_invocation_context.session.state = {}
90+
context = ReadonlyContext(mock_invocation_context)
91+
92+
# Call
93+
instruction = await self.provider(context)
94+
95+
# Assert
96+
assert instruction == "instruction part 1. instruction part 2"
97+
self.mock_session.get_prompt.assert_called_once_with(
98+
self.prompt_name, arguments={}
99+
)
100+
101+
@pytest.mark.asyncio
102+
async def test_call_success_with_args(self):
103+
"""Tests __call__ with a prompt that has arguments."""
104+
mock_arg1 = MagicMock()
105+
mock_arg1.name = "arg1"
106+
mock_prompt = MagicMock()
107+
mock_prompt.name = self.prompt_name
108+
mock_prompt.arguments = [mock_arg1]
109+
self.mock_session.list_prompts.return_value = MagicMock(
110+
prompts=[mock_prompt]
111+
)
112+
113+
mock_msg = MagicMock()
114+
mock_msg.content.type = "text"
115+
mock_msg.content.text = "instruction with arg1"
116+
self.mock_session.get_prompt.return_value = MagicMock(messages=[mock_msg])
117+
118+
mock_invocation_context = MagicMock()
119+
mock_invocation_context.session.state = {"arg1": "value1", "arg2": "value2"}
120+
context = ReadonlyContext(mock_invocation_context)
121+
122+
instruction = await self.provider(context)
123+
124+
assert instruction == "instruction with arg1"
125+
self.mock_session.get_prompt.assert_called_once_with(
126+
self.prompt_name, arguments={"arg1": "value1"}
127+
)
128+
129+
@pytest.mark.asyncio
130+
async def test_call_prompt_not_found_in_list_prompts(self):
131+
"""Tests __call__ when list_prompts doesn't return the prompt."""
132+
self.mock_session.list_prompts.return_value = MagicMock(prompts=[])
133+
134+
mock_msg = MagicMock()
135+
mock_msg.content.type = "text"
136+
mock_msg.content.text = "instruction"
137+
self.mock_session.get_prompt.return_value = MagicMock(messages=[mock_msg])
138+
139+
mock_invocation_context = MagicMock()
140+
mock_invocation_context.session.state = {"arg1": "value1"}
141+
context = ReadonlyContext(mock_invocation_context)
142+
143+
instruction = await self.provider(context)
144+
145+
assert instruction == "instruction"
146+
self.mock_session.get_prompt.assert_called_once_with(
147+
self.prompt_name, arguments={}
148+
)
149+
150+
@pytest.mark.asyncio
151+
async def test_call_get_prompt_returns_no_messages(self):
152+
"""Tests __call__ when get_prompt returns no messages."""
153+
# Setup mocks
154+
self.mock_session.list_prompts.return_value = MagicMock(prompts=[])
155+
self.mock_session.get_prompt.return_value = MagicMock(messages=[])
156+
157+
mock_invocation_context = MagicMock()
158+
mock_invocation_context.session.state = {}
159+
context = ReadonlyContext(mock_invocation_context)
160+
161+
# Call and assert
162+
with pytest.raises(
163+
ValueError, match="Failed to load MCP prompt 'test_prompt'."
164+
):
165+
await self.provider(context)
166+
167+
# Assert
168+
self.mock_session.get_prompt.assert_called_once_with(
169+
self.prompt_name, arguments={}
170+
)
171+
172+
@pytest.mark.asyncio
173+
async def test_call_ignore_non_text_messages(self):
174+
"""Tests __call__ ignores non-text messages."""
175+
# Setup mocks
176+
mock_prompt = MagicMock()
177+
mock_prompt.name = self.prompt_name
178+
mock_prompt.arguments = None
179+
self.mock_session.list_prompts.return_value = MagicMock(
180+
prompts=[mock_prompt]
181+
)
182+
183+
mock_msg1 = MagicMock()
184+
mock_msg1.content.type = "text"
185+
mock_msg1.content.text = "instruction part 1. "
186+
187+
mock_msg2 = MagicMock()
188+
mock_msg2.content.type = "image"
189+
mock_msg2.content.text = "ignored"
190+
191+
mock_msg3 = MagicMock()
192+
mock_msg3.content.type = "text"
193+
mock_msg3.content.text = "instruction part 2"
194+
195+
self.mock_session.get_prompt.return_value = MagicMock(
196+
messages=[mock_msg1, mock_msg2, mock_msg3]
197+
)
198+
199+
mock_invocation_context = MagicMock()
200+
mock_invocation_context.session.state = {}
201+
context = ReadonlyContext(mock_invocation_context)
202+
203+
# Call
204+
instruction = await self.provider(context)
205+
206+
# Assert
207+
assert instruction == "instruction part 1. instruction part 2"
208+
self.mock_session.get_prompt.assert_called_once_with(
209+
self.prompt_name, arguments={}
210+
)

0 commit comments

Comments
 (0)