Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions google/auth/_agent_identity_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for Agent Identity credentials."""

import base64
import hashlib
import os
import re

from google.auth import _exponential_backoff
from google.auth import environment_vars
from google.auth import exceptions

# SPIFFE trust domain patterns for Agent Identities.
_AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS = [
r"^agents\.global\.org-\d+\.system\.id\.goog$",
r"^agents\.global\.proj-\d+\.system\.id\.goog$",
]


def get_agent_identity_certificate_path():
"""Gets the certificate path from the certificate config file.
The path to the certificate config file is read from the
GOOGLE_API_CERTIFICATE_CONFIG environment variable. This function
implements a retry mechanism to handle cases where the environment
variable is set before the file is available on the filesystem.
Returns:
str: The path to the leaf certificate file.
Raises:
google.auth.exceptions.RefreshError: If the certificate config file
or the certificate file cannot be found after retries.
"""
import json

cert_config_path = os.environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG)
if not cert_config_path:
return None

# Use exponential backoff to retry loading the certificate config file.
# This is to handle the race condition where the env var is set before the file exists.
backoff = _exponential_backoff.ExponentialBackoff(total_attempts=5)

for _ in backoff:
if os.path.exists(cert_config_path):
with open(cert_config_path, "r") as f:
cert_config = json.load(f)
cert_path = (
cert_config.get("cert_configs", {})
.get("workload", {})
.get("cert_path")
)
if cert_path and os.path.exists(cert_path):
return cert_path

raise exceptions.RefreshError(
"Certificate config or certificate file not found after multiple retries."
)


def _is_agent_identity_certificate(cert_bytes):
"""Checks if a certificate is an Agent Identity certificate.
This is determined by checking the Subject Alternative Name (SAN) for a
SPIFFE ID with a trust domain matching Agent Identity patterns.
Args:
cert_bytes (bytes): The PEM-encoded certificate bytes.
Returns:
bool: True if the certificate is an Agent Identity certificate,
False otherwise.
"""
from cryptography import x509
from cryptography.x509.oid import ExtensionOID

cert = x509.load_pem_x509_certificate(cert_bytes)
try:
ext = cert.extensions.get_extension_for_oid(
ExtensionOID.SUBJECT_ALTERNATIVE_NAME
)
except x509.ExtensionNotFound:
return False
uris = ext.value.get_values_for_type(x509.UniformResourceIdentifier)

for uri in uris:
if uri.startswith("spiffe://"):
spiffe_id = uri[len("spiffe://") :]
trust_domain = spiffe_id.split("/", 1)[0]
for pattern in _AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS:
if re.match(pattern, trust_domain):
return True
return False


def calculate_certificate_fingerprint(cert_bytes):
"""Calculates the base64-encoded SHA256 hash of a DER-encoded certificate.
Args:
cert_bytes (bytes): The PEM-encoded certificate bytes.
Returns:
str: The base64-encoded SHA256 fingerprint.
"""
from cryptography import x509
from cryptography.hazmat.primitives import serialization

cert = x509.load_pem_x509_certificate(cert_bytes)
der_cert = cert.public_bytes(serialization.Encoding.DER)
fingerprint = hashlib.sha256(der_cert).digest()
return base64.b64encode(fingerprint).decode("utf-8")


def should_request_bound_token(cert_bytes):
"""Determines if a bound token should be requested.
This is based on the GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES
environment variable and whether the certificate is an agent identity cert.
Returns:
bool: True if a bound token should be requested, False otherwise.
"""
is_agent_cert = _is_agent_identity_certificate(cert_bytes)
is_opted_in = (
os.environ.get(
"GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES", "true"
).lower()
== "true"
)
return is_agent_cert and is_opted_in
20 changes: 17 additions & 3 deletions google/auth/compute_engine/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,26 @@ def get_service_account_token(request, service_account="default", scopes=None):
google.auth.exceptions.TransportError: if an error occurred while
retrieving metadata.
"""
from google.auth import _agent_identity_utils

params = {}
if scopes:
if not isinstance(scopes, str):
scopes = ",".join(scopes)
params = {"scopes": scopes}
else:
params = None
params["scopes"] = scopes

try:
cert_path = _agent_identity_utils.get_agent_identity_certificate_path()
if cert_path:
with open(cert_path, "rb") as cert_file:
cert_bytes = cert_file.read()
if _agent_identity_utils.should_request_bound_token(cert_bytes):
fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(
cert_bytes
)
params["bindCertificateFingerprint"] = fingerprint
except exceptions.RefreshError as e:
_LOGGER.warning("Could not load agent identity certificate: %s", e)

