Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 092c470

Browse files
author
deep-learning-dynamo
committed
Added an option to specify "api-version" query parameter for Azure OpenAI.
1 parent 87f2898 commit 092c470

File tree

3 files changed

+68
-16
lines changed

3 files changed

+68
-16
lines changed

src/main/java/dev/ai4j/openai4j/OpenAiApi.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,23 @@
1212
import retrofit2.http.Body;
1313
import retrofit2.http.Headers;
1414
import retrofit2.http.POST;
15+
import retrofit2.http.Query;
1516

1617
interface OpenAiApi {
1718

1819
@POST("completions")
1920
@Headers("Content-Type: application/json")
20-
Call<CompletionResponse> completions(@Body CompletionRequest request);
21+
Call<CompletionResponse> completions(@Body CompletionRequest request, @Query("api-version") String apiVersion);
2122

2223
@POST("chat/completions")
2324
@Headers("Content-Type: application/json")
24-
Call<ChatCompletionResponse> chatCompletions(@Body ChatCompletionRequest request);
25+
Call<ChatCompletionResponse> chatCompletions(@Body ChatCompletionRequest request, @Query("api-version") String apiVersion);
2526

2627
@POST("embeddings")
2728
@Headers("Content-Type: application/json")
28-
Call<EmbeddingResponse> embeddings(@Body EmbeddingRequest request);
29+
Call<EmbeddingResponse> embeddings(@Body EmbeddingRequest request, @Query("api-version") String apiVersion);
2930

3031
@POST("moderations")
3132
@Headers("Content-Type: application/json")
32-
Call<ModerationResponse> moderations(@Body ModerationRequest request);
33+
Call<ModerationResponse> moderations(@Body ModerationRequest request, @Query("api-version") String apiVersion);
3334
}

