|
79 | 79 | _TEST_TRAINING_DISABLE_EARLY_STOPPING = True
|
80 | 80 | _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-log-loss"
|
81 | 81 | _TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE = "classification"
|
| 82 | +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS = True |
| 83 | +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI = ( |
| 84 | +"bq://path.to.table" |
| 85 | +) |
| 86 | +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION = False |
82 | 87 | _TEST_ADDITIONAL_EXPERIMENTS = ["exp1", "exp2"]
|
83 | 88 | _TEST_TRAINING_TASK_INPUTS_DICT = {
|
84 | 89 | # required inputs
|
|
117 | 122 | },
|
118 | 123 | struct_pb2.Value(),
|
119 | 124 | )
|
| 125 | +_TEST_TRAINING_TASK_INPUTS_WITH_EXPORT_EVAL_DATA_ITEMS = json_format.ParseDict( |
| 126 | +{ |
| 127 | +**_TEST_TRAINING_TASK_INPUTS_DICT, |
| 128 | +"exportEvaluatedDataItemsConfig": { |
| 129 | +"destinationBigqueryUri": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, |
| 130 | +"overrideExistingTable": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, |
| 131 | +}, |
| 132 | +}, |
| 133 | +struct_pb2.Value(), |
| 134 | +) |
120 | 135 |
|
121 | 136 | _TEST_DATASET_NAME = "test-dataset-name"
|
122 | 137 |
|
@@ -366,6 +381,99 @@ def test_run_call_pipeline_service_create(
|
366 | 381 |
|
367 | 382 | assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
|
368 | 383 |
|
| 384 | +@pytest.mark.parametrize("sync", [True, False]) |
| 385 | +def test_run_call_pipeline_service_create_with_export_eval_data_items( |
| 386 | +self, |
| 387 | +mock_pipeline_service_create, |
| 388 | +mock_pipeline_service_get, |
| 389 | +mock_dataset_tabular, |
| 390 | +mock_model_service_get, |
| 391 | +sync, |
| 392 | +): |
| 393 | +aiplatform.init( |
| 394 | +project=_TEST_PROJECT, |
| 395 | +staging_bucket=_TEST_BUCKET_NAME, |
| 396 | +encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, |
| 397 | +) |
| 398 | + |
| 399 | +job = training_jobs.AutoMLTabularTrainingJob( |
| 400 | +display_name=_TEST_DISPLAY_NAME, |
| 401 | +optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, |
| 402 | +optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, |
| 403 | +column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, |
| 404 | +optimization_objective_recall_value=None, |
| 405 | +optimization_objective_precision_value=None, |
| 406 | +) |
| 407 | + |
| 408 | +model_from_job = job.run( |
| 409 | +dataset=mock_dataset_tabular, |
| 410 | +target_column=_TEST_TRAINING_TARGET_COLUMN, |
| 411 | +model_display_name=_TEST_MODEL_DISPLAY_NAME, |
| 412 | +training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, |
| 413 | +validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, |
| 414 | +test_fraction_split=_TEST_TEST_FRACTION_SPLIT, |
| 415 | +predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, |
| 416 | +weight_column=_TEST_TRAINING_WEIGHT_COLUMN, |
| 417 | +budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, |
| 418 | +disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, |
| 419 | +export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, |
| 420 | +export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, |
| 421 | +export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, |
| 422 | +sync=sync, |
| 423 | +) |
| 424 | + |
| 425 | +job.wait_for_resource_creation() |
| 426 | + |
| 427 | +assert job.resource_name == _TEST_PIPELINE_RESOURCE_NAME |
| 428 | + |
| 429 | +if not sync: |
| 430 | +model_from_job.wait() |
| 431 | + |
| 432 | +true_fraction_split = gca_training_pipeline.FractionSplit( |
| 433 | +training_fraction=_TEST_TRAINING_FRACTION_SPLIT, |
| 434 | +validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, |
| 435 | +test_fraction=_TEST_TEST_FRACTION_SPLIT, |
| 436 | +) |
| 437 | + |
| 438 | +true_managed_model = gca_model.Model( |
| 439 | +display_name=_TEST_MODEL_DISPLAY_NAME, |
| 440 | +encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, |
| 441 | +) |
| 442 | + |
| 443 | +true_input_data_config = gca_training_pipeline.InputDataConfig( |
| 444 | +fraction_split=true_fraction_split, |
| 445 | +predefined_split=gca_training_pipeline.PredefinedSplit( |
| 446 | +key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME |
| 447 | +), |
| 448 | +dataset_id=mock_dataset_tabular.name, |
| 449 | +) |
| 450 | + |
| 451 | +true_training_pipeline = gca_training_pipeline.TrainingPipeline( |
| 452 | +display_name=_TEST_DISPLAY_NAME, |
| 453 | +training_task_definition=schema.training_job.definition.automl_tabular, |
| 454 | +training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_EXPORT_EVAL_DATA_ITEMS, |
| 455 | +model_to_upload=true_managed_model, |
| 456 | +input_data_config=true_input_data_config, |
| 457 | +encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, |
| 458 | +) |
| 459 | + |
| 460 | +mock_pipeline_service_create.assert_called_once_with( |
| 461 | +parent=initializer.global_config.common_location_path(), |
| 462 | +training_pipeline=true_training_pipeline, |
| 463 | +) |
| 464 | + |
| 465 | +assert job._gca_resource is mock_pipeline_service_get.return_value |
| 466 | + |
| 467 | +mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) |
| 468 | + |
| 469 | +assert model_from_job._gca_resource is mock_model_service_get.return_value |
| 470 | + |
| 471 | +assert job.get_model()._gca_resource is mock_model_service_get.return_value |
| 472 | + |
| 473 | +assert not job.has_failed |
| 474 | + |
| 475 | +assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 476 | + |
369 | 477 | @pytest.mark.usefixtures("mock_pipeline_service_get")
|
370 | 478 | @pytest.mark.parametrize("sync", [True, False])
|
371 | 479 | def test_run_call_pipeline_if_no_model_display_name(
|
|
0 commit comments