 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer as aiplatform_initializer
 from google.cloud.aiplatform import utils as aiplatform_utils
+from google.cloud.aiplatform.compat import types as aiplatform_types
 from google.cloud.aiplatform.utils import gcs_utils
 from vertexai._model_garden import _model_garden_models
 from vertexai.language_models import (
@@ -148,18 +149,24 @@ def tune_model(
         self,
         training_data: Union[str, "pandas.core.frame.DataFrame"],
         *,
-        train_steps: int = 1000,
+        train_steps: Optional[int] = None,
         learning_rate: Optional[float] = None,
         learning_rate_multiplier: Optional[float] = None,
         tuning_job_location: Optional[str] = None,
         tuned_model_location: Optional[str] = None,
         model_display_name: Optional[str] = None,
         tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
         default_context: Optional[str] = None,
-    ):
+    ) -> "_LanguageModelTuningJob":
         """Tunes a model based on training data.

-        This method launches a model tuning job that can take some time.
+        This method launches and returns an asynchronous model tuning job.
+        Usage:
+        ```
+        tuning_job = model.tune_model(...)
+        ... do some other work
+        tuned_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
+        ```

         Args:
             training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
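To make the behavior change in the hunk above concrete, here is a minimal sketch of the new asynchronous flow from the caller's side. The project ID, bucket path, and `text-bison@001` version are placeholder assumptions, and whether the GA `TextGenerationModel` class already mixes in the new tunable behavior is not visible in this excerpt; only the call pattern (`tune_model()` returning a job, `get_tuned_model()` blocking) comes from this diff.

```python
import vertexai
from vertexai.language_models import TextGenerationModel

# Placeholder project/location; substitute your own.
vertexai.init(project="my-project", location="us-central1")

model = TextGenerationModel.from_pretrained("text-bison@001")

# tune_model() now returns the tuning job immediately instead of blocking.
tuning_job = model.tune_model(
    training_data="gs://my-bucket/tuning_data.jsonl",  # hypothetical JSONL dataset
    train_steps=100,
    tuning_job_location="europe-west4",
    tuned_model_location="us-central1",
)

# ... do other work while the tuning pipeline runs ...

# get_tuned_model() blocks until the pipeline finishes, then returns the tuned model.
tuned_model = tuning_job.get_tuned_model()
print(tuned_model.predict("Tell me a joke about trees.").text)
```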
@@ -303,16 +310,68 @@ def _tune_model(
             base_model=self,
             job=pipeline_job,
         )
-        self._job = job
-        tuned_model = job.result()
-        # The UXR study attendees preferred to tune model in place
-        self._endpoint = tuned_model._endpoint
-        self._endpoint_name = tuned_model._endpoint_name
+        return job


 class _TunableTextModelMixin(_TunableModelMixin):
     """Text model that can be tuned."""

+    def tune_model(
+        self,
+        training_data: Union[str, "pandas.core.frame.DataFrame"],
+        *,
+        train_steps: Optional[int] = None,
+        learning_rate_multiplier: Optional[float] = None,
+        tuning_job_location: Optional[str] = None,
+        tuned_model_location: Optional[str] = None,
+        model_display_name: Optional[str] = None,
+        tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
+    ) -> "_LanguageModelTuningJob":
+        """Tunes a model based on training data.
+
+        This method launches and returns an asynchronous model tuning job.
+        Usage:
+        ```
+        tuning_job = model.tune_model(...)
+        ... do some other work
+        tuned_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
+        ```
+
+        Args:
+            training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+                The dataset schema is model-specific.
+                See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+            train_steps: Number of training batches to tune on (batch size is 8 samples).
+            learning_rate_multiplier: Learning rate multiplier to use in tuning.
+            tuning_job_location: GCP location where the tuning job should be run.
+                Only "europe-west4" and "us-central1" locations are supported for now.
+            tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+            model_display_name: Custom display name for the tuned model.
+            tuning_evaluation_spec: Specification for the model evaluation during tuning.
+
+        Returns:
+            A `LanguageModelTuningJob` object that represents the tuning job.
+            Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+        Raises:
+            ValueError: If the "tuning_job_location" value is not supported
+            ValueError: If the "tuned_model_location" value is not supported
+            RuntimeError: If the model does not support tuning
+        """
+        # Note: Chat models do not support default_context
+        return super().tune_model(
+            training_data=training_data,
+            train_steps=train_steps,
+            learning_rate_multiplier=learning_rate_multiplier,
+            tuning_job_location=tuning_job_location,
+            tuned_model_location=tuned_model_location,
+            model_display_name=model_display_name,
+            tuning_evaluation_spec=tuning_evaluation_spec,
+        )
+
+
+class _PreviewTunableTextModelMixin(_TunableModelMixin):
+    """Text model that can be tuned."""
+
     def tune_model(
         self,
         training_data: Union[str, "pandas.core.frame.DataFrame"],
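The GA text mixin above also forwards a `tuning_evaluation_spec`. A hedged sketch of supplying one follows; the `TuningEvaluationSpec` field names, the dataset URIs, and the assumption that the public text model adopts this mixin are not taken from this excerpt.

```python
from vertexai.language_models import TextGenerationModel, TuningEvaluationSpec

model = TextGenerationModel.from_pretrained("text-bison@001")  # placeholder version

# Field names below are assumptions; check the TuningEvaluationSpec dataclass
# defined later in this module for the exact attributes.
eval_spec = TuningEvaluationSpec(
    evaluation_data="gs://my-bucket/eval_data.jsonl",  # hypothetical evaluation set
    evaluation_interval=20,
)

tuning_job = model.tune_model(
    training_data="gs://my-bucket/train_data.jsonl",  # hypothetical training set
    train_steps=200,
    learning_rate_multiplier=1.0,
    tuning_evaluation_spec=eval_spec,
)
tuned_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
```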
@@ -324,10 +383,20 @@ def tune_model(
         tuned_model_location: Optional[str] = None,
         model_display_name: Optional[str] = None,
         tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
-    ):
+    ) -> "_LanguageModelTuningJob":
         """Tunes a model based on training data.

-        This method launches a model tuning job that can take some time.
+        This method launches a model tuning job, waits for its completion, and
+        updates the model in-place. It returns the job object for forward
+        compatibility.
+        In the future (GA), this method will become asynchronous and will stop
+        updating the model in-place.
+
+        Usage:
+        ```
+        tuning_job = model.tune_model(...)  # Blocks until tuning is complete
+        tuned_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
+        ```

         Args:
             training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -353,7 +422,7 @@ def tune_model(
             RuntimeError: If the model does not support tuning
         """
         # Note: Chat models do not support default_context
-        return super().tune_model(
+        job = super().tune_model(
             training_data=training_data,
             train_steps=train_steps,
             learning_rate=learning_rate,
@@ -363,11 +432,74 @@ def tune_model(
             model_display_name=model_display_name,
             tuning_evaluation_spec=tuning_evaluation_spec,
         )
+        tuned_model = job.get_tuned_model()
+        self._endpoint = tuned_model._endpoint
+        self._endpoint_name = tuned_model._endpoint_name
+        return job


 class _TunableChatModelMixin(_TunableModelMixin):
     """Chat model that can be tuned."""

+    def tune_model(
+        self,
+        training_data: Union[str, "pandas.core.frame.DataFrame"],
+        *,
+        train_steps: Optional[int] = None,
+        learning_rate_multiplier: Optional[float] = None,
+        tuning_job_location: Optional[str] = None,
+        tuned_model_location: Optional[str] = None,
+        model_display_name: Optional[str] = None,
+        default_context: Optional[str] = None,
+    ) -> "_LanguageModelTuningJob":
+        """Tunes a model based on training data.
+
+        This method launches and returns an asynchronous model tuning job.
+        Usage:
+        ```
+        tuning_job = model.tune_model(...)
+        ... do some other work
+        tuned_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
+        ```
+
+        Args:
+            training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+                The dataset schema is model-specific.
+                See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+            train_steps: Number of training batches to tune on (batch size is 8 samples).
+            learning_rate_multiplier: Learning rate multiplier to use in tuning.
+            tuning_job_location: GCP location where the tuning job should be run.
+                Only "europe-west4" and "us-central1" locations are supported for now.
+            tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+            model_display_name: Custom display name for the tuned model.
+            default_context: The context to use for all training samples by default.
+
+        Returns:
+            A `LanguageModelTuningJob` object that represents the tuning job.
+            Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+        Raises:
+            ValueError: If the "tuning_job_location" value is not supported
+            ValueError: If the "tuned_model_location" value is not supported
+            RuntimeError: If the model does not support tuning
+        """
+        # Note: Chat models do not support tuning_evaluation_spec
+        return super().tune_model(
+            training_data=training_data,
+            train_steps=train_steps,
+            learning_rate_multiplier=learning_rate_multiplier,
+            tuning_job_location=tuning_job_location,
+            tuned_model_location=tuned_model_location,
+            model_display_name=model_display_name,
+            default_context=default_context,
+        )
+
+
+class _PreviewTunableChatModelMixin(_TunableModelMixin):
+    """Chat model that can be tuned."""
+
     def tune_model(
         self,
         training_data: Union[str, "pandas.core.frame.DataFrame"],
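For the chat-specific mixin above, a sketch of a tuning call that supplies `default_context`. Which public chat class adopts this GA mixin is not visible in this excerpt, so treat the import and the `chat-bison@001` version as assumptions; only the parameter set comes from the hunk.

```python
from vertexai.language_models import ChatModel

# Assumes a chat model class that mixes in the new _TunableChatModelMixin.
chat_model = ChatModel.from_pretrained("chat-bison@001")  # placeholder version

tuning_job = chat_model.tune_model(
    training_data="gs://my-bucket/chat_tuning_data.jsonl",  # hypothetical dataset
    train_steps=100,
    default_context="You are a concise support assistant.",  # applied to all samples
)

tuned_chat_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
chat = tuned_chat_model.start_chat()
print(chat.send_message("Hello!").text)
```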
@@ -379,10 +511,20 @@ def tune_model(
         tuned_model_location: Optional[str] = None,
         model_display_name: Optional[str] = None,
         default_context: Optional[str] = None,
-    ):
+    ) -> "_LanguageModelTuningJob":
         """Tunes a model based on training data.

-        This method launches a model tuning job that can take some time.
+        This method launches a model tuning job, waits for its completion, and
+        updates the model in-place. It returns the job object for forward
+        compatibility.
+        In the future (GA), this method will become asynchronous and will stop
+        updating the model in-place.
+
+        Usage:
+        ```
+        tuning_job = model.tune_model(...)  # Blocks until tuning is complete
+        tuned_model = tuning_job.get_tuned_model()  # Blocks until tuning is complete
+        ```

         Args:
             training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -408,7 +550,7 @@ def tune_model(
             RuntimeError: If the model does not support tuning
         """
         # Note: Chat models do not support tuning_evaluation_spec
-        return super().tune_model(
+        job = super().tune_model(
             training_data=training_data,
             train_steps=train_steps,
             learning_rate=learning_rate,
|
             model_display_name=model_display_name,
             default_context=default_context,
         )
+        tuned_model = job.get_tuned_model()
+        self._endpoint = tuned_model._endpoint
+        self._endpoint_name = tuned_model._endpoint_name
+        return job


 @dataclasses.dataclass
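The two hunks above preserve the preview behavior: `tune_model` still blocks, re-points the calling model at the tuned endpoint, and returns the job only for forward compatibility. A sketch of that flow follows; the preview import path, model version, and dataset URI are assumptions rather than values from this diff.

```python
from vertexai.preview.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")  # placeholder version

# In the preview namespace this call blocks until the pipeline finishes and
# updates `model` in place, per the hunk above.
tuning_job = model.tune_model(
    training_data="gs://my-bucket/tuning_data.jsonl",  # hypothetical dataset
    train_steps=100,
)

# `model` already serves from the tuned endpoint; the job handle is returned
# only so callers can migrate to the asynchronous GA pattern.
print(model.predict("Summarize: the quick brown fox ...").text)
tuned_model = tuning_job.get_tuned_model()  # Returns immediately; tuning is done
```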
@@ -746,7 +892,7 @@ class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):

 class _PreviewTextGenerationModel(
     _TextGenerationModel,
-    _TunableTextModelMixin,
+    _PreviewTunableTextModelMixin,
     _PreviewModelWithBatchPredict,
     _evaluatable_language_models._EvaluatableLanguageModel,
 ):
@@ -1076,7 +1222,7 @@ class ChatModel(_ChatModelBase):
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"


-class _PreviewChatModel(ChatModel, _TunableChatModelMixin):
+class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


@@ -1650,11 +1796,12 @@ def __init__(
         base_model: _LanguageModel,
         job: aiplatform.PipelineJob,
     ):
+        """Internal constructor. Do not call directly."""
         self._base_model = base_model
         self._job = job
         self._model: Optional[_LanguageModel] = None

-    def result(self) -> "_LanguageModel":
+    def get_tuned_model(self) -> "_LanguageModel":
         """Blocks until the tuning is complete and returns a `LanguageModel` object."""
         if self._model:
             return self._model
@@ -1681,11 +1828,12 @@ def result(self) -> "_LanguageModel":
         return self._model

     @property
-    def status(self):
-        """Job status"""
+    def _status(self) -> Optional[aiplatform_types.pipeline_state.PipelineState]:
+        """Job status."""
         return self._job.state

-    def cancel(self):
+    def _cancel(self):
+        """Cancels the job."""
         self._job.cancel()


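To summarize the job-wrapper changes above (the `result()` to `get_tuned_model()` rename and the now-private `_status`/`_cancel`), here is a toy model of the wrapper's contract. `ToyTuningJob` and its pipeline plumbing are illustrative stand-ins, not the SDK class itself.

```python
from typing import Any, Optional


class ToyTuningJob:
    """Illustrates the contract: get_tuned_model() blocks once, then caches."""

    def __init__(self, base_model: Any, pipeline_job: Any):
        self._base_model = base_model
        self._job = pipeline_job
        self._model: Optional[Any] = None  # Filled in by the first get_tuned_model() call

    def get_tuned_model(self) -> Any:
        if self._model:  # Later calls return the cached tuned model
            return self._model
        self._job.wait()  # First call blocks on the tuning pipeline
        self._model = self._base_model  # Stand-in for looking up the tuned endpoint
        return self._model

    @property
    def _status(self):
        # Renamed from the public `status`: kept for debugging, not as public API.
        return self._job.state

    def _cancel(self):
        # Renamed from the public `cancel`.
        self._job.cancel()
```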