Skip to content

Commit 051ab20

Browse files
lkawkalkawkagemini-code-assist[bot]holtskinner
authored
feat: custom ID generators (#490)
# Description This PR introduces a new feature allowing server maintainers to provide custom ID generators for system-created objects. I identified 2 key areas requiring ID generation: - [`ResourceContext`](https://github.com/a2aproject/a2a-python/blob/19091262d07b5226becdb177ed4d780b4befb5f4/src/a2a/server/agent_execution/context.py): Manages the creation of new task and context IDs when not supplied by the client. - [`TaskUpdater`](https://github.com/a2aproject/a2a-python/blob/19091262d07b5226becdb177ed4d780b4befb5f4/src/a2a/server/tasks/task_updater.py): Helper responsible for generating IDs for new artifacts and messages during task updates. To address this, a new abstract class, `IDGenerator`, has been introduced. This class enables the creation of specific ID generators for different resource types, which can then be integrated into the relevant SDK components. Alternative approaches were considered: 1. Using a single `IDGenerator` instance for all resource types, with resource type information passed to the generate method. This option was discarded due to potential coupling issues and increased maintenance complexity. 1. Implementing an additional `IDGeneratorRegistry` to manage per-resource-type generators via a `get_id_generator(resource_type)` method. While this would simplify object passing, it introduces additional indirection and overall system complexity. # Prerequisites - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #378 🦕 --------- Co-authored-by: lkawka <lkawka@google.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Holt Skinner <13262395+holtskinner@users.noreply.github.com>
1 parent dca66c3 commit 051ab20

File tree

5 files changed

+149
-9
lines changed

5 files changed

+149
-9
lines changed

src/a2a/server/agent_execution/context.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
import uuid
2-
31
from typing import Any
42

53
from a2a.server.context import ServerCallContext
4+
from a2a.server.id_generator import (
5+
IDGenerator,
6+
IDGeneratorContext,
7+
UUIDGenerator,
8+
)
69
from a2a.types import (
710
InvalidParamsError,
811
Message,
@@ -30,6 +33,8 @@ def __init__( # noqa: PLR0913
3033
task: Task | None = None,
3134
related_tasks: list[Task] | None = None,
3235
call_context: ServerCallContext | None = None,
36+
task_id_generator: IDGenerator | None = None,
37+
context_id_generator: IDGenerator | None = None,
3338
):
3439
"""Initializes the RequestContext.
3540
@@ -40,6 +45,8 @@ def __init__( # noqa: PLR0913
4045
task: The existing `Task` object retrieved from the store, if any.
4146
related_tasks: A list of other tasks related to the current request (e.g., for tool use).
4247
call_context: The server call context associated with this request.
48+
task_id_generator: ID generator for new task IDs. Defaults to UUID generator.
49+
context_id_generator: ID generator for new context IDs. Defaults to UUID generator.
4350
"""
4451
if related_tasks is None:
4552
related_tasks = []
@@ -49,6 +56,12 @@ def __init__( # noqa: PLR0913
4956
self._current_task = task
5057
self._related_tasks = related_tasks
5158
self._call_context = call_context
59+
self._task_id_generator = (
60+
task_id_generator if task_id_generator else UUIDGenerator()
61+
)
62+
self._context_id_generator = (
63+
context_id_generator if context_id_generator else UUIDGenerator()
64+
)
5265
# If the task id and context id were provided, make sure they
5366
# match the request. Otherwise, create them
5467
if self._params:
@@ -163,7 +176,9 @@ def _check_or_generate_task_id(self) -> None:
163176
return
164177

165178
if not self._task_id and not self._params.message.task_id:
166-
self._params.message.task_id = str(uuid.uuid4())
179+
self._params.message.task_id = self._task_id_generator.generate(
180+
IDGeneratorContext(context_id=self._context_id)
181+
)
167182
if self._params.message.task_id:
168183
self._task_id = self._params.message.task_id
169184

@@ -173,6 +188,10 @@ def _check_or_generate_context_id(self) -> None:
173188
return
174189

175190
if not self._context_id and not self._params.message.context_id:
176-
self._params.message.context_id = str(uuid.uuid4())
191+
self._params.message.context_id = (
192+
self._context_id_generator.generate(
193+
IDGeneratorContext(task_id=self._task_id)
194+
)
195+
)
177196
if self._params.message.context_id:
178197
self._context_id = self._params.message.context_id

src/a2a/server/id_generator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import uuid
2+
3+
from abc import ABC, abstractmethod
4+
5+
from pydantic import BaseModel
6+
7+
8+
class IDGeneratorContext(BaseModel):
9+
"""Context for providing additional information to ID generators."""
10+
11+
task_id: str | None = None
12+
context_id: str | None = None
13+
14+
15+
class IDGenerator(ABC):
16+
"""Interface for generating unique identifiers."""
17+
18+
@abstractmethod
19+
def generate(self, context: IDGeneratorContext) -> str:
20+
pass
21+
22+
23+
class UUIDGenerator(IDGenerator):
24+
"""UUID implementation of the IDGenerator interface."""
25+
26+
def generate(self, context: IDGeneratorContext) -> str:
27+
"""Generates a random UUID, ignoring the context."""
28+
return str(uuid.uuid4())

src/a2a/server/tasks/task_updater.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import asyncio
2-
import uuid
32

43
from datetime import datetime, timezone
54
from typing import Any
65

76
from a2a.server.events import EventQueue
7+
from a2a.server.id_generator import (
8+
IDGenerator,
9+
IDGeneratorContext,
10+
UUIDGenerator,
11+
)
812
from a2a.types import (
913
Artifact,
1014
Message,
@@ -23,13 +27,22 @@ class TaskUpdater:
2327
Simplifies the process of creating and enqueueing standard task events.
2428
"""
2529

26-
def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
30+
def __init__(
31+
self,
32+
event_queue: EventQueue,
33+
task_id: str,
34+
context_id: str,
35+
artifact_id_generator: IDGenerator | None = None,
36+
message_id_generator: IDGenerator | None = None,
37+
):
2738
"""Initializes the TaskUpdater.
2839
2940
Args:
3041
event_queue: The `EventQueue` associated with the task.
3142
task_id: The ID of the task.
3243
context_id: The context ID of the task.
44+
artifact_id_generator: ID generator for new artifact IDs. Defaults to UUID generator.
45+
message_id_generator: ID generator for new message IDs. Defaults to UUID generator.
3346
"""
3447
self.event_queue = event_queue
3548
self.task_id = task_id
@@ -42,6 +55,12 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
4255
TaskState.failed,
4356
TaskState.rejected,
4457
}
58+
self._artifact_id_generator = (
59+
artifact_id_generator if artifact_id_generator else UUIDGenerator()
60+
)
61+
self._message_id_generator = (
62+
message_id_generator if message_id_generator else UUIDGenerator()
63+
)
4564

4665
async def update_status(
4766
self,
@@ -110,7 +129,11 @@ async def add_artifact( # noqa: PLR0913
110129
extensions: Optional list of extensions for the artifact.
111130
"""
112131
if not artifact_id:
113-
artifact_id = str(uuid.uuid4())
132+
artifact_id = self._artifact_id_generator.generate(
133+
IDGeneratorContext(
134+
task_id=self.task_id, context_id=self.context_id
135+
)
136+
)
114137

115138
await self.event_queue.enqueue_event(
116139
TaskArtifactUpdateEvent(
@@ -205,7 +228,11 @@ def new_agent_message(
205228
role=Role.agent,
206229
task_id=self.task_id,
207230
context_id=self.context_id,
208-
message_id=str(uuid.uuid4()),
231+
message_id=self._message_id_generator.generate(
232+
IDGeneratorContext(
233+
task_id=self.task_id, context_id=self.context_id
234+
)
235+
),
209236
metadata=metadata,
210237
parts=parts,
211238
)

tests/server/agent_execution/test_context.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from a2a.server.agent_execution import RequestContext
88
from a2a.server.context import ServerCallContext
9+
from a2a.server.id_generator import IDGenerator
910
from a2a.types import (
1011
Message,
1112
MessageSendParams,
@@ -149,6 +150,20 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
149150
assert context.task_id == existing_id
150151
assert mock_params.message.task_id == existing_id
151152

153+
def test_check_or_generate_task_id_with_custom_id_generator(
154+
self, mock_params
155+
):
156+
"""Test _check_or_generate_task_id uses custom ID generator when provided."""
157+
id_generator = Mock(spec=IDGenerator)
158+
id_generator.generate.return_value = 'custom-task-id'
159+
160+
context = RequestContext(
161+
request=mock_params, task_id_generator=id_generator
162+
)
163+
# The method is called during initialization
164+
165+
assert context.task_id == 'custom-task-id'
166+
152167
def test_check_or_generate_context_id_no_params(self):
153168
"""Test _check_or_generate_context_id with no params does nothing."""
154169
context = RequestContext()
@@ -168,6 +183,20 @@ def test_check_or_generate_context_id_with_existing_context_id(
168183
assert context.context_id == existing_id
169184
assert mock_params.message.context_id == existing_id
170185

186+
def test_check_or_generate_context_id_with_custom_id_generator(
187+
self, mock_params
188+
):
189+
"""Test _check_or_generate_context_id uses custom ID generator when provided."""
190+
id_generator = Mock(spec=IDGenerator)
191+
id_generator.generate.return_value = 'custom-context-id'
192+
193+
context = RequestContext(
194+
request=mock_params, context_id_generator=id_generator
195+
)
196+
# The method is called during initialization
197+
198+
assert context.context_id == 'custom-context-id'
199+
171200
def test_init_raises_error_on_task_id_mismatch(
172201
self, mock_params, mock_task
173202
):

tests/server/tasks/test_task_updater.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import uuid
33

4-
from unittest.mock import AsyncMock, patch
4+
from unittest.mock import AsyncMock, Mock, patch
55

66
import pytest
77

88
from a2a.server.events import EventQueue
9+
from a2a.server.id_generator import IDGenerator
910
from a2a.server.tasks import TaskUpdater
1011
from a2a.types import (
1112
Message,
@@ -151,6 +152,26 @@ async def test_add_artifact_generates_id(
151152
assert event.last_chunk is None
152153

153154

155+
@pytest.mark.asyncio
156+
async def test_add_artifact_generates_custom_id(event_queue, sample_parts):
157+
"""Test add_artifact uses a custom ID generator when provided."""
158+
artifact_id_generator = Mock(spec=IDGenerator)
159+
artifact_id_generator.generate.return_value = 'custom-artifact-id'
160+
task_updater = TaskUpdater(
161+
event_queue=event_queue,
162+
task_id='test-task-id',
163+
context_id='test-context-id',
164+
artifact_id_generator=artifact_id_generator,
165+
)
166+
167+
await task_updater.add_artifact(parts=sample_parts, artifact_id=None)
168+
169+
event_queue.enqueue_event.assert_called_once()
170+
event = event_queue.enqueue_event.call_args[0][0]
171+
assert isinstance(event, TaskArtifactUpdateEvent)
172+
assert event.artifact.artifact_id == 'custom-artifact-id'
173+
174+
154175
@pytest.mark.asyncio
155176
@pytest.mark.parametrize(
156177
'append_val, last_chunk_val',
@@ -304,6 +325,22 @@ def test_new_agent_message_with_metadata(task_updater, sample_parts):
304325
assert message.metadata == metadata
305326

306327

328+
def test_new_agent_message_with_custom_id_generator(event_queue, sample_parts):
329+
"""Test creating a new agent message with a custom message ID generator."""
330+
message_id_generator = Mock(spec=IDGenerator)
331+
message_id_generator.generate.return_value = 'custom-message-id'
332+
task_updater = TaskUpdater(
333+
event_queue=event_queue,
334+
task_id='test-task-id',
335+
context_id='test-context-id',
336+
message_id_generator=message_id_generator,
337+
)
338+
339+
message = task_updater.new_agent_message(parts=sample_parts)
340+
341+
assert message.message_id == 'custom-message-id'
342+
343+
307344
@pytest.mark.asyncio
308345
async def test_failed_without_message(task_updater, event_queue):
309346
"""Test marking a task as failed without a message."""

0 commit comments

Comments
 (0)