2626import java .nio .file .FileAlreadyExistsException ;
2727import java .nio .file .Files ;
2828import java .util .Comparator ;
29- import java .util .HashMap ;
3029import java .util .List ;
3130import java .util .Map ;
32- import java .util .Objects ;
3331import java .util .Optional ;
3432import java .util .concurrent .ConcurrentHashMap ;
3533
3634import com .fasterxml .jackson .core .JsonProcessingException ;
3735import com .fasterxml .jackson .core .type .TypeReference ;
3836import com .fasterxml .jackson .databind .ObjectMapper ;
39- import com .fasterxml .jackson .databind .ObjectWriter ;
4037import com .fasterxml .jackson .databind .json .JsonMapper ;
4138import io .micrometer .observation .ObservationRegistry ;
4239import org .slf4j .Logger ;
5047import org .springframework .ai .vectorstore .observation .AbstractObservationVectorStore ;
5148import org .springframework .ai .vectorstore .observation .VectorStoreObservationContext ;
5249import org .springframework .ai .vectorstore .observation .VectorStoreObservationConvention ;
50+ import org .springframework .core .io .FileSystemResource ;
5351import org .springframework .core .io .Resource ;
52+ import org .springframework .util .Assert ;
5453
5554/**
56- * SimpleVectorStore is a simple implementation of the VectorStore interface.
57- *
55+ * Simple, in-memory implementation of the {@link VectorStore} interface.
56+ * <p/>
5857 * It also provides methods to save the current state of the vectors to a file, and to
5958 * load vectors from a file.
60- *
59+ * <p/>
6160 * For a deeper understanding of the mathematical concepts and computations involved in
6261 * calculating similarity scores among vectors, refer to this
6362 * [resource](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors).
6766 * @author Mark Pollack
6867 * @author Christian Tzolov
6968 * @author Sebastien Deleuze
69+ * @author John Blum
70+ * @see VectorStore
7071 */
7172public class SimpleVectorStore extends AbstractObservationVectorStore {
7273
@@ -87,54 +88,72 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
8788
8889super (observationRegistry , customObservationConvention );
8990
90- Objects .requireNonNull (embeddingModel , "EmbeddingModel must not be null" );
91+ Assert .notNull (embeddingModel , "EmbeddingModel must not be null" );
92+
9193this .embeddingModel = embeddingModel ;
9294this .objectMapper = JsonMapper .builder ().addModules (JacksonUtils .instantiateAvailableModules ()).build ();
9395}
9496
9597@ Override
9698public void doAdd (List <Document > documents ) {
9799for (Document document : documents ) {
98- logger .info ("Calling EmbeddingModel for document id = {}" , document .getId ());
99- float [] embedding = this .embeddingModel .embed (document );
100- document .setEmbedding (embedding );
100+ logger .info ("Calling EmbeddingModel for Document id = {}" , document .getId ());
101+ document = embed (document );
101102this .store .put (document .getId (), document );
102103}
103104}
104105
106+ protected Document embed (Document document ) {
107+ float [] documentEmbedding = this .embeddingModel .embed (document );
108+ document .setEmbedding (documentEmbedding );
109+ return document ;
110+ }
111+
105112@ Override
106113public Optional <Boolean > doDelete (List <String > idList ) {
107- for (String id : idList ) {
108- this .store .remove (id );
109- }
114+ idList .forEach (this .store ::remove );
110115return Optional .of (true );
111116}
112117
113118@ Override
114119public List <Document > doSimilaritySearch (SearchRequest request ) {
120+
115121if (request .getFilterExpression () != null ) {
116122throw new UnsupportedOperationException (
117- "The [" + this . getClass () + " ] doesn't support metadata filtering!" );
123+ "[%s ] doesn't support metadata filtering" . formatted ( getClass (). getName ()) );
118124}
119125
120- float [] userQueryEmbedding = getUserQueryEmbedding (request .getQuery ());
121- return this .store .values ()
122- .stream ()
123- .map (entry -> new Similarity (entry .getId (),
124- EmbeddingMath .cosineSimilarity (userQueryEmbedding , entry .getEmbedding ())))
125- .filter (s -> s .score >= request .getSimilarityThreshold ())
126- .sorted (Comparator .<Similarity >comparingDouble (s -> s .score ).reversed ())
126+ // @formatter:off
127+ return this .store .values ().stream ()
128+ .map (document -> computeSimilarity (request , document ))
129+ .filter (similarity -> similarity .score >= request .getSimilarityThreshold ())
130+ .sorted (Comparator .<Similarity >comparingDouble (similarity -> similarity .score ).reversed ())
127131.limit (request .getTopK ())
128- .map (s -> this .store .get (s .key ))
132+ .map (similarity -> this .store .get (similarity .key ))
129133.toList ();
134+ // @formatter:on
135+ }
136+
137+ protected Similarity computeSimilarity (SearchRequest request , Document document ) {
138+
139+ float [] userQueryEmbedding = getUserQueryEmbedding (request );
140+ float [] documentEmbedding = document .getEmbedding ();
141+
142+ double score = computeCosineSimilarity (userQueryEmbedding , documentEmbedding );
143+
144+ return new Similarity (document .getId (), score );
145+ }
146+
147+ protected double computeCosineSimilarity (float [] userQueryEmbedding , float [] storedDocumentEmbedding ) {
148+ return EmbeddingMath .cosineSimilarity (userQueryEmbedding , storedDocumentEmbedding );
130149}
131150
132151/**
133152 * Serialize the vector store content into a file in JSON format.
134153 * @param file the file to save the vector store content
135154 */
136155public void save (File file ) {
137- String json = getVectorDbAsJson ();
156+
138157try {
139158if (!file .exists ()) {
140159logger .info ("Creating new vector store file: {}" , file );
@@ -145,28 +164,30 @@ public void save(File file) {
145164throw new RuntimeException ("File already exists: " + file , e );
146165}
147166catch (IOException e ) {
148- throw new RuntimeException ("Failed to create new file: " + file + ". Reason: " + e .getMessage (), e );
167+ throw new RuntimeException ("Failed to create new file: " + file + "; Reason: " + e .getMessage (), e );
149168}
150169}
151170else {
152171logger .info ("Overwriting existing vector store file: {}" , file );
153172}
173+
154174try (OutputStream stream = new FileOutputStream (file );
155175Writer writer = new OutputStreamWriter (stream , StandardCharsets .UTF_8 )) {
176+ String json = getVectorDbAsJson ();
156177writer .write (json );
157178writer .flush ();
158179}
159180}
160181catch (IOException ex ) {
161- logger .error ("IOException occurred while saving vector store file. " , ex );
182+ logger .error ("IOException occurred while saving vector store file" , ex );
162183throw new RuntimeException (ex );
163184}
164185catch (SecurityException ex ) {
165- logger .error ("SecurityException occurred while saving vector store file. " , ex );
186+ logger .error ("SecurityException occurred while saving vector store file" , ex );
166187throw new RuntimeException (ex );
167188}
168189catch (NullPointerException ex ) {
169- logger .error ("NullPointerException occurred while saving vector store file. " , ex );
190+ logger .error ("NullPointerException occurred while saving vector store file" , ex );
170191throw new RuntimeException (ex );
171192}
172193}
@@ -176,45 +197,40 @@ public void save(File file) {
176197 * @param file the file to load the vector store content
177198 */
178199public void load (File file ) {
179- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
180-
181- };
182- try {
183- Map <String , Document > deserializedMap = this .objectMapper .readValue (file , typeRef );
184- this .store = deserializedMap ;
185- }
186- catch (IOException ex ) {
187- throw new RuntimeException (ex );
188- }
200+ load (new FileSystemResource (file ));
189201}
190202
191203/**
192204 * Deserialize the vector store content from a resource in JSON format into memory.
193205 * @param resource the resource to load the vector store content
194206 */
195207public void load (Resource resource ) {
196- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
197208
198- };
199209try {
200- Map <String , Document > deserializedMap = this .objectMapper .readValue (resource .getInputStream (), typeRef );
201- this .store = deserializedMap ;
210+ this .store = this .objectMapper .readValue (resource .getInputStream (), documentMapTypeRef ());
202211}
203212catch (IOException ex ) {
204213throw new RuntimeException (ex );
205214}
206215}
207216
217+ private TypeReference <Map <String , Document >> documentMapTypeRef () {
218+ return new TypeReference <>() {
219+ };
220+ }
221+
208222private String getVectorDbAsJson () {
209- ObjectWriter objectWriter = this .objectMapper .writerWithDefaultPrettyPrinter ();
210- String json ;
223+
211224try {
212- json = objectWriter .writeValueAsString (this .store );
225+ return this . objectMapper . writerWithDefaultPrettyPrinter () .writeValueAsString (this .store );
213226}
214227catch (JsonProcessingException e ) {
215- throw new RuntimeException ("Error serializing documentMap to JSON. " , e );
228+ throw new RuntimeException ("Error serializing Map of Documents to JSON" , e );
216229}
217- return json ;
230+ }
231+
232+ private float [] getUserQueryEmbedding (SearchRequest request ) {
233+ return getUserQueryEmbedding (request .getQuery ());
218234}
219235
220236private float [] getUserQueryEmbedding (String query ) {
@@ -232,9 +248,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232248
233249public static class Similarity {
234250
235- private String key ;
251+ private final String key ;
236252
237- private double score ;
253+ private final double score ;
238254
239255public Similarity (String key , double score ) {
240256this .key = key ;
@@ -243,16 +259,18 @@ public Similarity(String key, double score) {
243259
244260}
245261
246- public final class EmbeddingMath {
262+ public static final class EmbeddingMath {
247263
248264private EmbeddingMath () {
249265throw new UnsupportedOperationException ("This is a utility class and cannot be instantiated" );
250266}
251267
252268public static double cosineSimilarity (float [] vectorX , float [] vectorY ) {
269+
253270if (vectorX == null || vectorY == null ) {
254- throw new RuntimeException ("Vectors must not be null" );
271+ throw new IllegalArgumentException ("Vectors must not be null" );
255272}
273+
256274if (vectorX .length != vectorY .length ) {
257275throw new IllegalArgumentException ("Vectors lengths must be equal" );
258276}
@@ -268,20 +286,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
268286return dotProduct / (Math .sqrt (normX ) * Math .sqrt (normY ));
269287}
270288
271- public static float dotProduct (float [] vectorX , float [] vectorY ) {
289+ private static float dotProduct (float [] vectorX , float [] vectorY ) {
290+
272291if (vectorX .length != vectorY .length ) {
273292throw new IllegalArgumentException ("Vectors lengths must be equal" );
274293}
275294
276295float result = 0 ;
277- for (int i = 0 ; i < vectorX .length ; ++i ) {
278- result += vectorX [i ] * vectorY [i ];
296+
297+ for (int index = 0 ; index < vectorX .length ; ++index ) {
298+ result += vectorX [index ] * vectorY [index ];
279299}
280300
281301return result ;
282302}
283303
284- public static float norm (float [] vector ) {
304+ private static float norm (float [] vector ) {
285305return dotProduct (vector , vector );
286306}
287307
0 commit comments