Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<groupId>dev.ai4j</groupId>
<artifactId>openai4j</artifactId>
<version>0.17.0</version>
<version>0.18.0</version>

<name>Java Client for OpenAI (ChatGPT)</name>
<description>Java Client for OpenAI (ChatGPT)</description>
Expand Down
44 changes: 26 additions & 18 deletions src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,12 @@ public DefaultOpenAiClient build() {
}

@Override
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
public SyncOrAsyncOrStreaming<CompletionResponse> completion(OpenAiClientContext context,
CompletionRequest request) {
CompletionRequest syncRequest = CompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.completions(syncRequest, apiVersion),
openAiApi.completions(context.headers(), syncRequest, apiVersion),
r -> r,
okHttpClient,
formatUrl("completions"),
Expand All @@ -144,13 +145,13 @@ public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest r
}

@Override
public SyncOrAsyncOrStreaming<String> completion(String prompt) {
public SyncOrAsyncOrStreaming<String> completion(OpenAiClientContext context, String prompt) {
CompletionRequest request = CompletionRequest.builder().prompt(prompt).build();

CompletionRequest syncRequest = CompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.completions(syncRequest, apiVersion),
openAiApi.completions(context.headers(), syncRequest, apiVersion),
CompletionResponse::text,
okHttpClient,
formatUrl("completions"),
Expand All @@ -162,11 +163,12 @@ public SyncOrAsyncOrStreaming<String> completion(String prompt) {
}

@Override
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request) {
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(OpenAiClientContext context,
ChatCompletionRequest request) {
ChatCompletionRequest syncRequest = ChatCompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.chatCompletions(syncRequest, apiVersion),
openAiApi.chatCompletions(context.headers(), syncRequest, apiVersion),
r -> r,
okHttpClient,
formatUrl("chat/completions"),
Expand All @@ -178,13 +180,13 @@ public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatComplet
}

@Override
public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
public SyncOrAsyncOrStreaming<String> chatCompletion(OpenAiClientContext context, String userMessage) {
ChatCompletionRequest request = ChatCompletionRequest.builder().addUserMessage(userMessage).build();

ChatCompletionRequest syncRequest = ChatCompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.chatCompletions(syncRequest, apiVersion),
openAiApi.chatCompletions(context.headers(), syncRequest, apiVersion),
ChatCompletionResponse::content,
okHttpClient,
formatUrl("chat/completions"),
Expand All @@ -196,32 +198,38 @@ public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
}

@Override
public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), r -> r);
public SyncOrAsync<EmbeddingResponse> embedding(OpenAiClientContext context, EmbeddingRequest request) {
return new RequestExecutor<>(openAiApi.embeddings(context.headers(), request, apiVersion), r -> r);
}

@Override
public SyncOrAsync<List<Float>> embedding(String input) {
public SyncOrAsync<List<Float>> embedding(OpenAiClientContext context, String input) {
EmbeddingRequest request = EmbeddingRequest.builder().input(input).build();

return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), EmbeddingResponse::embedding);
return new RequestExecutor<>(openAiApi.embeddings(context.headers(), request, apiVersion),
EmbeddingResponse::embedding);
}

@Override
public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), r -> r);
public SyncOrAsync<ModerationResponse> moderation(OpenAiClientContext context,
ModerationRequest request) {
return new RequestExecutor<>(openAiApi.moderations(context.headers(), request, apiVersion),
r -> r);
}

@Override
public SyncOrAsync<ModerationResult> moderation(String input) {
public SyncOrAsync<ModerationResult> moderation(OpenAiClientContext context, String input) {
ModerationRequest request = ModerationRequest.builder().input(input).build();

return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), r -> r.results().get(0));
return new RequestExecutor<>(openAiApi.moderations(context.headers(), request, apiVersion),
r -> r.results().get(0));
}

