Skip to content

Commit 5c3d93e

Browse files
feat: [vertexai] add fluent API in ChatSession (#10597)
PiperOrigin-RevId: 617901539 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent a8aa591 commit 5c3d93e

File tree

5 files changed

+417
-370
lines changed

5 files changed

+417
-370
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

Lines changed: 150 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616

1717
package com.google.cloud.vertexai;
1818

19-
import static com.google.common.base.Preconditions.checkArgument;
20-
import static com.google.common.base.Preconditions.checkNotNull;
21-
2219
import com.google.api.core.InternalApi;
2320
import com.google.api.gax.core.CredentialsProvider;
2421
import com.google.api.gax.core.FixedCredentialsProvider;
@@ -31,10 +28,8 @@
3128
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
3229
import com.google.cloud.vertexai.api.PredictionServiceClient;
3330
import com.google.cloud.vertexai.api.PredictionServiceSettings;
34-
import com.google.common.base.Strings;
3531
import java.io.IOException;
3632
import java.util.List;
37-
import java.util.concurrent.locks.ReentrantLock;
3833
import java.util.logging.Level;
3934
import java.util.logging.Logger;
4035

@@ -61,8 +56,9 @@ public class VertexAI implements AutoCloseable {
6156
private Transport transport = Transport.GRPC;
6257
// The clients will be instantiated lazily
6358
private PredictionServiceClient predictionServiceClient = null;
59+
private PredictionServiceClient predictionServiceRestClient = null;
6460
private LlmUtilityServiceClient llmUtilityClient = null;
65-
private final ReentrantLock lock = new ReentrantLock();
61+
private LlmUtilityServiceClient llmUtilityRestClient = null;
6662

6763
/**
6864
* Construct a VertexAI instance.
@@ -197,35 +193,32 @@ public Credentials getCredentials() throws IOException {
197193

198194
/** Sets the value for {@link #getTransport()}. */
199195
public void setTransport(Transport transport) {
200-
checkNotNull(transport, "Transport can't be null.");
201-
if (this.transport == transport) {
202-
return;
203-
}
204-
205196
this.transport = transport;
206-
resetClients();
207197
}
208198

209199
/** Sets the value for {@link #getApiEndpoint()}. */
210200
public void setApiEndpoint(String apiEndpoint) {
211-
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
212-
if (this.apiEndpoint == apiEndpoint) {
213-
return;
214-
}
215201
this.apiEndpoint = apiEndpoint;
216-
resetClients();
217-
}
218202

219-
private void resetClients() {
220203
if (this.predictionServiceClient != null) {
221204
this.predictionServiceClient.close();
222205
this.predictionServiceClient = null;
223206
}
224207

208+
if (this.predictionServiceRestClient != null) {
209+
this.predictionServiceRestClient.close();
210+
this.predictionServiceRestClient = null;
211+
}
212+
225213
if (this.llmUtilityClient != null) {
226214
this.llmUtilityClient.close();
227215
this.llmUtilityClient = null;
228216
}
217+
218+
if (this.llmUtilityRestClient != null) {
219+
this.llmUtilityRestClient.close();
220+
this.llmUtilityRestClient = null;
221+
}
229222
}
230223

231224
/**
@@ -237,47 +230,78 @@ private void resetClients() {
237230
*/
238231
@InternalApi
239232
public PredictionServiceClient getPredictionServiceClient() throws IOException {
240-
if (predictionServiceClient != null) {
241-
return predictionServiceClient;
233+
if (this.transport == Transport.GRPC) {
234+
return getPredictionServiceGrpcClient();
235+
} else {
236+
return getPredictionServiceRestClient();
242237
}
243-
lock.lock();
244-
try {
245-
if (predictionServiceClient == null) {
246-
PredictionServiceSettings settings = getPredictionServiceSettings();
247-
// Disable the warning message logged in getApplicationDefault
248-
Logger defaultCredentialsProviderLogger =
249-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
250-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
251-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
252-
predictionServiceClient = PredictionServiceClient.create(settings);
253-
defaultCredentialsProviderLogger.setLevel(previousLevel);
238+
}
239+
240+
/**
241+
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242+
* first prediction API call is made.
243+
*
244+
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245+
* method calls that map to the API methods.
246+
*/
247+
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
248+
if (predictionServiceClient == null) {
249+
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
250+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
251+
if (this.credentialsProvider != null) {
252+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
254253
}
255-
return predictionServiceClient;
256-
} finally {
257-
lock.unlock();
254+
HeaderProvider headerProvider =
255+
FixedHeaderProvider.create(
256+
"user-agent",
257+
String.format(
258+
"%s/%s",
259+
Constants.USER_AGENT_HEADER,
260+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
261+
settingsBuilder.setHeaderProvider(headerProvider);
262+
// Disable the warning message logged in getApplicationDefault
263+
Logger defaultCredentialsProviderLogger =
264+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
265+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
266+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
267+
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
268+
defaultCredentialsProviderLogger.setLevel(previousLevel);
258269
}
270+
return predictionServiceClient;
259271
}
260272

261-
private PredictionServiceSettings getPredictionServiceSettings() throws IOException {
262-
PredictionServiceSettings.Builder builder;
263-
if (transport == Transport.REST) {
264-
builder = PredictionServiceSettings.newHttpJsonBuilder();
265-
} else {
266-
builder = PredictionServiceSettings.newBuilder();
267-
}
268-
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
269-
if (this.credentialsProvider != null) {
270-
builder.setCredentialsProvider(this.credentialsProvider);
273+
/**
274+
* Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275+
* first prediction API call is made.
276+
*
277+
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
278+
* method calls that map to the API methods.
279+
*/
280+
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
281+
if (predictionServiceRestClient == null) {
282+
PredictionServiceSettings.Builder settingsBuilder =
283+
PredictionServiceSettings.newHttpJsonBuilder();
284+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
285+
if (this.credentialsProvider != null) {
286+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
287+
}
288+
HeaderProvider headerProvider =
289+
FixedHeaderProvider.create(
290+
"user-agent",
291+
String.format(
292+
"%s/%s",
293+
Constants.USER_AGENT_HEADER,
294+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
295+
settingsBuilder.setHeaderProvider(headerProvider);
296+
// Disable the warning message logged in getApplicationDefault
297+
Logger defaultCredentialsProviderLogger =
298+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
299+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
300+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
301+
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
302+
defaultCredentialsProviderLogger.setLevel(previousLevel);
271303
}
272-
HeaderProvider headerProvider =
273-
FixedHeaderProvider.create(
274-
"user-agent",
275-
String.format(
276-
"%s/%s",
277-
Constants.USER_AGENT_HEADER,
278-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
279-
builder.setHeaderProvider(headerProvider);
280-
return builder.build();
304+
return predictionServiceRestClient;
281305
}
282306

283307
/**
@@ -289,47 +313,78 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
289313
*/
290314
@InternalApi
291315
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
292-
if (llmUtilityClient != null) {
293-
return llmUtilityClient;
316+
if (this.transport == Transport.GRPC) {
317+
return getLlmUtilityGrpcClient();
318+
} else {
319+
return getLlmUtilityRestClient();
294320
}
295-
lock.lock();
296-
try {
297-
if (llmUtilityClient == null) {
298-
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
299-
// Disable the warning message logged in getApplicationDefault
300-
Logger defaultCredentialsProviderLogger =
301-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
302-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
303-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
304-
llmUtilityClient = LlmUtilityServiceClient.create(settings);
305-
defaultCredentialsProviderLogger.setLevel(previousLevel);
321+
}
322+
323+
/**
324+
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325+
* first API call is made.
326+
*
327+
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328+
* method calls that map to the API methods.
329+
*/
330+
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
331+
if (llmUtilityClient == null) {
332+
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
333+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
334+
if (this.credentialsProvider != null) {
335+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
306336
}
307-
return llmUtilityClient;
308-
} finally {
309-
lock.unlock();
337+
HeaderProvider headerProvider =
338+
FixedHeaderProvider.create(
339+
"user-agent",
340+
String.format(
341+
"%s/%s",
342+
Constants.USER_AGENT_HEADER,
343+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
344+
settingsBuilder.setHeaderProvider(headerProvider);
345+
// Disable the warning message logged in getApplicationDefault
346+
Logger defaultCredentialsProviderLogger =
347+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
348+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
349+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
350+
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
351+
defaultCredentialsProviderLogger.setLevel(previousLevel);
310352
}
353+
return llmUtilityClient;
311354
}
312355

313-
private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
314-
LlmUtilityServiceSettings.Builder settingsBuilder;
315-
if (transport == Transport.REST) {
316-
settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder();
317-
} else {
318-
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
319-
}
320-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
321-
if (this.credentialsProvider != null) {
322-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
356+
/**
357+
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358+
* first API call is made.
359+
*
360+
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361+
* method calls that map to the API methods.
362+
*/
363+
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
364+
if (llmUtilityRestClient == null) {
365+
LlmUtilityServiceSettings.Builder settingsBuilder =
366+
LlmUtilityServiceSettings.newHttpJsonBuilder();
367+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
368+
if (this.credentialsProvider != null) {
369+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
370+
}
371+
HeaderProvider headerProvider =
372+
FixedHeaderProvider.create(
373+
"user-agent",
374+
String.format(
375+
"%s/%s",
376+
Constants.USER_AGENT_HEADER,
377+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
378+
settingsBuilder.setHeaderProvider(headerProvider);
379+
// Disable the warning message logged in getApplicationDefault
380+
Logger defaultCredentialsProviderLogger =
381+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
382+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
383+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
384+
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
385+
defaultCredentialsProviderLogger.setLevel(previousLevel);
323386
}
324-
HeaderProvider headerProvider =
325-
FixedHeaderProvider.create(
326-
"user-agent",
327-
String.format(
328-
"%s/%s",
329-
Constants.USER_AGENT_HEADER,
330-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
331-
settingsBuilder.setHeaderProvider(headerProvider);
332-
return settingsBuilder.build();
387+
return llmUtilityRestClient;
333388
}
334389

335390
/** Closes the VertexAI instance together with all its instantiated clients. */
@@ -338,8 +393,14 @@ public void close() {
338393
if (predictionServiceClient != null) {
339394
predictionServiceClient.close();
340395
}
396+
if (predictionServiceRestClient != null) {
397+
predictionServiceRestClient.close();
398+
}
341399
if (llmUtilityClient != null) {
342400
llmUtilityClient.close();
343401
}
402+
if (llmUtilityRestClient != null) {
403+
llmUtilityRestClient.close();
404+
}
344405
}
345406
}

0 commit comments

Comments
 (0)