Skip to content
This repository was archived by the owner on Sep 9, 2023. It is now read-only.
4 changes: 4 additions & 0 deletions google-cloud-aiplatform/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@
<classifier>testlib</classifier>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
</dependencies>

<profiles>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.google.cloud.aiplatform.v1beta1.utility;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;

/**
* Exposes utility methods for converting AI Platform messages to and
* from {@com.google.protobuf.Value} objects.
*/
public class ValueConverter {

/**
* An empty {@com.google.protobuf.Value} message.
*/
public static final Value EMPTY_VALUE = Value.newBuilder().build();

/**
* Converts a message type to a {@com.google.protobuf.Value}.
*
* @param message the message to convert
* @return the message as a {@com.google.protobuf.Value}
* @throws InvalidProtocolBufferException
*/
public static Value toValue(Message message) throws InvalidProtocolBufferException {
String jsonString = JsonFormat.printer().print(message);
Value.Builder value = Value.newBuilder();
JsonFormat.parser().merge(jsonString, value);
return value.build();
}

/**
* Converts a {@com.google.protobuf.Value} to a {@com.google.protobuf.Message}
* of the provided {@com.google.protobuf.Message.Builder}.
*
* @param messageBuilder a builder for the message type
* @param value the Value to convert to a message
* @return the value as a message
* @throws InvalidProtocolBufferException
*/
public static Message fromValue(Message.Builder messageBuilder, Value value)
throws InvalidProtocolBufferException {
String valueString = JsonFormat.printer().print(value);
JsonFormat.parser().merge(valueString, messageBuilder);
return messageBuilder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationMetadata;
import com.google.cloud.aiplatform.v1beta1.utility.ValueConverter;
import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
Expand Down Expand Up @@ -76,11 +80,13 @@ static void createTrainingPipelineImageClassificationSample(
+ "automl_image_classification_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);

String jsonString =
"{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
+ " \"disableEarlyStopping\": false}";
Value.Builder trainingTaskInputs = Value.newBuilder();
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
AutoMlImageClassificationInputs autoMlImageClassificationInputs =
AutoMlImageClassificationInputs.newBuilder()
.setModelType(ModelType.CLOUD)
.setMultiLabel(false)
.setBudgetMilliNodeHours(8000)
.setDisableEarlyStopping(false)
.build();

InputDataConfig trainingInputDataConfig =
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
Expand All @@ -89,7 +95,7 @@ static void createTrainingPipelineImageClassificationSample(
TrainingPipeline.newBuilder()
.setDisplayName(trainingPipelineDisplayName)
.setTrainingTaskDefinition(trainingTaskDefinition)
.setTrainingTaskInputs(trainingTaskInputs)
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
.setInputDataConfig(trainingInputDataConfig)
.setModelToUpload(model)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.ImageClassificationPredictionInstance;
import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageClassificationPredictionParams;
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
import com.google.cloud.aiplatform.v1beta1.utility.ValueConverter;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
Expand Down Expand Up @@ -60,23 +64,40 @@ static void predictImageClassification(String project, String fileName, String e
byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
String content = new String(contents, StandardCharsets.UTF_8);

Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();

String contentDict = "{\"content\": \"" + content + "\"}";
Value.Builder instance = Value.newBuilder();
JsonFormat.parser().merge(contentDict, instance);
ImageClassificationPredictionInstance predictionInstance =
ImageClassificationPredictionInstance.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(instance.build());
instances.add(ValueConverter.toValue(predictionInstance));

ImageClassificationPredictionParams predictionParams =
ImageClassificationPredictionParams.newBuilder()
.setConfidenceThreshold((float) 0.5)
.setMaxPredictions(5)
.build();

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameter);
predictionServiceClient.predict(endpointName, instances, ValueConverter.toValue(predictionParams));
System.out.println("Predict Image Classification Response");
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

System.out.println("Predictions");
for (Value prediction : predictResponse.getPredictionsList()) {
System.out.format("\tPrediction: %s\n", prediction);

ClassificationPredictionResult.Builder resultBuilder = ClassificationPredictionResult.newBuilder();
// Display names and confidences values correspond to
// IDs in the ID list.
ClassificationPredictionResult result =
(ClassificationPredictionResult)ValueConverter.fromValue(resultBuilder, prediction);
int counter = 0;
for (Long id : result.getIdsList()) {
System.out.printf("Label ID: %d\n", id);
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
counter++;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.TextClassificationPredictionInstance;
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
import com.google.cloud.aiplatform.v1beta1.utility.ValueConverter;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
Expand Down Expand Up @@ -52,25 +55,37 @@ static void predictTextClassificationSingleLabel(
try (PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings)) {
String location = "us-central1";
String jsonString = "{\"content\": \"" + content + "\"}";

EndpointName endpointName = EndpointName.of(project, location, endpointId);

Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
Value.Builder instance = Value.newBuilder();
JsonFormat.parser().merge(jsonString, instance);
TextClassificationPredictionInstance predictionInstance = TextClassificationPredictionInstance
.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(instance.build());
instances.add(ValueConverter.toValue(predictionInstance));

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameter);
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
System.out.println("Predict Text Classification Response");
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

System.out.println("Predictions");
System.out.println("Predictions:\n\n");
for (Value prediction : predictResponse.getPredictionsList()) {
System.out.format("\tPrediction: %s\n", prediction);

ClassificationPredictionResult.Builder resultBuilder = ClassificationPredictionResult.newBuilder();

// Display names and confidences values correspond to
// IDs in the ID list.
ClassificationPredictionResult result =
(ClassificationPredictionResult)ValueConverter.fromValue(resultBuilder, prediction);
int counter = 0;
for (Long id : result.getIdsList()) {
System.out.printf("Label ID: %d\n", id);
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
counter++;
}
}
}
}
Expand Down