src/main/java/dev/ai4j/openai4j/OpenAiClient.java

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ public class OpenAiClient {
2929
private static final Logger log = LoggerFactory.getLogger(OpenAiClient.class);
3030

3131
private final String baseUrl;
32+
private final String apiVersion;
3233
private final OkHttpClient okHttpClient;
3334
private final OpenAiApi openAiApi;
3435
private final boolean logStreamingResponses;
@@ -40,6 +41,7 @@ public OpenAiClient(String apiKey) {
4041
private OpenAiClient(Builder serviceBuilder) {
4142

4243
this.baseUrl = serviceBuilder.baseUrl;
44+
this.apiVersion = serviceBuilder.apiVersion;
4345

4446
OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
4547
.addInterceptor(new ApiKeyInsertingInterceptor(serviceBuilder.apiKey))
@@ -94,6 +96,7 @@ public static Builder builder() {
9496
public static class Builder {
9597

9698
private String baseUrl = "https://api.openai.com/v1/";
99+
private String apiVersion;
97100
private String apiKey;
98101
private Duration callTimeout = Duration.ofSeconds(60);
99102
private Duration connectTimeout = Duration.ofSeconds(60);
@@ -107,6 +110,12 @@ public static class Builder {
107110
private Builder() {
108111
}
109112

113+
/**
114+
* @param baseUrl Base URL of OpenAI API.
115+
* For OpenAI (default): "https://api.openai.com/v1/"
116+
* For Azure OpenAI: "https://{resource-name}.openai.azure.com/openai/deployments/{deployment-id}/"
117+
* @return builder
118+
*/
110119
public Builder baseUrl(String baseUrl) {
111120
if (baseUrl == null || baseUrl.trim().isEmpty()) {
112121
throw new IllegalArgumentException("baseUrl cannot be null or empty");
@@ -115,6 +124,15 @@ public Builder baseUrl(String baseUrl) {
115124
return this;
116125
}
117126

127+
/**
128+
* @param apiVersion Version of the API in the YYYY-MM-DD format. Applicable only for Azure OpenAI.
129+
* @return builder
130+
*/
131+
public Builder apiVersion(String apiVersion) {
132+
this.apiVersion = apiVersion;
133+
return this;
134+
}
135+
118136
public Builder apiKey(String apiKey) {
119137
if (apiKey == null || apiKey.trim().isEmpty()) {
120138
throw new IllegalArgumentException("API key cannot be null or empty. API keys can be generated here: https://platform.openai.com/account/api-keys");
@@ -208,11 +226,16 @@ public OpenAiClient build() {
208226

209227
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
210228

229+
CompletionRequest syncRequest = CompletionRequest.builder()
230+
.from(request)
231+
.stream(null)
232+
.build();
233+
211234
return new RequestExecutor<>(
212-
openAiApi.completions(CompletionRequest.builder().from(request).stream(null).build()),
235+
openAiApi.completions(syncRequest, apiVersion),
213236
(r) -> r,
214237
okHttpClient,
215-
baseUrl + "completions",
238+
formatUrl("completions"),
216239
() -> CompletionRequest.builder().from(request).stream(true).build(),
217240
CompletionResponse.class,
218241
(r) -> r,
@@ -227,11 +250,16 @@ public SyncOrAsyncOrStreaming<String> completion(String prompt) {
227250
.prompt(prompt)
228251
.build();
229252

253+
CompletionRequest syncRequest = CompletionRequest.builder()
254+
.from(request)
255+
.stream(null)
256+
.build();
257+
230258
return new RequestExecutor<>(
231-
openAiApi.completions(CompletionRequest.builder().from(request).stream(null).build()),
259+
openAiApi.completions(syncRequest, apiVersion),
232260
CompletionResponse::text,
233261
okHttpClient,
234-
baseUrl + "completions",
262+
formatUrl("completions"),
235263
() -> CompletionRequest.builder().from(request).stream(true).build(),
236264
CompletionResponse.class,
237265
CompletionResponse::text,
@@ -241,11 +269,16 @@ public SyncOrAsyncOrStreaming<String> completion(String prompt) {
241269

242270
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request) {
243271

272+
ChatCompletionRequest syncRequest = ChatCompletionRequest.builder()
273+
.from(request)
274+
.stream(null)
275+
.build();
276+
244277
return new RequestExecutor<>(
245-
openAiApi.chatCompletions(ChatCompletionRequest.builder().from(request).stream(null).build()),
278+
openAiApi.chatCompletions(syncRequest, apiVersion),
246279
(r) -> r,
247280
okHttpClient,
248-
baseUrl + "chat/completions",
281+
formatUrl("chat/completions"),
249282
() -> ChatCompletionRequest.builder().from(request).stream(true).build(),
250283
ChatCompletionResponse.class,
251284
(r) -> r,
@@ -260,11 +293,16 @@ public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
260293
.addUserMessage(userMessage)
261294
.build();
262295

296+
ChatCompletionRequest syncRequest = ChatCompletionRequest.builder()
297+
.from(request)
298+
.stream(null)
299+
.build();
300+
263301
return new RequestExecutor<>(
264-
openAiApi.chatCompletions(ChatCompletionRequest.builder().from(request).stream(null).build()),
302+
openAiApi.chatCompletions(syncRequest, apiVersion),
265303
ChatCompletionResponse::content,
266304
okHttpClient,
267-
baseUrl + "chat/completions",
305+
formatUrl("chat/completions"),
268306
() -> ChatCompletionRequest.builder().from(request).stream(true).build(),
269307
ChatCompletionResponse.class,
270308
(r) -> r.choices().get(0).delta().content(),
@@ -274,7 +312,7 @@ public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
274312

275313
public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
276314

277-
return new RequestExecutor<>(openAiApi.embeddings(request), (r) -> r);
315+
return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), (r) -> r);
278316
}
279317

280318
@Experimental
@@ -284,12 +322,12 @@ public SyncOrAsync<List<Float>> embedding(String input) {
284322
.input(input)
285323
.build();
286324

287-
return new RequestExecutor<>(openAiApi.embeddings(request), EmbeddingResponse::embedding);
325+
return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), EmbeddingResponse::embedding);
288326
}
289327

290328
public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
291329

292-
return new RequestExecutor<>(openAiApi.moderations(request), (r) -> r);
330+
return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), (r) -> r);
293331
}
294332

295333
@Experimental
@@ -299,6 +337,17 @@ public SyncOrAsync<ModerationResult> moderation(String input) {
299337
.input(input)
300338
.build();
301339

302-
return new RequestExecutor<>(openAiApi.moderations(request), (r) -> r.results().get(0));
340+
return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), (r) -> r.results().get(0));
341+
}
342+
343+
private String formatUrl(String endpoint) {
344+
return baseUrl + endpoint + apiVersionQueryParam();
345+
}
346+
347+
private String apiVersionQueryParam() {
348+
if (apiVersion == null || apiVersion.trim().isEmpty()) {
349+
return "";
350+
}
351+
return "?api-version=" + apiVersion;
303352
}
304353
}

src/test/java/dev/ai4j/openai4j/Test.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ public static void main(String[] args) {
1212
String apiKey = System.getenv("OPENAI_API_KEY");
1313

1414
OpenAiClient client = OpenAiClient.builder()
15+
.baseUrl("https://my-resource.openai.azure.com/openai/deployments/my-deployment/")
16+
.apiVersion("2023-06-13")
1517
.apiKey(apiKey)
1618
.callTimeout(ofSeconds(60))
1719
.connectTimeout(ofSeconds(60))

0 commit comments

Comments
 (0)