@Override
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request) {
return new RequestExecutor<>(openAiApi.imagesGenerations(request, apiVersion), r -> r);
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(OpenAiClientContext context,
GenerateImagesRequest request) {
return new RequestExecutor<>(openAiApi.imagesGenerations(context.headers(), request, apiVersion),
r -> r);
}

private String formatUrl(String endpoint) {
Expand Down
52 changes: 48 additions & 4 deletions src/main/java/dev/ai4j/openai4j/OpenAiApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,26 @@
import dev.ai4j.openai4j.image.GenerateImagesResponse;
import dev.ai4j.openai4j.moderation.ModerationRequest;
import dev.ai4j.openai4j.moderation.ModerationResponse;
import java.util.Map;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.HeaderMap;
import retrofit2.http.Headers;
import retrofit2.http.POST;
import retrofit2.http.Query;

interface OpenAiApi {
@POST("completions")
@Headers("Content-Type: application/json")
Call<CompletionResponse> completions(@Body CompletionRequest request, @Query("api-version") String apiVersion);
Call<CompletionResponse> completions(@Body CompletionRequest request,
@Query("api-version") String apiVersion);

@POST("completions")
@Headers("Content-Type: application/json")
Call<CompletionResponse> completions(
@HeaderMap Map<String, String> headers,
@Body CompletionRequest request,
@Query("api-version") String apiVersion);

@POST("chat/completions")
@Headers("Content-Type: application/json")
Expand All @@ -28,17 +38,51 @@ Call<ChatCompletionResponse> chatCompletions(
@Query("api-version") String apiVersion
);

@POST("chat/completions")
@Headers("Content-Type: application/json")
Call<ChatCompletionResponse> chatCompletions(
@HeaderMap Map<String, String> headers,
@Body ChatCompletionRequest request,
@Query("api-version") String apiVersion
);

@POST("embeddings")
@Headers("Content-Type: application/json")
Call<EmbeddingResponse> embeddings(
@Body EmbeddingRequest request,
@Query("api-version") String apiVersion);

@POST("embeddings")
@Headers("Content-Type: application/json")
Call<EmbeddingResponse> embeddings(@Body EmbeddingRequest request, @Query("api-version") String apiVersion);
Call<EmbeddingResponse> embeddings(
@HeaderMap Map<String, String> headers,
@Body EmbeddingRequest request,
@Query("api-version") String apiVersion);

@POST("moderations")
@Headers("Content-Type: application/json")
Call<ModerationResponse> moderations(
@Body ModerationRequest request,
@Query("api-version") String apiVersion);

@POST("moderations")
@Headers("Content-Type: application/json")
Call<ModerationResponse> moderations(@Body ModerationRequest request, @Query("api-version") String apiVersion);
Call<ModerationResponse> moderations(
@HeaderMap Map<String, String> headers,
@Body ModerationRequest request,
@Query("api-version") String apiVersion);

@POST("images/generations")
@Headers({"Content-Type: application/json"})
Call<GenerateImagesResponse> imagesGenerations(
@Body GenerateImagesRequest request,
@Query("api-version") String apiVersion
);

@POST("images/generations")
@Headers({ "Content-Type: application/json" })
@Headers({"Content-Type: application/json"})
Call<GenerateImagesResponse> imagesGenerations(
@HeaderMap Map<String, String> headers,
@Body GenerateImagesRequest request,
@Query("api-version") String apiVersion
);
Expand Down
123 changes: 106 additions & 17 deletions src/main/java/dev/ai4j/openai4j/OpenAiClient.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
package dev.ai4j.openai4j;

import java.net.InetSocketAddress;
import java.net.Proxy;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.List;
import static dev.ai4j.openai4j.LogLevel.DEBUG;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
Expand All @@ -20,29 +15,101 @@
import dev.ai4j.openai4j.moderation.ModerationResult;
import dev.ai4j.openai4j.spi.OpenAiClientBuilderFactory;
import dev.ai4j.openai4j.spi.ServiceHelper;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.ai4j.openai4j.LogLevel.DEBUG;

public abstract class OpenAiClient {

public abstract SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request);
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
return completion(new OpenAiClientContext(), request);
}

public SyncOrAsyncOrStreaming<CompletionResponse> completion(
OpenAiClientContext clientContext, CompletionRequest request) {
throw new UnsupportedOperationException();
}

public SyncOrAsyncOrStreaming<String> completion(String prompt) {
return completion(new OpenAiClientContext(), prompt);
}

public SyncOrAsyncOrStreaming<String> completion(OpenAiClientContext clientContext,
String prompt) {
throw new UnsupportedOperationException();
}

public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(
ChatCompletionRequest request) {
return chatCompletion(new OpenAiClientContext(), request);
}

public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(
OpenAiClientContext clientContext,
ChatCompletionRequest request) {
throw new UnsupportedOperationException();
}

public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
return chatCompletion(new OpenAiClientContext(), userMessage);
}

public abstract SyncOrAsyncOrStreaming<String> completion(String prompt);
public SyncOrAsyncOrStreaming<String> chatCompletion(
OpenAiClientContext clientContext,
String userMessage) {
throw new UnsupportedOperationException();
}

public abstract SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request);
public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
return embedding(new OpenAiClientContext(), request);
}

