Skip to content

Commit 4fec136

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Add live/bidi streaming support for Agent Engine
PiperOrigin-RevId: 799712899
1 parent c55f906 commit 4fec136

File tree

5 files changed

+364
-1
lines changed

5 files changed

+364
-1
lines changed

tests/unit/vertexai/genai/test_genai_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ async def test_async_client(self):
5656
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
5757
assert isinstance(test_client.aio, vertexai._genai.client.AsyncClient)
5858

59+
@pytest.mark.usefixtures("google_auth_mock")
60+
def test_live_client(self):
61+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
62+
test_async_client = test_client.aio
63+
assert isinstance(test_async_client.live, vertexai._genai.live.AsyncLive)
64+
5965
@pytest.mark.usefixtures("google_auth_mock")
6066
def test_types(self):
6167
assert vertexai.types is not None
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
import importlib
16+
import json
17+
from unittest import mock
18+
19+
import google.auth
20+
import google.auth.credentials
21+
from google.cloud import aiplatform
22+
import vertexai
23+
from google.cloud.aiplatform import initializer as aiplatform_initializer
24+
from vertexai._genai import live_agent_engines
25+
import pytest
26+
27+
28+
_TEST_PROJECT = "test-project"
29+
_TEST_LOCATION = "us-central1"
30+
pytestmark = pytest.mark.usefixtures("google_auth_mock")
31+
32+
33+
class TestLiveAgentEngines:
34+
"""Unit tests for the GenAI client."""
35+
36+
def setup_method(self):
37+
importlib.reload(aiplatform_initializer)
38+
importlib.reload(aiplatform)
39+
importlib.reload(vertexai)
40+
vertexai.init(
41+
project=_TEST_PROJECT,
42+
location=_TEST_LOCATION,
43+
)
44+
45+
@pytest.mark.asyncio
46+
@pytest.mark.usefixtures("google_auth_mock")
47+
@mock.patch.object(live_agent_engines, "ws_connect")
48+
@mock.patch.object(google.auth, "default")
49+
async def test_async_live_agent_engines_connect(
50+
self, mock_auth_default, mock_ws_connect
51+
):
52+
"""Tests the AsyncLiveAgentEngines.connect method, as well as the AsyncLiveAgentEngineSession methods."""
53+
# Mock credentials to avoid refresh issues
54+
mock_creds = mock.Mock(spec=google.auth.credentials.Credentials)
55+
mock_creds.token = "test-token"
56+
mock_creds.valid = True
57+
mock_auth_default.return_value = (mock_creds, None)
58+
59+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
60+
mock_ws = mock.AsyncMock()
61+
mock_ws_connect.return_value.__aenter__.return_value = mock_ws
62+
63+
mock_ws.recv.side_effect = [
64+
json.dumps({"output": "HELLO"}).encode("utf-8"),
65+
json.dumps({"output": "WORLD"}).encode("utf-8"),
66+
]
67+
68+
async with test_client.aio.live.agent_engines.connect(
69+
agent_engine="test-agent-engine",
70+
config={"class_method": "bidi_stream_query", "input": {"query": "hello"}},
71+
) as session:
72+
assert session is not None
73+
74+
# Send additional messages
75+
await session.send({"query": "world"})
76+
77+
# Receive responses
78+
responses = []
79+
response = await session.receive()
80+
responses.append(response)
81+
response = await session.receive()
82+
responses.append(response)
83+
84+
await session.close()
85+
86+
assert responses == [{"output": "HELLO"}, {"output": "WORLD"}]
87+
88+
mock_ws.send.assert_has_calls(
89+
[
90+
mock.call(
91+
json.dumps(
92+
{
93+
"setup": {
94+
"name": (
95+
f"projects/{_TEST_PROJECT}/locations/"
96+
f"{_TEST_LOCATION}/reasoningEngines/"
97+
"test-agent-engine"
98+
),
99+
"class_method": "bidi_stream_query",
100+
"input": {"query": "hello"},
101+
}
102+
}
103+
)
104+
),
105+
mock.call(json.dumps({"bidi_stream_input": {"query": "world"}})),
106+
]
107+
)
108+
mock_ws.close.assert_called_once()
109+
mock_auth_default.assert_called_once_with(
110+
scopes=["https://www.googleapis.com/auth/cloud-platform"]
111+
)
112+
mock_creds.refresh.assert_not_called()

