@@ -99,6 +99,44 @@ def _sanitize_schema_type(schema: dict[str, Any]) -> dict[str, Any]:
9999 return schema
100100
101101
102+ def _dereference_schema (schema : dict [str , Any ]) -> dict [str , Any ]:
103+ """Resolves $ref pointers in a JSON schema."""
104+
105+ defs = schema .get ("$defs" , {})
106+
107+ def _resolve_refs (sub_schema : Any ) -> Any :
108+ if isinstance (sub_schema , dict ):
109+ if "$ref" in sub_schema :
110+ ref_key = sub_schema ["$ref" ].split ("/" )[- 1 ]
111+ if ref_key in defs :
112+ # Found the reference, replace it with the definition.
113+ resolved = defs [ref_key ].copy ()
114+ # Merge properties from the reference, allowing overrides.
115+ sub_schema_copy = sub_schema .copy ()
116+ del sub_schema_copy ["$ref" ]
117+ resolved .update (sub_schema_copy )
118+ # Recursively resolve refs in the newly inserted part.
119+ return _resolve_refs (resolved )
120+ else :
121+ # Reference not found, return as is.
122+ return sub_schema
123+ else :
124+ # No $ref, so traverse deeper into the dictionary.
125+ return {key : _resolve_refs (value ) for key , value in sub_schema .items ()}
126+ elif isinstance (sub_schema , list ):
127+ # Traverse into lists.
128+ return [_resolve_refs (item ) for item in sub_schema ]
129+ else :
130+ # Not a dict or list, return as is.
131+ return sub_schema
132+
133+ dereferenced_schema = _resolve_refs (schema )
134+ # Remove the definitions block after resolving.
135+ if "$defs" in dereferenced_schema :
136+ del dereferenced_schema ["$defs" ]
137+ return dereferenced_schema
138+
139+
102140def _sanitize_schema_formats_for_gemini (
103141 schema : dict [str , Any ],
104142) -> dict [str , Any ]:
@@ -109,7 +147,10 @@ def _sanitize_schema_formats_for_gemini(
109147 "any_of" , # 'one_of', 'all_of', 'not' to come
110148 }
111149 snake_case_schema = {}
112- dict_schema_field_names : tuple [str ] = ("properties" ,) # 'defs' to come
150+ dict_schema_field_names : tuple [str , ...] = (
151+ "properties" ,
152+ "defs" ,
153+ )
113154 for field_name , field_value in schema .items ():
114155 field_name = _to_snake_case (field_name )
115156 if field_name in schema_field_names :
@@ -151,8 +192,9 @@ def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema:
151192 if not isinstance (openapi_schema , dict ):
152193 raise TypeError ("openapi_schema must be a dictionary" )
153194
154- openapi_schema = _sanitize_schema_formats_for_gemini (openapi_schema )
195+ dereferenced_schema = _dereference_schema (openapi_schema )
196+ sanitized_schema = _sanitize_schema_formats_for_gemini (dereferenced_schema )
155197 return Schema .from_json_schema (
156- json_schema = _ExtendedJSONSchema .model_validate (openapi_schema ),
198+ json_schema = _ExtendedJSONSchema .model_validate (sanitized_schema ),
157199 api_option = get_google_llm_variant (),
158200 )
0 commit comments