Skip to content

Commit 4540f9d

Browse files
speedstorm1 authored and copybara-github committed
feat: Add support for preference optimization tuning in the SDK.
PiperOrigin-RevId: 825682657
1 parent 64cab58 commit 4540f9d

File tree

3 files changed

+332
-70
lines changed

3 files changed

+332
-70
lines changed

google/genai/tests/tunings/test_tune.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,37 @@
6565
),
6666
exception_if_mldev="is not supported in Gemini API",
6767
),
68+
pytest_helper.TestTableItem(
69+
name="test_tune_simple_dpo",
70+
parameters=genai_types.CreateTuningJobParameters(
71+
base_model="gemini-2.5-flash",
72+
training_dataset=genai_types.TuningDataset(
73+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
74+
),
75+
config=genai_types.CreateTuningJobConfig(
76+
tuned_model_display_name="Model display name",
77+
epoch_count=1,
78+
method="PREFERENCE_TUNING",
79+
),
80+
),
81+
exception_if_mldev="parameter is not supported in Gemini API.",
82+
),
83+
pytest_helper.TestTableItem(
84+
name="test_tune_dpo_with_beta",
85+
parameters=genai_types.CreateTuningJobParameters(
86+
base_model="gemini-2.5-flash",
87+
training_dataset=genai_types.TuningDataset(
88+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
89+
),
90+
config=genai_types.CreateTuningJobConfig(
91+
tuned_model_display_name="Model display name",
92+
epoch_count=1,
93+
method=genai_types.TuningMethod.PREFERENCE_TUNING,
94+
beta=0.5,
95+
),
96+
),
97+
exception_if_mldev="parameter is not supported in Gemini API.",
98+
),
6899
pytest_helper.TestTableItem(
69100
name="test_non_pretuned_model_with_checkpoint_id",
70101
parameters=genai_types.CreateTuningJobParameters(

google/genai/tunings.py

Lines changed: 180 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def _CreateTuningJobConfig_to_mldev(
128128
if getv(from_object, ['labels']) is not None:
129129
raise ValueError('labels parameter is not supported in Gemini API.')
130130

131+
if getv(from_object, ['beta']) is not None:
132+
raise ValueError('beta parameter is not supported in Gemini API.')
133+
131134
return to_object
132135

133136

@@ -138,14 +141,28 @@ def _CreateTuningJobConfig_to_vertex(
138141
) -> dict[str, Any]:
139142
to_object: dict[str, Any] = {}
140143

141-
if getv(from_object, ['validation_dataset']) is not None:
142-
setv(
143-
parent_object,
144-
['supervisedTuningSpec'],
145-
_TuningValidationDataset_to_vertex(
146-
getv(from_object, ['validation_dataset']), to_object, root_object
147-
),
148-
)
144+
discriminator = getv(root_object, ['config', 'method'])
145+
if discriminator is None:
146+
discriminator = 'SUPERVISED_FINE_TUNING'
147+
if discriminator == 'SUPERVISED_FINE_TUNING':
148+
if getv(from_object, ['validation_dataset']) is not None:
149+
setv(
150+
parent_object,
151+
['supervisedTuningSpec'],
152+
_TuningValidationDataset_to_vertex(
153+
getv(from_object, ['validation_dataset']), to_object, root_object
154+
),
155+
)
156+
157+
elif discriminator == 'PREFERENCE_TUNING':
158+
if getv(from_object, ['validation_dataset']) is not None:
159+
setv(
160+
parent_object,
161+
['preferenceOptimizationSpec'],
162+
_TuningValidationDataset_to_vertex(
163+
getv(from_object, ['validation_dataset']), to_object, root_object
164+
),
165+
)
149166

150167
if getv(from_object, ['tuned_model_display_name']) is not None:
151168
setv(
@@ -157,52 +174,125 @@ def _CreateTuningJobConfig_to_vertex(
157174
if getv(from_object, ['description']) is not None:
158175
setv(parent_object, ['description'], getv(from_object, ['description']))
159176

160-
if getv(from_object, ['epoch_count']) is not None:
161-
setv(
162-
parent_object,
163-
['supervisedTuningSpec', 'hyperParameters', 'epochCount'],
164-
getv(from_object, ['epoch_count']),
165-
)
177+
discriminator = getv(root_object, ['config', 'method'])
178+
if discriminator is None:
179+
discriminator = 'SUPERVISED_FINE_TUNING'
180+
if discriminator == 'SUPERVISED_FINE_TUNING':
181+
if getv(from_object, ['epoch_count']) is not None:
182+
setv(
183+
parent_object,
184+
['supervisedTuningSpec', 'hyperParameters', 'epochCount'],
185+
getv(from_object, ['epoch_count']),
186+
)
166187

167-
if getv(from_object, ['learning_rate_multiplier']) is not None:
168-
setv(
169-
parent_object,
170-
['supervisedTuningSpec', 'hyperParameters', 'learningRateMultiplier'],
171-
getv(from_object, ['learning_rate_multiplier']),
172-
)
188+
elif discriminator == 'PREFERENCE_TUNING':
189+
if getv(from_object, ['epoch_count']) is not None:
190+
setv(
191+
parent_object,
192+
['preferenceOptimizationSpec', 'hyperParameters', 'epochCount'],
193+
getv(from_object, ['epoch_count']),
194+
)
173195

174-
if getv(from_object, ['export_last_checkpoint_only']) is not None:
175-
setv(
176-
parent_object,
177-
['supervisedTuningSpec', 'exportLastCheckpointOnly'],
178-
getv(from_object, ['export_last_checkpoint_only']),
179-
)
196+
discriminator = getv(root_object, ['config', 'method'])
197+
if discriminator is None:
198+
discriminator = 'SUPERVISED_FINE_TUNING'
199+
if discriminator == 'SUPERVISED_FINE_TUNING':
200+
if getv(from_object, ['learning_rate_multiplier']) is not None:
201+
setv(
202+
parent_object,
203+
['supervisedTuningSpec', 'hyperParameters', 'learningRateMultiplier'],
204+
getv(from_object, ['learning_rate_multiplier']),
205+
)
180206

181-
if getv(from_object, ['adapter_size']) is not None:
182-
setv(
183-
parent_object,
184-
['supervisedTuningSpec', 'hyperParameters', 'adapterSize'],
185-
getv(from_object, ['adapter_size']),
186-
)
207+
elif discriminator == 'PREFERENCE_TUNING':
208+
if getv(from_object, ['learning_rate_multiplier']) is not None:
209+
setv(
210+
parent_object,
211+
[
212+
'preferenceOptimizationSpec',
213+
'hyperParameters',
214+
'learningRateMultiplier',
215+
],
216+
getv(from_object, ['learning_rate_multiplier']),
217+
)
218+
219+
discriminator = getv(root_object, ['config', 'method'])
220+
if discriminator is None:
221+
discriminator = 'SUPERVISED_FINE_TUNING'
222+
if discriminator == 'SUPERVISED_FINE_TUNING':
223+
if getv(from_object, ['export_last_checkpoint_only']) is not None:
224+
setv(
225+
parent_object,
226+
['supervisedTuningSpec', 'exportLastCheckpointOnly'],
227+
getv(from_object, ['export_last_checkpoint_only']),
228+
)
229+
230+
elif discriminator == 'PREFERENCE_TUNING':
231+
if getv(from_object, ['export_last_checkpoint_only']) is not None:
232+
setv(
233+
parent_object,
234+
['preferenceOptimizationSpec', 'exportLastCheckpointOnly'],
235+
getv(from_object, ['export_last_checkpoint_only']),
236+
)
237+
238+
discriminator = getv(root_object, ['config', 'method'])
239+
if discriminator is None:
240+
discriminator = 'SUPERVISED_FINE_TUNING'
241+
if discriminator == 'SUPERVISED_FINE_TUNING':
242+
if getv(from_object, ['adapter_size']) is not None:
243+
setv(
244+
parent_object,
245+
['supervisedTuningSpec', 'hyperParameters', 'adapterSize'],
246+
getv(from_object, ['adapter_size']),
247+
)
248+
249+
elif discriminator == 'PREFERENCE_TUNING':
250+
if getv(from_object, ['adapter_size']) is not None:
251+
setv(
252+
parent_object,
253+
['preferenceOptimizationSpec', 'hyperParameters', 'adapterSize'],
254+
getv(from_object, ['adapter_size']),
255+
)
187256

188257
if getv(from_object, ['batch_size']) is not None:
189258
raise ValueError('batch_size parameter is not supported in Vertex AI.')
190259

191260
if getv(from_object, ['learning_rate']) is not None:
192261
raise ValueError('learning_rate parameter is not supported in Vertex AI.')
193262

194-
if getv(from_object, ['evaluation_config']) is not None:
195-
setv(
196-
parent_object,
197-
['supervisedTuningSpec', 'evaluationConfig'],
198-
_EvaluationConfig_to_vertex(
199-
getv(from_object, ['evaluation_config']), to_object, root_object
200-
),
201-
)
263+
discriminator = getv(root_object, ['config', 'method'])
264+
if discriminator is None:
265+
discriminator = 'SUPERVISED_FINE_TUNING'
266+
if discriminator == 'SUPERVISED_FINE_TUNING':
267+
if getv(from_object, ['evaluation_config']) is not None:
268+
setv(
269+
parent_object,
270+
['supervisedTuningSpec', 'evaluationConfig'],
271+
_EvaluationConfig_to_vertex(
272+
getv(from_object, ['evaluation_config']), to_object, root_object
273+
),
274+
)
275+
276+
elif discriminator == 'PREFERENCE_TUNING':
277+
if getv(from_object, ['evaluation_config']) is not None:
278+
setv(
279+
parent_object,
280+
['preferenceOptimizationSpec', 'evaluationConfig'],
281+
_EvaluationConfig_to_vertex(
282+
getv(from_object, ['evaluation_config']), to_object, root_object
283+
),
284+
)
202285

203286
if getv(from_object, ['labels']) is not None:
204287
setv(parent_object, ['labels'], getv(from_object, ['labels']))
205288

289+
if getv(from_object, ['beta']) is not None:
290+
setv(
291+
parent_object,
292+
['preferenceOptimizationSpec', 'hyperParameters', 'beta'],
293+
getv(from_object, ['beta']),
294+
)
295+
206296
return to_object
207297

208298

@@ -219,12 +309,8 @@ def _CreateTuningJobParametersPrivate_to_mldev(
219309
setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
220310

221311
if getv(from_object, ['training_dataset']) is not None:
222-
setv(
223-
to_object,
224-
['tuningTask', 'trainingData'],
225-
_TuningDataset_to_mldev(
226-
getv(from_object, ['training_dataset']), to_object, root_object
227-
),
312+
_TuningDataset_to_mldev(
313+
getv(from_object, ['training_dataset']), to_object, root_object
228314
)
229315

230316
if getv(from_object, ['config']) is not None:
@@ -501,19 +587,44 @@ def _TuningDataset_to_vertex(
501587
root_object: Optional[Union[dict[str, Any], object]] = None,
502588
) -> dict[str, Any]:
503589
to_object: dict[str, Any] = {}
504-
if getv(from_object, ['gcs_uri']) is not None:
505-
setv(
506-
parent_object,
507-
['supervisedTuningSpec', 'trainingDatasetUri'],
508-
getv(from_object, ['gcs_uri']),
509-
)
510590

511-
if getv(from_object, ['vertex_dataset_resource']) is not None:
512-
setv(
513-
parent_object,
514-
['supervisedTuningSpec', 'trainingDatasetUri'],
515-
getv(from_object, ['vertex_dataset_resource']),
516-
)
591+
discriminator = getv(root_object, ['config', 'method'])
592+
if discriminator is None:
593+
discriminator = 'SUPERVISED_FINE_TUNING'
594+
if discriminator == 'SUPERVISED_FINE_TUNING':
595+
if getv(from_object, ['gcs_uri']) is not None:
596+
setv(
597+
parent_object,
598+
['supervisedTuningSpec', 'trainingDatasetUri'],
599+
getv(from_object, ['gcs_uri']),
600+
)
601+
602+
elif discriminator == 'PREFERENCE_TUNING':
603+
if getv(from_object, ['gcs_uri']) is not None:
604+
setv(
605+
parent_object,
606+
['preferenceOptimizationSpec', 'trainingDatasetUri'],
607+
getv(from_object, ['gcs_uri']),
608+
)
609+
610+
discriminator = getv(root_object, ['config', 'method'])
611+
if discriminator is None:
612+
discriminator = 'SUPERVISED_FINE_TUNING'
613+
if discriminator == 'SUPERVISED_FINE_TUNING':
614+
if getv(from_object, ['vertex_dataset_resource']) is not None:
615+
setv(
616+
parent_object,
617+
['supervisedTuningSpec', 'trainingDatasetUri'],
618+
getv(from_object, ['vertex_dataset_resource']),
619+
)
620+
621+
elif discriminator == 'PREFERENCE_TUNING':
622+
if getv(from_object, ['vertex_dataset_resource']) is not None:
623+
setv(
624+
parent_object,
625+
['preferenceOptimizationSpec', 'trainingDatasetUri'],
626+
getv(from_object, ['vertex_dataset_resource']),
627+
)
517628

518629
if getv(from_object, ['examples']) is not None:
519630
raise ValueError('examples parameter is not supported in Vertex AI.')
@@ -635,6 +746,13 @@ def _TuningJob_from_vertex(
635746
getv(from_object, ['supervisedTuningSpec']),
636747
)
637748

749+
if getv(from_object, ['preferenceOptimizationSpec']) is not None:
750+
setv(
751+
to_object,
752+
['preference_optimization_spec'],
753+
getv(from_object, ['preferenceOptimizationSpec']),
754+
)
755+
638756
if getv(from_object, ['tuningDataStats']) is not None:
639757
setv(
640758
to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats'])
@@ -950,7 +1068,7 @@ def _tune(
9501068
training_dataset: types.TuningDatasetOrDict,
9511069
config: Optional[types.CreateTuningJobConfigOrDict] = None,
9521070
) -> types.TuningJob:
953-
"""Creates a supervised fine-tuning job and returns the TuningJob object.
1071+
"""Creates a tuning job and returns the TuningJob object.
9541072
9551073
Args:
9561074
base_model: The name of the model to tune.
@@ -1023,7 +1141,7 @@ def _tune_mldev(
10231141
training_dataset: types.TuningDatasetOrDict,
10241142
config: Optional[types.CreateTuningJobConfigOrDict] = None,
10251143
) -> types.TuningOperation:
1026-
"""Creates a supervised fine-tuning job and returns the TuningJob object.
1144+
"""Creates a tuning job and returns the TuningJob object.
10271145
10281146
Args:
10291147
base_model: The name of the model to tune.
@@ -1419,7 +1537,7 @@ async def _tune(
14191537
training_dataset: types.TuningDatasetOrDict,
14201538
config: Optional[types.CreateTuningJobConfigOrDict] = None,
14211539
) -> types.TuningJob:
1422-
"""Creates a supervised fine-tuning job and returns the TuningJob object.
1540+
"""Creates a tuning job and returns the TuningJob object.
14231541
14241542
Args:
14251543
base_model: The name of the model to tune.
@@ -1492,7 +1610,7 @@ async def _tune_mldev(
14921610
training_dataset: types.TuningDatasetOrDict,
14931611
config: Optional[types.CreateTuningJobConfigOrDict] = None,
14941612
) -> types.TuningOperation:
1495-
"""Creates a supervised fine-tuning job and returns the TuningJob object.
1613+
"""Creates a tuning job and returns the TuningJob object.
14961614
14971615
Args:
14981616
base_model: The name of the model to tune.

0 commit comments

Comments (0)