File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from google.cloud.aiplatform import compat
2828
from google.cloud.aiplatform import initializer
2929
from google.cloud.aiplatform import utils as aiplatform_utils
30+
from google.cloud.aiplatform.metadata import experiment_resources
3031
from google.cloud.aiplatform_v1.services import gen_ai_tuning_service
3132
from google.cloud.aiplatform_v1.types import job_state
3233
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job
@@ -35,6 +36,8 @@
3536

3637
import pytest
3738

39+
from unittest.mock import
40+
3841
from google.rpc import status_pb2
3942

4043

@@ -136,7 +139,14 @@ class MockTuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
136139
)
137140

138141

139-
@pytest.mark.usefixtures("google_auth_mock")
142+
@pytest.fixture()
143+
def experiment_init_mock():
144+
with .object(experiment_resources.Experiment, "__init__") as experiment_mock:
145+
experiment_mock.return_value = None
146+
yield experiment_mock
147+
148+
149+
@pytest.mark.usefixtures("google_auth_mock", "experiment_init_mock")
140150
class TestgenerativeModelTuning:
141151
"""Unit tests for generative model tuning."""
142152

Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.cloud.aiplatform import initializer as aiplatform_initializer
2727
from google.cloud.aiplatform import jobs
2828
from google.cloud.aiplatform import utils as aiplatform_utils
29+
from google.cloud.aiplatform.utils import _ipython_utils
2930
from google.cloud.aiplatform_v1.services import gen_ai_tuning_service as gen_ai_tuning_service_v1
3031
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types
3132
from google.cloud.aiplatform_v1 import types as gca_types
@@ -57,6 +58,7 @@ class TuningJob(aiplatform_base._VertexAiResourceNounPlus):
5758
_parse_resource_name_method = "parse_tuning_job_path"
5859
_format_resource_name_method = "tuning_job_path"
5960
_job_type = "tuning/tuningJob"
61+
_has_displayed_experiments_button = False
6062

6163
client_class = TuningJobClientWithOverride
6264

@@ -74,6 +76,9 @@ def refresh(self) -> "TuningJob":
7476
self._gca_resource: gca_tuning_job_types.TuningJob = (
7577
self._get_gca_resource(resource_name=self.resource_name)
7678
)
79+
if self.experiment and not self._has_displayed_experiments_button:
80+
self._has_displayed_experiments_button = True
81+
_ipython_utils.display_experiment_button(self.experiment)
7782
return self
7883

7984
@property

0 commit comments

Comments
 (0)