1616
1717package com .google .cloud .vertexai ;
1818
19- import static com .google .common .base .Preconditions .checkArgument ;
20- import static com .google .common .base .Preconditions .checkNotNull ;
21-
2219import com .google .api .core .InternalApi ;
2320import com .google .api .gax .core .CredentialsProvider ;
2421import com .google .api .gax .core .FixedCredentialsProvider ;
3128import com .google .cloud .vertexai .api .LlmUtilityServiceSettings ;
3229import com .google .cloud .vertexai .api .PredictionServiceClient ;
3330import com .google .cloud .vertexai .api .PredictionServiceSettings ;
34- import com .google .common .base .Strings ;
3531import java .io .IOException ;
3632import java .util .List ;
37- import java .util .concurrent .locks .ReentrantLock ;
3833import java .util .logging .Level ;
3934import java .util .logging .Logger ;
4035
@@ -61,8 +56,9 @@ public class VertexAI implements AutoCloseable {
6156 private Transport transport = Transport .GRPC ;
6257 // The clients will be instantiated lazily
6358 private PredictionServiceClient predictionServiceClient = null ;
59+ private PredictionServiceClient predictionServiceRestClient = null ;
6460 private LlmUtilityServiceClient llmUtilityClient = null ;
65- private final ReentrantLock lock = new ReentrantLock () ;
61+ private LlmUtilityServiceClient llmUtilityRestClient = null ;
6662
6763 /**
6864 * Construct a VertexAI instance.
@@ -197,35 +193,32 @@ public Credentials getCredentials() throws IOException {
197193
198194 /** Sets the value for {@link #getTransport()}. */
199195 public void setTransport (Transport transport ) {
200- checkNotNull (transport , "Transport can't be null." );
201- if (this .transport == transport ) {
202- return ;
203- }
204-
205196 this .transport = transport ;
206- resetClients ();
207197 }
208198
209199 /** Sets the value for {@link #getApiEndpoint()}. */
210200 public void setApiEndpoint (String apiEndpoint ) {
211- checkArgument (!Strings .isNullOrEmpty (apiEndpoint ), "Api endpoint can't be null or empty." );
212- if (this .apiEndpoint == apiEndpoint ) {
213- return ;
214- }
215201 this .apiEndpoint = apiEndpoint ;
216- resetClients ();
217- }
218202
219- private void resetClients () {
220203 if (this .predictionServiceClient != null ) {
221204 this .predictionServiceClient .close ();
222205 this .predictionServiceClient = null ;
223206 }
224207
208+ if (this .predictionServiceRestClient != null ) {
209+ this .predictionServiceRestClient .close ();
210+ this .predictionServiceRestClient = null ;
211+ }
212+
225213 if (this .llmUtilityClient != null ) {
226214 this .llmUtilityClient .close ();
227215 this .llmUtilityClient = null ;
228216 }
217+
218+ if (this .llmUtilityRestClient != null ) {
219+ this .llmUtilityRestClient .close ();
220+ this .llmUtilityRestClient = null ;
221+ }
229222 }
230223
231224 /**
@@ -237,47 +230,78 @@ private void resetClients() {
237230 */
238231 @ InternalApi
239232 public PredictionServiceClient getPredictionServiceClient () throws IOException {
240- if (predictionServiceClient != null ) {
241- return predictionServiceClient ;
233+ if (this .transport == Transport .GRPC ) {
234+ return getPredictionServiceGrpcClient ();
235+ } else {
236+ return getPredictionServiceRestClient ();
242237 }
243- lock .lock ();
244- try {
245- if (predictionServiceClient == null ) {
246- PredictionServiceSettings settings = getPredictionServiceSettings ();
247- // Disable the warning message logged in getApplicationDefault
248- Logger defaultCredentialsProviderLogger =
249- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
250- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
251- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
252- predictionServiceClient = PredictionServiceClient .create (settings );
253- defaultCredentialsProviderLogger .setLevel (previousLevel );
238+ }
239+
240+ /**
241+ * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242+ * first prediction API call is made.
243+ *
244+ * @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245+ * method calls that map to the API methods.
246+ */
247+ private PredictionServiceClient getPredictionServiceGrpcClient () throws IOException {
248+ if (predictionServiceClient == null ) {
249+ PredictionServiceSettings .Builder settingsBuilder = PredictionServiceSettings .newBuilder ();
250+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
251+ if (this .credentialsProvider != null ) {
252+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
254253 }
255- return predictionServiceClient ;
256- } finally {
257- lock .unlock ();
254+ HeaderProvider headerProvider =
255+ FixedHeaderProvider .create (
256+ "user-agent" ,
257+ String .format (
258+ "%s/%s" ,
259+ Constants .USER_AGENT_HEADER ,
260+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
261+ settingsBuilder .setHeaderProvider (headerProvider );
262+ // Disable the warning message logged in getApplicationDefault
263+ Logger defaultCredentialsProviderLogger =
264+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
265+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
266+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
267+ predictionServiceClient = PredictionServiceClient .create (settingsBuilder .build ());
268+ defaultCredentialsProviderLogger .setLevel (previousLevel );
258269 }
270+ return predictionServiceClient ;
259271 }
260272
261- private PredictionServiceSettings getPredictionServiceSettings () throws IOException {
262- PredictionServiceSettings .Builder builder ;
263- if (transport == Transport .REST ) {
264- builder = PredictionServiceSettings .newHttpJsonBuilder ();
265- } else {
266- builder = PredictionServiceSettings .newBuilder ();
267- }
268- builder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
269- if (this .credentialsProvider != null ) {
270- builder .setCredentialsProvider (this .credentialsProvider );
273+ /**
274+ * Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275+ * first prediction API call is made.
276+ *
277+ * @return {@link PredictionServiceClient} that send REST requests to the backing service through
278+ * method calls that map to the API methods.
279+ */
280+ private PredictionServiceClient getPredictionServiceRestClient () throws IOException {
281+ if (predictionServiceRestClient == null ) {
282+ PredictionServiceSettings .Builder settingsBuilder =
283+ PredictionServiceSettings .newHttpJsonBuilder ();
284+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
285+ if (this .credentialsProvider != null ) {
286+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
287+ }
288+ HeaderProvider headerProvider =
289+ FixedHeaderProvider .create (
290+ "user-agent" ,
291+ String .format (
292+ "%s/%s" ,
293+ Constants .USER_AGENT_HEADER ,
294+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
295+ settingsBuilder .setHeaderProvider (headerProvider );
296+ // Disable the warning message logged in getApplicationDefault
297+ Logger defaultCredentialsProviderLogger =
298+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
299+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
300+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
301+ predictionServiceRestClient = PredictionServiceClient .create (settingsBuilder .build ());
302+ defaultCredentialsProviderLogger .setLevel (previousLevel );
271303 }
272- HeaderProvider headerProvider =
273- FixedHeaderProvider .create (
274- "user-agent" ,
275- String .format (
276- "%s/%s" ,
277- Constants .USER_AGENT_HEADER ,
278- GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
279- builder .setHeaderProvider (headerProvider );
280- return builder .build ();
304+ return predictionServiceRestClient ;
281305 }
282306
283307 /**
@@ -289,47 +313,78 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
289313 */
290314 @ InternalApi
291315 public LlmUtilityServiceClient getLlmUtilityClient () throws IOException {
292- if (llmUtilityClient != null ) {
293- return llmUtilityClient ;
316+ if (this .transport == Transport .GRPC ) {
317+ return getLlmUtilityGrpcClient ();
318+ } else {
319+ return getLlmUtilityRestClient ();
294320 }
295- lock .lock ();
296- try {
297- if (llmUtilityClient == null ) {
298- LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings ();
299- // Disable the warning message logged in getApplicationDefault
300- Logger defaultCredentialsProviderLogger =
301- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
302- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
303- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
304- llmUtilityClient = LlmUtilityServiceClient .create (settings );
305- defaultCredentialsProviderLogger .setLevel (previousLevel );
321+ }
322+
323+ /**
324+ * Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325+ * first API call is made.
326+ *
327+ * @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328+ * method calls that map to the API methods.
329+ */
330+ private LlmUtilityServiceClient getLlmUtilityGrpcClient () throws IOException {
331+ if (llmUtilityClient == null ) {
332+ LlmUtilityServiceSettings .Builder settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
333+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
334+ if (this .credentialsProvider != null ) {
335+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
306336 }
307- return llmUtilityClient ;
308- } finally {
309- lock .unlock ();
337+ HeaderProvider headerProvider =
338+ FixedHeaderProvider .create (
339+ "user-agent" ,
340+ String .format (
341+ "%s/%s" ,
342+ Constants .USER_AGENT_HEADER ,
343+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
344+ settingsBuilder .setHeaderProvider (headerProvider );
345+ // Disable the warning message logged in getApplicationDefault
346+ Logger defaultCredentialsProviderLogger =
347+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
348+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
349+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
350+ llmUtilityClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
351+ defaultCredentialsProviderLogger .setLevel (previousLevel );
310352 }
353+ return llmUtilityClient ;
311354 }
312355
313- private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings () throws IOException {
314- LlmUtilityServiceSettings .Builder settingsBuilder ;
315- if (transport == Transport .REST ) {
316- settingsBuilder = LlmUtilityServiceSettings .newHttpJsonBuilder ();
317- } else {
318- settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
319- }
320- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
321- if (this .credentialsProvider != null ) {
322- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
356+ /**
357+ * Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358+ * first API call is made.
359+ *
360+ * @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361+ * method calls that map to the API methods.
362+ */
363+ private LlmUtilityServiceClient getLlmUtilityRestClient () throws IOException {
364+ if (llmUtilityRestClient == null ) {
365+ LlmUtilityServiceSettings .Builder settingsBuilder =
366+ LlmUtilityServiceSettings .newHttpJsonBuilder ();
367+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
368+ if (this .credentialsProvider != null ) {
369+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
370+ }
371+ HeaderProvider headerProvider =
372+ FixedHeaderProvider .create (
373+ "user-agent" ,
374+ String .format (
375+ "%s/%s" ,
376+ Constants .USER_AGENT_HEADER ,
377+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
378+ settingsBuilder .setHeaderProvider (headerProvider );
379+ // Disable the warning message logged in getApplicationDefault
380+ Logger defaultCredentialsProviderLogger =
381+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
382+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
383+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
384+ llmUtilityRestClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
385+ defaultCredentialsProviderLogger .setLevel (previousLevel );
323386 }
324- HeaderProvider headerProvider =
325- FixedHeaderProvider .create (
326- "user-agent" ,
327- String .format (
328- "%s/%s" ,
329- Constants .USER_AGENT_HEADER ,
330- GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
331- settingsBuilder .setHeaderProvider (headerProvider );
332- return settingsBuilder .build ();
387+ return llmUtilityRestClient ;
333388 }
334389
335390 /** Closes the VertexAI instance together with all its instantiated clients. */
@@ -338,8 +393,14 @@ public void close() {
338393 if (predictionServiceClient != null ) {
339394 predictionServiceClient .close ();
340395 }
396+ if (predictionServiceRestClient != null ) {
397+ predictionServiceRestClient .close ();
398+ }
341399 if (llmUtilityClient != null ) {
342400 llmUtilityClient .close ();
343401 }
402+ if (llmUtilityRestClient != null ) {
403+ llmUtilityRestClient .close ();
404+ }
344405 }
345406}
0 commit comments