public abstract SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage);
public SyncOrAsync<EmbeddingResponse> embedding(OpenAiClientContext clientContext,
EmbeddingRequest request) {
throw new UnsupportedOperationException();
}

public abstract SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request);
public SyncOrAsync<List<Float>> embedding(String input) {
return embedding(new OpenAiClientContext(), input);
}

public abstract SyncOrAsync<List<Float>> embedding(String input);
public SyncOrAsync<List<Float>> embedding(OpenAiClientContext clientContext,
String input) {
throw new UnsupportedOperationException();
}

public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
return moderation(new OpenAiClientContext(), request);
}

public SyncOrAsync<ModerationResponse> moderation(OpenAiClientContext clientContext,
ModerationRequest request) {
throw new UnsupportedOperationException();
}

public SyncOrAsync<ModerationResult> moderation(String input) {
return moderation(new OpenAiClientContext(), input);
}

public abstract SyncOrAsync<ModerationResponse> moderation(ModerationRequest request);
public SyncOrAsync<ModerationResult> moderation(OpenAiClientContext clientContext,
String input) {
throw new UnsupportedOperationException();
}

public abstract SyncOrAsync<ModerationResult> moderation(String input);
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request) {
return imagesGeneration(new OpenAiClientContext(), request);
}

public abstract SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request);
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(
OpenAiClientContext clientContext,
GenerateImagesRequest request) {
throw new UnsupportedOperationException();
}

public abstract void shutdown();

Expand All @@ -55,6 +122,28 @@ public static OpenAiClient.Builder builder() {
return DefaultOpenAiClient.builder();
}

public static class OpenAiClientContext {
private final Map<String, String> headers = new HashMap<>();

public OpenAiClientContext addHeaders(Map<String, String> headers) {
this.headers.putAll(headers);
return this;
}

public OpenAiClientContext addHeader(String key, String value) {
headers.put(key, value);
return this;
}

public Map<String, String> headers() {
return headers;
}

public static OpenAiClientContext create() {
return new OpenAiClientContext();
}
}

@SuppressWarnings("unchecked")
public abstract static class Builder<T extends OpenAiClient, B extends Builder<T, B>> {

Expand Down
11 changes: 6 additions & 5 deletions src/main/java/dev/ai4j/openai4j/chat/ChatCompletionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
public enum ChatCompletionModel {

GPT_3_5_TURBO("gpt-3.5-turbo"), // alias
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"),
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"),
GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),

GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias
GPT_3_5_TURBO_16K_0613("gpt-3.5-turbo-16k-0613"),

GPT_4("gpt-4"), // alias
GPT_4_0314("gpt-4-0314"),
GPT_4_0613("gpt-4-0613"),

GPT_4_TURBO("gpt-4-turbo"), // alias
GPT_4_TURBO_2024_04_09("gpt-4-turbo-2024-04-09"), // With vision support
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
Expand All @@ -22,7 +20,10 @@ public enum ChatCompletionModel {
GPT_4_32K_0314("gpt-4-32k-0314"),
GPT_4_32K_0613("gpt-4-32k-0613"),

GPT_4_VISION_PREVIEW("gpt-4-vision-preview");
@Deprecated
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
GPT_4O("gpt-4o"),
GPT_4O_2024_05_13("gpt-4o-2024-05-13");

private final String value;

Expand Down
Loading