Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 4c40747

Browse files
committed
support GPT-4 Turbo API
1 parent fdd7761 commit 4c40747

File tree

9 files changed

+199
-71
lines changed

9 files changed

+199
-71
lines changed

src/main/java/dev/ai4j/openai4j/FunctionCallUtil.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.lang.reflect.Type;
88
import java.util.Map;
99

10+
@Deprecated
1011
public class FunctionCallUtil {
1112

1213
public static final Gson GSON = new Gson();

src/main/java/dev/ai4j/openai4j/MessageTypeAdapter.java

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -85,54 +85,6 @@ public void write(JsonWriter out, Message message) throws IOException {
8585

8686
@Override
8787
public Message read(JsonReader in) throws IOException {
88-
// return delegate.read(in);
89-
in.beginObject();
90-
91-
Message.Builder builder = Message.builder();
92-
93-
while (in.hasNext()) {
94-
String name = in.nextName();
95-
96-
switch (name) {
97-
case "role":
98-
System.out.println("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"+in.nextString());
99-
if (in.peek() == JsonToken.STRING) {
100-
builder.role(Role.valueOf(in.nextString()));
101-
} else {
102-
// 错误情况:非字符串的 JSON 令牌。可以抛出异常或者跳过这个字段。
103-
in.skipValue();
104-
// 如果这个字段是必须的,你可能需要抛出异常,例如:
105-
// throw new JsonSyntaxException("Expected a string for role field but got: " + in.peek());
106-
}
107-
// builder.role(Role.valueOf(in.nextString())); // 假设 Role 是一个枚举
108-
break;
109-
case "content":
110-
if (in.peek() == JsonToken.STRING) {
111-
// 如果 content 是一个字符串,将其转化成 Content 对象
112-
String contentString = in.nextString();
113-
// Content content = //... 创建 Content 对象;
114-
Content content = Content.builder().text(contentString).type(ContentType.TEXT.stringValue()).build();
115-
builder.content(Collections.singletonList(content));
116-
} else if (in.peek() == JsonToken.BEGIN_ARRAY) {
117-
// 如果 content 是一个数组,使用标准方法来解析
118-
Type listOfContent = new TypeToken<List<Content>>(){}.getType();
119-
List<Content> contentList = GSON.fromJson(in, listOfContent);
120-
builder.content(contentList);
121-
}
122-
break;
123-
case "name":
124-
builder.name(in.nextString());
125-
break;
126-
case "function_call":
127-
// 解析 function_call
128-
FunctionCall functionCall = GSON.fromJson(in, FunctionCall.class);
129-
builder.functionCall(functionCall);
130-
break;
131-
// 处理其他字段...
132-
}
133-
}
134-
135-
in.endObject();
136-
return builder.build();
88+
return delegate.read(in);
13789
}
13890
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package dev.ai4j.openai4j;
2+
3+
import com.google.gson.Gson;
4+
import com.google.gson.reflect.TypeToken;
5+
import dev.ai4j.openai4j.chat.Function;
6+
import dev.ai4j.openai4j.chat.FunctionCall;
7+
8+
import java.lang.reflect.Type;
9+
import java.util.Map;
10+
11+
public class ToolCallsUtil {
12+
13+
public static final Gson GSON = new Gson();
14+
public static final Type MAP_TYPE = new TypeToken<Map<String, Object>>() {
15+
}.getType();
16+
17+
public static <T> T argument(Function function, String name) {
18+
Map<String, Object> arguments = argumentsAsMap(function.arguments()); // TODO cache
19+
return (T) arguments.get(name);
20+
}
21+
22+
public static Map<String, Object> argumentsAsMap(String arguments) {
23+
return GSON.fromJson(arguments, MAP_TYPE);
24+
}
25+
}

src/main/java/dev/ai4j/openai4j/chat/ChatCompletionChoice.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
public final class ChatCompletionChoice {
66

77
private final Integer index;
8-
private final Message message;
8+
private final MessageResponse message;
99
private final Delta delta;
1010
private final String finishReason;
1111

@@ -20,7 +20,7 @@ public Integer index() {
2020
return index;
2121
}
2222

23-
public Message message() {
23+
public MessageResponse message() {
2424
return message;
2525
}
2626

@@ -73,7 +73,7 @@ public static Builder builder() {
7373
public static final class Builder {
7474

7575
private Integer index;
76-
private Message message;
76+
private MessageResponse message;
7777
private Delta delta;
7878
private String finishReason;
7979

@@ -85,7 +85,7 @@ public Builder index(Integer index) {
8585
return this;
8686
}
8787

88-
public Builder message(Message message) {
88+
public Builder message(MessageResponse message) {
8989
this.message = message;
9090
return this;
9191
}

src/main/java/dev/ai4j/openai4j/chat/ChatCompletionResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public String systemFingerprint() {
5555
*/
5656
@Experimental
5757
public String content() {
58-
return choices().get(0).message().content().get(0).text();
58+
return choices().get(0).message().content();
5959
}
6060

6161
@Override
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package dev.ai4j.openai4j.chat;
2+
3+
import dev.ai4j.openai4j.Experimental;
4+
5+
import java.util.Arrays;
6+
import java.util.List;
7+
import java.util.Objects;
8+
9+
import static dev.ai4j.openai4j.chat.Role.*;
10+
11+
public final class MessageResponse {
12+
13+
private final Role role;
14+
private final String content;
15+
private final String name;
16+
17+
@Deprecated
18+
private final FunctionCall functionCall;
19+
private final List<ToolCalls> toolCalls;
20+
21+
22+
private MessageResponse(Builder builder) {
23+
this.role = builder.role;
24+
this.content = builder.content;
25+
this.name = builder.name;
26+
this.functionCall = builder.functionCall;
27+
this.toolCalls = builder.toolCalls;
28+
}
29+
30+
public Role role() {
31+
return role;
32+
}
33+
34+
public String content() {
35+
return content;
36+
}
37+
38+
public String name() {
39+
return name;
40+
}
41+
42+
public FunctionCall functionCall() {
43+
return functionCall;
44+
}
45+
46+
public List<ToolCalls> toolCalls() {
47+
return toolCalls;
48+
}
49+
50+
@Override
51+
public boolean equals(Object another) {
52+
if (this == another) return true;
53+
return another instanceof MessageResponse
54+
&& equalTo((MessageResponse) another);
55+
}
56+
57+
private boolean equalTo(MessageResponse another) {
58+
return Objects.equals(role, another.role)
59+
&& Objects.equals(content, another.content)
60+
&& Objects.equals(name, another.name)
61+
&& Objects.equals(functionCall, another.functionCall)
62+
&& Objects.equals(toolCalls, another.toolCalls);
63+
}
64+
65+
@Override
66+
public int hashCode() {
67+
int h = 5381;
68+
h += (h << 5) + Objects.hashCode(role);
69+
h += (h << 5) + Objects.hashCode(content);
70+
h += (h << 5) + Objects.hashCode(name);
71+
h += (h << 5) + Objects.hashCode(functionCall);
72+
h += (h << 5) + Objects.hashCode(toolCalls);
73+
return h;
74+
}
75+
76+
@Override
77+
public String toString() {
78+
return "Message{"
79+
+ "role=" + role
80+
+ ", content=" + content
81+
+ ", name=" + name
82+
+ ", functionCall=" + functionCall
83+
+ ", toolCalls=" + toolCalls
84+
+ "}";
85+
}
86+
87+
@Experimental
88+
public static Message assistantMessage(String content) {
89+
Content userContent = Content.builder().type(ContentType.TEXT.stringValue()).text(content).build();
90+
return Message.builder()
91+
.role(ASSISTANT)
92+
.content(Arrays.asList(userContent))
93+
.build();
94+
}
95+
96+
public static Builder builder() {
97+
return new Builder();
98+
}
99+
100+
public static final class Builder {
101+
102+
private Role role;
103+
private String content;
104+
private String name;
105+
@Deprecated
106+
private FunctionCall functionCall;
107+
private List<ToolCalls> toolCalls;
108+
109+
private Builder() {
110+
}
111+
112+
public Builder role(Role role) {
113+
this.role = role;
114+
return this;
115+
}
116+
117+
@Experimental
118+
public Builder role(String role) {
119+
return role(Role.from(role));
120+
}
121+
122+
public Builder content(String content) {
123+
this.content = content;
124+
return this;
125+
}
126+
127+
public Builder name(String name) {
128+
this.name = name;
129+
return this;
130+
}
131+
132+
@Deprecated
133+
public MessageResponse.Builder functionCall(FunctionCall functionCall) {
134+
this.functionCall = functionCall;
135+
return this;
136+
}
137+
138+
public Builder toolCalls(List<ToolCalls> toolCalls) {
139+
this.toolCalls = toolCalls;
140+
return this;
141+
}
142+
143+
public MessageResponse build() {
144+
return new MessageResponse(this);
145+
}
146+
}
147+
}

src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dev.ai4j.openai4j.FunctionCallUtil;
44
import dev.ai4j.openai4j.OpenAiClient;
55
import dev.ai4j.openai4j.RateLimitAwareTest;
6+
import dev.ai4j.openai4j.ToolCallsUtil;
67
import org.junit.jupiter.api.Test;
78
import org.junit.jupiter.params.ParameterizedTest;
89
import org.junit.jupiter.params.provider.Arguments;
@@ -29,7 +30,6 @@ class ChatCompletionAsyncTest extends RateLimitAwareTest {
2930

3031
private final OpenAiClient client = OpenAiClient.builder()
3132
.openAiApiKey(System.getenv("OPENAI_API_KEY"))
32-
.proxy(new Proxy(HTTP, new InetSocketAddress("127.0.0.1",7890)))
3333
.logRequests()
3434
.logResponses()
3535
.build();
@@ -67,7 +67,7 @@ void testCustomizableApi(ChatCompletionRequest request) throws Exception {
6767

6868
assertThat(response.choices()).hasSize(1);
6969
assertThat(response.choices().get(0).message().role()).isEqualTo(ASSISTANT);
70-
assertThat(response.choices().get(0).message().content().get(0).text()).containsIgnoringCase("hello world");
70+
assertThat(response.choices().get(0).message().content()).containsIgnoringCase("hello world");
7171

7272
assertThat(response.content()).containsIgnoringCase("hello world");
7373
}
@@ -120,17 +120,16 @@ void testFunctions() throws Exception {
120120

121121
ChatCompletionResponse response = future.get(30, SECONDS);
122122

123-
Message assistantMessage = response.choices().get(0).message();
123+
MessageResponse assistantMessage = response.choices().get(0).message();
124124
assertThat(assistantMessage.role()).isEqualTo(ASSISTANT);
125125
assertThat(assistantMessage.content()).isNull();
126126

127127
Function function = assistantMessage.toolCalls().get(0).function();
128-
129-
// FunctionCall functionCall = assistantMessage.functionCall();
130128
assertThat(function.name()).isEqualTo("get_current_weather");
131129
assertThat(function.arguments()).isNotBlank();
132130

133-
Map<String, Object> arguments = FunctionCallUtil.argumentsAsMap(function.arguments());
131+
Map<String, Object> arguments = ToolCallsUtil.argumentsAsMap(function.arguments());
132+
134133
assertThat(arguments).hasSize(1);
135134
assertThat(arguments.get("location").toString()).contains("Boston");
136135
}

src/test/java/dev/ai4j/openai4j/chat/ChatCompletionStreamingTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class ChatCompletionStreamingTest extends RateLimitAwareTest {
3030

3131
private final OpenAiClient client = OpenAiClient.builder()
3232
.openAiApiKey(System.getenv("OPENAI_API_KEY"))
33-
.proxy(new Proxy(HTTP, new InetSocketAddress("127.0.0.1",7890)))
3433
.logRequests()
3534
.logResponses()
3635
.logStreamingResponses()

0 commit comments

Comments
 (0)