Skip to content
This repository was archived by the owner on Sep 9, 2023. It is now read-only.
4 changes: 4 additions & 0 deletions google-cloud-aiplatform/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@
<classifier>testlib</classifier>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
</dependencies>

<profiles>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.aiplatform.utility;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: most other utility packages use util

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;

/**
* Exposes utility methods for converting AI Platform messages to and from
* {@com.google.protobuf.Value} objects.
*/
public class ValueConverter {

/** An empty {@com.google.protobuf.Value} message. */
public static final Value EMPTY_VALUE = Value.newBuilder().build();

/**
* Converts a message type to a {@com.google.protobuf.Value}.
*
* @param message the message to convert
* @return the message as a {@com.google.protobuf.Value}
* @throws InvalidProtocolBufferException
*/
public static Value toValue(Message message) throws InvalidProtocolBufferException {
String jsonString = JsonFormat.printer().print(message);
Value.Builder value = Value.newBuilder();
JsonFormat.parser().merge(jsonString, value);
return value.build();
}

/**
* Converts a {@com.google.protobuf.Value} to a {@com.google.protobuf.Message} of the provided
* {@com.google.protobuf.Message.Builder}.
*
* @param messageBuilder a builder for the message type
* @param value the Value to convert to a message
* @return the value as a message
* @throws InvalidProtocolBufferException
*/
public static Message fromValue(Message.Builder messageBuilder, Value value)
throws InvalidProtocolBufferException {
String valueString = JsonFormat.printer().print(value);
JsonFormat.parser().merge(valueString, messageBuilder);
return messageBuilder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.aiplatform.utility;

import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.util.Collection;
import java.util.stream.Collectors;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class ValueConverterTest {

private static JsonObject testJsonInputs;
private static Value testValueInputs;
private static AutoMlImageClassificationInputs testObjectInputs;

@Before
public void setUp() throws InvalidProtocolBufferException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything in this setup is only used in the one test - you can move it inside there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

testJsonInputs = new JsonObject();
testJsonInputs.addProperty("multi_label", true);
testJsonInputs.addProperty("model_type", "CLOUD");
testJsonInputs.addProperty("budget_milli_node_hours", 8000);

Value.Builder valueBuilder = Value.newBuilder();
JsonFormat.parser().merge(testJsonInputs.toString(), valueBuilder);
testValueInputs = valueBuilder.build();

testObjectInputs =
AutoMlImageClassificationInputs.newBuilder()
.setModelType(ModelType.CLOUD)
.setBudgetMilliNodeHours(8000)
.setMultiLabel(true)
.setDisableEarlyStopping(false)
.build();
}

@Test
public void testValueConverterToValue() throws InvalidProtocolBufferException {
Value actualConvertedValue = ValueConverter.toValue(testObjectInputs);

Struct actualStruct = actualConvertedValue.getStructValue();
Assert.assertEquals(3, actualStruct.getFieldsCount());

Collection<Object> innerFields = actualStruct.getAllFields().values();
Collection<MapEntry> fieldEntries = (Collection<MapEntry>) innerFields.toArray()[0];

MapEntry actualBoolValueEntry = null;
MapEntry actualStringValueEntry = null;
MapEntry actualNumberValueEntry = null;

for (MapEntry entry : fieldEntries) {
String key = entry.getKey().toString();
if (key.contains("multiLabel")) {
actualBoolValueEntry = entry;
} else if (key.contains("modelType")) {
actualStringValueEntry = entry;
} else if (key.contains("budgetMilliNodeHours")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should these checks for key name use equals?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

actualNumberValueEntry = entry;
}
}


Value actualBoolValue = (Value) actualBoolValueEntry.getValue();
Assert.assertEquals(testObjectInputs.getMultiLabel(), actualBoolValue.getBoolValue());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: usually Assert.assertEquals is statically imported

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


Value actualStringValue = (Value) actualStringValueEntry.getValue();
Assert.assertEquals("CLOUD", actualStringValue.getStringValue());

Value actualNumberValue = (Value) actualNumberValueEntry.getValue();
// protobuf stores int64 values as strings rather than numbers
long actualNumber = Long.parseLong(actualNumberValue.getStringValue());
Assert.assertEquals(testObjectInputs.getBudgetMilliNodeHours(), actualNumber);
}

@Test
public void testValueConverterFromValue() throws InvalidProtocolBufferException {
AutoMlImageClassificationInputs actualInputs =
(AutoMlImageClassificationInputs)
ValueConverter.fromValue(AutoMlImageClassificationInputs.newBuilder(), testValueInputs);

Assert.assertEquals(8000, actualInputs.getBudgetMilliNodeHours());
Assert.assertEquals(true, actualInputs.getMultiLabel());
Assert.assertEquals(ModelType.CLOUD, actualInputs.getModelType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.cloud.aiplatform.utility.ValueConverter;
import com.google.rpc.Status;
import java.io.IOException;

Expand Down Expand Up @@ -74,11 +75,13 @@ static void createTrainingPipelineImageClassificationSample(
+ "automl_image_classification_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);

String jsonString =
"{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
+ " \"disableEarlyStopping\": false}";
Value.Builder trainingTaskInputs = Value.newBuilder();
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
AutoMlImageClassificationInputs autoMlImageClassificationInputs =
AutoMlImageClassificationInputs.newBuilder()
.setModelType(ModelType.CLOUD)
.setMultiLabel(false)
.setBudgetMilliNodeHours(8000)
.setDisableEarlyStopping(false)
.build();

InputDataConfig trainingInputDataConfig =
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
Expand All @@ -87,7 +90,7 @@ static void createTrainingPipelineImageClassificationSample(
TrainingPipeline.newBuilder()
.setDisplayName(trainingPipelineDisplayName)
.setTrainingTaskDefinition(trainingTaskDefinition)
.setTrainingTaskInputs(trainingTaskInputs)
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
.setInputDataConfig(trainingInputDataConfig)
.setModelToUpload(model)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.ImageClassificationPredictionInstance;
import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageClassificationPredictionParams;
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
import com.google.cloud.aiplatform.utility.ValueConverter;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
Expand Down Expand Up @@ -60,23 +63,40 @@ static void predictImageClassification(String project, String fileName, String e
byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
String content = new String(contents, StandardCharsets.UTF_8);

Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();

String contentDict = "{\"content\": \"" + content + "\"}";
Value.Builder instance = Value.newBuilder();
JsonFormat.parser().merge(contentDict, instance);
ImageClassificationPredictionInstance predictionInstance =
ImageClassificationPredictionInstance.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(instance.build());
instances.add(ValueConverter.toValue(predictionInstance));

ImageClassificationPredictionParams predictionParams =
ImageClassificationPredictionParams.newBuilder()
.setConfidenceThreshold((float) 0.5)
.setMaxPredictions(5)
.build();

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameter);
predictionServiceClient.predict(endpointName, instances, ValueConverter.toValue(predictionParams));
System.out.println("Predict Image Classification Response");
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

System.out.println("Predictions");
for (Value prediction : predictResponse.getPredictionsList()) {
System.out.format("\tPrediction: %s\n", prediction);

ClassificationPredictionResult.Builder resultBuilder = ClassificationPredictionResult.newBuilder();
// Display names and confidences values correspond to
// IDs in the ID list.
ClassificationPredictionResult result =
(ClassificationPredictionResult)ValueConverter.fromValue(resultBuilder, prediction);
int counter = 0;
for (Long id : result.getIdsList()) {
System.out.printf("Label ID: %d\n", id);
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
counter++;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.TextClassificationPredictionInstance;
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
import com.google.cloud.aiplatform.utility.ValueConverter;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -52,25 +54,37 @@ static void predictTextClassificationSingleLabel(
try (PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings)) {
String location = "us-central1";
String jsonString = "{\"content\": \"" + content + "\"}";

EndpointName endpointName = EndpointName.of(project, location, endpointId);

Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
Value.Builder instance = Value.newBuilder();
JsonFormat.parser().merge(jsonString, instance);
TextClassificationPredictionInstance predictionInstance = TextClassificationPredictionInstance
.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(instance.build());
instances.add(ValueConverter.toValue(predictionInstance));

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameter);
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
System.out.println("Predict Text Classification Response");
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

System.out.println("Predictions");
System.out.println("Predictions:\n\n");
for (Value prediction : predictResponse.getPredictionsList()) {
System.out.format("\tPrediction: %s\n", prediction);

ClassificationPredictionResult.Builder resultBuilder = ClassificationPredictionResult.newBuilder();

// Display names and confidences values correspond to
// IDs in the ID list.
ClassificationPredictionResult result =
(ClassificationPredictionResult)ValueConverter.fromValue(resultBuilder, prediction);
int counter = 0;
for (Long id : result.getIdsList()) {
System.out.printf("Label ID: %d\n", id);
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
counter++;
}
}
}
}
Expand Down