1313# limitations under the License.
1414#
1515
16- from typing import Literal , Optional , Union
16+ from typing import Dict , Literal , Optional , Union
1717
1818from google .cloud .aiplatform_v1 .types import tuning_job as gca_tuning_job_types
19-
2019from vertexai import generative_models
2120from vertexai .tuning import _tuning
2221
@@ -30,27 +29,28 @@ def train(
3029 epochs : Optional [int ] = None ,
3130 learning_rate_multiplier : Optional [float ] = None ,
3231 adapter_size : Optional [Literal [1 , 4 , 8 , 16 ]] = None ,
32+ labels : Optional [Dict [str , str ]] = None ,
3333) -> "SupervisedTuningJob" :
34- """Tunes a model using supervised training.
34+ """Tunes a model using supervised training.
3535
36- Args:
37- source_model (str):
38- Model name for tuning, e.g., "gemini-1.0-pro-002".
39- train_dataset: Cloud Storage path to file containing training dataset for tuning .
40- The dataset should be in JSONL format.
41- validation_dataset: Cloud Storage path to file containing validation dataset for tuning .
42- The dataset should be in JSONL format.
43- tuned_model_display_name: The display name of the
44- [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
45- be up to 128 characters long and can consist of any UTF-8 characters .
46- epochs: Number of training epoches for this tuning job .
47- learning_rate_multiplier: Learning rate multiplier for tuning.
48- adapter_size: Adapter size for tuning.
36+ Args:
37+ source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
38+ train_dataset: Cloud Storage path to file containing training dataset for
39+ tuning. The dataset should be in JSONL format .
40+ validation_dataset: Cloud Storage path to file containing validation
41+ dataset for tuning. The dataset should be in JSONL format .
42+ tuned_model_display_name: The display name of the
43+ [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
44+ 128 characters long and can consist of any UTF-8 characters.
45+ epochs: Number of training epoches for this tuning job .
46+ learning_rate_multiplier: Learning rate multiplier for tuning.
47+ adapter_size: Adapter size for tuning.
48+ labels: User-defined metadata to be associated with trained models
4949
50- Returns:
51- A `TuningJob` object.
52- """
53- supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
50+ Returns:
51+ A `TuningJob` object.
52+ """
53+ supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
5454 training_dataset_uri = train_dataset ,
5555 validation_dataset_uri = validation_dataset ,
5656 hyper_parameters = gca_tuning_job_types .SupervisedHyperParameters (
@@ -60,14 +60,15 @@ def train(
6060 ),
6161 )
6262
63- if isinstance (source_model , generative_models .GenerativeModel ):
64- source_model = source_model ._prediction_resource_name .rpartition ('/' )[- 1 ]
63+ if isinstance (source_model , generative_models .GenerativeModel ):
64+ source_model = source_model ._prediction_resource_name .rpartition ('/' )[- 1 ]
6565
66- return SupervisedTuningJob ._create ( # pylint: disable=protected-access
67- base_model = source_model ,
68- tuning_spec = supervised_tuning_spec ,
69- tuned_model_display_name = tuned_model_display_name ,
70- )
66+ return SupervisedTuningJob ._create ( # pylint: disable=protected-access
67+ base_model = source_model ,
68+ tuning_spec = supervised_tuning_spec ,
69+ tuned_model_display_name = tuned_model_display_name ,
70+ labels = labels ,
71+ )
7172
7273
7374class SupervisedTuningJob (_tuning .TuningJob ):
0 commit comments