|
16 | 16 | #
|
17 | 17 |
|
18 | 18 | from importlib import reload
|
| 19 | +from unittest import mock |
19 | 20 | from unittest.mock import , call
|
20 | 21 |
|
21 | 22 | import pytest
|
22 | 23 | from google.api_core import exceptions
|
| 24 | +from google.api_core import operation |
| 25 | +from google.auth import credentials |
23 | 26 |
|
24 | 27 | from google.cloud import aiplatform
|
25 | 28 | from google.cloud.aiplatform import initializer
|
@@ -106,6 +109,32 @@ def get_metadata_store_mock():
|
106 | 109 | yield get_metadata_store_mock
|
107 | 110 |
|
108 | 111 |
|
| 112 | +@pytest.fixture |
| 113 | +def get_metadata_store_mock_raise_not_found_exception(): |
| 114 | +with .object( |
| 115 | +MetadataServiceClient, "get_metadata_store" |
| 116 | +) as get_metadata_store_mock: |
| 117 | +get_metadata_store_mock.side_effect = [ |
| 118 | +exceptions.NotFound("Test store not found."), |
| 119 | +GapicMetadataStore(name=_TEST_METADATASTORE,), |
| 120 | +] |
| 121 | + |
| 122 | +yield get_metadata_store_mock |
| 123 | + |
| 124 | + |
| 125 | +@pytest.fixture |
| 126 | +def create_metadata_store_mock(): |
| 127 | +with .object( |
| 128 | +MetadataServiceClient, "create_metadata_store" |
| 129 | +) as create_metadata_store_mock: |
| 130 | +create_metadata_store_lro_mock = mock.Mock(operation.Operation) |
| 131 | +create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( |
| 132 | +name=_TEST_METADATASTORE, |
| 133 | +) |
| 134 | +create_metadata_store_mock.return_value = create_metadata_store_lro_mock |
| 135 | +yield create_metadata_store_mock |
| 136 | + |
| 137 | + |
109 | 138 | @pytest.fixture
|
110 | 139 | def get_context_mock():
|
111 | 140 | with .object(MetadataServiceClient, "get_context") as get_context_mock:
|
@@ -364,6 +393,54 @@ def test_init_experiment_with_existing_metadataStore_and_context(
|
364 | 393 | get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
|
365 | 394 | get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)
|
366 | 395 |
|
| 396 | +def test_init_experiment_with_credentials( |
| 397 | +self, get_metadata_store_mock, get_context_mock |
| 398 | +): |
| 399 | +creds = credentials.AnonymousCredentials() |
| 400 | + |
| 401 | +aiplatform.init( |
| 402 | +project=_TEST_PROJECT, |
| 403 | +location=_TEST_LOCATION, |
| 404 | +experiment=_TEST_EXPERIMENT, |
| 405 | +credentials=creds, |
| 406 | +) |
| 407 | + |
| 408 | +assert ( |
| 409 | +metadata.metadata_service._experiment.api_client._transport._credentials |
| 410 | +== creds |
| 411 | +) |
| 412 | + |
| 413 | +get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) |
| 414 | +get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) |
| 415 | + |
| 416 | +def test_init_and_get_metadata_store_with_credentials( |
| 417 | +self, get_metadata_store_mock |
| 418 | +): |
| 419 | +creds = credentials.AnonymousCredentials() |
| 420 | + |
| 421 | +aiplatform.init( |
| 422 | +project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds |
| 423 | +) |
| 424 | + |
| 425 | +store = metadata._MetadataStore.get_or_create() |
| 426 | + |
| 427 | +assert store.api_client._transport._credentials == creds |
| 428 | + |
| 429 | +@pytest.mark.usefixtures( |
| 430 | +"get_metadata_store_mock_raise_not_found_exception", |
| 431 | +"create_metadata_store_mock", |
| 432 | +) |
| 433 | +def test_init_and_get_then_create_metadata_store_with_credentials(self): |
| 434 | +creds = credentials.AnonymousCredentials() |
| 435 | + |
| 436 | +aiplatform.init( |
| 437 | +project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds |
| 438 | +) |
| 439 | + |
| 440 | +store = metadata._MetadataStore.get_or_create() |
| 441 | + |
| 442 | +assert store.api_client._transport._credentials == creds |
| 443 | + |
367 | 444 | def test_init_experiment_with_existing_description(
|
368 | 445 | self, get_metadata_store_mock, get_context_mock
|
369 | 446 | ):
|
|
0 commit comments