metrics_header = {
metrics.API_CLIENT_HEADER: metrics.token_request_access_token_mds()
Expand Down
2 changes: 1 addition & 1 deletion google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _refresh_token(self, request):
service can't be reached if if the instance has not
credentials.
"""
scopes = self._scopes if self._scopes is not None else self._default_scopes
try:
self._retrieve_info(request)
scopes = self._scopes if self._scopes is not None else self._default_scopes
# Always fetch token with default service account email.
self.token, self.expiry = _metadata.get_service_account_token(
request, service_account="default", scopes=scopes
Expand Down
4 changes: 4 additions & 0 deletions google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,7 @@
GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED = "GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED"
"""Environment variable controlling whether to enable trust boundary feature.
The default value is false. Users have to explicitly set this value to true."""

GOOGLE_API_CERTIFICATE_CONFIG = "GOOGLE_API_CERTIFICATE_CONFIG"
"""Environment variable defining the location of Google API certificate config
file."""
15 changes: 11 additions & 4 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def refresh(self, request):
credentials, it will refresh the access token and the trust boundary.
"""
self._refresh_token(request)
self._handle_trust_boundary(request)

def _handle_trust_boundary(self, request):
# If we are impersonating, the trust boundary is handled by the
# impersonated credentials object. We need to get it from there.
if self._service_account_impersonation_url:
Expand All @@ -428,7 +431,7 @@ def refresh(self, request):
# Otherwise, refresh the trust boundary for the external account.
self._refresh_trust_boundary(request)

def _refresh_token(self, request):
def _refresh_token(self, request, cert_fingerprint=None):
scopes = self._scopes if self._scopes is not None else self._default_scopes

# Inject client certificate into request.
Expand All @@ -446,11 +449,15 @@ def _refresh_token(self, request):
self.expiry = self._impersonated_credentials.expiry
else:
now = _helpers.utcnow()
additional_options = None
additional_options = {}
# Do not pass workforce_pool_user_project when client authentication
# is used. The client ID is sufficient for determining the user project.
if self._workforce_pool_user_project and not self._client_id:
additional_options = {"userProject": self._workforce_pool_user_project}
additional_options["userProject"] = self._workforce_pool_user_project

if cert_fingerprint:
additional_options["bindCertificateFingerprint"] = cert_fingerprint

additional_headers = {
metrics.API_CLIENT_HEADER: metrics.byoid_metrics_header(
self._metrics_options
Expand All @@ -464,7 +471,7 @@ def _refresh_token(self, request):
audience=self._audience,
scopes=scopes,
requested_token_type=_STS_REQUESTED_TOKEN_TYPE,
additional_options=additional_options,
additional_options=additional_options if additional_options else None,
additional_headers=additional_headers,
)
self.token = response_data.get("access_token")
Expand Down
21 changes: 21 additions & 0 deletions google/auth/identity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,24 @@ def from_file(cls, filename, **kwargs):
credentials.
"""
return super(Credentials, cls).from_file(filename, **kwargs)

def refresh(self, request):
"""Refreshes the access token.
Args:
request (google.auth.transport.Request): The object used to make
HTTP requests.
"""
from google.auth import _agent_identity_utils

cert_fingerprint = None
# Check if the credential is X.509 based.
if self._credential_source_certificate is not None:
cert_bytes = self._get_cert_bytes()
if _agent_identity_utils.should_request_bound_token(cert_bytes):
cert_fingerprint = _agent_identity_utils.calculate_certificate_fingerprint(
cert_bytes
)

self._refresh_token(request, cert_fingerprint=cert_fingerprint)
self._handle_trust_boundary(request)
4 changes: 2 additions & 2 deletions google/auth/transport/_mtls_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import re
import subprocess

from google.auth import environment_vars
from google.auth import exceptions

CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
_CERTIFICATE_CONFIGURATION_ENV = "GOOGLE_API_CERTIFICATE_CONFIG"
_CERT_PROVIDER_COMMAND = "cert_provider_command"
_CERT_REGEX = re.compile(
b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL
Expand Down Expand Up @@ -132,7 +132,7 @@ def _get_cert_config_path(certificate_config_path=None):
"""

if certificate_config_path is None:
env_path = environ.get(_CERTIFICATE_CONFIGURATION_ENV, None)
env_path = environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG, None)
if env_path is not None and env_path != "":
certificate_config_path = env_path
else:
Expand Down
62 changes: 62 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,68 @@ def test_build_trust_boundary_lookup_url_no_email(

assert excinfo.match(r"missing 'email' field")

@mock.patch("google.auth.compute_engine._metadata.get")
@mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path")
@mock.patch(
"google.auth._agent_identity_utils.should_request_bound_token",
return_value=True,
)
@mock.patch(
"google.auth._agent_identity_utils.calculate_certificate_fingerprint",
return_value="fingerprint",
)
def test_refresh_with_agent_identity(
self,
mock_calculate_fingerprint,
mock_should_request,
mock_get_path,
mock_metadata_get,
tmpdir,
):
cert_path = tmpdir.join("cert.pem")
cert_path.write(b"cert_content")
mock_get_path.return_value = str(cert_path)

mock_metadata_get.side_effect = [
{"email": "service-account@example.com", "scopes": ["one", "two"]},
{"access_token": "token", "expires_in": 500},
]

self.credentials.refresh(None)

assert self.credentials.token == "token"
mock_should_request.assert_called_once_with(b"cert_content")
kwargs = mock_metadata_get.call_args[1]
assert kwargs["params"] == {
"scopes": "one,two",
"bindCertificateFingerprint": "fingerprint",
}

@mock.patch("google.auth.compute_engine._metadata.get")
@mock.patch("google.auth._agent_identity_utils.get_agent_identity_certificate_path")
@mock.patch(
"google.auth._agent_identity_utils.should_request_bound_token",
return_value=False,
)
def test_refresh_with_agent_identity_opt_out_or_not_agent(
self, mock_should_request, mock_get_path, mock_metadata_get, tmpdir
):
cert_path = tmpdir.join("cert.pem")
cert_path.write(b"cert_content")
mock_get_path.return_value = str(cert_path)

mock_metadata_get.side_effect = [
{"email": "service-account@example.com", "scopes": ["one", "two"]},
{"access_token": "token", "expires_in": 500},
]

self.credentials.refresh(None)

assert self.credentials.token == "token"
mock_should_request.assert_called_once_with(b"cert_content")
kwargs = mock_metadata_get.call_args[1]
assert "bindCertificateFingerprint" not in kwargs.get("params", {})


class TestIDTokenCredentials(object):
credentials = None
Expand Down
Loading