File tree

3 files changed

+200
-26
lines changed

3 files changed

+200
-26
lines changed
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_tuning(self, shared_state):
189189
df=training_data, upload_gcs_path=dataset_uri
190190
)
191191

192-
model.tune_model(
192+
tuning_job = model.tune_model(
193193
training_data=training_data,
194194
train_steps=1,
195195
tuning_job_location="europe-west4",
@@ -211,6 +211,18 @@ def test_tuning(self, shared_state):
211211
)
212212
# Deleting the Endpoint is a little less bad since the LLM SDK will recreate it, but it's not advised for the same reason.
213213

214+
# Testing the new model returned by the `tuning_job.get_tuned_model` method
215+
tuned_model1 = tuning_job.get_tuned_model()
216+
response1 = tuned_model1.predict(
217+
"What is the best recipe for banana bread? Recipe:",
218+
max_output_tokens=128,
219+
temperature=0,
220+
top_p=1,
221+
top_k=5,
222+
)
223+
assert response1.text
224+
225+
# Testing the model updated in-place (Deprecated. Preview only)
214226
response = model.predict(
215227
"What is the best recipe for banana bread? Recipe:",
216228
max_output_tokens=128,
Original file line numberDiff line numberDiff line change
@@ -1039,13 +1039,13 @@ def mock_get_tuned_model(get_endpoint_mock):
10391039
with mock.patch.object(
10401040
_language_models._TunableModelMixin, "get_tuned_model"
10411041
) as mock_text_generation_model:
1042-
mock_text_generation_model._model_id = (
1042+
mock_text_generation_model.return_value._model_id = (
10431043
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
10441044
)
1045-
mock_text_generation_model._endpoint_name = (
1045+
mock_text_generation_model.return_value._endpoint_name = (
10461046
test_constants.EndpointConstants._TEST_ENDPOINT_NAME
10471047
)
1048-
mock_text_generation_model._endpoint = get_endpoint_mock
1048+
mock_text_generation_model.return_value._endpoint = get_endpoint_mock
10491049
yield mock_text_generation_model
10501050

10511051

@@ -1344,7 +1344,7 @@ def test_tune_text_generation_model(
13441344
enable_early_stopping = True
13451345
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
13461346

1347-
model.tune_model(
1347+
tuning_job = model.tune_model(
13481348
training_data=_TEST_TEXT_BISON_TRAINING_DF,
13491349
tuning_job_location=tuning_job_location,
13501350
tuned_model_location="us-central1",
@@ -1375,6 +1375,13 @@ def test_tune_text_generation_model(
13751375
== _TEST_ENCRYPTION_KEY_NAME
13761376
)
13771377

1378+
# Testing the tuned model
1379+
tuned_model = tuning_job.get_tuned_model()
1380+
assert (
1381+
tuned_model._endpoint_name
1382+
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
1383+
)
1384+
13781385
@pytest.mark.parametrize(
13791386
"job_spec",
13801387
[_TEST_PIPELINE_SPEC_JSON],
@@ -1408,7 +1415,7 @@ def test_tune_chat_model(
14081415
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
14091416

14101417
default_context = "Default context"
1411-
model.tune_model(
1418+
tuning_job = model.tune_model(
14121419
training_data=_TEST_TEXT_BISON_TRAINING_DF,
14131420
tuning_job_location="europe-west4",
14141421
tuned_model_location="us-central1",
@@ -1421,6 +1428,13 @@ def test_tune_chat_model(
14211428
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
14221429
assert pipeline_arguments["default_context"] == default_context
14231430

1431+
# Testing the tuned model
1432+
tuned_model = tuning_job.get_tuned_model()
1433+
assert (
1434+
tuned_model._endpoint_name
1435+
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
1436+
)
1437+
14241438
@pytest.mark.parametrize(
14251439
"job_spec",
14261440
[_TEST_PIPELINE_SPEC_JSON],
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.cloud.aiplatform import base
2424
from google.cloud.aiplatform import initializer as aiplatform_initializer
2525
from google.cloud.aiplatform import utils as aiplatform_utils
26+
from google.cloud.aiplatform.compat import types as aiplatform_types
2627
from google.cloud.aiplatform.utils import gcs_utils
2728
from vertexai._model_garden import _model_garden_models
2829
from vertexai.language_models import (
@@ -148,18 +149,24 @@ def tune_model(
148149
self,
149150
training_data: Union[str, "pandas.core.frame.DataFrame"],
150151
*,
151-
train_steps: int = 1000,
152+
train_steps: Optional[int] = None,
152153
learning_rate: Optional[float] = None,
153154
learning_rate_multiplier: Optional[float] = None,
154155
tuning_job_location: Optional[str] = None,
155156
tuned_model_location: Optional[str] = None,
156157
model_display_name: Optional[str] = None,
157158
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
158159
default_context: Optional[str] = None,
159-
):
160+
) -> "_LanguageModelTuningJob":
160161
"""Tunes a model based on training data.
161162
162-
This method launches a model tuning job that can take some time.
163+
This method launches and returns an asynchronous model tuning job.
164+
Usage:
165+
```
166+
tuning_job = model.tune_model(...)
167+
... do some other work
168+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
169+
```
163170
164171
Args:
165172
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -303,16 +310,68 @@ def _tune_model(
303310
base_model=self,
304311
job=pipeline_job,
305312
)
306-
self._job = job
307-
tuned_model = job.result()
308-
# The UXR study attendees preferred to tune model in place
309-
self._endpoint = tuned_model._endpoint
310-
self._endpoint_name = tuned_model._endpoint_name
313+
return job
311314

312315

313316
class _TunableTextModelMixin(_TunableModelMixin):
314317
"""Text model that can be tuned."""
315318

319+
def tune_model(
320+
self,
321+
training_data: Union[str, "pandas.core.frame.DataFrame"],
322+
*,
323+
train_steps: Optional[int] = None,
324+
learning_rate_multiplier: Optional[float] = None,
325+
tuning_job_location: Optional[str] = None,
326+
tuned_model_location: Optional[str] = None,
327+
model_display_name: Optional[str] = None,
328+
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
329+
) -> "_LanguageModelTuningJob":
330+
"""Tunes a model based on training data.
331+
332+
This method launches and returns an asynchronous model tuning job.
333+
Usage:
334+
```
335+
tuning_job = model.tune_model(...)
336+
... do some other work
337+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
```
338+
339+
Args:
340+
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
341+
The dataset schema is model-specific.
342+
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
343+
train_steps: Number of training batches to tune on (batch size is 8 samples).
344+
learning_rate_multiplier: Learning rate multiplier to use in tuning.
345+
tuning_job_location: GCP location where the tuning job should be run.
346+
Only "europe-west4" and "us-central1" locations are supported for now.
347+
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
348+
model_display_name: Custom display name for the tuned model.
349+
tuning_evaluation_spec: Specification for the model evaluation during tuning.
350+
351+
Returns:
352+
A `LanguageModelTuningJob` object that represents the tuning job.
353+
Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
354+
355+
Raises:
356+
ValueError: If the "tuning_job_location" value is not supported
357+
ValueError: If the "tuned_model_location" value is not supported
358+
RuntimeError: If the model does not support tuning
359+
"""
360+
# Note: Chat models do not support default_context
361+
return super().tune_model(
362+
training_data=training_data,
363+
train_steps=train_steps,
364+
learning_rate_multiplier=learning_rate_multiplier,
365+
tuning_job_location=tuning_job_location,
366+
tuned_model_location=tuned_model_location,
367+
model_display_name=model_display_name,
368+
tuning_evaluation_spec=tuning_evaluation_spec,
369+
)
370+
371+
372+
class _PreviewTunableTextModelMixin(_TunableModelMixin):
373+
"""Text model that can be tuned."""
374+
316375
def tune_model(
317376
self,
318377
training_data: Union[str, "pandas.core.frame.DataFrame"],
@@ -324,10 +383,20 @@ def tune_model(
324383
tuned_model_location: Optional[str] = None,
325384
model_display_name: Optional[str] = None,
326385
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
327-
):
386+
) -> "_LanguageModelTuningJob":
328387
"""Tunes a model based on training data.
329388
330-
This method launches a model tuning job that can take some time.
389+
This method launches a model tuning job, waits for completion,
390+
and updates the model in-place. This method returns the job object for forward
391+
compatibility.
392+
In the future (GA), this method will become asynchronous and will stop
393+
updating the model in-place.
394+
395+
Usage:
396+
```
397+
tuning_job = model.tune_model(...) # Blocks until tuning is complete
398+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
399+
```
331400
332401
Args:
333402
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -353,7 +422,7 @@ def tune_model(
353422
RuntimeError: If the model does not support tuning
354423
"""
355424
# Note: Chat models do not support default_context
356-
return super().tune_model(
425+
job = super().tune_model(
357426
training_data=training_data,
358427
train_steps=train_steps,
359428
learning_rate=learning_rate,
@@ -363,11 +432,74 @@ def tune_model(
363432
model_display_name=model_display_name,
364433
tuning_evaluation_spec=tuning_evaluation_spec,
365434
)
435+
tuned_model = job.get_tuned_model()
436+
self._endpoint = tuned_model._endpoint
437+
self._endpoint_name = tuned_model._endpoint_name
438+
return job
366439

367440

368441
class _TunableChatModelMixin(_TunableModelMixin):
369442
"""Chat model that can be tuned."""
370443

444+
def tune_model(
445+
self,
446+
training_data: Union[str, "pandas.core.frame.DataFrame"],
447+
*,
448+
train_steps: Optional[int] = None,
449+
learning_rate_multiplier: Optional[float] = None,
450+
tuning_job_location: Optional[str] = None,
451+
tuned_model_location: Optional[str] = None,
452+
model_display_name: Optional[str] = None,
453+
default_context: Optional[str] = None,
454+
) -> "_LanguageModelTuningJob":
455+
"""Tunes a model based on training data.
456+
457+
This method launches and returns an asynchronous model tuning job.
458+
Usage:
459+
```
460+
tuning_job = model.tune_model(...)
461+
... do some other work
462+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
463+
```
464+
465+
Args:
466+
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
467+
The dataset schema is model-specific.
468+
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
469+
train_steps: Number of training batches to tune on (batch size is 8 samples).
470+
learning_rate: Deprecated. Use learning_rate_multiplier instead.
471+
Learning rate to use in tuning.
472+
learning_rate_multiplier: Learning rate multiplier to use in tuning.
473+
tuning_job_location: GCP location where the tuning job should be run.
474+
Only "europe-west4" and "us-central1" locations are supported for now.
475+
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
476+
model_display_name: Custom display name for the tuned model.
477+
default_context: The context to use for all training samples by default.
478+
479+
Returns:
480+
A `LanguageModelTuningJob` object that represents the tuning job.
481+
Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
482+
483+
Raises:
484+
ValueError: If the "tuning_job_location" value is not supported
485+
ValueError: If the "tuned_model_location" value is not supported
486+
RuntimeError: If the model does not support tuning
487+
"""
488+
# Note: Chat models do not support tuning_evaluation_spec
489+
return super().tune_model(
490+
training_data=training_data,
491+
train_steps=train_steps,
492+
learning_rate_multiplier=learning_rate_multiplier,
493+
tuning_job_location=tuning_job_location,
494+
tuned_model_location=tuned_model_location,
495+
model_display_name=model_display_name,
496+
default_context=default_context,
497+
)
498+
499+
500+
class _PreviewTunableChatModelMixin(_TunableModelMixin):
501+
"""Chat model that can be tuned."""
502+
371503
def tune_model(
372504
self,
373505
training_data: Union[str, "pandas.core.frame.DataFrame"],
@@ -379,10 +511,20 @@ def tune_model(
379511
tuned_model_location: Optional[str] = None,
380512
model_display_name: Optional[str] = None,
381513
default_context: Optional[str] = None,
382-
):
514+
) -> "_LanguageModelTuningJob":
383515
"""Tunes a model based on training data.
384516
385-
This method launches a model tuning job that can take some time.
517+
This method launches a model tuning job, waits for completion,
518+
and updates the model in-place. This method returns the job object for forward
519+
compatibility.
520+
In the future (GA), this method will become asynchronous and will stop
521+
updating the model in-place.
522+
523+
Usage:
524+
```
525+
tuning_job = model.tune_model(...) # Blocks until tuning is complete
526+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
527+
```
386528
387529
Args:
388530
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -408,7 +550,7 @@ def tune_model(
408550
RuntimeError: If the model does not support tuning
409551
"""
410552
# Note: Chat models do not support tuning_evaluation_spec
411-
return super().tune_model(
553+
job = super().tune_model(
412554
training_data=training_data,
413555
train_steps=train_steps,
414556
learning_rate=learning_rate,
@@ -418,6 +560,10 @@ def tune_model(
418560
model_display_name=model_display_name,
419561
default_context=default_context,
420562
)
563+
tuned_model = job.get_tuned_model()
564+
self._endpoint = tuned_model._endpoint
565+
self._endpoint_name = tuned_model._endpoint_name
566+
return job
421567

422568

423569
@dataclasses.dataclass
@@ -746,7 +892,7 @@ class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
746892

747893
class _PreviewTextGenerationModel(
748894
_TextGenerationModel,
749-
_TunableTextModelMixin,
895+
_PreviewTunableTextModelMixin,
750896
_PreviewModelWithBatchPredict,
751897
_evaluatable_language_models._EvaluatableLanguageModel,
752898
):
@@ -1076,7 +1222,7 @@ class ChatModel(_ChatModelBase):
10761222
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
10771223

10781224

1079-
class _PreviewChatModel(ChatModel, _TunableChatModelMixin):
1225+
class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
10801226
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
10811227

10821228

@@ -1650,11 +1796,12 @@ def __init__(
16501796
base_model: _LanguageModel,
16511797
job: aiplatform.PipelineJob,
16521798
):
1799+
"""Internal constructor. Do not call directly."""
16531800
self._base_model = base_model
16541801
self._job = job
16551802
self._model: Optional[_LanguageModel] = None
16561803

1657-
def result(self) -> "_LanguageModel":
1804+
def get_tuned_model(self) -> "_LanguageModel":
16581805
"""Blocks until the tuning is complete and returns a `LanguageModel` object."""
16591806
if self._model:
16601807
return self._model
@@ -1681,11 +1828,12 @@ def result(self) -> "_LanguageModel":
16811828
return self._model
16821829

16831830
@property
1684-
def status(self):
1685-
"""Job status"""
1831+
def _status(self) -> Optional[aiplatform_types.pipeline_state.PipelineState]:
1832+
"""Job status."""
16861833
return self._job.state
16871834

1688-
def cancel(self):
1835+
def _cancel(self):
1836+
"""Cancels the job."""
16891837
self._job.cancel()
16901838

16911839

0 commit comments

Comments
 (0)