Skip to content

Commit 9ef8d05

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: bring instrumentation in ADK preview template to parity with GA one
PiperOrigin-RevId: 825088080
1 parent 67f9099 commit 9ef8d05

File tree

2 files changed

+390
-51
lines changed

2 files changed

+390
-51
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 175 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
import base64
1717
import importlib
1818
import json
19+
import dataclasses
20+
import os
1921
from unittest import mock
22+
from typing import Optional
2023

2124
from google import auth
2225
import vertexai
@@ -25,6 +28,7 @@
2528
from vertexai.preview import reasoning_engines
2629
from google.genai import types
2730
import pytest
31+
import uuid
2832

2933

3034
try:
@@ -44,6 +48,7 @@ def __init__(self, name: str, model: str):
4448
_TEST_MODEL = "gemini-2.0-flash"
4549
_TEST_USER_ID = "test_user_id"
4650
_TEST_AGENT_NAME = "test_agent"
51+
_TEST_AGENT = Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
4752
_TEST_SESSION = {
4853
"id": "ca18c25a-644b-4e13-9b24-78c150ec3eb9",
4954
"app_name": "default-app-name",
@@ -92,15 +97,6 @@ def vertexai_init_mock():
9297
yield vertexai_init_mock
9398

9499

95-
@pytest.fixture
96-
def cloud_trace_exporter_mock():
97-
with mock.patch.object(
98-
_utils,
99-
"_import_cloud_trace_exporter_or_warn",
100-
) as cloud_trace_exporter_mock:
101-
yield cloud_trace_exporter_mock
102-
103-
104100
@pytest.fixture
105101
def tracer_provider_mock():
106102
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
@@ -116,12 +112,53 @@ def simple_span_processor_mock():
116112

117113

118114
@pytest.fixture
119-
def mock_adk_version():
115+
def cloud_trace_exporter_mock():
116+
import sys
117+
import opentelemetry
118+
119+
mock_cloud_trace_exporter = mock.Mock()
120+
121+
opentelemetry.exporter = type(sys)("exporter")
122+
opentelemetry.exporter.cloud_trace = type(sys)("cloud_trace")
123+
opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter = (
124+
mock_cloud_trace_exporter
125+
)
126+
127+
sys.modules["opentelemetry.exporter"] = opentelemetry.exporter
128+
sys.modules["opentelemetry.exporter.cloud_trace"] = (
129+
opentelemetry.exporter.cloud_trace
130+
)
131+
132+
yield mock_cloud_trace_exporter
133+
134+
del sys.modules["opentelemetry.exporter.cloud_trace"]
135+
del sys.modules["opentelemetry.exporter"]
136+
137+
138+
@pytest.fixture
139+
def trace_provider_mock():
140+
import opentelemetry.sdk.trace
141+
142+
with mock.patch.object(
143+
opentelemetry.sdk.trace, "TracerProvider"
144+
) as tracer_provider_mock:
145+
yield tracer_provider_mock
146+
147+
148+
@pytest.fixture
149+
def default_instrumentor_builder_mock():
120150
with mock.patch(
121-
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version",
122-
return_value="1.5.0",
123-
):
124-
yield
151+
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder"
152+
) as default_instrumentor_builder_mock:
153+
yield default_instrumentor_builder_mock
154+
155+
156+
@pytest.fixture
157+
def adk_version_mock():
158+
with mock.patch(
159+
"google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version"
160+
) as adk_version_mock:
161+
yield adk_version_mock
125162

126163

127164
class _MockRunner:
@@ -520,6 +557,130 @@ async def test_async_search_memory(self):
520557
)
521558
assert len(response.memories) >= 1
522559

