1515
1616from typing import Dict , Literal , Optional , Union
1717
18- from google .cloud .aiplatform_v1 .types import tuning_job as gca_tuning_job_types
18+ from google .cloud .aiplatform_v1 .types import (
19+ tuning_job as gca_tuning_job_types ,
20+ )
1921from vertexai import generative_models
2022from vertexai .tuning import _tuning
2123
@@ -31,44 +33,66 @@ def train(
3133 adapter_size : Optional [Literal [1 , 4 , 8 , 16 ]] = None ,
3234 labels : Optional [Dict [str , str ]] = None ,
3335) -> "SupervisedTuningJob" :
34- """Tunes a model using supervised training.
36+ """Tunes a model using supervised training.
3537
36- Args:
37- source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
38- train_dataset: Cloud Storage path to file containing training dataset for
39- tuning. The dataset should be in JSONL format.
40- validation_dataset: Cloud Storage path to file containing validation
41- dataset for tuning. The dataset should be in JSONL format.
42- tuned_model_display_name: The display name of the
43- [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
44- 128 characters long and can consist of any UTF-8 characters.
45- epochs: Number of training epoches for this tuning job.
46- learning_rate_multiplier: Learning rate multiplier for tuning.
47- adapter_size: Adapter size for tuning.
48- labels: User-defined metadata to be associated with trained models
38+ Args:
39+ source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
40+ train_dataset: Cloud Storage path to file containing training dataset for
41+ tuning. The dataset should be in JSONL format.
42+ validation_dataset: Cloud Storage path to file containing validation
43+ dataset for tuning. The dataset should be in JSONL format.
44+ tuned_model_display_name: The display name of the
45+ [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
46+ 128 characters long and can consist of any UTF-8 characters.
47+ epochs: Number of training epochs for this tuning job.
48+ learning_rate_multiplier: Learning rate multiplier for tuning.
49+ adapter_size: Adapter size for tuning.
50+ labels: User-defined metadata to be associated with trained models
4951
50- Returns:
51- A `TuningJob` object.
52- """
53- supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
52+ Returns:
53+ A `TuningJob` object.
54+ """
55+ if adapter_size is None :
56+ adapter_size_value = None
57+ elif adapter_size == 1 :
58+ adapter_size_value = (
59+ gca_tuning_job_types .SupervisedHyperParameters .AdapterSize .ADAPTER_SIZE_ONE
60+ )
61+ elif adapter_size == 4 :
62+ adapter_size_value = (
63+ gca_tuning_job_types .SupervisedHyperParameters .AdapterSize .ADAPTER_SIZE_FOUR
64+ )
65+ elif adapter_size == 8 :
66+ adapter_size_value = (
67+ gca_tuning_job_types .SupervisedHyperParameters .AdapterSize .ADAPTER_SIZE_EIGHT
68+ )
69+ elif adapter_size == 16 :
70+ adapter_size_value = (
71+ gca_tuning_job_types .SupervisedHyperParameters .AdapterSize .ADAPTER_SIZE_SIXTEEN
72+ )
73+ else :
74+ raise ValueError (
75+ f"Unsupported adapter size: { adapter_size } . The supported sizes are [1, 4, 8, 16]"
76+ )
77+ supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
5478 training_dataset_uri = train_dataset ,
5579 validation_dataset_uri = validation_dataset ,
5680 hyper_parameters = gca_tuning_job_types .SupervisedHyperParameters (
5781 epoch_count = epochs ,
5882 learning_rate_multiplier = learning_rate_multiplier ,
59- adapter_size = adapter_size ,
83+ adapter_size = adapter_size_value ,
6084 ),
6185 )
6286
63- if isinstance (source_model , generative_models .GenerativeModel ):
64- source_model = source_model ._prediction_resource_name .rpartition ('/' )[- 1 ]
87+ if isinstance (source_model , generative_models .GenerativeModel ):
88+ source_model = source_model ._prediction_resource_name .rpartition ("/" )[- 1 ]
6589
66- return SupervisedTuningJob ._create ( # pylint: disable=protected-access
67- base_model = source_model ,
68- tuning_spec = supervised_tuning_spec ,
69- tuned_model_display_name = tuned_model_display_name ,
70- labels = labels ,
71- )
90+ return SupervisedTuningJob ._create ( # pylint: disable=protected-access
91+ base_model = source_model ,
92+ tuning_spec = supervised_tuning_spec ,
93+ tuned_model_display_name = tuned_model_display_name ,
94+ labels = labels ,
95+ )
7296
7397
7498class SupervisedTuningJob (_tuning .TuningJob ):
0 commit comments