|
16 | 16 | import importlib |
17 | 17 | import json |
18 | 18 | import os |
19 | | -import cloudpickle |
20 | | -import sys |
21 | 19 | from unittest import mock |
22 | 20 | from typing import Optional |
23 | 21 | import dataclasses |
24 | 22 |
|
25 | 23 | from google import auth |
26 | | -from google.auth import credentials as auth_credentials |
27 | | -from google.cloud import storage |
28 | 24 | import vertexai |
29 | | -from google.cloud import aiplatform |
30 | | -from google.cloud.aiplatform_v1 import types as aip_types |
31 | | -from google.cloud.aiplatform_v1.services import reasoning_engine_service |
32 | | -from google.cloud.aiplatform import base |
33 | 25 | from google.cloud.aiplatform import initializer |
34 | 26 | from vertexai.agent_engines import _utils |
35 | 27 | from vertexai import agent_engines |
36 | | -from vertexai.agent_engines.templates import adk as adk_template |
37 | | -from vertexai.agent_engines import _agent_engines |
38 | | -from google.api_core import operation as ga_operation |
39 | 28 | from google.genai import types |
40 | 29 | import pytest |
41 | 30 | import uuid |
@@ -87,52 +76,6 @@ def __init__(self, name: str, model: str): |
87 | 76 | "streaming_mode": "sse", |
88 | 77 | "max_llm_calls": 500, |
89 | 78 | } |
90 | | -_TEST_STAGING_BUCKET = "gs://test-bucket" |
91 | | -_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) |
92 | | -_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" |
93 | | -_TEST_RESOURCE_ID = "1028944691210842416" |
94 | | -_TEST_AGENT_ENGINE_RESOURCE_NAME = ( |
95 | | - f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}" |
96 | | -) |
97 | | -_TEST_AGENT_ENGINE_DISPLAY_NAME = "Agent Engine Display Name" |
98 | | -_TEST_GCS_DIR_NAME = _agent_engines._DEFAULT_GCS_DIR_NAME |
99 | | -_TEST_BLOB_FILENAME = _agent_engines._BLOB_FILENAME |
100 | | -_TEST_REQUIREMENTS_FILE = _agent_engines._REQUIREMENTS_FILE |
101 | | -_TEST_EXTRA_PACKAGES_FILE = _agent_engines._EXTRA_PACKAGES_FILE |
102 | | -_TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}".format( |
103 | | - _TEST_STAGING_BUCKET, |
104 | | - _TEST_GCS_DIR_NAME, |
105 | | - _TEST_BLOB_FILENAME, |
106 | | -) |
107 | | -_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI = "{}/{}/{}".format( |
108 | | - _TEST_STAGING_BUCKET, |
109 | | - _TEST_GCS_DIR_NAME, |
110 | | - _TEST_EXTRA_PACKAGES_FILE, |
111 | | -) |
112 | | -_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI = "{}/{}/{}".format( |
113 | | - _TEST_STAGING_BUCKET, |
114 | | - _TEST_GCS_DIR_NAME, |
115 | | - _TEST_REQUIREMENTS_FILE, |
116 | | -) |
117 | | -_TEST_AGENT_ENGINE_PACKAGE_SPEC = aip_types.ReasoningEngineSpec.PackageSpec( |
118 | | - python_version=f"{sys.version_info.major}.{sys.version_info.minor}", |
119 | | - pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI, |
120 | | - dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, |
121 | | - requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, |
122 | | -) |
123 | | -_ADK_AGENT_FRAMEWORK = adk_template.AdkApp.agent_framework |
124 | | -_TEST_AGENT_ENGINE_OBJ = aip_types.ReasoningEngine( |
125 | | - name=_TEST_AGENT_ENGINE_RESOURCE_NAME, |
126 | | - display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, |
127 | | - spec=aip_types.ReasoningEngineSpec( |
128 | | - package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC, |
129 | | - agent_framework=_ADK_AGENT_FRAMEWORK, |
130 | | - ), |
131 | | -) |
132 | | - |
133 | | -GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( |
134 | | - "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" |
135 | | -) |
136 | 79 |
|
137 | 80 |
|
138 | 81 | @pytest.fixture(scope="module") |
@@ -809,174 +752,3 @@ async def test_async_stream_query_invalid_message_type(self): |
809 | 752 | ): |
810 | 753 | async for _ in app.async_stream_query(user_id=_TEST_USER_ID, message=123): |
811 | 754 | pass |
812 | | - |
813 | | - |
814 | | -@pytest.fixture(scope="module") |
815 | | -def create_agent_engine_mock(): |
816 | | - with mock.patch.object( |
817 | | - reasoning_engine_service.ReasoningEngineServiceClient, |
818 | | - "create_reasoning_engine", |
819 | | - ) as create_agent_engine_mock: |
820 | | - create_agent_engine_lro_mock = mock.Mock(ga_operation.Operation) |
821 | | - create_agent_engine_lro_mock.result.return_value = _TEST_AGENT_ENGINE_OBJ |
822 | | - create_agent_engine_mock.return_value = create_agent_engine_lro_mock |
823 | | - yield create_agent_engine_mock |
824 | | - |
825 | | - |
826 | | -@pytest.fixture(scope="module") |
827 | | -def get_agent_engine_mock(): |
828 | | - with mock.patch.object( |
829 | | - reasoning_engine_service.ReasoningEngineServiceClient, |
830 | | - "get_reasoning_engine", |
831 | | - ) as get_agent_engine_mock: |
832 | | - api_client_mock = mock.Mock() |
833 | | - api_client_mock.get_reasoning_engine.return_value = _TEST_AGENT_ENGINE_OBJ |
834 | | - get_agent_engine_mock.return_value = api_client_mock |
835 | | - yield get_agent_engine_mock |
836 | | - |
837 | | - |
838 | | -@pytest.fixture(scope="module") |
839 | | -def cloud_storage_create_bucket_mock(): |
840 | | - with mock.patch.object(storage, "Client") as cloud_storage_mock: |
841 | | - bucket_mock = mock.Mock(spec=storage.Bucket) |
842 | | - bucket_mock.blob.return_value.open.return_value = "blob_file" |
843 | | - bucket_mock.blob.return_value.upload_from_filename.return_value = None |
844 | | - bucket_mock.blob.return_value.upload_from_string.return_value = None |
845 | | - |
846 | | - cloud_storage_mock.get_bucket = mock.Mock( |
847 | | - side_effect=ValueError("bucket not found") |
848 | | - ) |
849 | | - cloud_storage_mock.bucket.return_value = bucket_mock |
850 | | - cloud_storage_mock.create_bucket.return_value = bucket_mock |
851 | | - |
852 | | - yield cloud_storage_mock |
853 | | - |
854 | | - |
855 | | -@pytest.fixture(scope="module") |
856 | | -def cloudpickle_dump_mock(): |
857 | | - with mock.patch.object(cloudpickle, "dump") as cloudpickle_dump_mock: |
858 | | - yield cloudpickle_dump_mock |
859 | | - |
860 | | - |
861 | | -@pytest.fixture(scope="module") |
862 | | -def cloudpickle_load_mock(): |
863 | | - with mock.patch.object(cloudpickle, "load") as cloudpickle_load_mock: |
864 | | - yield cloudpickle_load_mock |
865 | | - |
866 | | - |
867 | | -@pytest.fixture(scope="function") |
868 | | -def get_gca_resource_mock(): |
869 | | - with mock.patch.object( |
870 | | - base.VertexAiResourceNoun, |
871 | | - "_get_gca_resource", |
872 | | - ) as get_gca_resource_mock: |
873 | | - get_gca_resource_mock.return_value = _TEST_AGENT_ENGINE_OBJ |
874 | | - yield get_gca_resource_mock |
875 | | - |
876 | | - |
877 | | -# Function scope is required for the pytest parameterized tests. |
878 | | -@pytest.fixture(scope="function") |
879 | | -def update_agent_engine_mock(): |
880 | | - with mock.patch.object( |
881 | | - reasoning_engine_service.ReasoningEngineServiceClient, |
882 | | - "update_reasoning_engine", |
883 | | - ) as update_agent_engine_mock: |
884 | | - yield update_agent_engine_mock |
885 | | - |
886 | | - |
887 | | -@pytest.mark.usefixtures("google_auth_mock") |
888 | | -class TestAgentEngines: |
889 | | - def setup_method(self): |
890 | | - importlib.reload(initializer) |
891 | | - importlib.reload(aiplatform) |
892 | | - aiplatform.init( |
893 | | - project=_TEST_PROJECT, |
894 | | - location=_TEST_LOCATION, |
895 | | - credentials=_TEST_CREDENTIALS, |
896 | | - staging_bucket=_TEST_STAGING_BUCKET, |
897 | | - ) |
898 | | - |
899 | | - def teardown_method(self): |
900 | | - initializer.global_pool.shutdown(wait=True) |
901 | | - |
902 | | - @pytest.mark.parametrize( |
903 | | - "env_vars,expected_env_vars", |
904 | | - [ |
905 | | - ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
906 | | - (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
907 | | - ( |
908 | | - {"some_env": "some_val"}, |
909 | | - { |
910 | | - "some_env": "some_val", |
911 | | - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true", |
912 | | - }, |
913 | | - ), |
914 | | - ( |
915 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
916 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
917 | | - ), |
918 | | - ], |
919 | | - ) |
920 | | - def test_create_default_telemetry_enablement( |
921 | | - self, |
922 | | - create_agent_engine_mock: mock.Mock, |
923 | | - cloud_storage_create_bucket_mock: mock.Mock, |
924 | | - cloudpickle_dump_mock: mock.Mock, |
925 | | - cloudpickle_load_mock: mock.Mock, |
926 | | - get_gca_resource_mock: mock.Mock, |
927 | | - env_vars: dict[str, str], |
928 | | - expected_env_vars: dict[str, str], |
929 | | - ): |
930 | | - agent_engines.create( |
931 | | - agent_engine=agent_engines.AdkApp(agent=_TEST_AGENT), |
932 | | - env_vars=env_vars, |
933 | | - ) |
934 | | - create_agent_engine_mock.assert_called_once() |
935 | | - deployment_spec = create_agent_engine_mock.call_args.kwargs[ |
936 | | - "reasoning_engine" |
937 | | - ].spec.deployment_spec |
938 | | - assert _utils.to_dict(deployment_spec)["env"] == [ |
939 | | - {"name": key, "value": value} for key, value in expected_env_vars.items() |
940 | | - ] |
941 | | - |
942 | | - @pytest.mark.parametrize( |
943 | | - "env_vars,expected_env_vars", |
944 | | - [ |
945 | | - ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
946 | | - (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
947 | | - ( |
948 | | - {"some_env": "some_val"}, |
949 | | - { |
950 | | - "some_env": "some_val", |
951 | | - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true", |
952 | | - }, |
953 | | - ), |
954 | | - ( |
955 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
956 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
957 | | - ), |
958 | | - ], |
959 | | - ) |
960 | | - def test_update_default_telemetry_enablement( |
961 | | - self, |
962 | | - update_agent_engine_mock: mock.Mock, |
963 | | - cloud_storage_create_bucket_mock: mock.Mock, |
964 | | - cloudpickle_dump_mock: mock.Mock, |
965 | | - cloudpickle_load_mock: mock.Mock, |
966 | | - get_gca_resource_mock: mock.Mock, |
967 | | - get_agent_engine_mock: mock.Mock, |
968 | | - env_vars: dict[str, str], |
969 | | - expected_env_vars: dict[str, str], |
970 | | - ): |
971 | | - agent_engines.update( |
972 | | - resource_name=_TEST_AGENT_ENGINE_RESOURCE_NAME, |
973 | | - description="foobar", # avoid "At least one of ... must be specified" errors. |
974 | | - env_vars=env_vars, |
975 | | - ) |
976 | | - update_agent_engine_mock.assert_called_once() |
977 | | - deployment_spec = update_agent_engine_mock.call_args.kwargs[ |
978 | | - "request" |
979 | | - ].reasoning_engine.spec.deployment_spec |
980 | | - assert _utils.to_dict(deployment_spec)["env"] == [ |
981 | | - {"name": key, "value": value} for key, value in expected_env_vars.items() |
982 | | - ] |
0 commit comments