Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ This is an unofficial Java client library that helps to connect your Java applic
- [synchronous](https://github.com/ai-for-java/openai4j#synchronously-3)
- [asynchronous](https://github.com/ai-for-java/openai4j#asynchronously-3)
- [Functions](https://github.com/ai-for-java/openai4j/blob/main/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java)
- [Audio](https://platform.openai.com/docs/api-reference/audio)
- [Speech](https://platform.openai.com/docs/api-reference/audio/createSpeech)

## Coming soon:

Expand Down Expand Up @@ -380,6 +382,57 @@ Customizable way:
String localImage = response.data().get(0).url();
```


## Audio Generations
### Create speech

Simple way:

```
OpenAiClient client = OpenAiClient
.builder()
.openAiApiKey(System.getenv("OPENAI_API_KEY"))
.build();

SpeechRequest request = SpeechRequest
.builder()
.model(TTS_1)
.input("The quick brown fox jumped over the lazy dog.")
.voice(SpeechModel.Voice.ALLOY)
.build();

SpeechResponse response = client.speechGenerations(request).execute();

// Byte array audio speech generated
String speechData = response.data();
```

Customizable way:

```
OpenAiClient client = OpenAiClient
.builder()
.openAiApiKey(System.getenv("OPENAI_API_KEY"))
.logRequests()
.logResponses()
.withPersisting()
.build();

SpeechRequest request = SpeechRequest
.builder()
.model(TTS_1)
.input("The quick brown fox jumped over the lazy dog.")
.voice(SpeechModel.Voice.ALLOY)
.responseFormat(SpeechModel.ResponseFormat.WAV)
.speed(2)
.build();

AudioResponse response = client.speechGenerations(request).execute();

// your generated audio speech is here locally:
String speechUrl = response.url();
```

# Useful materials

- How to get best results form AI: https://www.deeplearning.ai/short-courses/chatgpt-prompt-engineering-for-developers/
Expand Down
48 changes: 48 additions & 0 deletions src/main/java/dev/ai4j/openai4j/BytesConverterFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package dev.ai4j.openai4j;

import dev.ai4j.openai4j.audio.GenerateSpeechResponse;
import okhttp3.ResponseBody;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Converter;
import retrofit2.Retrofit;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;

/**
* A converter factory to handle the conversion of Retrofit response bodies into
* ByteArrayObjectWrapper instances.
*/
public class BytesConverterFactory extends Converter.Factory {

private static final Logger logger = LoggerFactory.getLogger(BytesConverterFactory.class);

public BytesConverterFactory() {
// Constructor can be utilized for initializing if needed
}

@Override
public Converter<ResponseBody, ?> responseBodyConverter(Type type, Annotation[] annotations, Retrofit retrofit) {

logger.debug("Requesting conversion for type: {}", type.getTypeName());
if (GenerateSpeechResponse.class.equals(type)) {
return responseBody -> {
try {
logger.debug("Converting response body to GenerateSpeechResponse");
return GenerateSpeechResponse.builder()
.data(responseBody.bytes())
.build();
} catch (IOException e) {
logger.error("Failed to read bytes from response body", e);
throw new RuntimeException("Error reading response body", e);
} finally {
responseBody.close();
}
};
}
logger.debug("No converter found for type: {}", type.getTypeName());
return null;
}
}
9 changes: 9 additions & 0 deletions src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.ai4j.openai4j;

import dev.ai4j.openai4j.audio.GenerateSpeechRequest;
import dev.ai4j.openai4j.audio.GenerateSpeechResponse;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
Expand Down Expand Up @@ -96,6 +98,8 @@ private DefaultOpenAiClient(Builder serviceBuilder) {
retrofitBuilder.addConverterFactory(new PersistorConverterFactory(serviceBuilder.persistTo));
}

retrofitBuilder.addConverterFactory(new BytesConverterFactory());

retrofitBuilder.addConverterFactory(GsonConverterFactory.create(GSON));

this.openAiApi = retrofitBuilder.build().create(OpenAiApi.class);
Expand Down Expand Up @@ -224,6 +228,11 @@ public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesReques
return new RequestExecutor<>(openAiApi.imagesGenerations(request, apiVersion), r -> r);
}

@Override
public SyncOrAsync<GenerateSpeechResponse> speechGeneration(GenerateSpeechRequest request) {
return new RequestExecutor<>(openAiApi.speechGenerations(request, apiVersion), r -> r);
}

private String formatUrl(String endpoint) {
return baseUrl + endpoint + apiVersionQueryParam();
}
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/dev/ai4j/openai4j/FilePersistor.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ static Path persistFromUri(URI uri, Path destinationFolder) {

public static Path persistFromBase64String(String base64EncodedString, Path destinationFolder) throws IOException {
byte[] decodedBytes = Base64.getDecoder().decode(base64EncodedString);
Path destinationFile = destinationFolder.resolve(randomFileName());
return persistFromByteArray(decodedBytes, destinationFolder);
}

Files.write(destinationFile, decodedBytes, StandardOpenOption.CREATE);
public static Path persistFromByteArray(byte[] bytes, Path destinationFolder) throws IOException {
Path destinationFile = destinationFolder.resolve(randomFileName());
Files.write(destinationFile, bytes, StandardOpenOption.CREATE);

return destinationFile;
}
Expand Down
9 changes: 9 additions & 0 deletions src/main/java/dev/ai4j/openai4j/OpenAiApi.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.ai4j.openai4j;

import dev.ai4j.openai4j.audio.GenerateSpeechRequest;
import dev.ai4j.openai4j.audio.GenerateSpeechResponse;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
Expand Down Expand Up @@ -42,4 +44,11 @@ Call<GenerateImagesResponse> imagesGenerations(
@Body GenerateImagesRequest request,
@Query("api-version") String apiVersion
);

@POST("audio/speech")
@Headers({ "Content-Type: application/json" })
Call<GenerateSpeechResponse> speechGenerations(
@Body GenerateSpeechRequest request,
@Query("api-version") String apiVersion
);
}
4 changes: 4 additions & 0 deletions src/main/java/dev/ai4j/openai4j/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import java.time.Duration;
import java.util.List;

import dev.ai4j.openai4j.audio.GenerateSpeechRequest;
import dev.ai4j.openai4j.audio.GenerateSpeechResponse;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
Expand Down Expand Up @@ -44,6 +46,8 @@ public abstract class OpenAiClient {

public abstract SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request);

public abstract SyncOrAsync<GenerateSpeechResponse> speechGeneration(GenerateSpeechRequest request);

public abstract void shutdown();

@SuppressWarnings("rawtypes")
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/dev/ai4j/openai4j/PersistorConverterFactory.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.ai4j.openai4j;

import dev.ai4j.openai4j.audio.GenerateSpeechResponse;
import dev.ai4j.openai4j.image.GenerateImagesResponse;
import java.io.IOException;
import java.lang.annotation.Annotation;
Expand Down Expand Up @@ -49,6 +50,17 @@ public T convert(ResponseBody value) throws IOException {
});
}

if (response instanceof GenerateSpeechResponse) {
try {
GenerateSpeechResponse generateSpeechResponse = (GenerateSpeechResponse) response;
generateSpeechResponse.url(
FilePersistor.persistFromByteArray(generateSpeechResponse.data(), persistTo).toUri()
);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

return response;
}
}
Expand Down
109 changes: 109 additions & 0 deletions src/main/java/dev/ai4j/openai4j/audio/GenerateSpeechRequest.java
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious about the builder approach. Why not use lombok?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding another dependency? It's not much that would be saved here in terms of lines of code or complexity. I'd keep as is. @langchain4j @LizeRaes I'd love to use it here: langchain4j/langchain4j#255

Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package dev.ai4j.openai4j.audio;

import java.util.Objects;

/**
* This class represents a request to generate audio speech using specific parameters.
*/
public class GenerateSpeechRequest {
private final String model;
private final String input;
private final String voice;
private final String responseFormat;
private final double speed;

private GenerateSpeechRequest(Builder builder) {
this.model = Objects.requireNonNull(builder.model, "Model cannot be null");
this.input = Objects.requireNonNull(builder.input, "Input cannot be null");
this.voice = Objects.requireNonNull(builder.voice, "Voice cannot be null");
this.responseFormat = builder.responseFormat;
this.speed = builder.speed;
}

// Implementing equals method to ensure correct behavior in collections and other use cases.
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
GenerateSpeechRequest that = (GenerateSpeechRequest) o;
return Double.compare(that.speed, speed) == 0 &&
Objects.equals(model, that.model) &&
Objects.equals(input, that.input) &&
Objects.equals(voice, that.voice) &&
Objects.equals(responseFormat, that.responseFormat);
}

@Override
public int hashCode() {
return Objects.hash(model, input, voice, responseFormat, speed);
}

@Override
public String toString() {
return String.format(
"GenerateAudioRequest{model='%s', input='%s', voice='%s', responseFormat='%s', speed=%.1f}",
model, input, voice, responseFormat, speed
);
}

// Static factory method for the builder, improving code readability.
public static Builder builder() {
return new Builder();
}

// Builder class for GenerateAudioRequest.
public static class Builder {
private String model;
private String input;
private String voice;
private String responseFormat = SpeechModel.ResponseFormat.MP3.toString(); // Default response format
private double speed = 1.0; // Default speed

public Builder model(String model) {
this.model = model;
return this;
}

public Builder model(SpeechModel model) {
this.model = model.toString();
return this;
}

public Builder input(String input) {
this.input = input;
return this;
}

public Builder voice(String voice) {
this.voice = voice;
return this;
}

public Builder voice(SpeechModel.Voice voice) {
this.voice = voice.toString();
return this;
}

public Builder responseFormat(String responseFormat) {
this.responseFormat = responseFormat;
return this;
}

public Builder responseFormat(SpeechModel.ResponseFormat responseFormat) {
this.responseFormat = responseFormat.toString();
return this;
}

public Builder speed(double speed) {
if (speed <= 0) {
throw new IllegalArgumentException("Speed must be positive");
}
this.speed = speed;
return this;
}

public GenerateSpeechRequest build() {
return new GenerateSpeechRequest(this);
}
}
}
Loading