Skip to content

Commit 162228d

Browse files
pwwpchecopybara-github
authored andcommitted
feat: Integrating Plugin with ADK
This change integrates the plugin system with ADK. PluginManager is attached to the invocation context similar to session/artifact/memory. It includes integrations with following ADK internal callbacks: * App callbacks: Integrated in the BaseRunner class, in run_async and run_live * On Message callbacks: Integrated in the BaseRunner class, triggers on run_async. * Agent callbacks: Integrated in the BaseAgent class. Leveraging the existing *callback functions * Model callbacks: Integrating in the base_llm_flow. * Tool callbacks: Integrated in functions.py, wrapped around the code for agent tool_callbacks Sample code to use plugins: ```python # Add plugins to Runner runner = Runner( app_name="my-app", agent=root_agent, artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, plugins=[ MySamplePlugin(), LoggingPlugin(), ], ) ``` PiperOrigin-RevId: 781746586
1 parent 16ba91c commit 162228d

File tree

11 files changed

+759
-89
lines changed

11 files changed

+759
-89
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,18 @@ async def run_live(
227227
"""
228228
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
229229
ctx = self._create_invocation_context(parent_context)
230-
# TODO(hangfei): support before/after_agent_callback
230+
231+
if event := await self.__handle_before_agent_callback(ctx):
232+
yield event
233+
if ctx.end_invocation:
234+
return
231235

232236
async for event in self._run_live_impl(ctx):
233237
yield event
234238

239+
if event := await self.__handle_after_agent_callback(ctx):
240+
yield event
241+
235242
async def _run_async_impl(
236243
self, ctx: InvocationContext
237244
) -> AsyncGenerator[Event, None]:
@@ -335,82 +342,117 @@ async def __handle_before_agent_callback(
335342
) -> Optional[Event]:
336343
"""Runs the before_agent_callback if it exists.
337344
345+
Args:
346+
ctx: InvocationContext, the invocation context for this agent.
347+
338348
Returns:
339349
Optional[Event]: an event if callback provides content or changed state.
340350
"""
341-
ret_event = None
342-
343-
if not self.canonical_before_agent_callbacks:
344-
return ret_event
345-
346351
callback_context = CallbackContext(ctx)
347352

348-
for callback in self.canonical_before_agent_callbacks:
349-
before_agent_callback_content = callback(
350-
callback_context=callback_context
351-
)
352-
if inspect.isawaitable(before_agent_callback_content):
353-
before_agent_callback_content = await before_agent_callback_content
354-
if before_agent_callback_content:
355-
ret_event = Event(
356-
invocation_id=ctx.invocation_id,
357-
author=self.name,
358-
branch=ctx.branch,
359-
content=before_agent_callback_content,
360-
actions=callback_context._event_actions,
353+
# Run callbacks from the plugins.
354+
before_agent_callback_content = (
355+
await ctx.plugin_manager.run_before_agent_callback(
356+
agent=self, callback_context=callback_context
361357
)
362-
ctx.end_invocation = True
363-
return ret_event
358+
)
364359

365-
if callback_context.state.has_delta():
360+
# If no overrides are provided from the plugins, further run the canonical
361+
# callbacks.
362+
if (
363+
not before_agent_callback_content
364+
and self.canonical_before_agent_callbacks
365+
):
366+
for callback in self.canonical_before_agent_callbacks:
367+
before_agent_callback_content = callback(
368+
callback_context=callback_context
369+
)
370+
if inspect.isawaitable(before_agent_callback_content):
371+
before_agent_callback_content = await before_agent_callback_content
372+
if before_agent_callback_content:
373+
break
374+
375+
# Process the override content if exists, and further process the state
376+
# change if exists.
377+
if before_agent_callback_content:
366378
ret_event = Event(
379+
invocation_id=ctx.invocation_id,
380+
author=self.name,
381+
branch=ctx.branch,
382+
content=before_agent_callback_content,
383+
actions=callback_context._event_actions,
384+
)
385+
ctx.end_invocation = True
386+
return ret_event
387+
388+
if callback_context.state.has_delta():
389+
return Event(
367390
invocation_id=ctx.invocation_id,
368391
author=self.name,
369392
branch=ctx.branch,
370393
actions=callback_context._event_actions,
371394
)
372395

373-
return ret_event
396+
return None
374397

375398
async def __handle_after_agent_callback(
376399
self, invocation_context: InvocationContext
377400
) -> Optional[Event]:
378401
"""Runs the after_agent_callback if it exists.
379402
403+
Args:
404+
invocation_context: InvocationContext, the invocation context for this
405+
agent.
406+
380407
Returns:
381408
Optional[Event]: an event if callback provides content or changed state.
382409
"""
383-
ret_event = None
384-
385-
if not self.canonical_after_agent_callbacks:
386-
return ret_event
387410

388411
callback_context = CallbackContext(invocation_context)
389412

390-
for callback in self.canonical_after_agent_callbacks:
391-
after_agent_callback_content = callback(callback_context=callback_context)
392-
if inspect.isawaitable(after_agent_callback_content):
393-
after_agent_callback_content = await after_agent_callback_content
394-
if after_agent_callback_content:
395-
ret_event = Event(
396-
invocation_id=invocation_context.invocation_id,
397-
author=self.name,
398-
branch=invocation_context.branch,
399-
content=after_agent_callback_content,
400-
actions=callback_context._event_actions,
413+
# Run callbacks from the plugins.
414+
after_agent_callback_content = (
415+
await invocation_context.plugin_manager.run_after_agent_callback(
416+
agent=self, callback_context=callback_context
401417
)
402-
return ret_event
418+
)
403419

404-
if callback_context.state.has_delta():
420+
# If no overrides are provided from the plugins, further run the canonical
421+
# callbacks.
422+
if (
423+
not after_agent_callback_content
424+
and self.canonical_after_agent_callbacks
425+
):
426+
for callback in self.canonical_after_agent_callbacks:
427+
after_agent_callback_content = callback(
428+
callback_context=callback_context
429+
)
430+
if inspect.isawaitable(after_agent_callback_content):
431+
after_agent_callback_content = await after_agent_callback_content
432+
if after_agent_callback_content:
433+
break
434+
435+
# Process the override content if exists, and further process the state
436+
# change if exists.
437+
if after_agent_callback_content:
405438
ret_event = Event(
406439
invocation_id=invocation_context.invocation_id,
407440
author=self.name,
408441
branch=invocation_context.branch,
409442
content=after_agent_callback_content,
410443
actions=callback_context._event_actions,
411444
)
445+
return ret_event
412446

413-
return ret_event
447+
if callback_context.state.has_delta():
448+
return Event(
449+
invocation_id=invocation_context.invocation_id,
450+
author=self.name,
451+
branch=invocation_context.branch,
452+
content=after_agent_callback_content,
453+
actions=callback_context._event_actions,
454+
)
455+
return None
414456

415457
@override
416458
def model_post_init(self, __context: Any) -> None:

src/google/adk/agents/invocation_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..artifacts.base_artifact_service import BaseArtifactService
2525
from ..auth.credential_service.base_credential_service import BaseCredentialService
2626
from ..memory.base_memory_service import BaseMemoryService
27+
from ..plugins.plugin_manager import PluginManager
2728
from ..sessions.base_session_service import BaseSessionService
2829
from ..sessions.session import Session
2930
from .active_streaming_tool import ActiveStreamingTool
@@ -153,6 +154,9 @@ class InvocationContext(BaseModel):
153154
run_config: Optional[RunConfig] = None
154155
"""Configurations for live agents under this invocation."""
155156

157+
plugin_manager: PluginManager = PluginManager()
158+
"""The manager for keeping track of plugins in this invocation."""
159+
156160
_invocation_cost_manager: _InvocationCostManager = _InvocationCostManager()
157161
"""A container to keep track of different kinds of costs incurred as a part
158162
of this invocation.

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -565,21 +565,32 @@ async def _handle_before_model_callback(
565565
if not isinstance(agent, LlmAgent):
566566
return
567567

568-
if not agent.canonical_before_model_callbacks:
569-
return
570-
571568
callback_context = CallbackContext(
572569
invocation_context, event_actions=model_response_event.actions
573570
)
574571

572+
# First run callbacks from the plugins.
573+
callback_response = (
574+
await invocation_context.plugin_manager.run_before_model_callback(
575+
callback_context=callback_context,
576+
llm_request=llm_request,
577+
)
578+
)
579+
if callback_response:
580+
return callback_response
581+
582+
# If no overrides are provided from the plugins, further run the canonical
583+
# callbacks.
584+
if not agent.canonical_before_model_callbacks:
585+
return
575586
for callback in agent.canonical_before_model_callbacks:
576-
before_model_callback_content = callback(
587+
callback_response = callback(
577588
callback_context=callback_context, llm_request=llm_request
578589
)
579-
if inspect.isawaitable(before_model_callback_content):
580-
before_model_callback_content = await before_model_callback_content
581-
if before_model_callback_content:
582-
return before_model_callback_content
590+
if inspect.isawaitable(callback_response):
591+
callback_response = await callback_response
592+
if callback_response:
593+
return callback_response
583594

584595
async def _handle_after_model_callback(
585596
self,
@@ -593,21 +604,32 @@ async def _handle_after_model_callback(
593604
if not isinstance(agent, LlmAgent):
594605
return
595606

596-
if not agent.canonical_after_model_callbacks:
597-
return
598-
599607
callback_context = CallbackContext(
600608
invocation_context, event_actions=model_response_event.actions
601609
)
602610

611+
# First run callbacks from the plugins.
612+
callback_response = (
613+
await invocation_context.plugin_manager.run_after_model_callback(
614+
callback_context=CallbackContext(invocation_context),
615+
llm_response=llm_response,
616+
)
617+
)
618+
if callback_response:
619+
return callback_response
620+
621+
# If no overrides are provided from the plugins, further run the canonical
622+
# callbacks.
623+
if not agent.canonical_after_model_callbacks:
624+
return
603625
for callback in agent.canonical_after_model_callbacks:
604-
after_model_callback_content = callback(
626+
callback_response = callback(
605627
callback_context=callback_context, llm_response=llm_response
606628
)
607-
if inspect.isawaitable(after_model_callback_content):
608-
after_model_callback_content = await after_model_callback_content
609-
if after_model_callback_content:
610-
return after_model_callback_content
629+
if inspect.isawaitable(callback_response):
630+
callback_response = await callback_response
631+
if callback_response:
632+
return callback_response
611633

612634
def _finalize_model_response_event(
613635
self,

src/google/adk/flows/llm_flows/functions.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -153,37 +153,67 @@ async def handle_function_calls_async(
153153
# do not use "args" as the variable name, because it is a reserved keyword
154154
# in python debugger.
155155
function_args = function_call.args or {}
156-
function_response: Optional[dict] = None
157156

158-
for callback in agent.canonical_before_tool_callbacks:
159-
function_response = callback(
160-
tool=tool, args=function_args, tool_context=tool_context
161-
)
162-
if inspect.isawaitable(function_response):
163-
function_response = await function_response
164-
if function_response:
165-
break
157+
# Step 1: Check if plugin before_tool_callback overrides the function
158+
# response.
159+
function_response = (
160+
await invocation_context.plugin_manager.run_before_tool_callback(
161+
tool=tool, tool_args=function_args, tool_context=tool_context
162+
)
163+
)
166164

167-
if not function_response:
165+
# Step 2: If no overrides are provided from the plugins, further run the
166+
# canonical callback.
167+
if function_response is None:
168+
for callback in agent.canonical_before_tool_callbacks:
169+
function_response = callback(
170+
tool=tool, args=function_args, tool_context=tool_context
171+
)
172+
if inspect.isawaitable(function_response):
173+
function_response = await function_response
174+
if function_response:
175+
break
176+
177+
# Step 3: Otherwise, proceed calling the tool normally.
178+
if function_response is None:
168179
function_response = await __call_tool_async(
169180
tool, args=function_args, tool_context=tool_context
170181
)
171182

172-
for callback in agent.canonical_after_tool_callbacks:
173-
altered_function_response = callback(
174-
tool=tool,
175-
args=function_args,
176-
tool_context=tool_context,
177-
tool_response=function_response,
178-
)
179-
if inspect.isawaitable(altered_function_response):
180-
altered_function_response = await altered_function_response
181-
if altered_function_response is not None:
182-
function_response = altered_function_response
183-
break
183+
# Step 4: Check if plugin after_tool_callback overrides the function
184+
# response.
185+
altered_function_response = (
186+
await invocation_context.plugin_manager.run_after_tool_callback(
187+
tool=tool,
188+
tool_args=function_args,
189+
tool_context=tool_context,
190+
result=function_response,
191+
)
192+
)
193+
194+
# Step 5: If no overrides are provided from the plugins, further run the
195+
# canonical after_tool_callbacks.
196+
if altered_function_response is None:
197+
for callback in agent.canonical_after_tool_callbacks:
198+
altered_function_response = callback(
199+
tool=tool,
200+
args=function_args,
201+
tool_context=tool_context,
202+
tool_response=function_response,
203+
)
204+
if inspect.isawaitable(altered_function_response):
205+
altered_function_response = await altered_function_response
206+
if altered_function_response:
207+
break
208+
209+
# Step 6: If alternative response exists from after_tool_callback, use it
210+
# instead of the original function response.
211+
if altered_function_response is not None:
212+
function_response = altered_function_response
184213

185214
if tool.is_long_running:
186-
# Allow long running function to return None to not provide function response.
215+
# Allow long running function to return None to not provide function
216+
# response.
187217
if not function_response:
188218
continue
189219

@@ -264,6 +294,7 @@ async def handle_function_calls_live(
264294
# )
265295
# if new_response:
266296
# function_response = new_response
297+
altered_function_response = None
267298
if agent.after_tool_callback:
268299
altered_function_response = agent.after_tool_callback(
269300
tool=tool,
@@ -273,8 +304,8 @@ async def handle_function_calls_live(
273304
)
274305
if inspect.isawaitable(altered_function_response):
275306
altered_function_response = await altered_function_response
276-
if altered_function_response is not None:
277-
function_response = altered_function_response
307+
if altered_function_response is not None:
308+
function_response = altered_function_response
278309

279310
if tool.is_long_running:
280311
# Allow async function to return None to not provide function response.

0 commit comments

Comments
 (0)