Skip to content

Commit 456249e

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Support bidi stream query in agent engines and ADK template.
PiperOrigin-RevId: 800547989
1 parent 42c3c9c commit 456249e

File tree

4 files changed

+403
-11
lines changed

4 files changed

+403
-11
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15+
import asyncio
1516
import base64
1617
import importlib
1718
import json
@@ -180,6 +181,34 @@ async def run_async(self, *args, **kwargs):
180181
}
181182
)
182183

184+
async def run_live(self, *args, **kwargs):
185+
from google.adk.events import event
186+
187+
yield event.Event(
188+
**{
189+
"author": "currency_exchange_agent",
190+
"content": {
191+
"parts": [
192+
{
193+
"thought_signature": b"test_signature",
194+
"function_call": {
195+
"args": {
196+
"currency_date": "2025-04-03",
197+
"currency_from": "USD",
198+
"currency_to": "SEK",
199+
},
200+
"id": "af-c5a57692-9177-4091-a3df-098f834ee849",
201+
"name": "get_exchange_rate",
202+
},
203+
}
204+
],
205+
"role": "model",
206+
},
207+
"id": "9aaItGK9",
208+
"invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7",
209+
}
210+
)
211+
183212

184213
@pytest.mark.usefixtures("google_auth_mock")
185214
class TestAdkApp:
@@ -355,6 +384,29 @@ def test_streaming_agent_run_with_events(self):
355384
events = list(app.streaming_agent_run_with_events(request_json=request_json))
356385
assert len(events) == 1
357386

387+
@pytest.mark.asyncio
388+
async def test_async_bidi_stream_query(self):
389+
app = reasoning_engines.AdkApp(
390+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
391+
)
392+
assert app._tmpl_attrs.get("runner") is None
393+
app.set_up()
394+
app._tmpl_attrs["runner"] = _MockRunner()
395+
request_queue = asyncio.Queue()
396+
request_dict = {
397+
"user_id": _TEST_USER_ID,
398+
"live_request": {
399+
"input": "What is the exchange rate from USD to SEK?",
400+
},
401+
}
402+
403+
await request_queue.put(request_dict)
404+
await request_queue.put(None) # Sentinel to end the stream.
405+
events = []
406+
async for event in app.bidi_stream_query(request_queue):
407+
events.append(event)
408+
assert len(events) == 1
409+
358410
def test_create_session(self):
359411
app = reasoning_engines.AdkApp(
360412
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
@@ -568,3 +620,35 @@ async def test_async_stream_query_invalid_message_type(self):
568620
):
569621
async for _ in app.async_stream_query(user_id=_TEST_USER_ID, message=123):
570622
pass
623+
624+
@pytest.mark.asyncio
625+
async def test_bidi_stream_query_invalid_request_queue(self):
626+
app = reasoning_engines.AdkApp(
627+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
628+
)
629+
request_queue = []
630+
with pytest.raises(
631+
TypeError,
632+
match="request_queue must be an asyncio.Queue instance.",
633+
):
634+
async for _ in app.bidi_stream_query(request_queue):
635+
pass
636+
637+
@pytest.mark.asyncio
638+
async def test_bidi_stream_query_invalid_first_request(self):
639+
app = reasoning_engines.AdkApp(
640+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
641+
)
642+
request_queue = asyncio.Queue()
643+
request_dict = {
644+
"live_request": {
645+
"input": "What is the exchange rate from USD to SEK?",
646+
},
647+
}
648+
await request_queue.put(request_dict)
649+
with pytest.raises(
650+
ValueError,
651+
match="The first request must have a user_id.",
652+
):
653+
async for _ in app.bidi_stream_query(request_queue):
654+
pass

tests/unit/vertex_langchain/test_agent_engines.py

Lines changed: 168 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ def clone(self):
108108
return self
109109

110110

111+
class BidiStreamQueryEngine:
112+
"""A sample Agent Engine that implements `bidi_stream_query`."""
113+
114+
def set_up(self):
115+
pass
116+
117+
async def bidi_stream_query(self, unused_request_queue) -> AsyncIterable[Any]:
118+
"""Runs the bidi stream engine."""
119+
raise NotImplementedError()
120+
121+
def clone(self):
122+
return self
123+
124+
111125
class OperationRegistrableEngine:
112126
"""Add a test class that implements OperationRegistrable."""
113127

