Skip to content

Commit ed37e34

Browse files
Prhmmacopybara-github
authored andcommitted
feat(tools): support additional headers for google api toolset #non-breaking
Merge #3194 Allow Google API toolsets to accept optional per-request headers #3105 ## Testing Plan ### Unit Tests - ✅ Added `test_init_with_additional_headers` in `test_google_api_tool.py` to verify headers are passed to RestApiTool - ✅ Added `test_prepare_request_params_merges_default_headers` in `test_rest_api_tool.py` to verify custom headers are merged into requests - ✅ Added `test_prepare_request_params_preserves_existing_headers` in `test_rest_api_tool.py` to verify critical headers (Content-Type, User-Agent) are not overridden by additional_headers - ✅ Updated `test_init` and `test_get_tools` in `test_google_api_toolset.py` to verify the parameter is properly stored and passed through ### Manual Testing Tested with Google Ads API scenario (the original use case from issue #3105): ```python import os from google.adk.tools.google_api_tool import GoogleApiToolset # Create toolset with developer-token header required by Google Ads API google_ads_toolset = GoogleApiToolset( client_id=os.environ["CLIENT_ID"], client_secret=os.environ["CLIENT_SECRET"], api_name="googleads", api_version="v21", additional_headers={"developer-token": os.environ["GOOGLE_ADS_DEV_TOKEN"]} ) # Verify headers are included in API requests tools = await google_ads_toolset.get_tools() # Successfully made requests with the developer-token header COPYBARA_INTEGRATE_REVIEW=#3194 from Prhmma:feature/google-api-toolset-additional-headers-3105 e10489e PiperOrigin-RevId: 822273582
1 parent ce3418a commit ed37e34

File tree

6 files changed

+102
-4
lines changed

6 files changed

+102
-4
lines changed

src/google/adk/tools/google_api_tool/google_api_tool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ def __init__(
3939
client_id: Optional[str] = None,
4040
client_secret: Optional[str] = None,
4141
service_account: Optional[ServiceAccount] = None,
42+
*,
43+
additional_headers: Optional[Dict[str, str]] = None,
4244
):
4345
super().__init__(
4446
name=rest_api_tool.name,
4547
description=rest_api_tool.description,
4648
is_long_running=rest_api_tool.is_long_running,
4749
)
4850
self._rest_api_tool = rest_api_tool
51+
if additional_headers:
52+
self._rest_api_tool.set_default_headers(additional_headers)
4953
if service_account is not None:
5054
self.configure_sa_auth(service_account)
51-
else:
55+
elif client_id is not None and client_secret is not None:
5256
self.configure_auth(client_id, client_secret)
5357

5458
@override
@@ -57,7 +61,7 @@ def _get_declaration(self) -> FunctionDeclaration:
5761

