@@ -108,6 +108,20 @@ def clone(self):
108108 return self
109109
110110
111+ class BidiStreamQueryEngine :
112+ """A sample Agent Engine that implements `bidi_stream_query`."""
113+
114+ def set_up (self ):
115+ pass
116+
117+ async def bidi_stream_query (self , unused_request_queue ) -> AsyncIterable [Any ]:
118+ """Runs the bidi stream engine."""
119+ raise NotImplementedError ()
120+
121+ def clone (self ):
122+ return self
123+
124+
111125class OperationRegistrableEngine :
112126 """Add a test class that implements OperationRegistrable."""
113127
@@ -141,6 +155,10 @@ async def async_stream_query(
141155 for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE :
142156 yield chunk
143157
158+ async def bidi_stream_query (self , unused_request_queue ) -> AsyncIterable [Any ]:
159+ """Runs the bidi stream engine."""
160+ raise NotImplementedError ()
161+
144162 # Add a custom method to test the custom stream method registration.
145163 def custom_stream_query (self , unused_arbitrary_string_name : str ) -> Iterable [Any ]:
146164 """Runs the stream engine."""
@@ -158,6 +176,12 @@ async def custom_async_stream_method(
158176 for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE :
159177 yield chunk
160178
179+ async def custom_bidi_stream_method (
180+ self , unused_request_queue
181+ ) -> AsyncIterable [Any ]:
182+ """Runs the bidi stream engine."""
183+ raise NotImplementedError ()
184+
161185 def clone (self ):
162186 return self
163187
@@ -890,6 +914,17 @@ async def async_stream_query() -> str:
890914 return "RESPONSE"
891915
892916
917+ class InvalidCapitalizeEngineWithoutBidiStreamQuerySelf :
918+ """A sample Agent Engine with an invalid bidi_stream_query method."""
919+
920+ def set_up (self ):
921+ pass
922+
923+ async def bidi_stream_query () -> AsyncIterable [Any ]:
924+ """Runs the engine."""
925+ raise NotImplementedError ()
926+
927+
893928class InvalidCapitalizeEngineWithoutRegisterOperationsSelf :
894929 """A sample Agent Engine with an invalid register_operations method."""
895930
@@ -1625,6 +1660,23 @@ def test_get_agent_framework(
16251660 ),
16261661 ),
16271662 ),
1663+ (
1664+ "Update the bidi stream query engine" ,
1665+ {"agent_engine" : BidiStreamQueryEngine ()},
1666+ types .reasoning_engine_service .UpdateReasoningEngineRequest (
1667+ reasoning_engine = _generate_agent_engine_with_class_methods_and_agent_framework (
1668+ [],
1669+ _agent_engines ._DEFAULT_AGENT_FRAMEWORK ,
1670+ ),
1671+ update_mask = field_mask_pb2 .FieldMask (
1672+ paths = [
1673+ "spec.package_spec.pickle_object_gcs_uri" ,
1674+ "spec.class_methods" ,
1675+ "spec.agent_framework" ,
1676+ ]
1677+ ),
1678+ ),
1679+ ),
16281680 (
16291681 "Update the operation registrable engine" ,
16301682 {"agent_engine" : OperationRegistrableEngine ()},
@@ -2826,6 +2878,95 @@ async def test_async_stream_query_agent_engine_with_operation_schema(
28262878 )
28272879 )
28282880
2881+ # pytest does not allow absl.testing.parameterized.named_parameters.
2882+ @pytest .mark .parametrize (
2883+ "test_case_name, test_engine, test_class_method_docs, test_class_methods_spec" ,
2884+ [
2885+ (
2886+ "Default Bidi Stream Queryable (Not Operation Registrable) Engine" ,
2887+ BidiStreamQueryEngine (),
2888+ {},
2889+ _TEST_ASYNC_STREAM_QUERY_SCHEMAS ,
2890+ ),
2891+ (
2892+ "Operation Registrable Engine" ,
2893+ OperationRegistrableEngine (),
2894+ {},
2895+ _TEST_OPERATION_REGISTRABLE_SCHEMAS ,
2896+ ),
2897+ ],
2898+ )
2899+ @pytest .mark .asyncio
2900+ async def test_create_agent_engine_with_bidi_stream_query_operation_schema (
2901+ self ,
2902+ test_case_name ,
2903+ test_engine ,
2904+ test_class_method_docs ,
2905+ test_class_methods_spec ,
2906+ ):
2907+ with mock .patch .object (
2908+ base .VertexAiResourceNoun ,
2909+ "_get_gca_resource" ,
2910+ ) as get_gca_resource_mock :
2911+ test_spec = types .ReasoningEngineSpec ()
2912+ test_spec .class_methods .extend (test_class_methods_spec )
2913+ get_gca_resource_mock .return_value = types .ReasoningEngine (
2914+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
2915+ spec = test_spec ,
2916+ )
2917+ agent_engines .create (test_engine )
2918+
2919+ # pytest does not allow absl.testing.parameterized.named_parameters.
2920+ @pytest .mark .parametrize (
2921+ "test_case_name, test_engine, test_class_methods, test_class_methods_spec" ,
2922+ [
2923+ (
2924+ "Default Bidi Stream Queryable (Not Operation Registrable) Engine" ,
2925+ BidiStreamQueryEngine (),
2926+ [],
2927+ [],
2928+ ),
2929+ (
2930+ "Operation Registrable Engine" ,
2931+ OperationRegistrableEngine (),
2932+ [],
2933+ _TEST_OPERATION_REGISTRABLE_SCHEMAS ,
2934+ ),
2935+ ],
2936+ )
2937+ @pytest .mark .asyncio
2938+ async def test_update_agent_engine_with_bidi_stream_query_operation_schema (
2939+ self ,
2940+ test_case_name ,
2941+ test_engine ,
2942+ test_class_methods ,
2943+ test_class_methods_spec ,
2944+ update_agent_engine_mock ,
2945+ ):
2946+ with mock .patch .object (
2947+ base .VertexAiResourceNoun ,
2948+ "_get_gca_resource" ,
2949+ ) as get_gca_resource_mock :
2950+ test_spec = types .ReasoningEngineSpec ()
2951+ test_spec .class_methods .append (_TEST_METHOD_TO_BE_UNREGISTERED_SCHEMA )
2952+ get_gca_resource_mock .return_value = types .ReasoningEngine (
2953+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME , spec = test_spec
2954+ )
2955+ test_agent_engine = agent_engines .create (MethodToBeUnregisteredEngine ())
2956+ assert hasattr (test_agent_engine , _TEST_METHOD_TO_BE_UNREGISTERED_NAME )
2957+
2958+ with mock .patch .object (
2959+ base .VertexAiResourceNoun ,
2960+ "_get_gca_resource" ,
2961+ ) as get_gca_resource_mock :
2962+ test_spec = types .ReasoningEngineSpec ()
2963+ test_spec .class_methods .extend (test_class_methods_spec )
2964+ get_gca_resource_mock .return_value = types .ReasoningEngine (
2965+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
2966+ spec = test_spec ,
2967+ )
2968+ test_agent_engine .update (agent_engine = test_engine )
2969+
28292970
28302971@pytest .mark .usefixtures ("google_auth_mock" )
28312972class TestAgentEngineErrors :
@@ -2887,8 +3028,8 @@ def test_create_agent_engine_no_query_method(
28873028 TypeError ,
28883029 match = (
28893030 "agent_engine has none of the following callable methods: "
2890- "`query`, `async_query`, `stream_query`, `async_stream_query` "
2891- "or `register_operations`."
3031+ "`query`, `async_query`, `stream_query`, `async_stream_query`, "
3032+ "`bidi_stream_query` or `register_operations`."
28923033 ),
28933034 ):
28943035 agent_engines .create (
@@ -2911,8 +3052,8 @@ def test_create_agent_engine_noncallable_query_attribute(
29113052 TypeError ,
29123053 match = (
29133054 "agent_engine has none of the following callable methods: "
2914- "`query`, `async_query`, `stream_query`, `async_stream_query` "
2915- "or `register_operations`."
3055+ "`query`, `async_query`, `stream_query`, `async_stream_query`, "
3056+ "`bidi_stream_query` or `register_operations`."
29163057 ),
29173058 ):
29183059 agent_engines .create (
@@ -3024,6 +3165,23 @@ def test_create_agent_engine_with_invalid_async_stream_query_method(
30243165 requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
30253166 )
30263167
3168+ def test_create_agent_engine_with_invalid_bidi_stream_query_method (
3169+ self ,
3170+ create_agent_engine_mock ,
3171+ cloud_storage_create_bucket_mock ,
3172+ tarfile_open_mock ,
3173+ cloudpickle_dump_mock ,
3174+ cloudpickle_load_mock ,
3175+ importlib_metadata_version_mock ,
3176+ get_agent_engine_mock ,
3177+ ):
3178+ with pytest .raises (ValueError , match = "Invalid bidi_stream_query signature" ):
3179+ agent_engines .create (
3180+ InvalidCapitalizeEngineWithoutBidiStreamQuerySelf (),
3181+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
3182+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
3183+ )
3184+
30273185 def test_create_agent_engine_with_invalid_register_operations_method (
30283186 self ,
30293187 create_agent_engine_mock ,
@@ -3158,8 +3316,8 @@ def test_update_agent_engine_no_query_method(
31583316 TypeError ,
31593317 match = (
31603318 "agent_engine has none of the following callable methods: "
3161- "`query`, `async_query`, `stream_query`, `async_stream_query` "
3162- "or `register_operations`."
3319+ "`query`, `async_query`, `stream_query`, `async_stream_query`, "
3320+ "`bidi_stream_query` or `register_operations`."
31633321 ),
31643322 ):
31653323 test_agent_engine = _generate_agent_engine_to_update ()
@@ -3181,8 +3339,8 @@ def test_update_agent_engine_noncallable_query_attribute(
31813339 TypeError ,
31823340 match = (
31833341 "agent_engine has none of the following callable methods: "
3184- "`query`, `async_query`, `stream_query`, `async_stream_query` "
3185- "or `register_operations`."
3342+ "`query`, `async_query`, `stream_query`, `async_stream_query`, "
3343+ "`bidi_stream_query` or `register_operations`."
31863344 ),
31873345 ):
31883346 test_agent_engine = _generate_agent_engine_to_update ()
@@ -3324,7 +3482,8 @@ def test_update_class_methods_spec_with_registered_operation_not_found(self):
33243482 "register the API methods: "
33253483 "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
33263484 "Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
3327- "Supported modes are: ``, `async`, `async_stream`, `stream`.}"
3485+ "Supported modes are: ``, `async`, `async_stream`, "
3486+ "`bidi_stream`, `stream`.}"
33283487 ),
33293488 ),
33303489 ],
0 commit comments