@@ -141,6 +155,10 @@ async def async_stream_query(
141155
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
142156
yield chunk
143157

158+
async def bidi_stream_query(self, unused_request_queue) -> AsyncIterable[Any]:
159+
"""Runs the bidi stream engine."""
160+
raise NotImplementedError()
161+
144162
# Add a custom method to test the custom stream method registration.
145163
def custom_stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]:
146164
"""Runs the stream engine."""
@@ -158,6 +176,12 @@ async def custom_async_stream_method(
158176
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
159177
yield chunk
160178

179+
async def custom_bidi_stream_method(
180+
self, unused_request_queue
181+
) -> AsyncIterable[Any]:
182+
"""Runs the bidi stream engine."""
183+
raise NotImplementedError()
184+
161185
def clone(self):
162186
return self
163187

@@ -890,6 +914,17 @@ async def async_stream_query() -> str:
890914
return "RESPONSE"
891915

892916

917+
class InvalidCapitalizeEngineWithoutBidiStreamQuerySelf:
918+
"""A sample Agent Engine with an invalid bidi_stream_query method."""
919+
920+
def set_up(self):
921+
pass
922+
923+
async def bidi_stream_query() -> AsyncIterable[Any]:
924+
"""Runs the engine."""
925+
raise NotImplementedError()
926+
927+
893928
class InvalidCapitalizeEngineWithoutRegisterOperationsSelf:
894929
"""A sample Agent Engine with an invalid register_operations method."""
895930

@@ -1625,6 +1660,23 @@ def test_get_agent_framework(
16251660
),
16261661
),
16271662
),
1663+
(
1664+
"Update the bidi stream query engine",
1665+
{"agent_engine": BidiStreamQueryEngine()},
1666+
types.reasoning_engine_service.UpdateReasoningEngineRequest(
1667+
reasoning_engine=_generate_agent_engine_with_class_methods_and_agent_framework(
1668+
[],
1669+
_agent_engines._DEFAULT_AGENT_FRAMEWORK,
1670+
),
1671+
update_mask=field_mask_pb2.FieldMask(
1672+
paths=[
1673+
"spec.package_spec.pickle_object_gcs_uri",
1674+
"spec.class_methods",
1675+
"spec.agent_framework",
1676+
]
1677+
),
1678+
),
1679+
),
16281680
(
16291681
"Update the operation registrable engine",
16301682
{"agent_engine": OperationRegistrableEngine()},
@@ -2826,6 +2878,95 @@ async def test_async_stream_query_agent_engine_with_operation_schema(
28262878
)
28272879
)
28282880

2881+
# pytest does not allow absl.testing.parameterized.named_parameters.
2882+
@pytest.mark.parametrize(
2883+
"test_case_name, test_engine, test_class_method_docs, test_class_methods_spec",
2884+
[
2885+
(
2886+
"Default Bidi Stream Queryable (Not Operation Registrable) Engine",
2887+
BidiStreamQueryEngine(),
2888+
{},
2889+
_TEST_ASYNC_STREAM_QUERY_SCHEMAS,
2890+
),
2891+
(
2892+
"Operation Registrable Engine",
2893+
OperationRegistrableEngine(),
2894+
{},
2895+
_TEST_OPERATION_REGISTRABLE_SCHEMAS,
2896+
),
2897+
],
2898+
)
2899+
@pytest.mark.asyncio
2900+
async def test_create_agent_engine_with_bidi_stream_query_operation_schema(
2901+
self,
2902+
test_case_name,
2903+
test_engine,
2904+
test_class_method_docs,
2905+
test_class_methods_spec,
2906+
):
2907+
with mock.patch.object(
2908+
base.VertexAiResourceNoun,
2909+
"_get_gca_resource",
2910+
) as get_gca_resource_mock:
2911+
test_spec = types.ReasoningEngineSpec()
2912+
test_spec.class_methods.extend(test_class_methods_spec)
2913+
get_gca_resource_mock.return_value = types.ReasoningEngine(
2914+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2915+
spec=test_spec,
2916+
)
2917+
agent_engines.create(test_engine)
2918+
2919+
# pytest does not allow absl.testing.parameterized.named_parameters.
2920+
@pytest.mark.parametrize(
2921+
"test_case_name, test_engine, test_class_methods, test_class_methods_spec",
2922+
[
2923+
(
2924+
"Default Bidi Stream Queryable (Not Operation Registrable) Engine",
2925+
BidiStreamQueryEngine(),
2926+
[],
2927+
[],
2928+
),
2929+
(
2930+
"Operation Registrable Engine",
2931+
OperationRegistrableEngine(),
2932+
[],
2933+
_TEST_OPERATION_REGISTRABLE_SCHEMAS,
2934+
),
2935+
],
2936+
)
2937+
@pytest.mark.asyncio
2938+
async def test_update_agent_engine_with_bidi_stream_query_operation_schema(
2939+
self,
2940+
test_case_name,
2941+
test_engine,
2942+
test_class_methods,
2943+
test_class_methods_spec,
2944+
update_agent_engine_mock,
2945+
):
2946+
with mock.patch.object(
2947+
base.VertexAiResourceNoun,
2948+
"_get_gca_resource",
2949+
) as get_gca_resource_mock:
2950+
test_spec = types.ReasoningEngineSpec()
2951+
test_spec.class_methods.append(_TEST_METHOD_TO_BE_UNREGISTERED_SCHEMA)
2952+
get_gca_resource_mock.return_value = types.ReasoningEngine(
2953+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME, spec=test_spec
2954+
)
2955+
test_agent_engine = agent_engines.create(MethodToBeUnregisteredEngine())
2956+
assert hasattr(test_agent_engine, _TEST_METHOD_TO_BE_UNREGISTERED_NAME)
2957+
2958+
with mock.patch.object(
2959+
base.VertexAiResourceNoun,
2960+
"_get_gca_resource",
2961+
) as get_gca_resource_mock:
2962+
test_spec = types.ReasoningEngineSpec()
2963+
test_spec.class_methods.extend(test_class_methods_spec)
2964+
get_gca_resource_mock.return_value = types.ReasoningEngine(
2965+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2966+
spec=test_spec,
2967+
)
2968+
test_agent_engine.update(agent_engine=test_engine)
2969+
28292970