560+
@pytest.mark.parametrize(
561+
"adk_version,enable_tracing,enable_telemetry,want_tracing_setup,want_logging_setup",
562+
[
563+
("1.16.0", False, False, False, False),
564+
("1.16.0", False, True, False, True),
565+
("1.16.0", False, None, False, False),
566+
("1.16.0", True, False, False, False),
567+
("1.16.0", True, True, True, True),
568+
("1.16.0", True, None, True, False),
569+
("1.16.0", None, False, False, False),
570+
("1.16.0", None, True, False, True),
571+
("1.16.0", None, None, False, False),
572+
("1.17.0", False, False, False, False),
573+
("1.17.0", False, True, False, True),
574+
("1.17.0", False, None, False, False),
575+
("1.17.0", True, False, False, False),
576+
("1.17.0", True, True, True, True),
577+
("1.17.0", True, None, True, False),
578+
("1.17.0", None, False, False, False),
579+
("1.17.0", None, True, True, True),
580+
("1.17.0", None, None, False, False),
581+
],
582+
)
583+
@mock.patch.dict(os.environ)
584+
def test_default_instrumentor_enablement(
585+
self,
586+
adk_version: str,
587+
enable_tracing: Optional[bool],
588+
enable_telemetry: Optional[bool],
589+
want_tracing_setup: bool,
590+
want_logging_setup: bool,
591+
default_instrumentor_builder_mock: mock.Mock,
592+
adk_version_mock: mock.Mock,
593+
):
594+
# Arrange
595+
adk_version_mock.return_value = adk_version
596+
if enable_telemetry is not None:
597+
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = str(
598+
enable_telemetry
599+
)
600+
601+
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=enable_tracing)
602+
603+
# Act
604+
app.set_up()
605+
606+
# Assert
607+
default_instrumentor_builder_mock.assert_called_once_with(
608+
_TEST_PROJECT,
609+
enable_tracing=want_tracing_setup,
610+
enable_logging=want_logging_setup,
611+
)
612+
613+
@mock.patch.dict(
614+
os.environ,
615+
{
616+
"GOOGLE_CLOUD_AGENT_ENGINE_ID": "test_agent_id",
617+
"OTEL_RESOURCE_ATTRIBUTES": "some-attribute=some-value",
618+
},
619+
)
620+
def test_tracing_setup(
621+
self,
622+
trace_provider_mock: mock.Mock,
623+
cloud_trace_exporter_mock: mock.Mock,
624+
monkeypatch: pytest.MonkeyPatch,
625+
):
626+
monkeypatch.setattr(
627+
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
628+
)
629+
monkeypatch.setattr("os.getpid", lambda: 123123123)
630+
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
631+
app.set_up()
632+
633+
expected_attributes = {
634+
"telemetry.sdk.language": "python",
635+
"telemetry.sdk.name": "opentelemetry",
636+
"telemetry.sdk.version": "1.36.0",
637+
"gcp.project_id": "test-project",
638+
"cloud.account.id": "test-project",
639+
"service.name": "test_agent_id",
640+
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id",
641+
"service.instance.id": "12345678123456781234567812345678-123123123",
642+
"cloud.region": "us-central1",
643+
"some-attribute": "some-value",
644+
}
645+
646+
@dataclasses.dataclass
647+
class RegexMatchingAll:
648+
keys: set[str]
649+
650+
def __eq__(self, regex: object) -> bool:
651+
return isinstance(regex, str) and set(regex.split("|")) == self.keys
652+
653+
cloud_trace_exporter_mock.assert_called_once_with(
654+
project_id=_TEST_PROJECT,
655+
client=mock.ANY,
656+
resource_regex=RegexMatchingAll(keys=set(expected_attributes.keys())),
657+
)
658+
659+
assert (
660+
trace_provider_mock.call_args.kwargs["resource"].attributes
661+
== expected_attributes
662+
)
663+
664+
@mock.patch.dict(os.environ)
665+
def test_span_content_capture_disabled_by_default(self):
666+
app = reasoning_engines.AdkApp(agent=_TEST_AGENT)
667+
app.set_up()
668+
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false"
669+
670+
@mock.patch.dict(
671+
os.environ, {"OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT": "true"}
672+
)
673+
def test_span_content_capture_disabled_with_env_var(self):
674+
app = reasoning_engines.AdkApp(agent=_TEST_AGENT)
675+
app.set_up()
676+
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false"
677+
678+
@mock.patch.dict(os.environ)
679+
def test_span_content_capture_enabled_with_tracing(self):
680+
app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True)
681+
app.set_up()
682+
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "true"
683+
523684
@pytest.mark.usefixtures("caplog")
524685
def test_enable_tracing(
525686
self,
@@ -584,7 +745,6 @@ def test_dump_event_for_json():
584745
assert base64.b64decode(part["thought_signature"]) == raw_signature
585746

586747

587-
@pytest.mark.usefixtures("mock_adk_version")
588748
class TestAdkAppErrors:
589749
def test_raise_get_session_not_found_error(self):
590750
with pytest.raises(

0 commit comments

Comments
 (0)