@@ -18,7 +18,7 @@
 # pylint: disable=protected-access,bad-continuation
 import io
 import pytest
-from typing import Iterable, MutableSequence, Optional
+from typing import Dict, Iterable, List, MutableSequence, Optional
 from unittest import mock

 import vertexai
@@ -300,12 +300,31 @@ def mock_generate_content(
     if has_function_declarations and not had_any_function_calls:
         # response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
         # Workaround for the proto library bug
-        response_part_struct = dict(
-            function_call=gapic_tool_types.FunctionCall(
-                name="get_current_weather",
-                args={"location": "Boston"},
+        first_tool_with_function_declarations = next(
+            tool for tool in request.tools if tool.function_declarations
+        )
+        if (
+            first_tool_with_function_declarations.function_declarations[0].name
+            == update_weather_data.__name__
+        ):
+            response_part_struct = dict(
+                function_call=gapic_tool_types.FunctionCall(
+                    name=update_weather_data.__name__,
+                    args={
+                        "location": "Boston",
+                        "temperature": 60,
+                        "forecasts": [61, 62],
+                        "extra_info": {"humidity": 50},
+                    },
+                )
+            )
+        else:
+            response_part_struct = dict(
+                function_call=gapic_tool_types.FunctionCall(
+                    name="get_current_weather",
+                    args={"location": "Boston"},
+                )
             )
-        )
     elif has_function_declarations and latest_user_message_function_responses:
         function_response = latest_user_message_function_responses[0]
         function_response_dict = type(function_response).to_dict(function_response)
@@ -537,6 +556,19 @@ def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
     )


+def update_weather_data(
+    location: str, temperature: int, forecasts: List[int], extra_info: Dict[str, int]
+):
+    """Updates the weather data in the specified location."""
+    return dict(
+        location=location,
+        temperature=temperature,
+        forecasts=forecasts,
+        extra_info=extra_info,
+        result="Success",
+    )
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
 class TestGenerativeModels:
@@ -839,6 +871,46 @@ def test_generate_content_streaming(self, generative_models: generative_models):
         for chunk in stream:
             assert chunk.text

+    @patch_genai_services
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_generate_content_with_function_calling_integer_args(
+        self, generative_models: generative_models
+    ):
+        model = generative_models.GenerativeModel("gemini-pro")
+        weather_tool = generative_models.Tool(
+            function_declarations=[
+                generative_models.FunctionDeclaration.from_func(update_weather_data)
+            ],
+        )
+        response = model.generate_content(
+            "Please update the weather data in Boston, with temperature 60, "
+            "forecasts are 61 and 62, and humidity is 50",
+            tools=[weather_tool],
+        )
+        assert (
+            response.candidates[0].content.parts[0].function_call.name
+            == "update_weather_data"
+        )
+        assert [
+            function_call.name
+            for function_call in response.candidates[0].function_calls
+        ] == ["update_weather_data"]
+        for args in (
+            response.candidates[0].function_calls[0].args,
+            response.candidates[0].function_calls[0].to_dict()["args"],
+        ):
+            assert args["location"] == "Boston"
+            assert args["temperature"] == 60
+            assert isinstance(args["temperature"], int)
+            assert 61 in args["forecasts"] and 62 in args["forecasts"]
+            assert all(isinstance(forecast, int) for forecast in args["forecasts"])
+            assert args["extra_info"]["humidity"] == 50
+            assert isinstance(args["extra_info"]["humidity"], int)
+            assert args["location"] == "Boston"
+
     @patch_genai_services
     @pytest.mark.parametrize(
         "generative_models",
@@ -1052,6 +1124,52 @@ def test_chat_function_calling(self, generative_models: generative_models):
         assert "nice" in response2.text
         assert not response2.candidates[0].function_calls

+    @patch_genai_services
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_function_calling_with_integer_args(
+        self, generative_models: generative_models
+    ):
+        get_current_weather_func = generative_models.FunctionDeclaration.from_func(
+            update_weather_data
+        )
+        weather_tool = generative_models.Tool(
+            function_declarations=[get_current_weather_func],
+        )
+        model = generative_models.GenerativeModel(
+            "gemini-pro",
+            # Specifying the tools once to avoid specifying them in every request
+            tools=[weather_tool],
+        )
+        chat = model.start_chat()
+
+        response1 = chat.send_message(
+            "Please update the weather data in Boston, with temperature 60, "
+            "forecasts are 61 and 62, and humidity is 50"
+        )
+        assert (
+            response1.candidates[0].content.parts[0].function_call.name
+            == "update_weather_data"
+        )
+        assert [
+            function_call.name
+            for function_call in response1.candidates[0].function_calls
+        ] == ["update_weather_data"]
+        for args in (
+            response1.candidates[0].function_calls[0].args,
+            response1.candidates[0].function_calls[0].to_dict()["args"],
+        ):
+            assert args["location"] == "Boston"
+            assert args["temperature"] == 60
+            assert isinstance(args["temperature"], int)
+            assert 61 in args["forecasts"] and 62 in args["forecasts"]
+            assert all(isinstance(forecast, int) for forecast in args["forecasts"])
+            assert args["extra_info"]["humidity"] == 50
+            assert isinstance(args["extra_info"]["humidity"], int)
+            assert args["location"] == "Boston"
+
     @patch_genai_services
     @pytest.mark.parametrize(
         "generative_models",
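For reference, a minimal sketch of the API surface these tests exercise, outside the mocked services. It assumes vertexai.init() has been called for a real project; the model name and prompt are illustrative, and the dispatch at the end is one possible way to route the typed arguments back into update_weather_data, not part of the change itself.

from typing import Dict, List

from vertexai.generative_models import FunctionDeclaration, GenerativeModel, Tool


def update_weather_data(
    location: str, temperature: int, forecasts: List[int], extra_info: Dict[str, int]
):
    """Updates the weather data in the specified location."""
    return dict(location=location, temperature=temperature, result="Success")


# from_func infers the parameter schema from the type annotations, so integer
# and nested integer values should come back from the model as ints, not floats.
weather_tool = Tool(
    function_declarations=[FunctionDeclaration.from_func(update_weather_data)]
)
model = GenerativeModel("gemini-pro", tools=[weather_tool])

response = model.generate_content(
    "Please update the weather data in Boston, with temperature 60, "
    "forecasts are 61 and 62, and humidity is 50"
)
function_call = response.candidates[0].function_calls[0]

# Plain-dict view of the arguments, as asserted in the tests above.
args = function_call.to_dict()["args"]
result = update_weather_data(**args)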