Skip to content

Commit dd6c0a9

Browse files
filiphrilayaperumalg
authored andcommitted
Resolve OpenAI ApiKey for every request
- Set ApiKey as late as possible Signed-off-by: Filip Hrisafov <filip.hrisafov@gmail.com>
1 parent f4f2cfd commit dd6c0a9

File tree

8 files changed

+904
-22
lines changed

8 files changed

+904
-22
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
* @author Thomas Vitale
6363
* @author David Frizelle
6464
* @author Alexandros Pappas
65+
* @author Filip Hrisafov
6566
*/
6667
public class OpenAiApi {
6768

@@ -128,10 +129,6 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
128129

129130
// @formatter:off
130131
Consumer<HttpHeaders> finalHeaders = h -> {
131-
if (!(apiKey instanceof NoopApiKey)) {
132-
h.setBearerAuth(apiKey.getValue());
133-
}
134-
135132
h.setContentType(MediaType.APPLICATION_JSON);
136133
h.addAll(headers);
137134
};
@@ -179,12 +176,17 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
179176
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
180177
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");
181178

179+
// @formatter:off
182180
return this.restClient.post()
183181
.uri(this.completionsPath)
184-
.headers(headers -> headers.addAll(additionalHttpHeader))
182+
.headers(headers -> {
183+
headers.addAll(additionalHttpHeader);
184+
addDefaultHeadersIfMissing(headers);
185+
})
185186
.body(chatRequest)
186187
.retrieve()
187188
.toEntity(ChatCompletion.class);
189+
// @formatter:on
188190
}
189191

190192
/**
@@ -213,9 +215,13 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
213215

214216
AtomicBoolean isInsideTool = new AtomicBoolean(false);
215217

218+
// @formatter:off
216219
return this.webClient.post()
217220
.uri(this.completionsPath)
218-
.headers(headers -> headers.addAll(additionalHttpHeader))
221+
.headers(headers -> {
222+
headers.addAll(additionalHttpHeader);
223+
addDefaultHeadersIfMissing(headers);
224+
}) // @formatter:on
219225
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
220226
.retrieve()
221227
.bodyToFlux(String.class)
@@ -289,13 +295,20 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
289295

290296
return this.restClient.post()
291297
.uri(this.embeddingsPath)
298+
.headers(this::addDefaultHeadersIfMissing)
292299
.body(embeddingRequest)
293300
.retrieve()
294301
.toEntity(new ParameterizedTypeReference<>() {
295302

296303
});
297304
}
298305

306+
private void addDefaultHeadersIfMissing(HttpHeaders headers) {
307+
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
308+
headers.setBearerAuth(this.apiKey.getValue());
309+
}
310+
}
311+
299312
// Package-private getters for mutate/copy
300313
String getBaseUrl() {
301314
return this.baseUrl;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
* @author Christian Tzolov
5050
* @author Ilayaperumal Gopinathan
5151
* @author Jonghoon Park
52+
* @author Filip Hrisafov
5253
* @since 0.8.1
5354
*/
5455
public class OpenAiAudioApi {
@@ -71,20 +72,30 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
7172
ResponseErrorHandler responseErrorHandler) {
7273

7374
Consumer<HttpHeaders> authHeaders = h -> {
74-
if (!(apiKey instanceof NoopApiKey)) {
75-
h.setBearerAuth(apiKey.getValue());
76-
}
7775
h.addAll(headers);
78-
// h.setContentType(MediaType.APPLICATION_JSON);
7976
};
8077

78+
// @formatter:off
8179
this.restClient = restClientBuilder.clone()
8280
.baseUrl(baseUrl)
8381
.defaultHeaders(authHeaders)
8482
.defaultStatusHandler(responseErrorHandler)
83+
.defaultRequest(requestHeadersSpec -> {
84+
if (!(apiKey instanceof NoopApiKey)) {
85+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
86+
}
87+
})
8588
.build();
8689

87-
this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(authHeaders).build();
90+
this.webClient = webClientBuilder.clone()
91+
.baseUrl(baseUrl)
92+
.defaultHeaders(authHeaders)
93+
.defaultRequest(requestHeadersSpec -> {
94+
if (!(apiKey instanceof NoopApiKey)) {
95+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
96+
}
97+
})
98+
.build(); // @formatter:on
8899
}
89100

90101
public static Builder builder() {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30+
import org.springframework.http.HttpHeaders;
3031
import org.springframework.http.MediaType;
3132
import org.springframework.http.ResponseEntity;
3233
import org.springframework.util.Assert;
@@ -40,6 +41,7 @@
4041
*
4142
* @see <a href= "https://platform.openai.com/docs/api-reference/images">Images</a>
4243
* @author lambochen
44+
* @author Filip Hrisafov
4345
*/
4446
public class OpenAiImageApi {
4547

@@ -62,15 +64,18 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
6264
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
6365

6466
// @formatter:off
65-
this.restClient = restClientBuilder.baseUrl(baseUrl)
67+
this.restClient = restClientBuilder.clone()
68+
.baseUrl(baseUrl)
6669
.defaultHeaders(h -> {
67-
if (!(apiKey instanceof NoopApiKey)) {
68-
h.setBearerAuth(apiKey.getValue());
69-
}
7070
h.setContentType(MediaType.APPLICATION_JSON);
7171
h.addAll(headers);
7272
})
7373
.defaultStatusHandler(responseErrorHandler)
74+
.defaultRequest(requestHeadersSpec -> {
75+
if (!(apiKey instanceof NoopApiKey)) {
76+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
77+
}
78+
})
7479
.build();
7580
// @formatter:on
7681

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30+
import org.springframework.http.HttpHeaders;
3031
import org.springframework.http.MediaType;
3132
import org.springframework.http.ResponseEntity;
3233
import org.springframework.util.Assert;
@@ -40,6 +41,7 @@
4041
*
4142
* @author Ahmed Yousri
4243
* @author Ilayaperumal Gopinathan
44+
* @author Filip Hrisafov
4345
* @see <a href=
4446
* "https://platform.openai.com/docs/api-reference/moderations">https://platform.openai.com/docs/api-reference/moderations</a>
4547
*/
@@ -64,13 +66,20 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,
6466

6567
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
6668

67-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> {
68-
if (!(apiKey instanceof NoopApiKey)) {
69-
h.setBearerAuth(apiKey.getValue());
70-
}
71-
h.setContentType(MediaType.APPLICATION_JSON);
72-
h.addAll(headers);
73-
}).defaultStatusHandler(responseErrorHandler).build();
69+
// @formatter:off
70+
this.restClient = restClientBuilder.clone()
71+
.baseUrl(baseUrl)
72+
.defaultHeaders(h -> {
73+
h.setContentType(MediaType.APPLICATION_JSON);
74+
h.addAll(headers);
75+
})
76+
.defaultStatusHandler(responseErrorHandler)
77+
.defaultRequest(requestHeadersSpec -> {
78+
if (!(apiKey instanceof NoopApiKey)) {
79+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
80+
}
81+
})
82+
.build(); // @formatter:on
7483
}
7584

7685
public ResponseEntity<OpenAiModerationResponse> createModeration(OpenAiModerationRequest openAiModerationRequest) {

0 commit comments

Comments
 (0)