Skip to content

Commit ca2a460

Browse files
authored
Merge branch 'main' into fix/async_stream_append_chunk
2 parents 997bf1c + ce13aef commit ca2a460

File tree

4 files changed

+436
-0
lines changed

4 files changed

+436
-0
lines changed

google/genai/_extra_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,35 @@ def format_destination(
117117
return config
118118

119119

120+
def find_afc_incompatible_tool_indexes(
121+
config: Optional[types.GenerateContentConfigOrDict] = None,
122+
) -> list[int]:
123+
"""Checks if the config contains any AFC incompatible tools.
124+
125+
A `types.Tool` object that contains `function_declarations` is considered a
126+
non-AFC tool for this execution path.
127+
128+
Args:
129+
config: The GenerateContentConfig to check for incompatible tools.
130+
131+
Returns:
132+
A list of indexes of the incompatible tools in the config.
133+
"""
134+
if not config:
135+
return []
136+
config_model = _create_generate_content_config_model(config)
137+
incompatible_tools_indexes: list[int] = []
138+
139+
if not config_model or not config_model.tools:
140+
return incompatible_tools_indexes
141+
142+
for index, tool in enumerate(config_model.tools):
143+
if isinstance(tool, types.Tool) and tool.function_declarations:
144+
incompatible_tools_indexes.append(index)
145+
146+
return incompatible_tools_indexes
147+
148+
120149
def get_function_map(
121150
config: Optional[types.GenerateContentConfigOrDict] = None,
122151
mcp_to_genai_tool_adapters: Optional[

google/genai/models.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4993,6 +4993,9 @@ def generate_content(
49934993
# scones.
49944994
"""
49954995

4996+
incompatible_tools_indexes = (
4997+
_extra_utils.find_afc_incompatible_tool_indexes(config)
4998+
)
49964999
parsed_config = _extra_utils.parse_config_for_mcp_usage(config)
49975000
if (
49985001
parsed_config
@@ -5006,6 +5009,28 @@ def generate_content(
50065009
return self._generate_content(
50075010
model=model, contents=contents, config=parsed_config
50085011
)
5012+
if incompatible_tools_indexes:
5013+
original_tools_length = 0
5014+
if isinstance(config, types.GenerateContentConfig):
5015+
if config.tools:
5016+
original_tools_length = len(config.tools)
5017+
elif isinstance(config, dict):
5018+
tools = config.get('tools', [])
5019+
if tools:
5020+
original_tools_length = len(tools)
5021+
if len(incompatible_tools_indexes) != original_tools_length:
5022+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
5023+
logger.warning(
5024+
'Tools at indices [%s] are not compatible with automatic function '
5025+
'calling (AFC). AFC is disabled. If AFC is intended, please '
5026+
'include python callables in the tool list, and do not include '
5027+
'function declaration in the tool list.',
5028+
indices_str,
5029+
)
5030+
return self._generate_content(
5031+
model=model, contents=contents, config=parsed_config
5032+
)
5033+
50095034
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
50105035
parsed_config
50115036
)
@@ -5129,6 +5154,9 @@ def generate_content_stream(
51295154
# scones.
51305155
"""
51315156

5157+
incompatible_tools_indexes = (
5158+
_extra_utils.find_afc_incompatible_tool_indexes(config)
5159+
)
51325160
parsed_config = _extra_utils.parse_config_for_mcp_usage(config)
51335161
if (
51345162
parsed_config
@@ -5144,6 +5172,27 @@ def generate_content_stream(
51445172
)
51455173
return
51465174

5175+
if incompatible_tools_indexes:
5176+
original_tools_length = 0
5177+
if isinstance(config, types.GenerateContentConfig):
5178+
if config.tools:
5179+
original_tools_length = len(config.tools)
5180+
elif isinstance(config, dict):
5181+
tools = config.get('tools', [])
5182+
if tools:
5183+
original_tools_length = len(tools)
5184+
if len(incompatible_tools_indexes) != original_tools_length:
5185+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
5186+
logger.warning(
5187+
'Tools at indices [%s] are not compatible with automatic function '
5188+
'calling. AFC will be disabled.',
5189+
indices_str,
5190+
)
5191+
yield from self._generate_content_stream(
5192+
model=model, contents=contents, config=parsed_config
5193+
)
5194+
return
5195+
51475196
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
51485197
parsed_config
51495198
)
@@ -6759,13 +6808,37 @@ async def generate_content(
67596808
# J'aime les bagels.
67606809
"""
67616810
# Retrieve and cache any MCP sessions if provided.
6811+
incompatible_tools_indexes = (
6812+
_extra_utils.find_afc_incompatible_tool_indexes(config)
6813+
)
67626814
parsed_config, mcp_to_genai_tool_adapters = (
67636815
await _extra_utils.parse_config_for_mcp_sessions(config)
67646816
)
67656817
if _extra_utils.should_disable_afc(parsed_config):
67666818
return await self._generate_content(
67676819
model=model, contents=contents, config=parsed_config
67686820
)
6821+
if incompatible_tools_indexes:
6822+
original_tools_length = 0
6823+
if isinstance(config, types.GenerateContentConfig):
6824+
if config.tools:
6825+
original_tools_length = len(config.tools)
6826+
elif isinstance(config, dict):
6827+
tools = config.get('tools', [])
6828+
if tools:
6829+
original_tools_length = len(tools)
6830+
if len(incompatible_tools_indexes) != original_tools_length:
6831+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
6832+
logger.warning(
6833+
'Tools at indices [%s] are not compatible with automatic function '
6834+
'calling (AFC). AFC is disabled. If AFC is intended, please '
6835+
'include python callables in the tool list, and do not include '
6836+
'function declaration in the tool list.',
6837+
indices_str,
6838+
)
6839+
return await self._generate_content(
6840+
model=model, contents=contents, config=parsed_config
6841+
)
67696842
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
67706843
parsed_config
67716844
)
@@ -6890,6 +6963,10 @@ async def generate_content_stream(
68906963
# scones.
68916964
"""
68926965

6966+
# Retrieve and cache any MCP sessions if provided.
6967+
incompatible_tools_indexes = (
6968+
_extra_utils.find_afc_incompatible_tool_indexes(config)
6969+
)
68936970
# Retrieve and cache any MCP sessions if provided.
68946971
parsed_config, mcp_to_genai_tool_adapters = (
68956972
await _extra_utils.parse_config_for_mcp_sessions(config)
@@ -6905,6 +6982,34 @@ async def base_async_generator(model, contents, config): # type: ignore[no-unty
69056982

69066983
return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
69076984

6985+
if incompatible_tools_indexes:
6986+
original_tools_length = 0
6987+
if isinstance(config, types.GenerateContentConfig):
6988+
if config.tools:
6989+
original_tools_length = len(config.tools)
6990+
elif isinstance(config, dict):
6991+
tools = config.get('tools', [])
6992+
if tools:
6993+
original_tools_length = len(tools)
6994+
if len(incompatible_tools_indexes) != original_tools_length:
6995+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
6996+
logger.warning(
6997+
'Tools at indices [%s] are not compatible with automatic function '
6998+
'calling (AFC). AFC is disabled. If AFC is intended, please '
6999+
'include python callables in the tool list, and do not include '
7000+
'function declaration in the tool list.',
7001+
indices_str,
7002+
)
7003+
response = await self._generate_content_stream(
7004+
model=model, contents=contents, config=parsed_config
7005+
)
7006+
7007+
async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
7008+
async for chunk in response: # type: ignore[attr-defined]
7009+
yield chunk
7010+
7011+
return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
7012+
69087013
async def async_generator(model, contents, config): # type: ignore[no-untyped-def]
69097014
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
69107015
logger.info(

0 commit comments

Comments
 (0)