Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 4af45de

Browse files
[Feature] Added per-call headers capabilities to OpenAiClient (#25)
# Description To better facilitate integration with OpenAI proxies, we want to be able to inject headers per request, so that a proxy can use these headers for auxiliary handling of the actual OpenAI calls, e.g., usage tracking and rate limiting. # Approach The approach is to retain most of the existing `OpenAiClient` interface and calls while extending it to send extra headers per request. This requires opening up the `abstract` methods to take a new nested class, `OpenAiClient.OpenAiClientContext`, as a method argument. This context can be extended so that other implementations of the `OpenAiClient` abstract class can use the same context object to pass additional metadata needed when executing OpenAI API methods. For now, I've only added `headers` to `OpenAiClientContext`, but I can imagine it being used for other cases as well. The existing signatures are retained on `OpenAiClient`, so `chatCompletion(ChatCompletionRequest)` can still be called as before; it simply passes along a default (empty) `OpenAiClientContext`. cc: @langchain4j
1 parent e4d41a0 commit 4af45de

File tree

8 files changed

+205
-62
lines changed

8 files changed

+205
-62
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
<groupId>dev.ai4j</groupId>
66
<artifactId>openai4j</artifactId>
7-
<version>0.17.0</version>
7+
<version>0.18.0</version>
88

99
<name>Java Client for OpenAI (ChatGPT)</name>
1010
<description>Java Client for OpenAI (ChatGPT)</description>

src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,12 @@ public DefaultOpenAiClient build() {
128128
}
129129

130130
@Override
131-
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
131+
public SyncOrAsyncOrStreaming<CompletionResponse> completion(OpenAiClientContext context,
132+
CompletionRequest request) {
132133
CompletionRequest syncRequest = CompletionRequest.builder().from(request).stream(null).build();
133134

134135
return new RequestExecutor<>(
135-
openAiApi.completions(syncRequest, apiVersion),
136+
openAiApi.completions(context.headers(), syncRequest, apiVersion),
136137
r -> r,
137138
okHttpClient,
138139
formatUrl("completions"),
@@ -144,13 +145,13 @@ public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest r
144145
}
145146

146147
@Override
147-
public SyncOrAsyncOrStreaming<String> completion(String prompt) {
148+
public SyncOrAsyncOrStreaming<String> completion(OpenAiClientContext context, String prompt) {
148149
CompletionRequest request = CompletionRequest.builder().prompt(prompt).build();
149150

150151
CompletionRequest syncRequest = CompletionRequest.builder().from(request).stream(null).build();
151152

152153
return new RequestExecutor<>(
153-
openAiApi.completions(syncRequest, apiVersion),
154+
openAiApi.completions(context.headers(), syncRequest, apiVersion),
154155
CompletionResponse::text,
155156
okHttpClient,
156157
formatUrl("completions"),
@@ -162,11 +163,12 @@ public SyncOrAsyncOrStreaming<String> completion(String prompt) {
162163
}
163164

164165
@Override
165-
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request) {
166+
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(OpenAiClientContext context,
167+
ChatCompletionRequest request) {
166168
ChatCompletionRequest syncRequest = ChatCompletionRequest.builder().from(request).stream(null).build();
167169

168170
return new RequestExecutor<>(
169-
openAiApi.chatCompletions(syncRequest, apiVersion),
171+
openAiApi.chatCompletions(context.headers(), syncRequest, apiVersion),
170172
r -> r,
171173
okHttpClient,
172174
formatUrl("chat/completions"),
@@ -178,13 +180,13 @@ public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatComplet
178180
}
179181

180182
@Override
181-
public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
183+
public SyncOrAsyncOrStreaming<String> chatCompletion(OpenAiClientContext context, String userMessage) {
182184
ChatCompletionRequest request = ChatCompletionRequest.builder().addUserMessage(userMessage).build();
183185

184186
ChatCompletionRequest syncRequest = ChatCompletionRequest.builder().from(request).stream(null).build();
185187

186188
return new RequestExecutor<>(
187-
openAiApi.chatCompletions(syncRequest, apiVersion),
189+
openAiApi.chatCompletions(context.headers(), syncRequest, apiVersion),
188190
ChatCompletionResponse::content,
189191
okHttpClient,
190192
formatUrl("chat/completions"),
@@ -196,32 +198,38 @@ public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
196198
}
197199

198200
@Override
199-
public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
200-
return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), r -> r);
201+
public SyncOrAsync<EmbeddingResponse> embedding(OpenAiClientContext context, EmbeddingRequest request) {
202+
return new RequestExecutor<>(openAiApi.embeddings(context.headers(), request, apiVersion), r -> r);
201203
}
202204

203205
@Override
204-
public SyncOrAsync<List<Float>> embedding(String input) {
206+
public SyncOrAsync<List<Float>> embedding(OpenAiClientContext context, String input) {
205207
EmbeddingRequest request = EmbeddingRequest.builder().input(input).build();
206208

207-
return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), EmbeddingResponse::embedding);
209+
return new RequestExecutor<>(openAiApi.embeddings(context.headers(), request, apiVersion),
210+
EmbeddingResponse::embedding);
208211
}
209212

210213
@Override
211-
public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
212-
return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), r -> r);
214+
public SyncOrAsync<ModerationResponse> moderation(OpenAiClientContext context,
215+
ModerationRequest request) {
216+
return new RequestExecutor<>(openAiApi.moderations(context.headers(), request, apiVersion),
217+
r -> r);
213218
}
214219

215220
@Override
216-
public SyncOrAsync<ModerationResult> moderation(String input) {
221+
public SyncOrAsync<ModerationResult> moderation(OpenAiClientContext context, String input) {
217222
ModerationRequest request = ModerationRequest.builder().input(input).build();
218223

219-
return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), r -> r.results().get(0));
224+
return new RequestExecutor<>(openAiApi.moderations(context.headers(), request, apiVersion),
225+
r -> r.results().get(0));
220226
}
221227

222228
@Override
223-
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request) {
224-
return new RequestExecutor<>(openAiApi.imagesGenerations(request, apiVersion), r -> r);
229+
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(OpenAiClientContext context,
230+
GenerateImagesRequest request) {
231+
return new RequestExecutor<>(openAiApi.imagesGenerations(context.headers(), request, apiVersion),
232+
r -> r);
225233
}
226234

227235
private String formatUrl(String endpoint) {

src/main/java/dev/ai4j/openai4j/OpenAiApi.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,26 @@
1010
import dev.ai4j.openai4j.image.GenerateImagesResponse;
1111
import dev.ai4j.openai4j.moderation.ModerationRequest;
1212
import dev.ai4j.openai4j.moderation.ModerationResponse;
13+
import java.util.Map;
1314
import retrofit2.Call;
1415
import retrofit2.http.Body;
16+
import retrofit2.http.HeaderMap;
1517
import retrofit2.http.Headers;
1618
import retrofit2.http.POST;
1719
import retrofit2.http.Query;
1820

1921
interface OpenAiApi {
2022
@POST("completions")
2123
@Headers("Content-Type: application/json")
22-
Call<CompletionResponse> completions(@Body CompletionRequest request, @Query("api-version") String apiVersion);
24+
Call<CompletionResponse> completions(@Body CompletionRequest request,
25+
@Query("api-version") String apiVersion);
26+
27+
@POST("completions")
28+
@Headers("Content-Type: application/json")
29+
Call<CompletionResponse> completions(
30+
@HeaderMap Map<String, String> headers,
31+
@Body CompletionRequest request,
32+
@Query("api-version") String apiVersion);
2333

2434
@POST("chat/completions")
2535
@Headers("Content-Type: application/json")
@@ -28,17 +38,51 @@ Call<ChatCompletionResponse> chatCompletions(
2838
@Query("api-version") String apiVersion
2939
);
3040

41+
@POST("chat/completions")
42+
@Headers("Content-Type: application/json")
43+
Call<ChatCompletionResponse> chatCompletions(
44+
@HeaderMap Map<String, String> headers,
45+
@Body ChatCompletionRequest request,
46+
@Query("api-version") String apiVersion
47+
);
48+
49+
@POST("embeddings")
50+
@Headers("Content-Type: application/json")
51+
Call<EmbeddingResponse> embeddings(
52+
@Body EmbeddingRequest request,
53+
@Query("api-version") String apiVersion);
54+
3155
@POST("embeddings")
3256
@Headers("Content-Type: application/json")
33-
Call<EmbeddingResponse> embeddings(@Body EmbeddingRequest request, @Query("api-version") String apiVersion);
57+
Call<EmbeddingResponse> embeddings(
58+
@HeaderMap Map<String, String> headers,
59+
@Body EmbeddingRequest request,
60+
@Query("api-version") String apiVersion);
61+
62+
@POST("moderations")
63+
@Headers("Content-Type: application/json")
64+
Call<ModerationResponse> moderations(
65+
@Body ModerationRequest request,
66+
@Query("api-version") String apiVersion);
3467

3568
@POST("moderations")
3669
@Headers("Content-Type: application/json")
37-
Call<ModerationResponse> moderations(@Body ModerationRequest request, @Query("api-version") String apiVersion);
70+
Call<ModerationResponse> moderations(
71+
@HeaderMap Map<String, String> headers,
72+
@Body ModerationRequest request,
73+
@Query("api-version") String apiVersion);
74+
75+
@POST("images/generations")
76+
@Headers({"Content-Type: application/json"})
77+
Call<GenerateImagesResponse> imagesGenerations(
78+
@Body GenerateImagesRequest request,
79+
@Query("api-version") String apiVersion
80+
);
3881

3982
@POST("images/generations")
40-
@Headers({ "Content-Type: application/json" })
83+
@Headers({"Content-Type: application/json"})
4184
Call<GenerateImagesResponse> imagesGenerations(
85+
@HeaderMap Map<String, String> headers,
4286
@Body GenerateImagesRequest request,
4387
@Query("api-version") String apiVersion
4488
);

src/main/java/dev/ai4j/openai4j/OpenAiClient.java

Lines changed: 106 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
package dev.ai4j.openai4j;
22

3-
import java.net.InetSocketAddress;
4-
import java.net.Proxy;
5-
import java.nio.file.Path;
6-
import java.nio.file.Paths;
7-
import java.time.Duration;
8-
import java.util.List;
3+
import static dev.ai4j.openai4j.LogLevel.DEBUG;
94

105
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
116
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
@@ -20,29 +15,101 @@
2015
import dev.ai4j.openai4j.moderation.ModerationResult;
2116
import dev.ai4j.openai4j.spi.OpenAiClientBuilderFactory;
2217
import dev.ai4j.openai4j.spi.ServiceHelper;
18+
import java.net.InetSocketAddress;
19+
import java.net.Proxy;
20+
import java.nio.file.Path;
21+
import java.nio.file.Paths;
22+
import java.time.Duration;
23+
import java.util.HashMap;
24+
import java.util.List;
2325
import java.util.Map;
2426

25-
import static dev.ai4j.openai4j.LogLevel.DEBUG;
26-
2727
public abstract class OpenAiClient {
2828

29-
public abstract SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request);
29+
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
30+
return completion(new OpenAiClientContext(), request);
31+
}
32+
33+
public SyncOrAsyncOrStreaming<CompletionResponse> completion(
34+
OpenAiClientContext clientContext, CompletionRequest request) {
35+
throw new UnsupportedOperationException();
36+
}
37+
38+
public SyncOrAsyncOrStreaming<String> completion(String prompt) {
39+
return completion(new OpenAiClientContext(), prompt);
40+
}
41+
42+
public SyncOrAsyncOrStreaming<String> completion(OpenAiClientContext clientContext,
43+
String prompt) {
44+
throw new UnsupportedOperationException();
45+
}
46+
47+
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(
48+
ChatCompletionRequest request) {
49+
return chatCompletion(new OpenAiClientContext(), request);
50+
}
51+
52+
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(
53+
OpenAiClientContext clientContext,
54+
ChatCompletionRequest request) {
55+
throw new UnsupportedOperationException();
56+
}
57+
58+
public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
59+
return chatCompletion(new OpenAiClientContext(), userMessage);
60+
}
3061

31-
public abstract SyncOrAsyncOrStreaming<String> completion(String prompt);
62+
public SyncOrAsyncOrStreaming<String> chatCompletion(
63+
OpenAiClientContext clientContext,
64+
String userMessage) {
65+
throw new UnsupportedOperationException();
66+
}
3267

33-
public abstract SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request);
68+
public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
69+
return embedding(new OpenAiClientContext(), request);
70+
}
3471

