Skip to content

Commit 7983b44

Browse files
authored
fix: predict image samples params (#150)
1 parent 69fc7fd commit 7983b44

File tree

4 files changed

+17
-20
lines changed

4 files changed

+17
-20
lines changed

.sample_configs/param_handlers/predict_image_classification_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def make_instances(filename: str) -> typing.Sequence[google.protobuf.struct_pb2.
3232
def make_parameters() -> google.protobuf.struct_pb2.Value:
3333
# See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters.
3434
parameters_dict = {
35-
"confidence_threshold": 0.5,
36-
"max_predictions": 5
35+
"confidenceThreshold": 0.5,
36+
"maxPredictions": 5
3737
}
3838
parameters = to_protobuf_value(parameters_dict)
3939

.sample_configs/param_handlers/predict_image_object_detection_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def make_instances(filename: str) -> typing.Sequence[google.protobuf.struct_pb2.
3232
def make_parameters() -> google.protobuf.struct_pb2.Value:
3333
# See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml for the format of the parameters.
3434
parameters_dict = {
35-
"confidence_threshold": 0.5,
36-
"max_predictions": 5
35+
"confidenceThreshold": 0.5,
36+
"maxPredictions": 5
3737
}
3838
parameters = to_protobuf_value(parameters_dict)
3939

samples/snippets/predict_image_classification_sample.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import base64
1717

1818
from google.cloud import aiplatform
19-
from google.cloud.aiplatform.schema import predict
19+
from google.protobuf import json_format
20+
from google.protobuf.struct_pb2 import Value
2021

2122

2223
def predict_image_classification_sample(
@@ -36,29 +37,25 @@ def predict_image_classification_sample(
3637

3738
# The format of each instance should conform to the deployed model's prediction input schema.
3839
encoded_content = base64.b64encode(file_content).decode("utf-8")
40+
instance_dict = {"content": encoded_content}
3941

40-
instance_obj = predict.instance.ImageClassificationPredictionInstance(
41-
content=encoded_content)
42-
43-
instance_val = instance_obj.to_value()
44-
instances = [instance_val]
45-
46-
params_obj = predict.params.ImageClassificationPredictionParams(
47-
confidence_threshold=0.5, max_predictions=5)
48-
42+
instance = json_format.ParseDict(instance_dict, Value())
43+
instances = [instance]
44+
# See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters.
45+
parameters_dict = {"confidenceThreshold": 0.5, "maxPredictions": 5}
46+
parameters = json_format.ParseDict(parameters_dict, Value())
4947
endpoint = client.endpoint_path(
5048
project=project, location=location, endpoint=endpoint_id
5149
)
5250
response = client.predict(
53-
endpoint=endpoint, instances=instances, parameters=params_obj
51+
endpoint=endpoint, instances=instances, parameters=parameters
5452
)
5553
print("response")
56-
print("\tdeployed_model_id:", response.deployed_model_id)
54+
print(" deployed_model_id:", response.deployed_model_id)
5755
# See gs://google-cloud-aiplatform/schema/predict/prediction/classification.yaml for the format of the predictions.
5856
predictions = response.predictions
59-
for prediction_ in predictions:
60-
prediction_obj = predict.prediction.ClassificationPredictionResult.from_map(prediction_)
61-
print(prediction_obj)
57+
for prediction in predictions:
58+
print(" prediction:", dict(prediction))
6259

6360

6461
# [END aiplatform_predict_image_classification_sample]

samples/snippets/predict_image_object_detection_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def predict_image_object_detection_sample(
4242
instance = json_format.ParseDict(instance_dict, Value())
4343
instances = [instance]
4444
# See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml for the format of the parameters.
45-
parameters_dict = {"confidence_threshold": 0.5, "max_predictions": 5}
45+
parameters_dict = {"confidenceThreshold": 0.5, "maxPredictions": 5}
4646
parameters = json_format.ParseDict(parameters_dict, Value())
4747
endpoint = client.endpoint_path(
4848
project=project, location=location, endpoint=endpoint_id

0 commit comments

Comments
 (0)