Original file line numberDiff line numberDiff line change
Expand Up@@ -385,6 +385,13 @@ def create(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
model_monitoring_objective_config: Optional[
"aiplatform.model_monitoring.ObjectiveConfig"
] = None,
model_monitoring_alert_config: Optional[
"aiplatform.model_monitoring.AlertConfig"
] = None,
analysis_instance_schema_uri: Optional[str] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.

Expand DownExpand Up@@ -551,6 +558,23 @@ def create(
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
Optional. The objective config for model monitoring. Passing this parameter enables
monitoring on the model associated with this batch prediction job.
model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
Optional. Configures how model monitoring alerts are sent to the user. Right now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these docstrings copied from the source at https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform? Will we be able to remember to update this when/if alerts other than email alert become supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not directly copied from GAPIC. But Jing's team also confirmed that there are no plans for additional alert configs.

only email alert is supported.
analysis_instance_schema_uri (str):
Optional. Only applicable if model_monitoring_objective_config is also passed.
This parameter specifies the YAML schema file uri describing the format of a single
instance that you want Tensorflow Data Validation (TFDV) to
analyze. If this field is empty, all the feature data types are
inferred from predict_instance_schema_uri, meaning that TFDV
will use the data in the exact format as prediction request/response.
If there are any data type differences between predict instance
and TFDV instance, this field can be used to override the schema.
For models trained with Vertex AI, this field must be set as all the
fields in predict instance formatted as string.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand DownExpand Up@@ -601,7 +625,18 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

# TODO: remove temporary import statements once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
batch_prediction_job_v1beta1 as gca_bp_job_compat,
model_monitoring_v1beta1 as gca_model_monitoring_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)
gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

# Required Fields
Expand DownExpand Up@@ -688,6 +723,28 @@ def create(
)
)

# Model Monitoring
if model_monitoring_objective_config:
if model_monitoring_objective_config.drift_detection_config:
_LOGGER.info(
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
)
if model_monitoring_objective_config.explanation_config:
_LOGGER.info(
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
)
gapic_batch_prediction_job.model_monitoring_config = (
gca_model_monitoring_compat.ModelMonitoringConfig(
objective_configs=[
model_monitoring_objective_config.as_proto(config_for_bp=True)
],
alert_config=model_monitoring_alert_config.as_proto(
config_for_bp=True
),
analysis_instance_schema_uri=analysis_instance_schema_uri,
)
)

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
Expand All@@ -702,6 +759,11 @@ def create(
sync=sync,
create_request_timeout=create_request_timeout,
)
# TODO: b/242108750
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)

@classmethod
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -24,10 +24,15 @@
from google.api_core import exceptions as core_exceptions
from tests.system.aiplatform import e2e_base

from google.cloud.aiplatform_v1.types import (
io as gca_io,
model_monitoring as gca_model_monitoring,
)

# constants used for testing
USER_EMAIL = ""
MODEL_NAME = "churn"
MODEL_NAME2 = "churn2"
MODEL_DISPLAYNAME_KEY = "churn"
MODEL_DISPLAYNAME_KEY2 = "churn2"
IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest"
ENDPOINT = "us-central1-aiplatform.googleapis.com"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
Expand DownExpand Up@@ -139,7 +144,7 @@ def temp_endpoint(self, shared_state):
)

model = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)
Expand All@@ -157,19 +162,19 @@ def temp_endpoint_with_two_models(self, shared_state):
)

model1 = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)

model2 = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)
shared_state["resources"] = [model1, model2]
endpoint = aiplatform.Endpoint.create(
display_name=self._make_display_name(key=MODEL_NAME)
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY)
)
endpoint.deploy(
model=model1, machine_type="n1-standard-2", traffic_percentage=100
Expand DownExpand Up@@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
0
].objective_config
assert gca_obj_config.training_dataset == skew_config.training_dataset

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)
assert gca_obj_config.training_dataset == expected_training_dataset
assert (
gca_obj_config.training_prediction_skew_detection_config
== skew_config.as_proto()
Expand DownExpand Up@@ -297,12 +309,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
)
assert gapic_job.model_monitoring_alert_config.enable_logging

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)

for config in gapic_job.model_deployment_monitoring_objective_configs:
gca_obj_config = config.objective_config
deployed_model_id = config.deployed_model_id
assert (
gca_obj_config.training_dataset
== all_configs[deployed_model_id].skew_detection_config.training_dataset
gca_obj_config.as_proto().training_dataset == expected_training_dataset
)
assert (
gca_obj_config.training_prediction_skew_detection_config
Expand Down
Loading