Skip to content

Commit 92773f1

Browse files
tzolovmarkpollack
authored andcommitted
Add embedding model dimensions retrieval
- Add dimension method to the EmbeddingClient interface. Default dimension implementation uses the embed method to produce results and counts the result dimensions. - Add EmbeddingUtil#dimensions utilities that look up the model dimensions from a pre-defined (static) file. If the requested model is unknown, fallback to the default behaviour. - Override the dimensions method in the OpenAiEmbeddingClient and AzureOpenAiEmbeddingClient to implement local caching. - Add unit and IT tests. Resolves #28
1 parent 141f503 commit 92773f1

File tree

7 files changed

+213
-8
lines changed

7 files changed

+213
-8
lines changed

spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/embedding/AzureOpenAiEmbeddingClient.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
import org.springframework.ai.embedding.Embedding;
1212
import org.springframework.ai.embedding.EmbeddingClient;
1313
import org.springframework.ai.embedding.EmbeddingResponse;
14+
import org.springframework.ai.embedding.EmbeddingUtil;
1415
import org.springframework.util.Assert;
1516

1617
import java.util.ArrayList;
1718
import java.util.HashMap;
1819
import java.util.List;
1920
import java.util.Map;
21+
import java.util.concurrent.atomic.AtomicInteger;
2022
import java.util.stream.Collectors;
2123

2224
public class AzureOpenAiEmbeddingClient implements EmbeddingClient {
@@ -27,6 +29,8 @@ public class AzureOpenAiEmbeddingClient implements EmbeddingClient {
2729

2830
private final String model;
2931

32+
private final AtomicInteger embeddingDimensions = new AtomicInteger(-1);
33+
3034
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) {
3135
this(azureOpenAiClient, "text-embedding-ada-002");
3236
}
@@ -89,8 +93,6 @@ private Map<String, Object> generateMetadata(String model, EmbeddingsUsage embed
8993
Map<String, Object> metadata = new HashMap<>();
9094
metadata.put("model", model);
9195
metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens());
92-
// NOTE, not in API of AzureAI - metadata.put("completion-tokens",
93-
// embeddingsUsage.getCompletionTokens());
9496
metadata.put("total-tokens", embeddingsUsage.getTotalTokens());
9597
return metadata;
9698
}
@@ -106,4 +108,12 @@ private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
106108
return data;
107109
}
108110