vertexai/_genai/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.genai import _common
2323
from google.genai import client as genai_client
2424
from google.genai import types
25+
from . import live
2526

2627

2728
_GENAI_MODULES_TELEMETRY_HEADER = "vertex-genai-modules"
@@ -45,15 +46,23 @@ def _add_tracking_headers(headers: dict[str, str]) -> None:
4546

4647

4748
class AsyncClient:
48-
4949
"""Async Client for the GenAI SDK."""
5050

5151
def __init__(self, api_client: genai_client.Client):
5252
self._api_client = api_client
53+
self._live = live.AsyncLive(self._api_client)
5354
self._evals = None
5455
self._agent_engines = None
5556
self._prompt_optimizer = None
5657

58+
@property
59+
@_common.experimental_warning(
60+
"The Vertex SDK GenAI live module is experimental, and may change in future "
61+
"versions."
62+
)
63+
def live(self) -> live.AsyncLive:
64+
return self._live
65+
5766
@property
5867
@_common.experimental_warning(
5968
"The Vertex SDK GenAI evals module is experimental, and may change in future "

vertexai/_genai/live.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
"""[Preview] Live API client."""
17+
18+
import importlib
19+
import logging
20+
21+
22+
from google.genai import _api_module
23+
from google.genai import _common
24+
from google.genai._api_client import BaseApiClient
25+
26+
logger = logging.getLogger("google_genai.live")
27+
28+
29+
class AsyncLive(_api_module.BaseModule):
30+
"""[Preview] AsyncLive."""
31+
32+
def __init__(self, api_client: BaseApiClient):
33+
super().__init__(api_client)
34+
self._agent_engines = None
35+
36+
@property
37+
@_common.experimental_warning(
38+
"The Vertex SDK GenAI agent engines module is experimental, "
39+
"and may change in future versions."
40+
)
41+
def agent_engines(self):
42+
if self._agent_engines is None:
43+
try:
44+
# We need to lazy load the live_agent_engines module to handle
45+
# the possibility of ImportError when dependencies are not
46+
# installed.
47+
self._agent_engines = importlib.import_module(
48+
".live_agent_engines",
49+
__package__,
50+
)
51+
except ImportError as e:
52+
raise ImportError(
53+
"The 'agent_engines' module requires 'additional packages'. "
54+
"Please install them using pip install "
55+
"google-cloud-aiplatform[agent_engines]"
56+
) from e
57+
return self._agent_engines.AsyncLiveAgentEngines(self._api_client)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
"""Live AgentEngine API client."""
17+
18+
import contextlib
19+
import json
20+
from typing import Any, AsyncIterator, Dict, Optional
21+
import google.auth
22+
23+
from google.genai import _api_module
24+
from .types import QueryAgentEngineConfig, QueryAgentEngineConfigOrDict
25+
26+
27+
try:
28+
from websockets.asyncio.client import ClientConnection
29+
from websockets.asyncio.client import connect as ws_connect
30+
except ModuleNotFoundError:
31+
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
32+
from websockets.client import ClientConnection # type: ignore
33+
from websockets.client import connect as ws_connect # type: ignore
34+
35+
36+
class AsyncLiveAgentEngineSession:
37+
"""AsyncLiveAgentEngineSession."""
38+
39+
def __init__(self, websocket: ClientConnection):
40+
self._ws = websocket
41+
42+
async def send(self, query_input: Dict[str, Any]) -> None:
43+
"""Send a query input to the Agent.
44+
45+
Args:
46+
query_input: A JSON serializable Python Dict to be send to the Agent.
47+
"""
48+
49+
try:
50+
json_request = json.dumps({"bidi_stream_input": query_input})
51+
except json.JSONEncoderError as exc:
52+
raise ValueError(
53+
"Failed to encode query input to JSON in live_agent_engines: "
54+
f"{str(query_input)}"
55+
) from exc
56+
await self._ws.send(json_request)
57+
58+
async def receive(self) -> Dict[str, Any]:
59+
"""Receive one response from the Agent.
60+
61+
Returns:
62+
A response from the Agent.
63+
64+
Raises:
65+
websockets.exceptions.ConnectionClosed: If the connection is closed.
66+
"""
67+
68+
response = await self._ws.recv()
69+
try:
70+
return json.loads(response)
71+
except json.decoder.JSONDecodeError as exc:
72+
raise ValueError(
73+
"Failed to parse response to JSON in live_agent_engines: "
74+
f"{str(response)}"
75+
) from exc
76+
77+
async def close(self) -> None:
78+
"""Close the connection."""
79+
await self._ws.close()
80+
81+
82+
class AsyncLiveAgentEngines(_api_module.BaseModule):
83+
"""AsyncLiveAgentEngines.
84+
85+
Example usage:
86+
87+
.. code-block:: python
88+
89+
from pathlib import Path
90+
91+
from google import genai
92+
from google.genai import types
93+
94+
class MyAgentEngine(client):
95+
def bidi_stream_query(self, input_queue: asyncio.Queue):
96+
while True:
97+
input = await input_queue.get()
98+
yield {"output": f"Agent received {input}!"}
99+
100+
client = vertexai.Client(project="my-project", location="us-central1")
101+
agent_engine = client.agent_engines.create(agent)
102+
103+
async with client.aio.live.agent_engines.connect(
104+
agent_engine=agent_engine.api_resource.name,
105+
setup={"class_method": "bidi_stream_query"},
106+
) as session:
107+
await session.send(input={"input": "Hello world"})
108+
109+
response = await session.receive()
110+
# {"output": "Agent received Hello world!"}
111+
...
112+
"""
113+
114+
@contextlib.asynccontextmanager
115+
async def connect(
116+
self,
117+
*,
118+
agent_engine: str,
119+
config: Optional[QueryAgentEngineConfigOrDict] = None,
120+
) -> AsyncIterator[AsyncLiveAgentEngineSession]:
121+
"""Connect to the agent deployed to Agent Engine in a live (bidirectional streaming) session.
122+
123+
Args:
124+
agent_engine: The resource name of the Agent Engine to use for the
125+
live session.
126+
config: The optional configuration for starting the live Agent Engine
127+
session. Custom class_method and an optional initial input could be
128+
provided. If no class_method is provided, the default class_method
129+
"bidi_stream_query" will be used by the Agent Engine.
130+
131+
Yields:
132+
An AsyncLiveAgentEngineSession object.
133+
"""
134+
if isinstance(config, dict):
135+
config = QueryAgentEngineConfig(**config)
136+
137+
agent_engine_resource_name = agent_engine
138+
if not agent_engine_resource_name.startswith("projects/"):
139+
agent_engine_resource_name = f"projects/{self._api_client.project}/locations/{self._api_client.location}/reasoningEngines/{agent_engine}"
140+
request_dict = {"setup": {"name": agent_engine_resource_name}}
141+
if config.class_method:
142+
request_dict["setup"]["class_method"] = config.class_method
143+
if config.input:
144+
request_dict["setup"]["input"] = config.input
145+
146+
request = json.dumps(request_dict)
147+
148+
if not self._api_client._credentials:
149+
# Get bearer token through Application Default Credentials.
150+
creds, _ = google.auth.default( # type: ignore
151+
scopes=["https://www.googleapis.com/auth/cloud-platform"]
152+
)
153+
else:
154+
creds = self._api_client._credentials
155+
# creds.valid is False, and creds.token is None
156+
# Need to refresh credentials to populate those
157+
if not (creds.token and creds.valid):
158+
auth_req = google.auth.transport.requests.Request() # type: ignore
159+
creds.refresh(auth_req)
160+
bearer_token = creds.token
161+
162+
original_headers = self._api_client._http_options.headers
163+
headers = original_headers.copy() if original_headers is not None else {}
164+
headers["Authorization"] = f"Bearer {bearer_token}"
165+
166+
base_url = self._api_client._websocket_base_url()
167+
if isinstance(base_url, bytes):
168+
base_url = base_url.decode("utf-8")
169+
uri = (
170+
f"{base_url}/ws/google.cloud.aiplatform."
171+
f"{self._api_client._http_options.api_version}"
172+
".ReasoningEngineExecutionService/BidiQueryReasoningEngine"
173+
)
174+
175+
async with ws_connect(
176+
uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
177+
) as ws:
178+
await ws.send(request)
179+
yield AsyncLiveAgentEngineSession(websocket=ws)

0 commit comments

Comments
 (0)