Perform batch code prediction using a pre-trained code generation model.
Code sample
Go
Before trying this sample, follow the Go setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Go API reference documentation.
To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.
import ( "context" "fmt" "io" aiplatform "cloud.google.com/go/aiplatform/apiv1" aiplatformpb "cloud.google.com/go/aiplatform/apiv1/aiplatformpb" "google.golang.org/api/option" "google.golang.org/protobuf/types/known/structpb" ) // batchCodePredict perform batch code prediction using a pre-trained code generation model func batchCodePredict(w io.Writer, projectID, location, name, outputURI string, inputURIs []string) error { // inputURI := []string{"gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl"} // outputURI: existing template path. Following formats are allowed: // - gs://BUCKET_NAME/DIRECTORY/ // - bq://project_name.llm_dataset ctx := context.Background() apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) // Pretrained code model model := "publishers/google/models/code-bison" parameters := map[string]interface{}{ "temperature": 0.2, "maxOutputTokens": 200, } parametersValue, err := structpb.NewValue(parameters) if err != nil { fmt.Fprintf(w, "unable to convert parameters to Value: %v", err) return err } client, err := aiplatform.NewJobClient(ctx, option.WithEndpoint(apiEndpoint)) if err != nil { return err } defer client.Close() req := &aiplatformpb.CreateBatchPredictionJobRequest{ Parent: fmt.Sprintf("projects/%s/locations/%s", projectID, location), BatchPredictionJob: &aiplatformpb.BatchPredictionJob{ DisplayName: name, Model: model, ModelParameters: parametersValue, InputConfig: &aiplatformpb.BatchPredictionJob_InputConfig{ Source: &aiplatformpb.BatchPredictionJob_InputConfig_GcsSource{ GcsSource: &aiplatformpb.GcsSource{ Uris: inputURIs, }, }, // List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model InstancesFormat: "jsonl", }, OutputConfig: &aiplatformpb.BatchPredictionJob_OutputConfig{ Destination: &aiplatformpb.BatchPredictionJob_OutputConfig_GcsDestination{ GcsDestination: &aiplatformpb.GcsDestination{ OutputUriPrefix: outputURI, }, }, // List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model PredictionsFormat: "jsonl", }, }, } job, err := client.CreateBatchPredictionJob(ctx, req) if err != nil { return err } fmt.Fprint(w, job.GetDisplayName()) return nil }
Java
Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.
To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.
import com.google.cloud.aiplatform.v1.BatchPredictionJob; import com.google.cloud.aiplatform.v1.GcsDestination; import com.google.cloud.aiplatform.v1.GcsSource; import com.google.cloud.aiplatform.v1.JobServiceClient; import com.google.cloud.aiplatform.v1.JobServiceSettings; import com.google.cloud.aiplatform.v1.LocationName; import com.google.gson.Gson; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import java.io.IOException; import java.util.HashMap; import java.util.Map; public class BatchCodePredictionSample { public static void main(String[] args) throws IOException, InterruptedException { // TODO(developer): Replace these variables before running the sample. String project = "YOUR_PROJECT_ID"; String location = "us-central1"; // inputUri: URI of the input dataset. // Could be a BigQuery table or a Google Cloud Storage file. // E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]" String inputUri = "gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl"; // outputUri: URI where the output will be stored. // Could be a BigQuery table or a Google Cloud Storage file. // E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]" String outputUri = "gs://YOUR_BUCKET/batch_code_predict_output"; String codeModel = "code-bison"; batchCodePredictionSample(project, location, inputUri, outputUri, codeModel); } // Perform batch code prediction using a pre-trained code generation model. // Example of using Google Cloud Storage bucket as the input and output data source public static BatchPredictionJob batchCodePredictionSample( String project, String location, String inputUri, String outputUri, String codeModel) throws IOException { BatchPredictionJob response; JobServiceSettings jobServiceSettings = JobServiceSettings.newBuilder() .setEndpoint("us-central1-aiplatform.googleapis.com:443").build(); LocationName parent = LocationName.of(project, location); String modelName = String.format( "projects/%s/locations/%s/publishers/google/models/%s", project, location, codeModel); // Construct your modelParameters Map<String, String> modelParameters = new HashMap<>(); modelParameters.put("maxOutputTokens", "200"); modelParameters.put("temperature", "0.2"); modelParameters.put("topP", "0.95"); modelParameters.put("topK", "40"); Value parameterValue = mapToValue(modelParameters); // Initialize client that will be used to send requests. This client only needs to be created // once, and can be reused for multiple requests. try (JobServiceClient client = JobServiceClient.create(jobServiceSettings)) { BatchPredictionJob batchPredictionJob = BatchPredictionJob.newBuilder() .setDisplayName("my batch code prediction job " + System.currentTimeMillis()) .setModel(modelName) .setInputConfig( BatchPredictionJob.InputConfig.newBuilder() .setGcsSource(GcsSource.newBuilder().addUris(inputUri).build()) .setInstancesFormat("jsonl") .build()) .setOutputConfig( BatchPredictionJob.OutputConfig.newBuilder() .setGcsDestination(GcsDestination.newBuilder() .setOutputUriPrefix(outputUri).build()) .setPredictionsFormat("jsonl") .build()) .setModelParameters(parameterValue) .build(); response = client.createBatchPredictionJob(parent, batchPredictionJob); System.out.format("response: %s\n", response); System.out.format("\tName: %s\n", response.getName()); } return response; } private static Value mapToValue(Map<String, String> map) throws InvalidProtocolBufferException { Gson gson = new Gson(); String json = gson.toJson(map); Value.Builder builder = Value.newBuilder(); JsonFormat.parser().merge(json, builder); return builder.build(); } }
What's next
To search and filter code samples for other Google Cloud products, see the Google Cloud sample browser.