Skip to content

Commit 16ba91c

Browse files
pwwpchecopybara-github
authored andcommitted
feat: Implement PluginService for registering and executing plugins
PluginService takes the registration of plugins, and provide the wrapper utilities to execute all plugins. PiperOrigin-RevId: 781745769
1 parent 4dce9ef commit 16ba91c

File tree

3 files changed

+754
-0
lines changed

3 files changed

+754
-0
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import logging
18+
from typing import Any
19+
from typing import List
20+
from typing import Literal
21+
from typing import Optional
22+
from typing import TYPE_CHECKING
23+
24+
from google.genai import types
25+
26+
from .base_plugin import BasePlugin
27+
28+
if TYPE_CHECKING:
29+
from ..agents.base_agent import BaseAgent
30+
from ..agents.callback_context import CallbackContext
31+
from ..agents.invocation_context import InvocationContext
32+
from ..events.event import Event
33+
from ..models.llm_request import LlmRequest
34+
from ..models.llm_response import LlmResponse
35+
from ..tools.base_tool import BaseTool
36+
from ..tools.tool_context import ToolContext
37+
38+
# A type alias for the names of the available plugin callbacks.
39+
# This helps with static analysis and prevents typos when calling run_callbacks.
40+
PluginCallbackName = Literal[
41+
"on_user_message_callback",
42+
"before_run_callback",
43+
"after_run_callback",
44+
"on_event_callback",
45+
"before_agent_callback",
46+
"after_agent_callback",
47+
"before_tool_callback",
48+
"after_tool_callback",
49+
"before_model_callback",
50+
"after_model_callback",
51+
]
52+
53+
logger = logging.getLogger("google_adk." + __name__)
54+
55+
56+
class PluginManager:
57+
"""Manages the registration and execution of plugins.
58+
59+
The PluginManager is an internal class that orchestrates the invocation of
60+
plugin callbacks at key points in the SDK's execution lifecycle. It maintains
61+
a list of registered plugins and ensures they are called in the order they
62+
were registered.
63+
64+
The core execution logic implements an "early exit" strategy: if any plugin
65+
callback returns a non-`None` value, the execution of subsequent plugins for
66+
that specific event is halted, and the returned value is propagated up the
67+
call stack. This allows plugins to short-circuit operations like agent runs,
68+
tool calls, or model requests.
69+
"""
70+
71+
def __init__(self, plugins: Optional[List[BasePlugin]] = None):
72+
"""Initializes the plugin service.
73+
74+
Args:
75+
plugins: An optional list of plugins to register upon initialization.
76+
"""
77+
self.plugins: List[BasePlugin] = []
78+
if plugins:
79+
for plugin in plugins:
80+
self.register_plugin(plugin)
81+
82+
def register_plugin(self, plugin: BasePlugin) -> None:
83+
"""Registers a new plugin.
84+
85+
Args:
86+
plugin: The plugin instance to register.
87+
88+
Raises:
89+
ValueError: If a plugin with the same name is already registered.
90+
"""
91+
if any(p.name == plugin.name for p in self.plugins):
92+
raise ValueError(f"Plugin with name '{plugin.name}' already registered.")
93+
self.plugins.append(plugin)
94+
logger.info("Plugin '%s' registered.", plugin.name)
95+
96+
def get_plugin(self, plugin_name: str) -> Optional[BasePlugin]:
97+
"""Retrieves a registered plugin by its name.
98+
99+
Args:
100+
plugin_name: The name of the plugin to retrieve.
101+
102+
Returns:
103+
The plugin instance if found, otherwise `None`.
104+
"""
105+
return next((p for p in self.plugins if p.name == plugin_name), None)
106+
107+
async def run_on_user_message_callback(
108+
self,
109+
*,
110+
user_message: types.Content,
111+
invocation_context: InvocationContext,
112+
) -> Optional[types.Content]:
113+
"""Runs the `on_user_message_callback` for all plugins."""
114+
return await self._run_callbacks(
115+
"on_user_message_callback",
116+
user_message=user_message,
117+
invocation_context=invocation_context,
118+
)
119+
120+
async def run_before_run_callback(
121+
self, *, invocation_context: InvocationContext
122+
) -> Optional[types.Content]:
123+
"""Runs the `before_run_callback` for all plugins."""
124+
return await self._run_callbacks(
125+
"before_run_callback", invocation_context=invocation_context
126+
)
127+
128+
async def run_after_run_callback(
129+
self, *, invocation_context: InvocationContext
130+
) -> Optional[None]:
131+
"""Runs the `after_run_callback` for all plugins."""
132+
return await self._run_callbacks(
133+
"after_run_callback", invocation_context=invocation_context
134+
)
135+
136+
async def run_on_event_callback(
137+
self, *, invocation_context: InvocationContext, event: Event
138+
) -> Optional[Event]:
139+
"""Runs the `on_event_callback` for all plugins."""
140+
return await self._run_callbacks(
141+
"on_event_callback",
142+
invocation_context=invocation_context,
143+
event=event,
144+
)
145+
146+
async def run_before_agent_callback(
147+
self, *, agent: BaseAgent, callback_context: CallbackContext
148+
) -> Optional[types.Content]:
149+
"""Runs the `before_agent_callback` for all plugins."""
150+
return await self._run_callbacks(
151+
"before_agent_callback",
152+
agent=agent,
153+
callback_context=callback_context,
154+
)
155+
156+
async def run_after_agent_callback(
157+
self, *, agent: BaseAgent, callback_context: CallbackContext
158+
) -> Optional[types.Content]:
159+
"""Runs the `after_agent_callback` for all plugins."""
160+
return await self._run_callbacks(
161+
"after_agent_callback",
162+
agent=agent,
163+
callback_context=callback_context,
164+
)
165+
166+
async def run_before_tool_callback(
167+
self,
168+
*,
169+
tool: BaseTool,
170+
tool_args: dict[str, Any],
171+
tool_context: ToolContext,
172+
) -> Optional[dict]:
173+
"""Runs the `before_tool_callback` for all plugins."""
174+
return await self._run_callbacks(
175+
"before_tool_callback",
176+
tool=tool,
177+
tool_args=tool_args,
178+
tool_context=tool_context,
179+
)
180+
181+
async def run_after_tool_callback(
182+
self,
183+
*,
184+
tool: BaseTool,
185+
tool_args: dict[str, Any],
186+
tool_context: ToolContext,
187+
result: dict,
188+
) -> Optional[dict]:
189+
"""Runs the `after_tool_callback` for all plugins."""
190+
return await self._run_callbacks(
191+
"after_tool_callback",
192+
tool=tool,
193+
tool_args=tool_args,
194+
tool_context=tool_context,
195+
result=result,
196+
)
197+
198+
async def run_before_model_callback(
199+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
200+
) -> Optional[LlmResponse]:
201+
"""Runs the `before_model_callback` for all plugins."""
202+
return await self._run_callbacks(
203+
"before_model_callback",
204+
callback_context=callback_context,
205+
llm_request=llm_request,
206+
)
207+
208+
async def run_after_model_callback(
209+
self, *, callback_context: CallbackContext, llm_response: LlmResponse
210+
) -> Optional[LlmResponse]:
211+
"""Runs the `after_model_callback` for all plugins."""
212+
return await self._run_callbacks(
213+
"after_model_callback",
214+
callback_context=callback_context,
215+
llm_response=llm_response,
216+
)
217+
218+
async def _run_callbacks(
219+
self, callback_name: PluginCallbackName, **kwargs: Any
220+
) -> Optional[Any]:
221+
"""Executes a specific callback for all registered plugins.
222+
223+
This private method iterates through the plugins and calls the specified
224+
callback method on each one, passing the provided keyword arguments.
225+
226+
The execution stops as soon as a plugin's callback returns a non-`None`
227+
value. This "early exit" value is then returned by this method. If all
228+
plugins are executed and all return `None`, this method also returns `None`.
229+
230+
Args:
231+
callback_name: The name of the callback method to execute.
232+
**kwargs: Keyword arguments to be passed to the callback method.
233+
234+
Returns:
235+
The first non-`None` value returned by a plugin callback, or `None` if
236+
all callbacks return `None`.
237+
238+
Raises:
239+
RuntimeError: If a plugin encounters an unhandled exception during
240+
execution. The original exception is chained.
241+
"""
242+
for plugin in self.plugins:
243+
# Each plugin might not implement all callbacks. The base class provides
244+
# default `pass` implementations, so `getattr` will always succeed.
245+
callback_method = getattr(plugin, callback_name)
246+
try:
247+
result = await callback_method(**kwargs)
248+
if result is not None:
249+
# Early exit: A plugin has returned a value. We stop
250+
# processing further plugins and return this value immediately.
251+
logger.debug(
252+
"Plugin '%s' returned a value for callback '%s', exiting early.",
253+
plugin.name,
254+
callback_name,
255+
)
256+
return result
257+
except Exception as e:
258+
error_message = (
259+
f"Error in plugin '{plugin.name}' during '{callback_name}'"
260+
f" callback: {e}"
261+
)
262+
logger.error(error_message, exc_info=True)
263+
raise RuntimeError(error_message) from e
264+
265+
return None

0 commit comments

Comments
 (0)