111+
@Override
112+
public int dimensions() {
113+
if (this.embeddingDimensions.get() < 0) {
114+
this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, this.model));
115+
}
116+
return this.embeddingDimensions.get();
117+
}
118+
109119
}

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ public interface EmbeddingClient {
1414

1515
EmbeddingResponse embedForResponse(List<String> texts);
1616

17+
default int dimensions() {
18+
return embed("Test String").size();
19+
}
20+
1721
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.embedding;
18+
19+
import java.io.IOException;
20+
import java.util.Map;
21+
import java.util.Properties;
22+
import java.util.stream.Collectors;
23+
24+
import org.springframework.core.io.DefaultResourceLoader;
25+
26+
/**
27+
* @author Christian Tzolov
28+
*/
29+
public class EmbeddingUtil {
30+
31+
private static Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();
32+
33+
/**
34+
* Return the dimension of the requested embedding model name. If the model name is
35+
* unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed and count
36+
* the response dimensions.
37+
* @param embeddingClient Fall-back client to determine, empirically the dimensions.
38+
* @param modelName Embedding model name to retrieve the dimensions for.
39+
* @return Returns the embedding dimensions for the modelName.
40+
*/
41+
public static int dimensions(EmbeddingClient embeddingClient, String modelName) {
42+
43+
if (KNOWN_EMBEDDING_DIMENSIONS.containsKey(modelName)) {
44+
// Retrieve the dimension from a pre-configured file.
45+
return KNOWN_EMBEDDING_DIMENSIONS.get(modelName);
46+
}
47+
else {
48+
// Determine the dimensions empirically.
49+
// Generate an embedding and count the dimension size;
50+
return embeddingClient.embed("Test String").size();
51+
}
52+
}
53+
54+
private static Map<String, Integer> loadKnownModelDimensions() {
55+
try {
56+
Properties properties = new Properties();
57+
properties.load(new DefaultResourceLoader()
58+
.getResource("classpath:/embedding/embedding-model-dimensions.properties")
59+
.getInputStream());
60+
return properties.entrySet()
61+
.stream()
62+
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
63+
}
64+
catch (IOException e) {
65+
throw new RuntimeException(e);
66+
}
67+
}
68+
69+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Map of embedding model names and their dimesions
2+
# OpenAI
3+
text-embedding-ada-002=1536
4+
text-similarity-ada-001=1024
5+
text-similarity-babbage-001=2048
6+
text-similarity-curie-001=4096
7+
text-similarity-davinci-001=12288
8+
text-search-ada-doc-001=1024
9+
text-search-ada-query-001=1024
10+
text-search-babbage-doc-001=2048
11+
text-search-babbage-query-001=2048
12+
text-search-curie-doc-001=4096
13+
text-search-curie-query-001=4096
14+
text-search-davinci-doc-001=12288
15+
text-search-davinci-query-001=12288
16+
code-search-ada-code-001=1024
17+
code-search-ada-text-001=1024
18+
code-search-babbage-code-001=2048
19+
code-search-babbage-text-001=2048
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.embedding;
18+
19+
import java.util.List;
20+
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.extension.ExtendWith;
23+
import org.junit.jupiter.params.ParameterizedTest;
24+
import org.junit.jupiter.params.provider.CsvFileSource;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
28+
import org.springframework.ai.document.Document;
29+
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
import static org.mockito.ArgumentMatchers.any;
32+
import static org.mockito.ArgumentMatchers.eq;
33+
import static org.mockito.Mockito.never;
34+
import static org.mockito.Mockito.verify;
35+
import static org.mockito.Mockito.when;
36+
37+
/**
38+
* @author Christian Tzolov
39+
*/
40+
@ExtendWith(MockitoExtension.class)
41+
public class EmbeddingUtilTest {
42+
43+
@Mock
44+
private EmbeddingClient embeddingClient;
45+
46+
@Test
47+
public void testDefaultMethodImplementation() {
48+
49+
EmbeddingClient dummy = new EmbeddingClient() {
50+
51+
@Override
52+
public List<Double> embed(String text) {
53+
return List.of(0.1, 0.1, 0.1);
54+
}
55+
56+
@Override
57+
public List<Double> embed(Document document) {
58+
throw new UnsupportedOperationException("Unimplemented method 'embed'");
59+
}
60+
61+
@Override
62+
public List<List<Double>> embed(List<String> texts) {
63+
throw new UnsupportedOperationException("Unimplemented method 'embed'");
64+
}
65+
66+
@Override
67+
public EmbeddingResponse embedForResponse(List<String> texts) {
68+
throw new UnsupportedOperationException("Unimplemented method 'embedForResponse'");
69+
}
70+
};
71+
72+
assertThat(dummy.dimensions()).isEqualTo(3);
73+
}
74+
75+
@ParameterizedTest
76+
@CsvFileSource(resources = "/embedding/embedding-model-dimensions.properties", numLinesToSkip = 1, delimiter = '=')
77+
public void testKnownEmbeddingModelDimensions(String model, String dimension) {
78+
assertThat(EmbeddingUtil.dimensions(embeddingClient, model)).isEqualTo(Integer.valueOf(dimension));
79+
verify(embeddingClient, never()).embed(any(String.class));
80+
verify(embeddingClient, never()).embed(any(Document.class));
81+
}
82+
83+
@Test
84+
public void testUnknownModelDimension() {
85+
when(embeddingClient.embed(eq("Test String"))).thenReturn(List.of(0.1, 0.1, 0.1));
86+
assertThat(EmbeddingUtil.dimensions(embeddingClient, "unknown_model")).isEqualTo(3);
87+
}
88+
89+
}

spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingClient.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
package org.springframework.ai.openai.embedding;
22

3+
import java.util.ArrayList;
4+
import java.util.HashMap;
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.concurrent.atomic.AtomicInteger;
8+
import java.util.stream.Collectors;
9+
310
import com.theokanning.openai.Usage;
411
import com.theokanning.openai.embedding.EmbeddingRequest;
512
import com.theokanning.openai.service.OpenAiService;
613
import org.slf4j.Logger;
714
import org.slf4j.LoggerFactory;
15+
816
import org.springframework.ai.document.Document;
917
import org.springframework.ai.embedding.Embedding;
1018
import org.springframework.ai.embedding.EmbeddingClient;
1119
import org.springframework.ai.embedding.EmbeddingResponse;
20+
import org.springframework.ai.embedding.EmbeddingUtil;
1221
import org.springframework.util.Assert;
1322

14-
import java.util.ArrayList;
15-
import java.util.HashMap;
16-
import java.util.List;
17-
import java.util.Map;
18-
import java.util.stream.Collectors;
19-
2023
public class OpenAiEmbeddingClient implements EmbeddingClient {
2124

2225
private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingClient.class);
@@ -25,6 +28,8 @@ public class OpenAiEmbeddingClient implements EmbeddingClient {
2528

2629
private final String model;
2730

31+
private final AtomicInteger embeddingDimensions = new AtomicInteger(-1);
32+
2833
public OpenAiEmbeddingClient(OpenAiService openAiService) {
2934
this(openAiService, "text-embedding-ada-002");
3035
}
@@ -95,4 +100,12 @@ private Map<String, Object> generateMetadata(String model, Usage usage) {
95100
return metadata;
96101
}
97102

103+
@Override
104+
public int dimensions() {
105+
if (this.embeddingDimensions.get() < 0) {
106+
this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, this.model));
107+
}
108+
return this.embeddingDimensions.get();
109+
}
110+
98111
}

spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ void simpleEmbedding() {
2828
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2L);
2929
assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2L);
3030

31+
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
3132
}
3233

3334
}

0 commit comments

Comments
 (0)