Skip to content

Commit 959d798

Browse files
Frances Hubis Thomacopybara-github
authored andcommitted
feat: Enable Vertex Multimodal Dataset as input to supervised fine-tuning.
PiperOrigin-RevId: 772177718
1 parent b4708de commit 959d798

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

vertexai/tuning/_supervised_tuning.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from typing import Dict, Literal, Optional, Union
1717

18+
from google.cloud.aiplatform.preview import datasets
1819
from google.cloud.aiplatform.utils import _ipython_utils
1920
from google.cloud.aiplatform_v1beta1.types import (
2021
tuning_job as gca_tuning_job_types,
@@ -26,8 +27,8 @@
2627
def train(
2728
*,
2829
source_model: Union[str, generative_models.GenerativeModel],
29-
train_dataset: str,
30-
validation_dataset: Optional[str] = None,
30+
train_dataset: Union[str, datasets.MultimodalDataset],
31+
validation_dataset: Optional[Union[str, datasets.MultimodalDataset]] = None,
3132
tuned_model_display_name: Optional[str] = None,
3233
epochs: Optional[int] = None,
3334
learning_rate_multiplier: Optional[float] = None,
@@ -38,8 +39,8 @@ def train(
3839
3940
Args:
4041
source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
41-
train_dataset: Training dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset.
42-
validation_dataset: Validation dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset.
42+
train_dataset: Training dataset used for tuning. The dataset can be a JSONL file on Google Cloud Storage (specified as its GCS URI) or a Vertex Multimodal Dataset (either as the dataset object itself or as its resource name).
43+
validation_dataset: Validation dataset used for tuning. The dataset can be a JSONL file on Google Cloud Storage (specified as its GCS URI) or a Vertex Multimodal Dataset (either as the dataset object itself or as the resource name).
4344
tuned_model_display_name: The display name of the
4445
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
4546
128 characters long and can consist of any UTF-8 characters.
@@ -73,6 +74,10 @@ def train(
7374
raise ValueError(
7475
f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16]"
7576
)
77+
if isinstance(train_dataset, datasets.MultimodalDataset):
78+
train_dataset = train_dataset.resource_name
79+
if isinstance(validation_dataset, datasets.MultimodalDataset):
80+
validation_dataset = validation_dataset.resource_name
7681
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
7782
training_dataset_uri=train_dataset,
7883
validation_dataset_uri=validation_dataset,

0 commit comments

Comments
 (0)