Skip to content

Commit d585635

Browse files
mikeas1holtskinner
andauthored
feat: Add a ClientFactory.connect() method for easy client creation (#509)
# Description This PR adds a convenience method for constructing a Client from either an AgentCard URL or an AgentCard directly. The goal is to reduce the number of lines of code required for simple client creation, but still enabling more advanced handled of client construction. Usage example: ```python my_agent_url = 'https://travel-agent.example.com' client = await ClientFactory.connect(my_agent_url) await client.send_message(...) ``` Release-As: 0.3.10 --------- Co-authored-by: Holt Skinner <13262395+holtskinner@users.noreply.github.com>
1 parent 317df0a commit d585635

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
AAgent
12
ACard
23
AClient
34
ACMRTUXB

src/a2a/client/client_factory.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import logging
44

55
from collections.abc import Callable
6+
from typing import Any
67

78
import httpx
89

910
from a2a.client.base_client import BaseClient
11+
from a2a.client.card_resolver import A2ACardResolver
1012
from a2a.client.client import Client, ClientConfig, Consumer
1113
from a2a.client.middleware import ClientCallInterceptor
1214
from a2a.client.transports.base import ClientTransport
@@ -101,6 +103,71 @@ def _register_defaults(
101103
GrpcTransport.create,
102104
)
103105

106+
@classmethod
107+
async def connect( # noqa: PLR0913
108+
cls,
109+
agent: str | AgentCard,
110+
client_config: ClientConfig | None = None,
111+
consumers: list[Consumer] | None = None,
112+
interceptors: list[ClientCallInterceptor] | None = None,
113+
relative_card_path: str | None = None,
114+
resolver_http_kwargs: dict[str, Any] | None = None,
115+
extra_transports: dict[str, TransportProducer] | None = None,
116+
) -> Client:
117+
"""Convenience method for constructing a client.
118+
119+
Constructs a client that connects to the specified agent. Note that
120+
creating multiple clients via this method is less efficient than
121+
constructing an instance of ClientFactory and reusing that.
122+
123+
.. code-block:: python
124+
125+
# This will search for an AgentCard at /.well-known/agent-card.json
126+
my_agent_url = 'https://travel.agents.example.com'
127+
client = await ClientFactory.connect(my_agent_url)
128+
129+
130+
Args:
131+
agent: The base URL of the agent, or the AgentCard to connect to.
132+
client_config: The ClientConfig to use when connecting to the agent.
133+
consumers: A list of `Consumer` methods to pass responses to.
134+
interceptors: A list of interceptors to use for each request. These
135+
are used for things like attaching credentials or http headers
136+
to all outbound requests.
137+
relative_card_path: If the agent field is a URL, this value is used as
138+
the relative path when resolving the agent card. See
139+
A2AAgentCardResolver.get_agent_card for more details.
140+
resolver_http_kwargs: Dictionary of arguments to provide to the httpx
141+
client when resolving the agent card. This value is provided to
142+
A2AAgentCardResolver.get_agent_card as the http_kwargs parameter.
143+
extra_transports: Additional transport protocols to enable when
144+
constructing the client.
145+
146+
Returns:
147+
A `Client` object.
148+
"""
149+
client_config = client_config or ClientConfig()
150+
if isinstance(agent, str):
151+
if not client_config.httpx_client:
152+
async with httpx.AsyncClient() as client:
153+
resolver = A2ACardResolver(client, agent)
154+
card = await resolver.get_agent_card(
155+
relative_card_path=relative_card_path,
156+
http_kwargs=resolver_http_kwargs,
157+
)
158+
else:
159+
resolver = A2ACardResolver(client_config.httpx_client, agent)
160+
card = await resolver.get_agent_card(
161+
relative_card_path=relative_card_path,
162+
http_kwargs=resolver_http_kwargs,
163+
)
164+
else:
165+
card = agent
166+
factory = cls(client_config)
167+
for label, generator in (extra_transports or {}).items():
168+
factory.register(label, generator)
169+
return factory.create(card, consumers, interceptors)
170+
104171
def register(self, label: str, generator: TransportProducer) -> None:
105172
"""Register a new transport producer for a given transport label."""
106173
self._registry[label] = generator

tests/client/test_client_factory.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for the ClientFactory."""
22

3+
from unittest.mock import AsyncMock, MagicMock, patch
4+
35
import httpx
46
import pytest
57

@@ -103,3 +105,158 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard):
103105
factory = ClientFactory(config)
104106
with pytest.raises(ValueError, match='no compatible transports found'):
105107
factory.create(base_agent_card)
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_client_factory_connect_with_agent_card(
112+
base_agent_card: AgentCard,
113+
):
114+
"""Verify that connect works correctly when provided with an AgentCard."""
115+
client = await ClientFactory.connect(base_agent_card)
116+
assert isinstance(client._transport, JsonRpcTransport)
117+
assert client._transport.url == 'http://primary-url.com'
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_client_factory_connect_with_url(base_agent_card: AgentCard):
122+
"""Verify that connect works correctly when provided with a URL."""
123+
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
124+
mock_resolver.return_value.get_agent_card = AsyncMock(
125+
return_value=base_agent_card
126+
)
127+
128+
agent_url = 'http://example.com'
129+
client = await ClientFactory.connect(agent_url)
130+
131+
mock_resolver.assert_called_once()
132+
assert mock_resolver.call_args[0][1] == agent_url
133+
mock_resolver.return_value.get_agent_card.assert_awaited_once()
134+
135+
assert isinstance(client._transport, JsonRpcTransport)
136+
assert client._transport.url == 'http://primary-url.com'
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_client_factory_connect_with_url_and_client_config(
141+
base_agent_card: AgentCard,
142+
):
143+
"""Verify connect with a URL and a pre-configured httpx client."""
144+
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
145+
mock_resolver.return_value.get_agent_card = AsyncMock(
146+
return_value=base_agent_card
147+
)
148+
149+
agent_url = 'http://example.com'
150+
mock_httpx_client = httpx.AsyncClient()
151+
config = ClientConfig(httpx_client=mock_httpx_client)
152+
153+
client = await ClientFactory.connect(agent_url, client_config=config)
154+
155+
mock_resolver.assert_called_once_with(mock_httpx_client, agent_url)
156+
mock_resolver.return_value.get_agent_card.assert_awaited_once()
157+
158+
assert isinstance(client._transport, JsonRpcTransport)
159+
assert client._transport.url == 'http://primary-url.com'
160+
161+
162+
@pytest.mark.asyncio
163+
async def test_client_factory_connect_with_resolver_args(
164+
base_agent_card: AgentCard,
165+
):
166+
"""Verify connect passes resolver arguments correctly."""
167+
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
168+
mock_resolver.return_value.get_agent_card = AsyncMock(
169+
return_value=base_agent_card
170+
)
171+
172+
agent_url = 'http://example.com'
173+
relative_path = '/card'
174+
http_kwargs = {'headers': {'X-Test': 'true'}}
175+
176+
# The resolver args are only passed if an httpx_client is provided in config
177+
config = ClientConfig(httpx_client=httpx.AsyncClient())
178+
179+
await ClientFactory.connect(
180+
agent_url,
181+
client_config=config,
182+
relative_card_path=relative_path,
183+
resolver_http_kwargs=http_kwargs,
184+
)
185+
186+
mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
187+
relative_card_path=relative_path,
188+
http_kwargs=http_kwargs,
189+
)
190+
191+
192+
@pytest.mark.asyncio
193+
async def test_client_factory_connect_resolver_args_without_client(
194+
base_agent_card: AgentCard,
195+
):
196+
"""Verify resolver args are ignored if no httpx_client is provided."""
197+
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
198+
mock_resolver.return_value.get_agent_card = AsyncMock(
199+
return_value=base_agent_card
200+
)
201+
202+
agent_url = 'http://example.com'
203+
relative_path = '/card'
204+
http_kwargs = {'headers': {'X-Test': 'true'}}
205+
206+
await ClientFactory.connect(
207+
agent_url,
208+
relative_card_path=relative_path,
209+
resolver_http_kwargs=http_kwargs,
210+
)
211+
212+
mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
213+
relative_card_path=relative_path,
214+
http_kwargs=http_kwargs,
215+
)
216+
217+
218+
@pytest.mark.asyncio
219+
async def test_client_factory_connect_with_extra_transports(
220+
base_agent_card: AgentCard,
221+
):
222+
"""Verify that connect can register and use extra transports."""
223+
224+
class CustomTransport:
225+
pass
226+
227+
def custom_transport_producer(*args, **kwargs):
228+
return CustomTransport()
229+
230+
base_agent_card.preferred_transport = 'custom'
231+
base_agent_card.url = 'custom://foo'
232+
233+
config = ClientConfig(supported_transports=['custom'])
234+
235+
client = await ClientFactory.connect(
236+
base_agent_card,
237+
client_config=config,
238+
extra_transports={'custom': custom_transport_producer},
239+
)
240+
241+
assert isinstance(client._transport, CustomTransport)
242+
243+
244+
@pytest.mark.asyncio
245+
async def test_client_factory_connect_with_consumers_and_interceptors(
246+
base_agent_card: AgentCard,
247+
):
248+
"""Verify consumers and interceptors are passed through correctly."""
249+
consumer1 = MagicMock()
250+
interceptor1 = MagicMock()
251+
252+
with patch('a2a.client.client_factory.BaseClient') as mock_base_client:
253+
await ClientFactory.connect(
254+
base_agent_card,
255+
consumers=[consumer1],
256+
interceptors=[interceptor1],
257+
)
258+
259+
mock_base_client.assert_called_once()
260+
call_args = mock_base_client.call_args[0]
261+
assert call_args[3] == [consumer1]
262+
assert call_args[4] == [interceptor1]

0 commit comments

Comments
 (0)