File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -2705,6 +2705,9 @@ def run(
27052705
budget_milli_node_hours: int = 1000,
27062706
model_display_name: Optional[str] = None,
27072707
disable_early_stopping: bool = False,
2708+
export_evaluated_data_items: bool = False,
2709+
export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
2710+
export_evaluated_data_items_override_destination: bool = False,
27082711
sync: bool = True,
27092712
) -> models.Model:
27102713
"""Runs the training job and returns a model.
@@ -2777,6 +2780,27 @@ def run(
27772780
that training might stop before the entire training budget has been
27782781
used, if further training no longer brings significant improvement
27792782
to the model.
2783+
export_evaluated_data_items (bool):
2784+
Whether to export the test set predictions to a BigQuery table.
2785+
If False, then the export is not performed.
2786+
export_evaluated_data_items_bigquery_destination_uri (str):
2787+
Optional. URI of desired destination BigQuery table for exported test set predictions.
2788+
2789+
Expected format:
2790+
``bq://<project_id>:<dataset_id>:<table>``
2791+
2792+
If not specified, then results are exported to the following auto-created BigQuery
2793+
table:
2794+
``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>.evaluated_examples``
2795+
2796+
Applies only if [export_evaluated_data_items] is True.
2797+
export_evaluated_data_items_override_destination (bool):
2798+
Whether to overwrite the contents of [export_evaluated_data_items_bigquery_destination_uri],
2799+
if the table exists, for exported test set predictions. If False, and the
2800+
table exists, then the training job will fail.
2801+
2802+
Applies only if [export_evaluated_data_items] is True and
2803+
[export_evaluated_data_items_bigquery_destination_uri] is specified.
27802804
sync (bool):
27812805
Whether to execute this method synchronously. If False, this method
27822806
will be executed in a concurrent Future and any downstream object will
@@ -2806,6 +2830,9 @@ def run(
28062830
budget_milli_node_hours=budget_milli_node_hours,
28072831
model_display_name=model_display_name,
28082832
disable_early_stopping=disable_early_stopping,
2833+
export_evaluated_data_items=export_evaluated_data_items,
2834+
export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
2835+
export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
28092836
sync=sync,
28102837
)
28112838

@@ -2822,6 +2849,9 @@ def _run(
28222849
budget_milli_node_hours: int = 1000,
28232850
model_display_name: Optional[str] = None,
28242851
disable_early_stopping: bool = False,
2852+
export_evaluated_data_items: bool = False,
2853+
export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
2854+
export_evaluated_data_items_override_destination: bool = False,
28252855
sync: bool = True,
28262856
) -> models.Model:
28272857
"""Runs the training job and returns a model.
@@ -2894,6 +2924,27 @@ def _run(
28942924
that training might stop before the entire training budget has been
28952925
used, if further training no longer brings significant improvement
28962926
to the model.
2927+
export_evaluated_data_items (bool):
2928+
Whether to export the test set predictions to a BigQuery table.
2929+
If False, then the export is not performed.
2930+
export_evaluated_data_items_bigquery_destination_uri (str):
2931+
Optional. URI of desired destination BigQuery table for exported test set predictions.
2932+
2933+
Expected format:
2934+
``bq://<project_id>:<dataset_id>:<table>``
2935+
2936+
If not specified, then results are exported to the following auto-created BigQuery
2937+
table:
2938+
``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>.evaluated_examples``
2939+
2940+
Applies only if [export_evaluated_data_items] is True.
2941+
export_evaluated_data_items_override_destination (bool):
2942+
Whether to overwrite the contents of [export_evaluated_data_items_bigquery_destination_uri],
2943+
if the table exists, for exported test set predictions. If False, and the
2944+
table exists, then the training job will fail.
2945+
2946+
Applies only if [export_evaluated_data_items] is True and
2947+
[export_evaluated_data_items_bigquery_destination_uri] is specified.
28972948
sync (bool):
28982949
Whether to execute this method synchronously. If False, this method
28992950
will be executed in a concurrent Future and any downstream object will
@@ -2940,6 +2991,18 @@ def _run(
29402991
"optimizationObjectivePrecisionValue": self._optimization_objective_precision_value,
29412992
}
29422993

2994+
final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
2995+
if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
2996+
"bq://"
2997+
):
2998+
final_export_eval_bq_uri = f"bq://{final_export_eval_bq_uri}"
2999+
3000+
if export_evaluated_data_items:
3001+
training_task_inputs_dict["exportEvaluatedDataItemsConfig"] = {
3002+
"destinationBigqueryUri": final_export_eval_bq_uri,
3003+
"overrideExistingTable": export_evaluated_data_items_override_destination,
3004+
}
3005+
29433006
if self._additional_experiments:
29443007
training_task_inputs_dict[
29453008
"additionalExperiments"
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@
7979
_TEST_TRAINING_DISABLE_EARLY_STOPPING = True
8080
_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-log-loss"
8181
_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE = "classification"
82+
_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS = True
83+
_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI = (
84+
"bq://path.to.table"
85+
)
86+
_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION = False
8287
_TEST_ADDITIONAL_EXPERIMENTS = ["exp1", "exp2"]
8388
_TEST_TRAINING_TASK_INPUTS_DICT = {
8489
# required inputs
@@ -117,6 +122,16 @@
117122
},
118123
struct_pb2.Value(),
119124
)
125+
_TEST_TRAINING_TASK_INPUTS_WITH_EXPORT_EVAL_DATA_ITEMS = json_format.ParseDict(
126+
{
127+
**_TEST_TRAINING_TASK_INPUTS_DICT,
128+
"exportEvaluatedDataItemsConfig": {
129+
"destinationBigqueryUri": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
130+
"overrideExistingTable": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
131+
},
132+
},
133+
struct_pb2.Value(),
134+
)
120135

