1515 */
1616package org .springframework .data .jpa .repository ;
1717
18- import static org .assertj .core .api .Assertions .* ;
18+ import static org .assertj .core .api .Assertions .assertThat ;
1919
2020import jakarta .persistence .Column ;
2121import jakarta .persistence .Entity ;
3636import org .junit .jupiter .api .Test ;
3737import org .junit .jupiter .params .ParameterizedTest ;
3838import org .junit .jupiter .params .provider .MethodSource ;
39-
4039import org .springframework .beans .factory .annotation .Autowired ;
4140import org .springframework .data .domain .Range ;
4241import org .springframework .data .domain .Score ;
5352 * Testcase to verify Vector Search work with Hibernate.
5453 *
5554 * @author Mark Paluch
55+ * @author Christoph Strobl
5656 */
5757@ Transactional
5858@ Rollback (value = false )
@@ -65,10 +65,11 @@ abstract class AbstractVectorIntegrationTests {
6565@ BeforeEach
6666void setUp () {
6767
68- WithVector w1 = new WithVector ("de" , "one" , new float [] { 0.1001f , 0.22345f , 0.33456f , 0.44567f , 0.55678f });
69- WithVector w2 = new WithVector ("de" , "two" , new float [] { 0.2001f , 0.32345f , 0.43456f , 0.54567f , 0.65678f });
70- WithVector w3 = new WithVector ("en" , "three" , new float [] { 0.9001f , 0.82345f , 0.73456f , 0.64567f , 0.55678f });
71- WithVector w4 = new WithVector ("de" , "four" , new float [] { 0.9001f , 0.92345f , 0.93456f , 0.94567f , 0.95678f });
68+ WithVector w1 = new WithVector ("de" , "one" , "d1" , new float [] { 0.1001f , 0.22345f , 0.33456f , 0.44567f , 0.55678f });
69+ WithVector w2 = new WithVector ("de" , "two" , "d2" , new float [] { 0.2001f , 0.32345f , 0.43456f , 0.54567f , 0.65678f });
70+ WithVector w3 = new WithVector ("en" , "three" , "d3" ,
71+ new float [] { 0.9001f , 0.82345f , 0.73456f , 0.64567f , 0.55678f });
72+ WithVector w4 = new WithVector ("de" , "four" , "d4" , new float [] { 0.9001f , 0.92345f , 0.93456f , 0.94567f , 0.95678f });
7273
7374repository .deleteAllInBatch ();
7475repository .saveAllAndFlush (Arrays .asList (w1 , w2 , w3 , w4 ));
@@ -93,7 +94,7 @@ static Set<VectorScoringFunctions> scoringFunctions() {
9394VectorScoringFunctions .EUCLIDEAN );
9495}
9596
96- @ Test
97+ @ Test // GH-3868
9798void shouldNormalizeEuclideanSimilarity () {
9899
99100SearchResults <WithVector > results = repository .searchTop5ByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -108,7 +109,16 @@ void shouldNormalizeEuclideanSimilarity() {
108109assertThat (two .getScore ().getValue ()).isGreaterThan (0.99 );
109110}
110111
111- @ Test
112+ @ Test // GH-3868
113+ void orderTargetsProperty () {
114+
115+ SearchResults <WithVector > results = repository .searchTop5ByCountryAndEmbeddingWithinOrderByDistance ("de" , VECTOR ,
116+ Similarity .of (0 , VectorScoringFunctions .EUCLIDEAN ));
117+
118+ assertThat (results .getContent ()).extracting (it -> it .getContent ().getDistance ()).containsExactly ("d1" , "d2" , "d4" );
119+ }
120+
121+ @ Test // GH-3868
112122void shouldNormalizeCosineSimilarity () {
113123
114124SearchResults <WithVector > results = repository .searchTop5ByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -123,7 +133,7 @@ void shouldNormalizeCosineSimilarity() {
123133assertThat (two .getScore ().getValue ()).isGreaterThan (0.99 );
124134}
125135
126- @ Test
136+ @ Test // GH-3868
127137void shouldRunStringQuery () {
128138
129139List <WithVector > results = repository .findAnnotatedByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -133,7 +143,7 @@ void shouldRunStringQuery() {
133143assertThat (results ).extracting (WithVector ::getDescription ).containsSequence ("two" , "one" , "four" );
134144}
135145
136- @ Test
146+ @ Test // GH-3868
137147void shouldRunStringQueryWithDistance () {
138148
139149SearchResults <WithVector > results = repository .searchAnnotatedByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -149,7 +159,7 @@ void shouldRunStringQueryWithDistance() {
149159assertThat (result .getScore ().getFunction ()).isEqualTo (VectorScoringFunctions .COSINE );
150160}
151161
152- @ Test
162+ @ Test // GH-3868
153163void shouldRunStringQueryWithFloatDistance () {
154164
155165SearchResults <WithVector > results = repository .searchAnnotatedByCountryAndEmbeddingWithin ("de" , VECTOR , 2 );
@@ -164,7 +174,7 @@ void shouldRunStringQueryWithFloatDistance() {
164174assertThat (result .getScore ().getFunction ()).isEqualTo (ScoringFunction .unspecified ());
165175}
166176
167- @ Test
177+ @ Test // GH-3868
168178void shouldApplyVectorSearchWithRange () {
169179
170180SearchResults <WithVector > results = repository .searchAllByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -176,7 +186,7 @@ void shouldApplyVectorSearchWithRange() {
176186.containsSequence ("two" , "one" , "four" );
177187}
178188
179- @ Test
189+ @ Test // GH-3868
180190void shouldApplyVectorSearchAndReturnList () {
181191
182192List <WithVector > results = repository .findAllByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -186,7 +196,7 @@ void shouldApplyVectorSearchAndReturnList() {
186196assertThat (results ).extracting (WithVector ::getDescription ).containsSequence ("one" , "two" , "four" );
187197}
188198
189- @ Test
199+ @ Test // GH-3868
190200void shouldProjectVectorSearchAsInterface () {
191201
192202SearchResults <WithDescription > results = repository .searchInterfaceProjectionByCountryAndEmbeddingWithin ("de" ,
@@ -196,7 +206,7 @@ void shouldProjectVectorSearchAsInterface() {
196206.containsSequence ("two" , "one" , "four" );
197207}
198208
199- @ Test
209+ @ Test // GH-3868
200210void shouldProjectVectorSearchAsDto () {
201211
202212SearchResults <DescriptionDto > results = repository .searchDtoByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -206,7 +216,7 @@ void shouldProjectVectorSearchAsDto() {
206216.containsSequence ("two" , "one" , "four" );
207217}
208218
209- @ Test
219+ @ Test // GH-3868
210220void shouldProjectVectorSearchDynamically () {
211221
212222SearchResults <DescriptionDto > dtos = repository .searchDynamicByCountryAndEmbeddingWithin ("de" , VECTOR ,
@@ -233,16 +243,19 @@ public static class WithVector {
233243private String country ;
234244private String description ;
235245
246+ private String distance ;
247+
236248@ Column (name = "the_embedding" )
237249@ JdbcTypeCode (SqlTypes .VECTOR )
238250@ Array (length = 5 ) private float [] embedding ;
239251
240252public WithVector () {}
241253
242- public WithVector (String country , String description , float [] embedding ) {
254+ public WithVector (String country , String description , String distance , float [] embedding ) {
243255this .country = country ;
244256this .description = description ;
245257this .embedding = embedding ;
258+ this .distance = distance ;
246259}
247260
248261public Integer getId () {
@@ -273,9 +286,22 @@ public void setEmbedding(float[] embedding) {
273286this .embedding = embedding ;
274287}
275288
289+ public void setDescription (String description ) {
290+ this .description = description ;
291+ }
292+
293+ public String getDistance () {
294+ return distance ;
295+ }
296+
297+ public void setDistance (String distance ) {
298+ this .distance = distance ;
299+ }
300+
276301@ Override
277302public String toString () {
278- return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}' ;
303+ return "WithVector{" + "id=" + id + ", country='" + country + '\'' + ", description='" + description + '\''
304+ + ", distance='" + distance + '\'' + ", embedding=" + Arrays .toString (embedding ) + '}' ;
279305}
280306}
281307
@@ -328,6 +354,9 @@ SearchResults<WithVector> searchAllByCountryAndEmbeddingWithin(String country, V
328354
329355SearchResults <WithVector > searchTop5ByCountryAndEmbeddingWithin (String country , Vector embedding , Score distance );
330356
357+ SearchResults <WithVector > searchTop5ByCountryAndEmbeddingWithinOrderByDistance (String country , Vector embedding ,
358+ Score distance );
359+
331360SearchResults <WithDescription > searchInterfaceProjectionByCountryAndEmbeddingWithin (String country ,
332361Vector embedding , Score distance );
333362
0 commit comments