|
16 | 16 | # |
17 | 17 |
|
18 | 18 | from importlib import reload |
| 19 | +from unittest import mock |
19 | 20 | from unittest.mock import patch, call |
20 | 21 |
|
21 | 22 | import pytest |
22 | 23 | from google.api_core import exceptions |
| 24 | +from google.api_core import operation |
| 25 | +from google.auth import credentials |
23 | 26 |
|
24 | 27 | from google.cloud import aiplatform |
25 | 28 | from google.cloud.aiplatform import initializer |
@@ -106,6 +109,32 @@ def get_metadata_store_mock(): |
106 | 109 | yield get_metadata_store_mock |
107 | 110 |
|
108 | 111 |
|
| 112 | +@pytest.fixture |
| 113 | +def get_metadata_store_mock_raise_not_found_exception(): |
| 114 | + with patch.object( |
| 115 | + MetadataServiceClient, "get_metadata_store" |
| 116 | + ) as get_metadata_store_mock: |
| 117 | + get_metadata_store_mock.side_effect = [ |
| 118 | + exceptions.NotFound("Test store not found."), |
| 119 | + GapicMetadataStore(name=_TEST_METADATASTORE,), |
| 120 | + ] |
| 121 | + |
| 122 | + yield get_metadata_store_mock |
| 123 | + |
| 124 | + |
| 125 | +@pytest.fixture |
| 126 | +def create_metadata_store_mock(): |
| 127 | + with patch.object( |
| 128 | + MetadataServiceClient, "create_metadata_store" |
| 129 | + ) as create_metadata_store_mock: |
| 130 | + create_metadata_store_lro_mock = mock.Mock(operation.Operation) |
| 131 | + create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( |
| 132 | + name=_TEST_METADATASTORE, |
| 133 | + ) |
| 134 | + create_metadata_store_mock.return_value = create_metadata_store_lro_mock |
| 135 | + yield create_metadata_store_mock |
| 136 | + |
| 137 | + |
109 | 138 | @pytest.fixture |
110 | 139 | def get_context_mock(): |
111 | 140 | with patch.object(MetadataServiceClient, "get_context") as get_context_mock: |
@@ -364,6 +393,54 @@ def test_init_experiment_with_existing_metadataStore_and_context( |
364 | 393 | get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) |
365 | 394 | get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) |
366 | 395 |
|
| 396 | + def test_init_experiment_with_credentials( |
| 397 | + self, get_metadata_store_mock, get_context_mock |
| 398 | + ): |
| 399 | + creds = credentials.AnonymousCredentials() |
| 400 | + |
| 401 | + aiplatform.init( |
| 402 | + project=_TEST_PROJECT, |
| 403 | + location=_TEST_LOCATION, |
| 404 | + experiment=_TEST_EXPERIMENT, |
| 405 | + credentials=creds, |
| 406 | + ) |
| 407 | + |
| 408 | + assert ( |
| 409 | + metadata.metadata_service._experiment.api_client._transport._credentials |
| 410 | + == creds |
| 411 | + ) |
| 412 | + |
| 413 | + get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) |
| 414 | + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) |
| 415 | + |
| 416 | + def test_init_and_get_metadata_store_with_credentials( |
| 417 | + self, get_metadata_store_mock |
| 418 | + ): |
| 419 | + creds = credentials.AnonymousCredentials() |
| 420 | + |
| 421 | + aiplatform.init( |
| 422 | + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds |
| 423 | + ) |
| 424 | + |
| 425 | + store = metadata._MetadataStore.get_or_create() |
| 426 | + |
| 427 | + assert store.api_client._transport._credentials == creds |
| 428 | + |
| 429 | + @pytest.mark.usefixtures( |
| 430 | + "get_metadata_store_mock_raise_not_found_exception", |
| 431 | + "create_metadata_store_mock", |
| 432 | + ) |
| 433 | + def test_init_and_get_then_create_metadata_store_with_credentials(self): |
| 434 | + creds = credentials.AnonymousCredentials() |
| 435 | + |
| 436 | + aiplatform.init( |
| 437 | + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds |
| 438 | + ) |
| 439 | + |
| 440 | + store = metadata._MetadataStore.get_or_create() |
| 441 | + |
| 442 | + assert store.api_client._transport._credentials == creds |
| 443 | + |
367 | 444 | def test_init_experiment_with_existing_description( |
368 | 445 | self, get_metadata_store_mock, get_context_mock |
369 | 446 | ): |
|
0 commit comments