121136
_TEST_DATASET_NAME = "test-dataset-name"
122137

@@ -366,6 +381,99 @@ def test_run_call_pipeline_service_create(
366381

367382
assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
368383

384+
@pytest.mark.parametrize("sync", [True, False])
385+
def test_run_call_pipeline_service_create_with_export_eval_data_items(
386+
self,
387+
mock_pipeline_service_create,
388+
mock_pipeline_service_get,
389+
mock_dataset_tabular,
390+
mock_model_service_get,
391+
sync,
392+
):
393+
aiplatform.init(
394+
project=_TEST_PROJECT,
395+
staging_bucket=_TEST_BUCKET_NAME,
396+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
397+
)
398+
399+
job = training_jobs.AutoMLTabularTrainingJob(
400+
display_name=_TEST_DISPLAY_NAME,
401+
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
402+
optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
403+
column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS,
404+
optimization_objective_recall_value=None,
405+
optimization_objective_precision_value=None,
406+
)
407+
408+
model_from_job = job.run(
409+
dataset=mock_dataset_tabular,
410+
target_column=_TEST_TRAINING_TARGET_COLUMN,
411+
model_display_name=_TEST_MODEL_DISPLAY_NAME,
412+
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
413+
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
414+
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
415+
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
416+
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
417+
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
418+
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
419+
export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS,
420+
export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
421+
export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
422+
sync=sync,
423+
)
424+
425+
job.wait_for_resource_creation()
426+
427+
assert job.resource_name == _TEST_PIPELINE_RESOURCE_NAME
428+
429+
if not sync:
430+
model_from_job.wait()
431+
432+
true_fraction_split = gca_training_pipeline.FractionSplit(
433+
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
434+
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
435+
test_fraction=_TEST_TEST_FRACTION_SPLIT,
436+
)
437+
438+
true_managed_model = gca_model.Model(
439+
display_name=_TEST_MODEL_DISPLAY_NAME,
440+
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
441+
)
442+
443+
true_input_data_config = gca_training_pipeline.InputDataConfig(
444+
fraction_split=true_fraction_split,
445+
predefined_split=gca_training_pipeline.PredefinedSplit(
446+
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
447+
),
448+
dataset_id=mock_dataset_tabular.name,
449+
)
450+
451+
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
452+
display_name=_TEST_DISPLAY_NAME,
453+
training_task_definition=schema.training_job.definition.automl_tabular,
454+
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_EXPORT_EVAL_DATA_ITEMS,
455+
model_to_upload=true_managed_model,
456+
input_data_config=true_input_data_config,
457+
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
458+
)
459+
460+
mock_pipeline_service_create.assert_called_once_with(
461+
parent=initializer.global_config.common_location_path(),
462+
training_pipeline=true_training_pipeline,
463+
)
464+
465+
assert job._gca_resource is mock_pipeline_service_get.return_value
466+
467+
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)
468+
469+
assert model_from_job._gca_resource is mock_model_service_get.return_value
470+
471+
assert job.get_model()._gca_resource is mock_model_service_get.return_value
472+
473+
assert not job.has_failed
474+
475+
assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
476+
369477
@pytest.mark.usefixtures("mock_pipeline_service_get")
370478
@pytest.mark.parametrize("sync", [True, False])
371479
def test_run_call_pipeline_if_no_model_display_name(

0 commit comments

Comments
 (0)