File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,12 @@ def run(
249249

250250
_LOGGER.log_create_with_lro(self.__class__)
251251

252+
# PipelineJob.name is not used by pipeline service
253+
pipeline_job_id = self._gca_resource.name.split("/")[-1]
252254
self._gca_resource = self.api_client.create_pipeline_job(
253-
parent=self._parent, pipeline_job=self._gca_resource
255+
parent=self._parent,
256+
pipeline_job=self._gca_resource,
257+
pipeline_job_id=pipeline_job_id,
254258
)
255259

256260
_LOGGER.log_create_complete_with_getter(
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
_TEST_PROJECT = "test-project"
4242
_TEST_LOCATION = "us-central1"
43+
_TEST_PIPELINE_JOB_DISPLAY_NAME = "sample-pipeline-job-display-name"
4344
_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111"
4445
_TEST_GCS_BUCKET_NAME = "my-bucket"
4546
_TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
@@ -199,7 +200,7 @@ def test_run_call_pipeline_service_create(
199200
)
200201

201202
job = pipeline_jobs.PipelineJob(
202-
display_name=_TEST_PIPELINE_JOB_ID,
203+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
203204
template_path=_TEST_TEMPLATE_PATH,
204205
job_id=_TEST_PIPELINE_JOB_ID,
205206
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
@@ -222,7 +223,7 @@ def test_run_call_pipeline_service_create(
222223

223224
# Construct expected request
224225
expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
225-
display_name=_TEST_PIPELINE_JOB_ID,
226+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
226227
name=_TEST_PIPELINE_JOB_NAME,
227228
pipeline_spec={
228229
"components": {},
@@ -233,7 +234,9 @@ def test_run_call_pipeline_service_create(
233234
)
234235

235236
mock_pipeline_service_create.assert_called_once_with(
236-
parent=_TEST_PARENT, pipeline_job=expected_gapic_pipeline_job,
237+
parent=_TEST_PARENT,
238+
pipeline_job=expected_gapic_pipeline_job,
239+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
237240
)
238241

239242
mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME)
@@ -242,6 +245,14 @@ def test_run_call_pipeline_service_create(
242245
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
243246
)
244247

248+
@pytest.mark.usefixtures("mock_pipeline_service_get")
249+
def test_get_pipeline_job(self, mock_pipeline_service_get):
250+
aiplatform.init(project=_TEST_PROJECT)
251+
job = pipeline_jobs.PipelineJob.get(resource_name=_TEST_PIPELINE_JOB_ID)
252+
253+
mock_pipeline_service_get.assert_called_once_with(name=_TEST_PIPELINE_JOB_NAME)
254+
assert isinstance(job, pipeline_jobs.PipelineJob)
255+
245256
@pytest.mark.usefixtures(
246257
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
247258
)
@@ -255,7 +266,7 @@ def test_cancel_pipeline_job(
255266
)
256267

257268
job = pipeline_jobs.PipelineJob(
258-
display_name=_TEST_PIPELINE_JOB_ID,
269+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
259270
template_path=_TEST_TEMPLATE_PATH,
260271
job_id=_TEST_PIPELINE_JOB_ID,
261272
)
@@ -267,14 +278,6 @@ def test_cancel_pipeline_job(
267278
name=_TEST_PIPELINE_JOB_NAME
268279
)
269280

270-
@pytest.mark.usefixtures("mock_pipeline_service_get")
271-
def test_get_training_job(self, mock_pipeline_service_get):
272-
aiplatform.init(project=_TEST_PROJECT)
273-
job = pipeline_jobs.PipelineJob.get(resource_name=_TEST_PIPELINE_JOB_ID)
274-
275-
mock_pipeline_service_get.assert_called_once_with(name=_TEST_PIPELINE_JOB_NAME)
276-
assert isinstance(job, pipeline_jobs.PipelineJob)
277-
278281
@pytest.mark.usefixtures(
279282
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
280283
)
@@ -288,7 +291,7 @@ def test_cancel_pipeline_job_without_running(
288291
)
289292

290293
job = pipeline_jobs.PipelineJob(
291-
display_name=_TEST_PIPELINE_JOB_ID,
294+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
292295
template_path=_TEST_TEMPLATE_PATH,
293296
job_id=_TEST_PIPELINE_JOB_ID,
294297
)
@@ -313,7 +316,7 @@ def test_pipeline_failure_raises(self, sync):
313316
)
314317

315318
job = pipeline_jobs.PipelineJob(
316-
display_name=_TEST_PIPELINE_JOB_ID,
319+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
317320
template_path=_TEST_TEMPLATE_PATH,
318321
job_id=_TEST_PIPELINE_JOB_ID,
319322
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,

0 commit comments

Comments
 (0)