Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/main/java/dev/ai4j/openai4j/chat/Content.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ public final class Content {
private final String text;
@JsonProperty
private final ImageUrl imageUrl;
@JsonProperty
private final InputAudio inputAudio;

/**
 * Creates a content part by copying every field held by the given builder.
 *
 * @param builder source of the type, text, imageUrl and inputAudio values
 */
public Content(Builder builder) {
    // The four assignments are independent of each other; each one mirrors a builder field.
    this.inputAudio = builder.inputAudio;
    this.imageUrl = builder.imageUrl;
    this.text = builder.text;
    this.type = builder.type;
}

public ContentType type() {
Expand All @@ -40,6 +43,10 @@ public ImageUrl imageUrl() {
return imageUrl;
}

/**
 * Returns the audio payload of this content part.
 *
 * @return the inputAudio value copied from the builder, or {@code null} when none was set
 */
public InputAudio inputAudio() {
    return this.inputAudio;
}

@Override
public boolean equals(Object another) {
if (this == another) return true;
Expand All @@ -50,7 +57,8 @@ public boolean equals(Object another) {
private boolean equalTo(Content another) {
return Objects.equals(type, another.type)
&& Objects.equals(text, another.text)
&& Objects.equals(imageUrl, another.imageUrl);
&& Objects.equals(imageUrl, another.imageUrl)
&& Objects.equals(inputAudio, another.inputAudio);
}

@Override
Expand All @@ -59,6 +67,7 @@ public int hashCode() {
h += (h << 5) + Objects.hashCode(type);
h += (h << 5) + Objects.hashCode(text);
h += (h << 5) + Objects.hashCode(imageUrl);
h += (h << 5) + Objects.hashCode(inputAudio);
return h;
}

Expand All @@ -68,6 +77,7 @@ public String toString() {
"type=" + type +
", text=" + text +
", imageUrl=" + imageUrl +
", inputAudio=" + inputAudio +
"}";
}

Expand All @@ -83,6 +93,7 @@ public static final class Builder {
private ContentType type;
private String text;
private ImageUrl imageUrl;
private InputAudio inputAudio;

public Builder type(ContentType type) {
this.type = type;
Expand All @@ -99,6 +110,11 @@ public Builder imageUrl(ImageUrl imageUrl) {
return this;
}

/**
 * Sets the audio payload that the built content part will carry.
 *
 * @param inputAudio the audio payload; stored as given, without validation
 * @return this builder, for chaining
 */
public Builder inputAudio(InputAudio inputAudio) {
    final Builder self = this;
    self.inputAudio = inputAudio;
    return self;
}

/**
 * Creates a new {@code Content} from the values currently held by this builder.
 *
 * @return a new content instance
 */
public Content build() {
    final Content built = new Content(this);
    return built;
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/dev/ai4j/openai4j/chat/ContentType.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ public enum ContentType {
@JsonProperty("text")
TEXT,
@JsonProperty("image_url")
IMAGE_URL
IMAGE_URL,
@JsonProperty("input_audio")
AUDIO
}
87 changes: 87 additions & 0 deletions src/main/java/dev/ai4j/openai4j/chat/InputAudio.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package dev.ai4j.openai4j.chat;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;

import java.util.Objects;

/**
 * Audio payload of a chat message content part: a {@code data} string plus a {@code format} label.
 *
 * <p>Instances are immutable and created through {@link Builder}. Null fields are omitted from
 * the JSON output ({@code NON_NULL} inclusion).
 *
 * <p>NOTE(review): {@code data} is presumably the Base64-encoded audio and {@code format} its
 * container type (e.g. "wav") - confirm against the API reference.
 */
@JsonDeserialize(builder = InputAudio.Builder.class)
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class InputAudio {

    private final String data;
    private final String format;

    private InputAudio(Builder builder) {
        data = builder.data;
        format = builder.format;
    }

    /**
     * @return the audio data string, or {@code null} when none was set on the builder
     */
    public String getData() {
        return data;
    }

    /**
     * @return the audio format label, or {@code null} when none was set on the builder
     */
    public String getFormat() {
        return format;
    }

    /**
     * Two instances are equal when both {@code data} and {@code format} are equal (null-safe).
     */
    @Override
    public boolean equals(Object another) {
        if (this == another) return true;
        return another instanceof InputAudio
                && equalTo((InputAudio) another);
    }

    private boolean equalTo(InputAudio another) {
        return Objects.equals(data, another.data)
                && Objects.equals(format, another.format);
    }

    @Override
    public int hashCode() {
        int h = 5381;
        h += (h << 5) + Objects.hashCode(data);
        h += (h << 5) + Objects.hashCode(format);
        return h;
    }

    @Override
    public String toString() {
        return "InputAudio{" +
                "data=" + data +
                ", format=" + format +
                "}";
    }

    /**
     * @return a new, empty builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link InputAudio}.
     *
     * <p>The setter methods carry no "with" prefix, so the prefix is declared explicitly here;
     * without it Jackson's builder-based deserialization looks for {@code withData} /
     * {@code withFormat} and cannot bind the JSON properties.
     */
    @JsonPOJOBuilder(withPrefix = "")
    public static final class Builder {

        private String data;
        private String format;

        /**
         * @param data the audio data string; stored as given
         * @return this builder, for chaining
         */
        public Builder data(String data) {
            this.data = data;
            return this;
        }

        /**
         * @param format the audio format label; stored as given
         * @return this builder, for chaining
         */
        public Builder format(String format) {
            this.format = format;
            return this;
        }

        /**
         * Same as {@link InputAudio#builder()}; kept so existing callers of
         * {@code InputAudio.Builder.builder()} continue to compile.
         *
         * @return a new, empty builder
         */
        public static Builder builder() {
            return new Builder();
        }

        /**
         * @return a new {@link InputAudio} holding this builder's current values
         */
        public InputAudio build() {
            return new InputAudio(this);
        }

    }
}
26 changes: 20 additions & 6 deletions src/main/java/dev/ai4j/openai4j/chat/UserMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ private Builder() {
}

public Builder addText(String text) {
if (this.content == null) {
this.content = new ArrayList<>();
}
initializeContent();
Content content = Content.builder()
.type(TEXT)
.text(text)
Expand All @@ -123,9 +121,7 @@ public Builder addImageUrl(String imageUrl) {
}

public Builder addImageUrl(String imageUrl, ImageDetail imageDetail) {
if (this.content == null) {
this.content = new ArrayList<>();
}
initializeContent();
Content content = Content.builder()
.type(IMAGE_URL)
.imageUrl(ImageUrl.builder()
Expand All @@ -143,6 +139,18 @@ public Builder addImageUrls(String... imageUrls) {
}
return this;
}

/**
 * Appends an audio content part to this message.
 *
 * @param inputAudio the audio payload to wrap in a content part of type {@code AUDIO}
 * @return this builder, for chaining
 */
public Builder addInputAudio(InputAudio inputAudio) {
    // Make sure the content list exists before appending to it.
    initializeContent();
    Content audioContent = Content.builder()
            .type(ContentType.AUDIO)
            .inputAudio(inputAudio)
            .build();
    this.content.add(audioContent);
    return this;
}

public Builder content(List<Content> content) {
if (content != null) {
Expand All @@ -164,5 +172,11 @@ public Builder name(String name) {
/**
 * Creates a new {@code UserMessage} from the values currently held by this builder.
 *
 * @return a new user message
 */
public UserMessage build() {
    final UserMessage message = new UserMessage(this);
    return message;
}

/**
 * Lazily creates the content list so that the add* methods can append to it.
 * Does nothing when the list already exists.
 */
private void initializeContent() {
    if (this.content != null) {
        return;
    }
    this.content = new ArrayList<>();
}
}
}
37 changes: 37 additions & 0 deletions src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

Expand Down Expand Up @@ -479,4 +482,38 @@ void testGpt4Vision() throws Exception {
// then
assertThat(response.content()).containsIgnoringCase("cat");
}

/**
 * Sends a text question together with an audio clip to the "gpt-4o-audio-preview" model through
 * the asynchronous API and checks that the answer mentions "hello".
 *
 * <p>The clip is read from the {@code sample.b64} classpath resource and passed verbatim as the
 * input-audio data string with format "wav".
 */
@Test
void testGpt4Audio() throws Exception {

    // given
    URL resource = getClass().getClassLoader().getResource("sample.b64");
    // Fail with a readable message instead of a bare NullPointerException when the fixture is missing.
    assertThat(resource).as("classpath resource sample.b64").isNotNull();
    final byte[] bytes = Files.readAllBytes(Paths.get(resource.toURI()));
    // Decode with an explicit charset: new String(byte[]) would use the platform default.
    // NOTE(review): the fixture is presumably ASCII Base64 text, so UTF-8 decodes it unchanged - confirm.
    final String audioData = new String(bytes, java.nio.charset.StandardCharsets.UTF_8);

    ChatCompletionRequest request = ChatCompletionRequest.builder()
            .model("gpt-4o-audio-preview")
            .messages(UserMessage.builder()
                    .addText("What is on the audio?")
                    .addInputAudio(InputAudio.builder()
                            .format("wav")
                            .data(audioData)
                            .build())
                    .build())
            .maxCompletionTokens(100)
            .temperature(0.0)
            .build();

    CompletableFuture<ChatCompletionResponse> future = new CompletableFuture<>();

    // when
    client.chatCompletion(request)
            .onResponse(future::complete)
            .onError(future::completeExceptionally)
            .execute();

    ChatCompletionResponse response = future.get(30, SECONDS);

    // then
    assertThat(response.content()).containsIgnoringCase("hello");
}
}
1 change: 1 addition & 0 deletions src/test/resources/sample.b64

Large diffs are not rendered by default.