28302971
@pytest.mark.usefixtures("google_auth_mock")
28312972
class TestAgentEngineErrors:
@@ -2887,8 +3028,8 @@ def test_create_agent_engine_no_query_method(
28873028
TypeError,
28883029
match=(
28893030
"agent_engine has none of the following callable methods: "
2890-
"`query`, `async_query`, `stream_query`, `async_stream_query` "
2891-
"or `register_operations`."
3031+
"`query`, `async_query`, `stream_query`, `async_stream_query`, "
3032+
"`bidi_stream_query` or `register_operations`."
28923033
),
28933034
):
28943035
agent_engines.create(
@@ -2911,8 +3052,8 @@ def test_create_agent_engine_noncallable_query_attribute(
29113052
TypeError,
29123053
match=(
29133054
"agent_engine has none of the following callable methods: "
2914-
"`query`, `async_query`, `stream_query`, `async_stream_query` "
2915-
"or `register_operations`."
3055+
"`query`, `async_query`, `stream_query`, `async_stream_query`, "
3056+
"`bidi_stream_query` or `register_operations`."
29163057
),
29173058
):
29183059
agent_engines.create(
@@ -3024,6 +3165,23 @@ def test_create_agent_engine_with_invalid_async_stream_query_method(
30243165
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
30253166
)
30263167

3168+
def test_create_agent_engine_with_invalid_bidi_stream_query_method(
3169+
self,
3170+
create_agent_engine_mock,
3171+
cloud_storage_create_bucket_mock,
3172+
tarfile_open_mock,
3173+
cloudpickle_dump_mock,
3174+
cloudpickle_load_mock,
3175+
importlib_metadata_version_mock,
3176+
get_agent_engine_mock,
3177+
):
3178+
with pytest.raises(ValueError, match="Invalid bidi_stream_query signature"):
3179+
agent_engines.create(
3180+
InvalidCapitalizeEngineWithoutBidiStreamQuerySelf(),
3181+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
3182+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
3183+
)
3184+
30273185
def test_create_agent_engine_with_invalid_register_operations_method(
30283186
self,
30293187
create_agent_engine_mock,
@@ -3158,8 +3316,8 @@ def test_update_agent_engine_no_query_method(
31583316
TypeError,
31593317
match=(
31603318
"agent_engine has none of the following callable methods: "
3161-
"`query`, `async_query`, `stream_query`, `async_stream_query` "
3162-
"or `register_operations`."
3319+
"`query`, `async_query`, `stream_query`, `async_stream_query`, "
3320+
"`bidi_stream_query` or `register_operations`."
31633321
),
31643322
):
31653323
test_agent_engine = _generate_agent_engine_to_update()
@@ -3181,8 +3339,8 @@ def test_update_agent_engine_noncallable_query_attribute(
31813339
TypeError,
31823340
match=(
31833341
"agent_engine has none of the following callable methods: "
3184-
"`query`, `async_query`, `stream_query`, `async_stream_query` "
3185-
"or `register_operations`."
3342+
"`query`, `async_query`, `stream_query`, `async_stream_query`, "
3343+
"`bidi_stream_query` or `register_operations`."
31863344
),
31873345
):
31883346
test_agent_engine = _generate_agent_engine_to_update()
@@ -3324,7 +3482,8 @@ def test_update_class_methods_spec_with_registered_operation_not_found(self):
33243482
"register the API methods: "
33253483
"https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
33263484
"Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
3327-
"Supported modes are: ``, `async`, `async_stream`, `stream`.}"
3485+
"Supported modes are: ``, `async`, `async_stream`, "
3486+
"`bidi_stream`, `stream`.}"
33283487
),
33293488
),
33303489
],

0 commit comments

Comments
 (0)