Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit fdd7761

Browse files
committed
support GPT-4-Turbo API Test unit
1 parent edd5096 commit fdd7761

File tree

7 files changed

+139
-34
lines changed

7 files changed

+139
-34
lines changed

src/main/java/dev/ai4j/openai4j/MessageTypeAdapter.java

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,23 @@
55
import com.google.gson.TypeAdapterFactory;
66
import com.google.gson.reflect.TypeToken;
77
import com.google.gson.stream.JsonReader;
8+
import com.google.gson.stream.JsonToken;
89
import com.google.gson.stream.JsonWriter;
9-
import dev.ai4j.openai4j.chat.Content;
10-
import dev.ai4j.openai4j.chat.FunctionCall;
11-
import dev.ai4j.openai4j.chat.Message;
12-
import dev.ai4j.openai4j.chat.ToolCalls;
10+
import dev.ai4j.openai4j.chat.*;
1311

1412
import java.io.IOException;
13+
import java.lang.reflect.Type;
14+
import java.util.ArrayList;
15+
import java.util.Arrays;
16+
import java.util.Collections;
1517
import java.util.List;
1618

19+
import static dev.ai4j.openai4j.Json.GSON;
20+
1721
class MessageTypeAdapter extends TypeAdapter<Message> {
1822

23+
24+
1925
static final TypeAdapterFactory MESSAGE_TYPE_ADAPTER_FACTORY = new TypeAdapterFactory() {
2026

2127
@Override
@@ -50,7 +56,7 @@ public void write(JsonWriter out, Message message) throws IOException {
5056
out.setSerializeNulls(serializeNulls);
5157
} else {
5258
if (message.content().get(0).type() != null){
53-
TypeAdapter<List> contentTypeAdapter = Json.GSON.getAdapter(List.class);
59+
TypeAdapter<List> contentTypeAdapter = GSON.getAdapter(List.class);
5460
contentTypeAdapter.write(out,message.content());
5561
}else {
5662
out.value(message.content().get(0).text());
@@ -64,13 +70,13 @@ public void write(JsonWriter out, Message message) throws IOException {
6470

6571
if (message.functionCall() != null) {
6672
out.name("function_call");
67-
TypeAdapter<FunctionCall> functionCallTypeAdapter = Json.GSON.getAdapter(FunctionCall.class);
73+
TypeAdapter<FunctionCall> functionCallTypeAdapter = GSON.getAdapter(FunctionCall.class);
6874
functionCallTypeAdapter.write(out, message.functionCall());
6975
}
7076

7177
if (message.toolCalls() != null){
7278
out.name("tool_calls");
73-
TypeAdapter<List> toolCallsTypeAdapter = Json.GSON.getAdapter(List.class);
79+
TypeAdapter<List> toolCallsTypeAdapter = GSON.getAdapter(List.class);
7480
toolCallsTypeAdapter.write(out, message.toolCalls());
7581
}
7682

@@ -79,6 +85,54 @@ public void write(JsonWriter out, Message message) throws IOException {
7985

8086
@Override
8187
public Message read(JsonReader in) throws IOException {
82-
return delegate.read(in);
88+
// return delegate.read(in);
89+
in.beginObject();
90+
91+
Message.Builder builder = Message.builder();
92+
93+
while (in.hasNext()) {
94+
String name = in.nextName();
95+
96+
switch (name) {
97+
case "role":
98+
System.out.println("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"+in.nextString());
99+
if (in.peek() == JsonToken.STRING) {
100+
builder.role(Role.valueOf(in.nextString()));
101+
} else {
102+
// 错误情况:非字符串的 JSON 令牌。可以抛出异常或者跳过这个字段。
103+
in.skipValue();
104+
// 如果这个字段是必须的,你可能需要抛出异常,例如:
105+
// throw new JsonSyntaxException("Expected a string for role field but got: " + in.peek());
106+
}
107+
// builder.role(Role.valueOf(in.nextString())); // 假设 Role 是一个枚举
108+
break;
109+
case "content":
110+
if (in.peek() == JsonToken.STRING) {
111+
// 如果 content 是一个字符串,将其转化成 Content 对象
112+
String contentString = in.nextString();
113+
// Content content = //... 创建 Content 对象;
114+
Content content = Content.builder().text(contentString).type(ContentType.TEXT.stringValue()).build();
115+
builder.content(Collections.singletonList(content));
116+
} else if (in.peek() == JsonToken.BEGIN_ARRAY) {
117+
// 如果 content 是一个数组,使用标准方法来解析
118+
Type listOfContent = new TypeToken<List<Content>>(){}.getType();
119+
List<Content> contentList = GSON.fromJson(in, listOfContent);
120+
builder.content(contentList);
121+
}
122+
break;
123+
case "name":
124+
builder.name(in.nextString());
125+
break;
126+
case "function_call":
127+
// 解析 function_call
128+
FunctionCall functionCall = GSON.fromJson(in, FunctionCall.class);
129+
builder.functionCall(functionCall);
130+
break;
131+
// 处理其他字段...
132+
}
133+
}
134+
135+
in.endObject();
136+
return builder.build();
83137
}
84138
}

src/main/java/dev/ai4j/openai4j/chat/Function.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ public class Function {
1212
private final String description;
1313
private final Parameters parameters;
1414

15+
private final String arguments;
16+
1517
private Function(Builder builder) {
1618
this.name = builder.name;
1719
this.description = builder.description;
1820
this.parameters = builder.parameters;
21+
this.arguments = builder.arguments;
1922
}
2023

2124
public String name() {
@@ -30,6 +33,10 @@ public Parameters parameters() {
3033
return parameters;
3134
}
3235

36+
public String arguments() {
37+
return arguments;
38+
}
39+
3340
@Override
3441
public boolean equals(Object another) {
3542
if (this == another) return true;
@@ -40,7 +47,8 @@ public boolean equals(Object another) {
4047
private boolean equalTo(Function another) {
4148
return Objects.equals(name, another.name)
4249
&& Objects.equals(description, another.description)
43-
&& Objects.equals(parameters, another.parameters);
50+
&& Objects.equals(parameters, another.parameters)
51+
&& Objects.equals(arguments, another.arguments);
4452
}
4553

4654
@Override
@@ -49,6 +57,7 @@ public int hashCode() {
4957
h += (h << 5) + Objects.hashCode(name);
5058
h += (h << 5) + Objects.hashCode(description);
5159
h += (h << 5) + Objects.hashCode(parameters);
60+
h += (h << 5) + Objects.hashCode(arguments);
5261
return h;
5362
}
5463

@@ -58,6 +67,7 @@ public String toString() {
5867
+ "name=" + name
5968
+ ", description=" + description
6069
+ ", parameters=" + parameters
70+
+ ", arguments=" + arguments
6171
+ "}";
6272
}
6373

@@ -71,6 +81,8 @@ public static final class Builder {
7181
private String description;
7282
private Parameters parameters;
7383

84+
private String arguments;
85+
7486
private Builder() {
7587
}
7688

@@ -89,6 +101,11 @@ public Builder parameters(Parameters parameters) {
89101
return this;
90102
}
91103

104+
public Builder arguments(String arguments) {
105+
this.arguments = arguments;
106+
return this;
107+
}
108+
92109
@Experimental
93110
public Builder addParameter(String name, JsonSchemaProperty... jsonSchemaProperties) {
94111
addOptionalParameter(name, jsonSchemaProperties);

src/main/java/dev/ai4j/openai4j/chat/Message.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import dev.ai4j.openai4j.Experimental;
44

5-
import java.util.ArrayList;
65
import java.util.Arrays;
76
import java.util.List;
87
import java.util.Objects;

src/main/java/dev/ai4j/openai4j/chat/ToolCalls.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@ public class ToolCalls {
88
private final String name;
99
private final String arguments;
1010

11+
private final Integer index;
12+
13+
private final Function function;
14+
1115
private ToolCalls(ToolCalls.Builder builder) {
1216
this.name = builder.name;
1317
this.arguments = builder.arguments;
18+
this.index = builder.index;
19+
this.function = builder.function;
1420
}
1521

1622
public String name() {
@@ -21,6 +27,14 @@ public String arguments() {
2127
return arguments;
2228
}
2329

30+
public Integer index(){
31+
return index;
32+
}
33+
34+
public Function function(){
35+
return function;
36+
}
37+
2438
@Override
2539
public boolean equals(Object another) {
2640
if (this == another) return true;
@@ -30,14 +44,18 @@ public boolean equals(Object another) {
3044

3145
private boolean equalTo(ToolCalls another) {
3246
return Objects.equals(name, another.name)
33-
&& Objects.equals(arguments, another.arguments);
47+
&& Objects.equals(arguments, another.arguments)
48+
&& Objects.equals(index, another.index)
49+
&& Objects.equals(function, another.function);
3450
}
3551

3652
@Override
3753
public int hashCode() {
3854
int h = 5381;
3955
h += (h << 5) + Objects.hashCode(name);
4056
h += (h << 5) + Objects.hashCode(arguments);
57+
h += (h << 5) + Objects.hashCode(index);
58+
h += (h << 5) + Objects.hashCode(function);
4159
return h;
4260
}
4361

@@ -46,6 +64,8 @@ public String toString() {
4664
return "Function{"
4765
+ "name=" + name
4866
+ ", arguments=" + arguments
67+
+ ", index=" + index
68+
+ ", function=" + function
4969
+ "}";
5070
}
5171

@@ -57,6 +77,9 @@ public static final class Builder {
5777

5878
private String name;
5979
private String arguments;
80+
private Integer index;
81+
82+
private Function function;
6083

6184
private Builder() {
6285
}
@@ -71,6 +94,16 @@ public ToolCalls.Builder arguments(String arguments) {
7194
return this;
7295
}
7396

97+
public ToolCalls.Builder index(Integer index) {
98+
this.index = index;
99+
return this;
100+
}
101+
102+
public ToolCalls.Builder function(Function function) {
103+
this.function = function;
104+
return this;
105+
}
106+
74107
public ToolCalls build() {
75108
return new ToolCalls(this);
76109
}

src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import org.junit.jupiter.params.provider.Arguments;
99
import org.junit.jupiter.params.provider.MethodSource;
1010

11+
import java.net.InetSocketAddress;
12+
import java.net.Proxy;
1113
import java.util.Map;
1214
import java.util.concurrent.CompletableFuture;
1315
import java.util.stream.Stream;
@@ -16,6 +18,7 @@
1618
import static dev.ai4j.openai4j.chat.JsonSchemaProperty.*;
1719
import static dev.ai4j.openai4j.chat.Message.userMessage;
1820
import static dev.ai4j.openai4j.chat.Role.ASSISTANT;
21+
import static java.net.Proxy.Type.HTTP;
1922
import static java.util.Collections.singletonList;
2023
import static java.util.concurrent.TimeUnit.SECONDS;
2124
import static org.assertj.core.api.Assertions.assertThat;
@@ -26,6 +29,7 @@ class ChatCompletionAsyncTest extends RateLimitAwareTest {
2629

2730
private final OpenAiClient client = OpenAiClient.builder()
2831
.openAiApiKey(System.getenv("OPENAI_API_KEY"))
32+
.proxy(new Proxy(HTTP, new InetSocketAddress("127.0.0.1",7890)))
2933
.logRequests()
3034
.logResponses()
3135
.build();
@@ -120,11 +124,13 @@ void testFunctions() throws Exception {
120124
assertThat(assistantMessage.role()).isEqualTo(ASSISTANT);
121125
assertThat(assistantMessage.content()).isNull();
122126

123-
FunctionCall functionCall = assistantMessage.functionCall();
124-
assertThat(functionCall.name()).isEqualTo("get_current_weather");
125-
assertThat(functionCall.arguments()).isNotBlank();
127+
Function function = assistantMessage.toolCalls().get(0).function();
126128

127-
Map<String, Object> arguments = FunctionCallUtil.argumentsAsMap(functionCall.arguments());
129+
// FunctionCall functionCall = assistantMessage.functionCall();
130+
assertThat(function.name()).isEqualTo("get_current_weather");
131+
assertThat(function.arguments()).isNotBlank();
132+
133+
Map<String, Object> arguments = FunctionCallUtil.argumentsAsMap(function.arguments());
128134
assertThat(arguments).hasSize(1);
129135
assertThat(arguments.get("location").toString()).contains("Boston");
130136
}

src/test/java/dev/ai4j/openai4j/chat/ChatCompletionStreamingTest.java

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,19 @@ void testSimpleApi() throws Exception {
5050

5151

5252
client.chatCompletion(request)
53-
.onPartialResponse(responseBuilder::append)
53+
.onPartialResponse(partialResponse -> {
54+
String content = partialResponse.choices().get(0).delta().content();
55+
if (content != null) {
56+
responseBuilder.append(content);
57+
}
58+
})
5459
.onComplete(() -> future.complete(responseBuilder.toString()))
5560
.onError(future::completeExceptionally)
5661
.execute();
5762

5863

5964
String response = future.get(30, SECONDS);
60-
System.out.println("-----------------:"+response);
61-
// assertThat(response).containsIgnoringCase("hello world");
65+
assertThat(response).containsIgnoringCase("hello world");
6266
}
6367

6468
@MethodSource
@@ -130,36 +134,24 @@ void testFunctions() throws Exception {
130134
client.chatCompletion(request)
131135
.onPartialResponse(partialResponse -> {
132136
Delta delta = partialResponse.choices().get(0).delta();
133-
System.out.println("@@@@@@@@@@@@@@@"+delta.toString());
134137

135138
assertThat(delta.content()).isNull();
136139

137140
List<ToolCalls> toolCalls = delta.toolCalls();
138141

139142
if (partialResponse.choices().get(0).finishReason() == null) {
140143
toolCalls.stream().forEach(toolCall ->{
141-
System.out.println("-----------------"+toolCall.name()+"||"+toolCall.arguments());
142-
if (toolCall.name() != null) {
143-
responseBuilder.append(toolCall.name());
144-
} else if (toolCall.arguments() != null) {
145-
responseBuilder.append(toolCall.arguments());
144+
if (toolCall.function().name() != null) {
145+
responseBuilder.append(toolCall.function().name());
146+
} else if (toolCall.function().arguments() != null) {
147+
responseBuilder.append(toolCall.function().arguments());
146148
}
147149
});
148150
}
149-
150-
// FunctionCall functionCall = delta.functionCall();
151-
// if (partialResponse.choices().get(0).finishReason() == null) {
152-
// if (functionCall.name() != null) {
153-
// responseBuilder.append(functionCall.name());
154-
// } else if (functionCall.arguments() != null) {
155-
// responseBuilder.append(functionCall.arguments());
156-
// }
157-
// }
158151
})
159152
.onComplete(() -> future.complete(responseBuilder.toString()))
160153
.onError(future::completeExceptionally)
161154
.execute();
162-
163155
String response = future.get(30, SECONDS);
164156

165157
assertThat(response).contains("get_current_weather");

0 commit comments

Comments
 (0)