35-
public abstract SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage);
72+
public SyncOrAsync<EmbeddingResponse> embedding(OpenAiClientContext clientContext,
73+
EmbeddingRequest request) {
74+
throw new UnsupportedOperationException();
75+
}
3676

37-
public abstract SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request);
77+
public SyncOrAsync<List<Float>> embedding(String input) {
78+
return embedding(new OpenAiClientContext(), input);
79+
}
3880

39-
public abstract SyncOrAsync<List<Float>> embedding(String input);
81+
public SyncOrAsync<List<Float>> embedding(OpenAiClientContext clientContext,
82+
String input) {
83+
throw new UnsupportedOperationException();
84+
}
85+
86+
public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
87+
return moderation(new OpenAiClientContext(), request);
88+
}
89+
90+
public SyncOrAsync<ModerationResponse> moderation(OpenAiClientContext clientContext,
91+
ModerationRequest request) {
92+
throw new UnsupportedOperationException();
93+
}
94+
95+
public SyncOrAsync<ModerationResult> moderation(String input) {
96+
return moderation(new OpenAiClientContext(), input);
97+
}
4098

41-
public abstract SyncOrAsync<ModerationResponse> moderation(ModerationRequest request);
99+
public SyncOrAsync<ModerationResult> moderation(OpenAiClientContext clientContext,
100+
String input) {
101+
throw new UnsupportedOperationException();
102+
}
42103

