File tree

3 files changed

+86
-7
lines changed

3 files changed

+86
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,19 @@ def init(
9292
if metadata.metadata_service.experiment_name:
9393
logging.info("project/location updated, reset Metadata config.")
9494
metadata.metadata_service.reset()
95+
9596
if project:
9697
self._project = project
9798
if location:
9899
utils.validate_region(location)
99100
self._location = location
101+
if staging_bucket:
102+
self._staging_bucket = staging_bucket
103+
if credentials:
104+
self._credentials = credentials
105+
if encryption_spec_key_name:
106+
self._encryption_spec_key_name = encryption_spec_key_name
107+
100108
if experiment:
101109
metadata.metadata_service.set_experiment(
102110
experiment=experiment, description=experiment_description
@@ -105,12 +113,6 @@ def init(
105113
raise ValueError(
106114
"Experiment name needs to be set in `init` in order to add experiment descriptions."
107115
)
108-
if staging_bucket:
109-
self._staging_bucket = staging_bucket
110-
if credentials:
111-
self._credentials = credentials
112-
if encryption_spec_key_name:
113-
self._encryption_spec_key_name = encryption_spec_key_name
114116

115117
def get_encryption_spec(
116118
self,
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _get(
205205
project: Optional[str] = None,
206206
location: Optional[str] = None,
207207
credentials: Optional[auth_credentials.Credentials] = None,
208-
) -> "Optional[_MetadataStore]":
208+
) -> Optional["_MetadataStore"]:
209209
"""Returns a MetadataStore resource.
210210
211211
Args:
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
#
1717

1818
from importlib import reload
19+
from unittest import mock
1920
from unittest.mock import , call
2021

2122
import pytest
2223
from google.api_core import exceptions
24+
from google.api_core import operation
25+
from google.auth import credentials
2326

2427
from google.cloud import aiplatform
2528
from google.cloud.aiplatform import initializer
@@ -106,6 +109,32 @@ def get_metadata_store_mock():
106109
yield get_metadata_store_mock
107110

108111

112+
@pytest.fixture
113+
def get_metadata_store_mock_raise_not_found_exception():
114+
with .object(
115+
MetadataServiceClient, "get_metadata_store"
116+
) as get_metadata_store_mock:
117+
get_metadata_store_mock.side_effect = [
118+
exceptions.NotFound("Test store not found."),
119+
GapicMetadataStore(name=_TEST_METADATASTORE,),
120+
]
121+
122+
yield get_metadata_store_mock
123+
124+
125+
@pytest.fixture
126+
def create_metadata_store_mock():
127+
with .object(
128+
MetadataServiceClient, "create_metadata_store"
129+
) as create_metadata_store_mock:
130+
create_metadata_store_lro_mock = mock.Mock(operation.Operation)
131+
create_metadata_store_lro_mock.result.return_value = GapicMetadataStore(
132+
name=_TEST_METADATASTORE,
133+
)
134+
create_metadata_store_mock.return_value = create_metadata_store_lro_mock
135+
yield create_metadata_store_mock
136+
137+
109138
@pytest.fixture
110139
def get_context_mock():
111140
with .object(MetadataServiceClient, "get_context") as get_context_mock:
@@ -364,6 +393,54 @@ def test_init_experiment_with_existing_metadataStore_and_context(
364393
get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
365394
get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)
366395

396+
def test_init_experiment_with_credentials(
397+
self, get_metadata_store_mock, get_context_mock
398+
):
399+
creds = credentials.AnonymousCredentials()
400+
401+
aiplatform.init(
402+
project=_TEST_PROJECT,
403+
location=_TEST_LOCATION,
404+
experiment=_TEST_EXPERIMENT,
405+
credentials=creds,
406+
)
407+
408+
assert (
409+
metadata.metadata_service._experiment.api_client._transport._credentials
410+
== creds
411+
)
412+
413+
get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
414+
get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)
415+
416+
def test_init_and_get_metadata_store_with_credentials(
417+
self, get_metadata_store_mock
418+
):
419+
creds = credentials.AnonymousCredentials()
420+
421+
aiplatform.init(
422+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds
423+
)
424+
425+
store = metadata._MetadataStore.get_or_create()
426+
427+
assert store.api_client._transport._credentials == creds
428+
429+
@pytest.mark.usefixtures(
430+
"get_metadata_store_mock_raise_not_found_exception",
431+
"create_metadata_store_mock",
432+
)
433+
def test_init_and_get_then_create_metadata_store_with_credentials(self):
434+
creds = credentials.AnonymousCredentials()
435+
436+
aiplatform.init(
437+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds
438+
)
439+
440+
store = metadata._MetadataStore.get_or_create()
441+
442+
assert store.api_client._transport._credentials == creds
443+
367444
def test_init_experiment_with_existing_description(
368445
self, get_metadata_store_mock, get_context_mock
369446
):

0 commit comments

Comments
 (0)