
Commit 1cc22c3

ilai-deutel authored and copybara-github committed
fix: GenAI - Tuning - Supervised - Fix adapter_size parameter handling to match enum values.
PiperOrigin-RevId: 636608417
1 parent bed3dec commit 1cc22c3

File tree

tests/unit/vertexai/test_tuning.py
vertexai/tuning/_supervised_tuning.py

2 files changed: +53 −28 lines changed


tests/unit/vertexai/test_tuning.py

1 addition, 0 deletions

@@ -172,6 +172,7 @@ def test_genai_tuning_service_supervised_tuning_tune_model(self):
             validation_dataset="gs://some-bucket/some_dataset.jsonl",
             epochs=300,
             learning_rate_multiplier=1.0,
+            adapter_size=8,
         )
         assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING
         assert not sft_tuning_job.has_ended
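
For orientation, a usage sketch of the call path this test exercises is below. The `sft` import alias, the source model name, and the GCS paths mirror the SDK samples and this test rather than anything added by the commit, so treat them as assumptions. After this fix, the integer passed as `adapter_size` is translated to the matching `SupervisedHyperParameters.AdapterSize` enum value before the tuning job request is built.

# Usage sketch (assumed import path and placeholder GCS paths; not part of this commit).
from vertexai.preview.tuning import sft

sft_tuning_job = sft.train(
    source_model="gemini-1.0-pro-002",
    train_dataset="gs://some-bucket/some_dataset.jsonl",
    validation_dataset="gs://some-bucket/some_dataset.jsonl",
    epochs=300,
    learning_rate_multiplier=1.0,
    adapter_size=8,  # now mapped to AdapterSize.ADAPTER_SIZE_EIGHT rather than passed as a raw int
)
print(sft_tuning_job.state)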

vertexai/tuning/_supervised_tuning.py

52 additions, 28 deletions

@@ -15,7 +15,9 @@

 from typing import Dict, Literal, Optional, Union

-from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types
+from google.cloud.aiplatform_v1.types import (
+    tuning_job as gca_tuning_job_types,
+)
 from vertexai import generative_models
 from vertexai.tuning import _tuning

@@ -31,44 +33,66 @@ def train(
     adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
     labels: Optional[Dict[str, str]] = None,
 ) -> "SupervisedTuningJob":
-    """Tunes a model using supervised training.
+    """Tunes a model using supervised training.

-    Args:
-        source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
-        train_dataset: Cloud Storage path to file containing training dataset for
-            tuning. The dataset should be in JSONL format.
-        validation_dataset: Cloud Storage path to file containing validation
-            dataset for tuning. The dataset should be in JSONL format.
-        tuned_model_display_name: The display name of the
-            [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
-            128 characters long and can consist of any UTF-8 characters.
-        epochs: Number of training epoches for this tuning job.
-        learning_rate_multiplier: Learning rate multiplier for tuning.
-        adapter_size: Adapter size for tuning.
-        labels: User-defined metadata to be associated with trained models
+    Args:
+        source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
+        train_dataset: Cloud Storage path to file containing training dataset for
+            tuning. The dataset should be in JSONL format.
+        validation_dataset: Cloud Storage path to file containing validation
+            dataset for tuning. The dataset should be in JSONL format.
+        tuned_model_display_name: The display name of the
+            [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
+            128 characters long and can consist of any UTF-8 characters.
+        epochs: Number of training epoches for this tuning job.
+        learning_rate_multiplier: Learning rate multiplier for tuning.
+        adapter_size: Adapter size for tuning.
+        labels: User-defined metadata to be associated with trained models

-    Returns:
-        A `TuningJob` object.
-    """
-    supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
+    Returns:
+        A `TuningJob` object.
+    """
+    if adapter_size is None:
+        adapter_size_value = None
+    elif adapter_size == 1:
+        adapter_size_value = (
+            gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_ONE
+        )
+    elif adapter_size == 4:
+        adapter_size_value = (
+            gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_FOUR
+        )
+    elif adapter_size == 8:
+        adapter_size_value = (
+            gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_EIGHT
+        )
+    elif adapter_size == 16:
+        adapter_size_value = (
+            gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_SIXTEEN
+        )
+    else:
+        raise ValueError(
+            f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16]"
+        )
+    supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
         training_dataset_uri=train_dataset,
         validation_dataset_uri=validation_dataset,
         hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
             epoch_count=epochs,
             learning_rate_multiplier=learning_rate_multiplier,
-            adapter_size=adapter_size,
+            adapter_size=adapter_size_value,
         ),
     )

-    if isinstance(source_model, generative_models.GenerativeModel):
-        source_model = source_model._prediction_resource_name.rpartition('/')[-1]
+    if isinstance(source_model, generative_models.GenerativeModel):
+        source_model = source_model._prediction_resource_name.rpartition("/")[-1]

-    return SupervisedTuningJob._create(  # pylint: disable=protected-access
-        base_model=source_model,
-        tuning_spec=supervised_tuning_spec,
-        tuned_model_display_name=tuned_model_display_name,
-        labels=labels,
-    )
+    return SupervisedTuningJob._create(  # pylint: disable=protected-access
+        base_model=source_model,
+        tuning_spec=supervised_tuning_spec,
+        tuned_model_display_name=tuned_model_display_name,
+        labels=labels,
+    )


 class SupervisedTuningJob(_tuning.TuningJob):
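
The if/elif chain added above is the substance of the fix: the user-facing integers 1, 4, 8, and 16 are now converted to `SupervisedHyperParameters.AdapterSize` enum members rather than being forwarded as raw ints, which need not line up with the enum's wire values. A minimal equivalent sketch follows; the `_ADAPTER_SIZES` table and the `to_adapter_size_enum` helper are illustrative names, not part of the commit.

# Illustrative sketch only; the helper name and lookup table are not part of this commit.
from typing import Literal, Optional

from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types

_ADAPTER_SIZES = {
    1: gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_ONE,
    4: gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_FOUR,
    8: gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_EIGHT,
    16: gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_SIXTEEN,
}


def to_adapter_size_enum(adapter_size: Optional[Literal[1, 4, 8, 16]]):
    """Maps the user-facing integer adapter size to the proto enum value."""
    if adapter_size is None:
        return None
    if adapter_size not in _ADAPTER_SIZES:
        raise ValueError(
            f"Unsupported adapter size: {adapter_size}. "
            "The supported sizes are [1, 4, 8, 16]"
        )
    return _ADAPTER_SIZES[adapter_size]


# The hyperparameters are then built with the enum value, mirroring the diff above.
assert (
    to_adapter_size_enum(8)
    == gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_EIGHT
)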