5862
@override
5963
async def run_async(
60-
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
64+
self, *, args: Dict[str, Any], tool_context: Optional[ToolContext]
6165
) -> Dict[str, Any]:
6266
return await self._rest_api_tool.run_async(
6367
args=args, tool_context=tool_context

src/google/adk/tools/google_api_tool/google_api_toolset.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Dict
1718
from typing import List
1819
from typing import Optional
1920
from typing import Union
@@ -45,6 +46,8 @@ class GoogleApiToolset(BaseToolset):
4546
tool_filter: Optional filter to include only specific tools or use a predicate function.
4647
service_account: Optional service account for authentication.
4748
tool_name_prefix: Optional prefix to add to all tool names in this toolset.
49+
additional_headers: Optional dict of HTTP headers to inject into every request
50+
executed by this toolset.
4851
"""
4952

5053
def __init__(
@@ -56,13 +59,16 @@ def __init__(
5659
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
5760
service_account: Optional[ServiceAccount] = None,
5861
tool_name_prefix: Optional[str] = None,
62+
*,
63+
additional_headers: Optional[Dict[str, str]] = None,
5964
):
6065
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
6166
self.api_name = api_name
6267
self.api_version = api_version
6368
self._client_id = client_id
6469
self._client_secret = client_secret
6570
self._service_account = service_account
71+
self._additional_headers = additional_headers
6672
self._openapi_toolset = self._load_toolset_with_oidc_auth()
6773

6874
@override
@@ -72,7 +78,11 @@ async def get_tools(
7278
"""Get all tools in the toolset."""
7379
return [
7480
GoogleApiTool(
75-
tool, self._client_id, self._client_secret, self._service_account
81+
tool,
82+
self._client_id,
83+
self._client_secret,
84+
self._service_account,
85+
additional_headers=self._additional_headers,
7686
)
7787
for tool in await self._openapi_toolset.get_tools(readonly_context)
7888
if self._is_tool_selected(tool, readonly_context)

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134

135135
# Private properties
136136
self.credential_exchanger = AutoAuthCredentialExchanger()
137+
self._default_headers: Dict[str, str] = {}
137138
if should_parse_operation:
138139
self._operation_parser = OperationParser(self.operation)
139140

@@ -216,6 +217,10 @@ def configure_auth_credential(
216217
auth_credential = AuthCredential.model_validate_json(auth_credential)
217218
self.auth_credential = auth_credential
218219

220+
def set_default_headers(self, headers: Dict[str, str]):
221+
"""Sets default headers that are merged into every request."""
222+
self._default_headers = headers
223+
219224
def _prepare_auth_request_params(
220225
self,
221226
auth_scheme: AuthScheme,
@@ -335,6 +340,9 @@ def _prepare_request_params(
335340
k: v for k, v in query_params.items() if v is not None
336341
}
337342

343+
for key, value in self._default_headers.items():
344+
header_params.setdefault(key, value)
345+
338346
request_params: Dict[str, Any] = {
339347
"method": method,
340348
"url": url,

tests/unittests/tools/google_api_tool/test_google_api_tool.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ def test_init(self, mock_rest_api_tool):
5656
assert tool.is_long_running is False
5757
assert tool._rest_api_tool == mock_rest_api_tool
5858

59+
def test_init_with_additional_headers(self, mock_rest_api_tool):
60+
"""Test GoogleApiTool initialization with additional headers."""
61+
headers = {"developer-token": "test-token"}
62+
63+
GoogleApiTool(mock_rest_api_tool, additional_headers=headers)
64+
65+
mock_rest_api_tool.set_default_headers.assert_called_once_with(headers)
66+
5967
def test_get_declaration(self, mock_rest_api_tool):
6068
"""Test _get_declaration method."""
6169
tool = GoogleApiTool(mock_rest_api_tool)

tests/unittests/tools/google_api_tool/test_google_api_toolset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,14 @@ def test_init(
126126

127127
client_id = "test_client_id"
128128
client_secret = "test_client_secret"
129+
additional_headers = {"developer-token": "abc123"}
129130

130131
tool_set = GoogleApiToolset(
131132
api_name=TEST_API_NAME,
132133
api_version=TEST_API_VERSION,
133134
client_id=client_id,
134135
client_secret=client_secret,
136+
additional_headers=additional_headers,
135137
)
136138

137139
assert tool_set.api_name == TEST_API_NAME
@@ -141,6 +143,7 @@ def test_init(
141143
assert tool_set._service_account is None
142144
assert tool_set.tool_filter is None
143145
assert tool_set._openapi_toolset == mock_openapi_toolset_instance
146+
assert tool_set._additional_headers == additional_headers
144147

145148
mock_converter_class.assert_called_once_with(
146149
TEST_API_NAME, TEST_API_VERSION
@@ -191,13 +194,15 @@ async def test_get_tools(
191194
client_id = "cid"
192195
client_secret = "csecret"
193196
sa_mock = mock.MagicMock(spec=ServiceAccount)
197+
additional_headers = {"developer-token": "token"}
194198

195199
tool_set = GoogleApiToolset(
196200
api_name=TEST_API_NAME,
197201
api_version=TEST_API_VERSION,
198202
client_id=client_id,
199203
client_secret=client_secret,
200204
service_account=sa_mock,
205+
additional_headers=additional_headers,
201206
)
202207

203208
tools = await tool_set.get_tools(mock_readonly_context)
@@ -209,7 +214,11 @@ async def test_get_tools(
209214

210215
for i, rest_tool in enumerate(mock_rest_api_tools):
211216
mock_google_api_tool_class.assert_any_call(
212-
rest_tool, client_id, client_secret, sa_mock
217+
rest_tool,
218+
client_id,
219+
client_secret,
220+
sa_mock,
221+
additional_headers=additional_headers,
213222
)
214223
assert tools[i] is mock_google_api_tool_instances[i]
215224

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,65 @@ def test_prepare_request_params_unknown_parameter(
686686
# Make sure unknown parameters are ignored and do not raise errors.
687687
assert "unknown_param" not in request_params["params"]
688688

689+
def test_prepare_request_params_merges_default_headers(
690+
self,
691+
sample_endpoint,
692+
sample_auth_credential,
693+
sample_auth_scheme,
694+
sample_operation,
695+
):
696+
tool = RestApiTool(
697+
name="test_tool",
698+
description="Test Tool",
699+
endpoint=sample_endpoint,
700+
operation=sample_operation,
701+
auth_credential=sample_auth_credential,
702+
auth_scheme=sample_auth_scheme,
703+
)
704+
tool.set_default_headers({"developer-token": "token"})
705+
706+
request_params = tool._prepare_request_params([], {})
707+
708+
assert request_params["headers"]["developer-token"] == "token"
709+
710+
def test_prepare_request_params_preserves_existing_headers(
711+
self,
712+
sample_endpoint,
713+
sample_auth_credential,
714+
sample_auth_scheme,
715+
sample_operation,
716+
sample_api_parameters,
717+
):
718+
tool = RestApiTool(
719+
name="test_tool",
720+
description="Test Tool",
721+
endpoint=sample_endpoint,
722+
operation=sample_operation,
723+
auth_credential=sample_auth_credential,
724+
auth_scheme=sample_auth_scheme,
725+
)
726+
tool.set_default_headers({
727+
"Content-Type": "text/plain",
728+
"developer-token": "token",
729+
"User-Agent": "custom-default",
730+
})
731+
732+
header_param = ApiParameter(
733+
original_name="User-Agent",
734+
py_name="user_agent",
735+
param_location="header",
736+
param_schema=OpenAPISchema(type="string"),
737+
)
738+
739+
params = sample_api_parameters + [header_param]
740+
kwargs = {"test_body_param": "value", "user_agent": "api-client"}
741+
742+
request_params = tool._prepare_request_params(params, kwargs)
743+
744+
assert request_params["headers"]["Content-Type"] == "application/json"
745+
assert request_params["headers"]["developer-token"] == "token"
746+
assert request_params["headers"]["User-Agent"] == "api-client"
747+
689748
def test_prepare_request_params_base_url_handling(
690749
self, sample_auth_credential, sample_auth_scheme, sample_operation
691750
):

0 commit comments

Comments
 (0)