22
33import static dev .ai4j .openai4j .Json .GSON ;
44
5- import java .io .IOException ;
6- import java .util .Collections ;
7- import java .util .List ;
8-
95import org .slf4j .Logger ;
106import org .slf4j .LoggerFactory ;
117
1511import dev .ai4j .openai4j .completion .CompletionResponse ;
1612import dev .ai4j .openai4j .embedding .EmbeddingRequest ;
1713import dev .ai4j .openai4j .embedding .EmbeddingResponse ;
14+ import dev .ai4j .openai4j .image .GenerateImagesRequest ;
15+ import dev .ai4j .openai4j .image .GenerateImagesResponse ;
1816import dev .ai4j .openai4j .moderation .ModerationRequest ;
1917import dev .ai4j .openai4j .moderation .ModerationResponse ;
2018import dev .ai4j .openai4j .moderation .ModerationResult ;
19+ import java .io .IOException ;
20+ import java .util .Collections ;
21+ import java .util .List ;
2122import okhttp3 .Cache ;
2223import okhttp3 .OkHttpClient ;
2324import retrofit2 .Retrofit ;
@@ -38,15 +39,14 @@ public DefaultOpenAiClient(String apiKey) {
3839 }
3940
4041 private DefaultOpenAiClient (Builder serviceBuilder ) {
41-
4242 this .baseUrl = serviceBuilder .baseUrl ;
4343 this .apiVersion = serviceBuilder .apiVersion ;
4444
4545 OkHttpClient .Builder okHttpClientBuilder = new OkHttpClient .Builder ()
46- .callTimeout (serviceBuilder .callTimeout )
47- .connectTimeout (serviceBuilder .connectTimeout )
48- .readTimeout (serviceBuilder .readTimeout )
49- .writeTimeout (serviceBuilder .writeTimeout );
46+ .callTimeout (serviceBuilder .callTimeout )
47+ .connectTimeout (serviceBuilder .connectTimeout )
48+ .readTimeout (serviceBuilder .readTimeout )
49+ .writeTimeout (serviceBuilder .writeTimeout );
5050
5151 if (serviceBuilder .openAiApiKey == null && serviceBuilder .azureApiKey == null ) {
5252 throw new IllegalArgumentException ("openAiApiKey OR azureApiKey must be defined" );
@@ -78,17 +78,18 @@ private DefaultOpenAiClient(Builder serviceBuilder) {
7878
7979 this .okHttpClient = okHttpClientBuilder .build ();
8080
81- Retrofit retrofit = new Retrofit .Builder ()
82- .baseUrl (serviceBuilder .baseUrl )
83- .client (okHttpClient )
84- .addConverterFactory (GsonConverterFactory .create (GSON ))
85- .build ();
81+ Retrofit .Builder retrofitBuilder = new Retrofit .Builder ().baseUrl (serviceBuilder .baseUrl ).client (okHttpClient );
82+
83+ if (serviceBuilder .persistTo != null ) {
84+ retrofitBuilder .addConverterFactory (new PersistorConverterFactory (serviceBuilder .persistTo ));
85+ }
86+
87+ retrofitBuilder .addConverterFactory (GsonConverterFactory .create (GSON ));
8688
87- this .openAiApi = retrofit .create (OpenAiApi .class );
89+ this .openAiApi = retrofitBuilder . build () .create (OpenAiApi .class );
8890 }
8991
9092 public void shutdown () {
91-
9293 okHttpClient .dispatcher ().executorService ().shutdown ();
9394
9495 okHttpClient .connectionPool ().evictAll ();
@@ -116,122 +117,99 @@ public DefaultOpenAiClient build() {
116117
117118 @ Override
118119 public SyncOrAsyncOrStreaming <CompletionResponse > completion (CompletionRequest request ) {
119-
120- CompletionRequest syncRequest = CompletionRequest .builder ()
121- .from (request )
122- .stream (null )
123- .build ();
120+ CompletionRequest syncRequest = CompletionRequest .builder ().from (request ).stream (null ).build ();
124121
125122 return new RequestExecutor <>(
126- openAiApi .completions (syncRequest , apiVersion ),
127- ( r ) -> r ,
128- okHttpClient ,
129- formatUrl ("completions" ),
130- () -> CompletionRequest .builder ().from (request ).stream (true ).build (),
131- CompletionResponse .class ,
132- ( r ) -> r ,
133- logStreamingResponses
123+ openAiApi .completions (syncRequest , apiVersion ),
124+ r -> r ,
125+ okHttpClient ,
126+ formatUrl ("completions" ),
127+ () -> CompletionRequest .builder ().from (request ).stream (true ).build (),
128+ CompletionResponse .class ,
129+ r -> r ,
130+ logStreamingResponses
134131 );
135132 }
136133
137134 @ Override
138135 public SyncOrAsyncOrStreaming <String > completion (String prompt ) {
136+ CompletionRequest request = CompletionRequest .builder ().prompt (prompt ).build ();
139137
140- CompletionRequest request = CompletionRequest .builder ()
141- .prompt (prompt )
142- .build ();
143-
144- CompletionRequest syncRequest = CompletionRequest .builder ()
145- .from (request )
146- .stream (null )
147- .build ();
138+ CompletionRequest syncRequest = CompletionRequest .builder ().from (request ).stream (null ).build ();
148139
149140 return new RequestExecutor <>(
150- openAiApi .completions (syncRequest , apiVersion ),
151- CompletionResponse ::text ,
152- okHttpClient ,
153- formatUrl ("completions" ),
154- () -> CompletionRequest .builder ().from (request ).stream (true ).build (),
155- CompletionResponse .class ,
156- CompletionResponse ::text ,
157- logStreamingResponses
141+ openAiApi .completions (syncRequest , apiVersion ),
142+ CompletionResponse ::text ,
143+ okHttpClient ,
144+ formatUrl ("completions" ),
145+ () -> CompletionRequest .builder ().from (request ).stream (true ).build (),
146+ CompletionResponse .class ,
147+ CompletionResponse ::text ,
148+ logStreamingResponses
158149 );
159150 }
160151
161152 @ Override
162153 public SyncOrAsyncOrStreaming <ChatCompletionResponse > chatCompletion (ChatCompletionRequest request ) {
163-
164- ChatCompletionRequest syncRequest = ChatCompletionRequest .builder ()
165- .from (request )
166- .stream (null )
167- .build ();
154+ ChatCompletionRequest syncRequest = ChatCompletionRequest .builder ().from (request ).stream (null ).build ();
168155
169156 return new RequestExecutor <>(
170- openAiApi .chatCompletions (syncRequest , apiVersion ),
171- ( r ) -> r ,
172- okHttpClient ,
173- formatUrl ("chat/completions" ),
174- () -> ChatCompletionRequest .builder ().from (request ).stream (true ).build (),
175- ChatCompletionResponse .class ,
176- ( r ) -> r ,
177- logStreamingResponses
157+ openAiApi .chatCompletions (syncRequest , apiVersion ),
158+ r -> r ,
159+ okHttpClient ,
160+ formatUrl ("chat/completions" ),
161+ () -> ChatCompletionRequest .builder ().from (request ).stream (true ).build (),
162+ ChatCompletionResponse .class ,
163+ r -> r ,
164+ logStreamingResponses
178165 );
179166 }
180167
181168 @ Override
182169 public SyncOrAsyncOrStreaming <String > chatCompletion (String userMessage ) {
170+ ChatCompletionRequest request = ChatCompletionRequest .builder ().addUserMessage (userMessage ).build ();
183171
184- ChatCompletionRequest request = ChatCompletionRequest .builder ()
185- .addUserMessage (userMessage )
186- .build ();
187-
188- ChatCompletionRequest syncRequest = ChatCompletionRequest .builder ()
189- .from (request )
190- .stream (null )
191- .build ();
172+ ChatCompletionRequest syncRequest = ChatCompletionRequest .builder ().from (request ).stream (null ).build ();
192173
193174 return new RequestExecutor <>(
194- openAiApi .chatCompletions (syncRequest , apiVersion ),
195- ChatCompletionResponse ::content ,
196- okHttpClient ,
197- formatUrl ("chat/completions" ),
198- () -> ChatCompletionRequest .builder ().from (request ).stream (true ).build (),
199- ChatCompletionResponse .class ,
200- ( r ) -> r .choices ().get (0 ).delta ().content (),
201- logStreamingResponses
175+ openAiApi .chatCompletions (syncRequest , apiVersion ),
176+ ChatCompletionResponse ::content ,
177+ okHttpClient ,
178+ formatUrl ("chat/completions" ),
179+ () -> ChatCompletionRequest .builder ().from (request ).stream (true ).build (),
180+ ChatCompletionResponse .class ,
181+ r -> r .choices ().get (0 ).delta ().content (),
182+ logStreamingResponses
202183 );
203184 }
204185
205186 @ Override
206187 public SyncOrAsync <EmbeddingResponse > embedding (EmbeddingRequest request ) {
207-
208- return new RequestExecutor <>(openAiApi .embeddings (request , apiVersion ), (r ) -> r );
188+ return new RequestExecutor <>(openAiApi .embeddings (request , apiVersion ), r -> r );
209189 }
210190
211191 @ Override
212192 public SyncOrAsync <List <Float >> embedding (String input ) {
213-
214- EmbeddingRequest request = EmbeddingRequest .builder ()
215- .input (input )
216- .build ();
193+ EmbeddingRequest request = EmbeddingRequest .builder ().input (input ).build ();
217194
218195 return new RequestExecutor <>(openAiApi .embeddings (request , apiVersion ), EmbeddingResponse ::embedding );
219196 }
220197
221198 @ Override
222199 public SyncOrAsync <ModerationResponse > moderation (ModerationRequest request ) {
223-
224- return new RequestExecutor <>(openAiApi .moderations (request , apiVersion ), (r ) -> r );
200+ return new RequestExecutor <>(openAiApi .moderations (request , apiVersion ), r -> r );
225201 }
226202
227203 @ Override
228204 public SyncOrAsync <ModerationResult > moderation (String input ) {
205+ ModerationRequest request = ModerationRequest .builder ().input (input ).build ();
229206
230- ModerationRequest request = ModerationRequest .builder ()
231- .input (input )
232- .build ();
207+ return new RequestExecutor <>(openAiApi .moderations (request , apiVersion ), r -> r .results ().get (0 ));
208+ }
233209
234- return new RequestExecutor <>(openAiApi .moderations (request , apiVersion ), (r ) -> r .results ().get (0 ));
210+ @ Override
211+ public SyncOrAsync <GenerateImagesResponse > imagesGeneration (GenerateImagesRequest request ) {
212+ return new RequestExecutor <>(openAiApi .imagesGenerations (request , apiVersion ), r -> r );
235213 }
236214
237215 private String formatUrl (String endpoint ) {
0 commit comments