Skip to content

Commit a1857ed

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Convert float values with no decimals to integers in FunctionCall.
PiperOrigin-RevId: 695551016
1 parent 8e41265 commit a1857ed

File tree

2 files changed

+144
-8
lines changed

2 files changed

+144
-8
lines changed

tests/unit/vertexai/test_generative_models.py

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# pylint: disable=protected-access,bad-continuation
1919
import io
2020
import pytest
21-
from typing import Iterable, MutableSequence, Optional
21+
from typing import Dict, Iterable, List, MutableSequence, Optional
2222
from unittest import mock
2323

2424
import vertexai
@@ -300,12 +300,31 @@ def mock_generate_content(
300300
if has_function_declarations and not had_any_function_calls:
301301
# response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
302302
# Workaround for the proto library bug
303-
response_part_struct = dict(
304-
function_call=gapic_tool_types.FunctionCall(
305-
name="get_current_weather",
306-
args={"location": "Boston"},
303+
first_tool_with_function_declarations = next(
304+
tool for tool in request.tools if tool.function_declarations
305+
)
306+
if (
307+
first_tool_with_function_declarations.function_declarations[0].name
308+
== update_weather_data.__name__
309+
):
310+
response_part_struct = dict(
311+
function_call=gapic_tool_types.FunctionCall(
312+
name=update_weather_data.__name__,
313+
args={
314+
"location": "Boston",
315+
"temperature": 60,
316+
"forecasts": [61, 62],
317+
"extra_info": {"humidity": 50},
318+
},
319+
)
320+
)
321+
else:
322+
response_part_struct = dict(
323+
function_call=gapic_tool_types.FunctionCall(
324+
name="get_current_weather",
325+
args={"location": "Boston"},
326+
)
307327
)
308-
)
309328
elif has_function_declarations and latest_user_message_function_responses:
310329
function_response = latest_user_message_function_responses[0]
311330
function_response_dict = type(function_response).to_dict(function_response)
@@ -537,6 +556,19 @@ def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
537556
)
538557

539558

559+
def update_weather_data(
560+
location: str, temperature: int, forecasts: List[int], extra_info: Dict[str, int]
561+
):
562+
"""Updates the weather data in the specified location."""
563+
return dict(
564+
location=location,
565+
temperature=temperature,
566+
forecasts=forecasts,
567+
extra_info=extra_info,
568+
result="Success",
569+
)
570+
571+
540572
@pytest.mark.usefixtures("google_auth_mock")
541573
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
542574
class TestGenerativeModels:
@@ -839,6 +871,46 @@ def test_generate_content_streaming(self, generative_models: generative_models):
839871
for chunk in stream:
840872
assert chunk.text
841873

874+
@patch_genai_services
875+
@pytest.mark.parametrize(
876+
"generative_models",
877+
[generative_models, preview_generative_models],
878+
)
879+
def test_generate_content_with_function_calling_integer_args(
880+
self, generative_models: generative_models
881+
):
882+
model = generative_models.GenerativeModel("gemini-pro")
883+
weather_tool = generative_models.Tool(
884+
function_declarations=[
885+
generative_models.FunctionDeclaration.from_func(update_weather_data)
886+
],
887+
)
888+
response = model.generate_content(
889+
"Please update the weather data in Boston, with temperature 60, "
890+
"forecasts are 61 and 62, and humidity is 50",
891+
tools=[weather_tool],
892+
)
893+
assert (
894+
response.candidates[0].content.parts[0].function_call.name
895+
== "update_weather_data"
896+
)
897+
assert [
898+
function_call.name
899+
for function_call in response.candidates[0].function_calls
900+
] == ["update_weather_data"]
901+
for args in (
902+
response.candidates[0].function_calls[0].args,
903+
response.candidates[0].function_calls[0].to_dict()["args"],
904+
):
905+
assert args["location"] == "Boston"
906+
assert args["temperature"] == 60
907+
assert isinstance(args["temperature"], int)
908+
assert 61 in args["forecasts"] and 62 in args["forecasts"]
909+
assert all(isinstance(forecast, int) for forecast in args["forecasts"])
910+
assert args["extra_info"]["humidity"] == 50
911+
assert isinstance(args["extra_info"]["humidity"], int)
912+
assert args["location"] == "Boston"
913+
842914
@patch_genai_services
843915
@pytest.mark.parametrize(
844916
"generative_models",
@@ -1052,6 +1124,52 @@ def test_chat_function_calling(self, generative_models: generative_models):
10521124
assert "nice" in response2.text
10531125
assert not response2.candidates[0].function_calls
10541126

1127+
@patch_genai_services
1128+
@pytest.mark.parametrize(
1129+
"generative_models",
1130+
[generative_models, preview_generative_models],
1131+
)
1132+
def test_chat_function_calling_with_integer_args(
1133+
self, generative_models: generative_models
1134+
):
1135+
get_current_weather_func = generative_models.FunctionDeclaration.from_func(
1136+
update_weather_data
1137+
)
1138+
weather_tool = generative_models.Tool(
1139+
function_declarations=[get_current_weather_func],
1140+
)
1141+
model = generative_models.GenerativeModel(
1142+
"gemini-pro",
1143+
# Specifying the tools once to avoid specifying them in every request
1144+
tools=[weather_tool],
1145+
)
1146+
chat = model.start_chat()
1147+
1148+
response1 = chat.send_message(
1149+
"Please update the weather data in Boston, with temperature 60, "
1150+
"forecasts are 61 and 62, and humidity is 50"
1151+
)
1152+
assert (
1153+
response1.candidates[0].content.parts[0].function_call.name
1154+
== "update_weather_data"
1155+
)
1156+
assert [
1157+
function_call.name
1158+
for function_call in response1.candidates[0].function_calls
1159+
] == ["update_weather_data"]
1160+
for args in (
1161+
response1.candidates[0].function_calls[0].args,
1162+
response1.candidates[0].function_calls[0].to_dict()["args"],
1163+
):
1164+
assert args["location"] == "Boston"
1165+
assert args["temperature"] == 60
1166+
assert isinstance(args["temperature"], int)
1167+
assert 61 in args["forecasts"] and 62 in args["forecasts"]
1168+
assert all(isinstance(forecast, int) for forecast in args["forecasts"])
1169+
assert args["extra_info"]["humidity"] == 50
1170+
assert isinstance(args["extra_info"]["humidity"], int)
1171+
assert args["location"] == "Boston"
1172+
10551173
@patch_genai_services
10561174
@pytest.mark.parametrize(
10571175
"generative_models",

vertexai/generative_models/_generative_models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,7 +2631,11 @@ def _from_gapic(cls, raw_message: aiplatform_types.FunctionCall) -> "FunctionCal
26312631
return response
26322632

26332633
def to_dict(self) -> Dict[str, Any]:
2634-
return _proto_to_dict(self._raw_message)
2634+
function_call_dict = _proto_to_dict(self._raw_message)
2635+
function_call_dict["args"] = self._convert_number_values(
2636+
function_call_dict["args"]
2637+
)
2638+
return function_call_dict
26352639

26362640
def __repr__(self) -> str:
26372641
return self._raw_message.__repr__()
@@ -2644,7 +2648,21 @@ def name(self) -> str:
26442648
def args(self) -> Dict[str, Any]:
26452649
# We cannot use `type(self.args).to_dict(self.args)`
26462650
# due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
2647-
return self.to_dict().get("args")
2651+
args_dict = self.to_dict().get("args")
2652+
return self._convert_number_values(args_dict)
2653+
2654+
def _convert_number_values(self, data: Any) -> Any:
2655+
"""Converts float values with no decimal part to integers."""
2656+
if isinstance(data, float) and data.is_integer():
2657+
return int(data)
2658+
elif isinstance(data, dict):
2659+
return {
2660+
key: self._convert_number_values(value) for key, value in data.items()
2661+
}
2662+
elif isinstance(data, list):
2663+
return [self._convert_number_values(item) for item in data]
2664+
else:
2665+
return data
26482666

26492667

26502668
class SafetySetting:

0 commit comments

Comments
 (0)