Skip to content

Commit 4db0d1d

Browse files
feat: [vertexai] enable AutomaticFunctionCallingResponder in ChatSession (#10913)
PiperOrigin-RevId: 638768526 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent 76f0396 commit 4db0d1d

File tree

65 files changed

+4534
-6553
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+4534
-6553
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/PredictionServiceClient.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,7 +1410,6 @@ public final GenerateContentResponse generateContent(String model, List<Content>
14101410
* .addAllContents(new ArrayList<Content>())
14111411
* .setSystemInstruction(Content.newBuilder().build())
14121412
* .addAllTools(new ArrayList<Tool>())
1413-
* .setToolConfig(ToolConfig.newBuilder().build())
14141413
* .addAllSafetySettings(new ArrayList<SafetySetting>())
14151414
* .setGenerationConfig(GenerationConfig.newBuilder().build())
14161415
* .build();
@@ -1444,7 +1443,6 @@ public final GenerateContentResponse generateContent(GenerateContentRequest requ
14441443
* .addAllContents(new ArrayList<Content>())
14451444
* .setSystemInstruction(Content.newBuilder().build())
14461445
* .addAllTools(new ArrayList<Tool>())
1447-
* .setToolConfig(ToolConfig.newBuilder().build())
14481446
* .addAllSafetySettings(new ArrayList<SafetySetting>())
14491447
* .setGenerationConfig(GenerationConfig.newBuilder().build())
14501448
* .build();
@@ -1479,7 +1477,6 @@ public final GenerateContentResponse generateContent(GenerateContentRequest requ
14791477
* .addAllContents(new ArrayList<Content>())
14801478
* .setSystemInstruction(Content.newBuilder().build())
14811479
* .addAllTools(new ArrayList<Tool>())
1482-
* .setToolConfig(ToolConfig.newBuilder().build())
14831480
* .addAllSafetySettings(new ArrayList<SafetySetting>())
14841481
* .setGenerationConfig(GenerationConfig.newBuilder().build())
14851482
* .build();

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/EndpointServiceStubSettings.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,15 @@ public EndpointServiceStub createStub() throws IOException {
390390
"Transport not supported: %s", getTransportChannelProvider().getTransportName()));
391391
}
392392

393+
/** Returns the endpoint set by the user or the the service's default endpoint. */
394+
@Override
395+
public String getEndpoint() {
396+
if (super.getEndpoint() != null) {
397+
return super.getEndpoint();
398+
}
399+
return getDefaultEndpoint();
400+
}
401+
393402
/** Returns the default service name. */
394403
@Override
395404
public String getServiceName() {
@@ -989,6 +998,15 @@ public UnaryCallSettings.Builder<GetIamPolicyRequest, Policy> getIamPolicySettin
989998
return testIamPermissionsSettings;
990999
}
9911000

1001+
/** Returns the endpoint set by the user or the the service's default endpoint. */
1002+
@Override
1003+
public String getEndpoint() {
1004+
if (super.getEndpoint() != null) {
1005+
return super.getEndpoint();
1006+
}
1007+
return getDefaultEndpoint();
1008+
}
1009+
9921010
@Override
9931011
public EndpointServiceStubSettings build() throws IOException {
9941012
return new EndpointServiceStubSettings(this);

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/GrpcLlmUtilityServiceStub.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@
5454
@Generated("by gapic-generator-java")
5555
public class GrpcLlmUtilityServiceStub extends LlmUtilityServiceStub {
5656
private static final MethodDescriptor<CountTokensRequest, CountTokensResponse>
57+
// TODO(b/317255628): switch back to the google.cloud.aiplatform.v1.LlmUtilityServiceClient
5758
countTokensMethodDescriptor =
58-
MethodDescriptor.<CountTokensRequest, CountTokensResponse>newBuilder()
59-
.setType(MethodDescriptor.MethodType.UNARY)
60-
.setFullMethodName("google.cloud.aiplatform.v1.LlmUtilityService/CountTokens")
61-
.setRequestMarshaller(ProtoUtils.marshaller(CountTokensRequest.getDefaultInstance()))
62-
.setResponseMarshaller(
63-
ProtoUtils.marshaller(CountTokensResponse.getDefaultInstance()))
64-
.build();
59+
MethodDescriptor.<CountTokensRequest, CountTokensResponse>newBuilder()
60+
.setType(MethodDescriptor.MethodType.UNARY)
61+
.setFullMethodName("google.cloud.aiplatform.v1beta1.PredictionService/CountTokens")
62+
.setRequestMarshaller(ProtoUtils.marshaller(CountTokensRequest.getDefaultInstance()))
63+
.setResponseMarshaller(ProtoUtils.marshaller(CountTokensResponse.getDefaultInstance()))
64+
.build();
6565

6666
private static final MethodDescriptor<ComputeTokensRequest, ComputeTokensResponse>
6767
computeTokensMethodDescriptor =

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/GrpcPredictionServiceStub.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -385,24 +385,12 @@ protected GrpcPredictionServiceStub(
385385
streamDirectPredictTransportSettings =
386386
GrpcCallSettings.<StreamDirectPredictRequest, StreamDirectPredictResponse>newBuilder()
387387
.setMethodDescriptor(streamDirectPredictMethodDescriptor)
388-
.setParamsExtractor(
389-
request -> {
390-
RequestParamsBuilder builder = RequestParamsBuilder.create();
391-
builder.add("endpoint", String.valueOf(request.getEndpoint()));
392-
return builder.build();
393-
})
394388
.build();
395389
GrpcCallSettings<StreamDirectRawPredictRequest, StreamDirectRawPredictResponse>
396390
streamDirectRawPredictTransportSettings =
397391
GrpcCallSettings
398392
.<StreamDirectRawPredictRequest, StreamDirectRawPredictResponse>newBuilder()
399393
.setMethodDescriptor(streamDirectRawPredictMethodDescriptor)
400-
.setParamsExtractor(
401-
request -> {
402-
RequestParamsBuilder builder = RequestParamsBuilder.create();
403-
builder.add("endpoint", String.valueOf(request.getEndpoint()));
404-
return builder.build();
405-
})
406394
.build();
407395
GrpcCallSettings<StreamingPredictRequest, StreamingPredictResponse>
408396
streamingPredictTransportSettings =

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/HttpJsonEndpointServiceStub.java

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -688,16 +688,6 @@ protected HttpJsonEndpointServiceStub(
688688
"google.longrunning.Operations.CancelOperation",
689689
HttpRule.newBuilder()
690690
.setPost("/ui/{name=projects/*/locations/*/operations/*}:cancel")
691-
.addAdditionalBindings(
692-
HttpRule.newBuilder()
693-
.setPost(
694-
"/ui/{name=projects/*/locations/*/agents/*/operations/*}:cancel")
695-
.build())
696-
.addAdditionalBindings(
697-
HttpRule.newBuilder()
698-
.setPost(
699-
"/ui/{name=projects/*/locations/*/apps/*/operations/*}:cancel")
700-
.build())
701691
.addAdditionalBindings(
702692
HttpRule.newBuilder()
703693
.setPost(
@@ -1077,15 +1067,6 @@ protected HttpJsonEndpointServiceStub(
10771067
"google.longrunning.Operations.DeleteOperation",
10781068
HttpRule.newBuilder()
10791069
.setDelete("/ui/{name=projects/*/locations/*/operations/*}")
1080-
.addAdditionalBindings(
1081-
HttpRule.newBuilder()
1082-
.setDelete(
1083-
"/ui/{name=projects/*/locations/*/agents/*/operations/*}")
1084-
.build())
1085-
.addAdditionalBindings(
1086-
HttpRule.newBuilder()
1087-
.setDelete("/ui/{name=projects/*/locations/*/apps/*/operations/*}")
1088-
.build())
10891070
.addAdditionalBindings(
10901071
HttpRule.newBuilder()
10911072
.setDelete(
@@ -1495,14 +1476,6 @@ protected HttpJsonEndpointServiceStub(
14951476
"google.longrunning.Operations.GetOperation",
14961477
HttpRule.newBuilder()
14971478
.setGet("/ui/{name=projects/*/locations/*/operations/*}")
1498-
.addAdditionalBindings(
1499-
HttpRule.newBuilder()
1500-
.setGet("/ui/{name=projects/*/locations/*/agents/*/operations/*}")
1501-
.build())
1502-
.addAdditionalBindings(
1503-
HttpRule.newBuilder()
1504-
.setGet("/ui/{name=projects/*/locations/*/apps/*/operations/*}")
1505-
.build())
15061479
.addAdditionalBindings(
15071480
HttpRule.newBuilder()
15081481
.setGet("/ui/{name=projects/*/locations/*/datasets/*/operations/*}")
@@ -1919,14 +1892,6 @@ protected HttpJsonEndpointServiceStub(
19191892
"google.longrunning.Operations.ListOperations",
19201893
HttpRule.newBuilder()
19211894
.setGet("/ui/{name=projects/*/locations/*}/operations")
1922-
.addAdditionalBindings(
1923-
HttpRule.newBuilder()
1924-
.setGet("/ui/{name=projects/*/locations/*/agents/*}/operations")
1925-
.build())
1926-
.addAdditionalBindings(
1927-
HttpRule.newBuilder()
1928-
.setGet("/ui/{name=projects/*/locations/*/apps/*}/operations")
1929-
.build())
19301895
.addAdditionalBindings(
19311896
HttpRule.newBuilder()
19321897
.setGet("/ui/{name=projects/*/locations/*/datasets/*}/operations")
@@ -2329,16 +2294,6 @@ protected HttpJsonEndpointServiceStub(
23292294
"google.longrunning.Operations.WaitOperation",
23302295
HttpRule.newBuilder()
23312296
.setPost("/ui/{name=projects/*/locations/*/operations/*}:wait")
2332-
.addAdditionalBindings(
2333-
HttpRule.newBuilder()
2334-
.setPost(
2335-
"/ui/{name=projects/*/locations/*/agents/*/operations/*}:wait")
2336-
.build())
2337-
.addAdditionalBindings(
2338-
HttpRule.newBuilder()
2339-
.setPost(
2340-
"/ui/{name=projects/*/locations/*/apps/*/operations/*}:wait")
2341-
.build())
23422297
.addAdditionalBindings(
23432298
HttpRule.newBuilder()
23442299
.setPost(

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/HttpJsonLlmUtilityServiceStub.java

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -63,42 +63,43 @@ public class HttpJsonLlmUtilityServiceStub extends LlmUtilityServiceStub {
6363
private static final TypeRegistry typeRegistry = TypeRegistry.newBuilder().build();
6464

6565
private static final ApiMethodDescriptor<CountTokensRequest, CountTokensResponse>
66+
// TODO(b/317255628): switch back to the google.cloud.aiplatform.v1.LlmUtilityServiceClient
6667
countTokensMethodDescriptor =
67-
ApiMethodDescriptor.<CountTokensRequest, CountTokensResponse>newBuilder()
68-
.setFullMethodName("google.cloud.aiplatform.v1.LlmUtilityService/CountTokens")
69-
.setHttpMethod("POST")
70-
.setType(ApiMethodDescriptor.MethodType.UNARY)
71-
.setRequestFormatter(
72-
ProtoMessageRequestFormatter.<CountTokensRequest>newBuilder()
73-
.setPath(
74-
"/v1/{endpoint=projects/*/locations/*/endpoints/*}:countTokens",
75-
request -> {
76-
Map<String, String> fields = new HashMap<>();
77-
ProtoRestSerializer<CountTokensRequest> serializer =
78-
ProtoRestSerializer.create();
79-
serializer.putPathParam(fields, "endpoint", request.getEndpoint());
80-
return fields;
81-
})
82-
.setAdditionalPaths(
83-
"/v1/{endpoint=projects/*/locations/*/publishers/*/models/*}:countTokens")
84-
.setQueryParamsExtractor(
85-
request -> {
86-
Map<String, List<String>> fields = new HashMap<>();
87-
ProtoRestSerializer<CountTokensRequest> serializer =
88-
ProtoRestSerializer.create();
89-
return fields;
90-
})
91-
.setRequestBodyExtractor(
92-
request ->
93-
ProtoRestSerializer.create()
94-
.toBody("*", request.toBuilder().clearEndpoint().build(), false))
95-
.build())
96-
.setResponseParser(
97-
ProtoMessageResponseParser.<CountTokensResponse>newBuilder()
98-
.setDefaultInstance(CountTokensResponse.getDefaultInstance())
99-
.setDefaultTypeRegistry(typeRegistry)
100-
.build())
101-
.build();
68+
ApiMethodDescriptor.<CountTokensRequest, CountTokensResponse>newBuilder()
69+
.setFullMethodName("google.cloud.aiplatform.v1beta1.PredictionService/CountTokens")
70+
.setHttpMethod("POST")
71+
.setType(ApiMethodDescriptor.MethodType.UNARY)
72+
.setRequestFormatter(
73+
ProtoMessageRequestFormatter.<CountTokensRequest>newBuilder()
74+
.setPath(
75+
"/v1/{endpoint=projects/*/locations/*/endpoints/*}:countTokens",
76+
request -> {
77+
Map<String, String> fields = new HashMap<>();
78+
ProtoRestSerializer<CountTokensRequest> serializer =
79+
ProtoRestSerializer.create();
80+
serializer.putPathParam(fields, "endpoint", request.getEndpoint());
81+
return fields;
82+
})
83+
.setAdditionalPaths(
84+
"/v1/{endpoint=projects/*/locations/*/publishers/*/models/*}:countTokens")
85+
.setQueryParamsExtractor(
86+
request -> {
87+
Map<String, List<String>> fields = new HashMap<>();
88+
ProtoRestSerializer<CountTokensRequest> serializer =
89+
ProtoRestSerializer.create();
90+
return fields;
91+
})
92+
.setRequestBodyExtractor(
93+
request ->
94+
ProtoRestSerializer.create()
95+
.toBody("*", request.toBuilder().clearEndpoint().build(), false))
96+
.build())
97+
.setResponseParser(
98+
ProtoMessageResponseParser.<CountTokensResponse>newBuilder()
99+
.setDefaultInstance(CountTokensResponse.getDefaultInstance())
100+
.setDefaultTypeRegistry(typeRegistry)
101+
.build())
102+
.build();
102103

103104
private static final ApiMethodDescriptor<ComputeTokensRequest, ComputeTokensResponse>
104105
computeTokensMethodDescriptor =

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/HttpJsonPredictionServiceStub.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,6 @@ public class HttpJsonPredictionServiceStub extends PredictionServiceStub {
211211
serializer.putPathParam(fields, "endpoint", request.getEndpoint());
212212
return fields;
213213
})
214-
.setAdditionalPaths(
215-
"/v1/{endpoint=projects/*/locations/*/publishers/*/models/*}:directPredict")
216214
.setQueryParamsExtractor(
217215
request -> {
218216
Map<String, List<String>> fields = new HashMap<>();
@@ -249,8 +247,6 @@ public class HttpJsonPredictionServiceStub extends PredictionServiceStub {
249247
serializer.putPathParam(fields, "endpoint", request.getEndpoint());
250248
return fields;
251249
})
252-
.setAdditionalPaths(
253-
"/v1/{endpoint=projects/*/locations/*/publishers/*/models/*}:directRawPredict")
254250
.setQueryParamsExtractor(
255251
request -> {
256252
Map<String, List<String>> fields = new HashMap<>();

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/LlmUtilityServiceStubSettings.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,15 @@ public LlmUtilityServiceStub createStub() throws IOException {
226226
"Transport not supported: %s", getTransportChannelProvider().getTransportName()));
227227
}
228228

229+
/** Returns the endpoint set by the user or the the service's default endpoint. */
230+
@Override
231+
public String getEndpoint() {
232+
if (super.getEndpoint() != null) {
233+
return super.getEndpoint();
234+
}
235+
return getDefaultEndpoint();
236+
}
237+
229238
/** Returns the default service name. */
230239
@Override
231240
public String getServiceName() {
@@ -531,6 +540,15 @@ public UnaryCallSettings.Builder<GetIamPolicyRequest, Policy> getIamPolicySettin
531540
return testIamPermissionsSettings;
532541
}
533542

543+
/** Returns the endpoint set by the user or the the service's default endpoint. */
544+
@Override
545+
public String getEndpoint() {
546+
if (super.getEndpoint() != null) {
547+
return super.getEndpoint();
548+
}
549+
return getDefaultEndpoint();
550+
}
551+
534552
@Override
535553
public LlmUtilityServiceStubSettings build() throws IOException {
536554
return new LlmUtilityServiceStubSettings(this);

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/api/stub/PredictionServiceStubSettings.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,7 @@
125125
public class PredictionServiceStubSettings extends StubSettings<PredictionServiceStubSettings> {
126126
/** The default scopes of the service. */
127127
private static final ImmutableList<String> DEFAULT_SERVICE_SCOPES =
128-
ImmutableList.<String>builder()
129-
.add("https://www.googleapis.com/auth/cloud-platform")
130-
.add("https://www.googleapis.com/auth/cloud-platform.read-only")
131-
.build();
128+
ImmutableList.<String>builder().add("https://www.googleapis.com/auth/cloud-platform").build();
132129

133130
private final UnaryCallSettings<PredictRequest, PredictResponse> predictSettings;
134131
private final UnaryCallSettings<RawPredictRequest, HttpBody> rawPredictSettings;
@@ -331,6 +328,15 @@ public PredictionServiceStub createStub() throws IOException {
331328
"Transport not supported: %s", getTransportChannelProvider().getTransportName()));
332329
}
333330

331+
/** Returns the endpoint set by the user or the the service's default endpoint. */
332+
@Override
333+
public String getEndpoint() {
334+
if (super.getEndpoint() != null) {
335+
return super.getEndpoint();
336+
}
337+
return getDefaultEndpoint();
338+
}
339+
334340
/** Returns the default service name. */
335341
@Override
336342
public String getServiceName() {
@@ -800,6 +806,15 @@ public UnaryCallSettings.Builder<GetIamPolicyRequest, Policy> getIamPolicySettin
800806
return testIamPermissionsSettings;
801807
}
802808

809+
/** Returns the endpoint set by the user or the the service's default endpoint. */
810+
@Override
811+
public String getEndpoint() {
812+
if (super.getEndpoint() != null) {
813+
return super.getEndpoint();
814+
}
815+
return getDefaultEndpoint();
816+
}
817+
803818
@Override
804819
public PredictionServiceStubSettings build() throws IOException {
805820
return new PredictionServiceStubSettings(this);

0 commit comments

Comments
 (0)