43-
public abstract SyncOrAsync<ModerationResult> moderation(String input);
104+
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request) {
105+
return imagesGeneration(new OpenAiClientContext(), request);
106+
}
44107

45-
public abstract SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request);
108+
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(
109+
OpenAiClientContext clientContext,
110+
GenerateImagesRequest request) {
111+
throw new UnsupportedOperationException();
112+
}
46113

47114
public abstract void shutdown();
48115

@@ -55,6 +122,28 @@ public static OpenAiClient.Builder builder() {
55122
return DefaultOpenAiClient.builder();
56123
}
57124

125+
public static class OpenAiClientContext {
126+
private final Map<String, String> headers = new HashMap<>();
127+
128+
public OpenAiClientContext addHeaders(Map<String, String> headers) {
129+
this.headers.putAll(headers);
130+
return this;
131+
}
132+
133+
public OpenAiClientContext addHeader(String key, String value) {
134+
headers.put(key, value);
135+
return this;
136+
}
137+
138+
public Map<String, String> headers() {
139+
return headers;
140+
}
141+
142+
public static OpenAiClientContext create() {
143+
return new OpenAiClientContext();
144+
}
145+
}
146+
58147
@SuppressWarnings("unchecked")
59148
public abstract static class Builder<T extends OpenAiClient, B extends Builder<T, B>> {
60149

src/main/java/dev/ai4j/openai4j/chat/ChatCompletionModel.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
public enum ChatCompletionModel {
44

55
GPT_3_5_TURBO("gpt-3.5-turbo"), // alias
6-
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"),
76
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"),
87
GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
98

10-
GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias
11-
GPT_3_5_TURBO_16K_0613("gpt-3.5-turbo-16k-0613"),
12-
139
GPT_4("gpt-4"), // alias
1410
GPT_4_0314("gpt-4-0314"),
1511
GPT_4_0613("gpt-4-0613"),
1612

13+
GPT_4_TURBO("gpt-4-turbo"), // alias
14+
GPT_4_TURBO_2024_04_09("gpt-4-turbo-2024-04-09"), // With vision support
1715
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias
1816
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
1917
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
@@ -22,7 +20,10 @@ public enum ChatCompletionModel {
2220
GPT_4_32K_0314("gpt-4-32k-0314"),
2321
GPT_4_32K_0613("gpt-4-32k-0613"),
2422

25-
GPT_4_VISION_PREVIEW("gpt-4-vision-preview");
23+
@Deprecated
24+
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
25+
GPT_4O("gpt-4o"),
26+
GPT_4O_2024_05_13("gpt-4o-2024-05-13");
2627

2728
private final String value;
2829

0 commit comments

Comments
 (0)