Skip to content
Prev Previous commit
Next Next commit
Add reactive search implementation.
  • Loading branch information
mp911de committed Apr 29, 2025
commit 72aed1ff10ae8d0bacb725111d6e4bcd4a890d87
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import reactor.core.publisher.Mono;

import org.reactivestreams.Publisher;

import org.springframework.core.convert.converter.Converter;
import org.springframework.data.cassandra.ReactiveResultSet;
import org.springframework.data.cassandra.core.CassandraOperations;
Expand All @@ -26,6 +27,7 @@
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.ExistsExecution;
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.ResultProcessingConverter;
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.ResultProcessingExecution;
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.SearchExecution;
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.SingleEntityExecution;
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.SlicedExecution;
import org.springframework.data.cassandra.repository.query.ReactiveCassandraQueryExecution.WindowExecution;
Expand Down Expand Up @@ -126,6 +128,8 @@ private ReactiveCassandraQueryExecution getExecutionToWrap(CassandraParameterAcc
} else if (getQueryMethod().isScrollQuery()) {
return new WindowExecution(getReactiveCassandraOperations(), parameterAccessor.getScrollPosition(),
parameterAccessor.getLimit());
} else if (getQueryMethod().isSearchQuery()) {
return new SearchExecution(getReactiveCassandraOperations(), parameterAccessor);
} else if (getQueryMethod().isCollectionQuery()) {
return new CollectionExecution(getReactiveCassandraOperations());
} else if (isCountQuery()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.SearchResult;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Similarity;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.mapping.context.MappingContext;
Expand Down Expand Up @@ -204,8 +205,10 @@ public Object execute(Statement<?> statement, Class<?> type) {
/**
 * Reads the value stored under {@code columnName} in the given {@code row} and wraps it as a {@link Score}.
 * <p>
 * The column value is cast to {@link Number}; a non-numeric value fails with a {@link ClassCastException} and a
 * {@code null} value fails with a {@link NullPointerException}.
 *
 * @param row the row to read the score from.
 * @param columnName name of the column holding the score value.
 * @param function the scoring function to associate with the score; when {@code null}, an unspecified scoring
 *          function is used instead.
 * @return the score built from the column value.
 */
private Score getScore(Row row, String columnName, @Nullable ScoringFunction function) {

Object object = row.getObject(columnName);
// NOTE(review): this diff hunk renders the removed line (Score.of with ScoringFunction.UNSPECIFIED) directly
// above its replacement (Similarity.raw with ScoringFunction.unspecified()); presumably only the second return
// remains in the committed file - confirm against the repository.
return Score.of(((Number) object).doubleValue(), function == null ? ScoringFunction.UNSPECIFIED : function);
return Similarity.raw(((Number) object).doubleValue(),
function == null ? ScoringFunction.unspecified() : function);
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@
class QueryStatementCreator {

private static final Map<ScoringFunction, SimilarityFunction> SIMILARITY_FUNCTIONS = Map.of(
VectorScoringFunctions.COSINE, SimilarityFunction.COSINE, VectorScoringFunctions.EUCLIDEAN,
SimilarityFunction.EUCLIDEAN, VectorScoringFunctions.DOT, SimilarityFunction.DOT_PRODUCT,
VectorScoringFunctions.INNER_PRODUCT, SimilarityFunction.DOT_PRODUCT);
VectorScoringFunctions.COSINE, SimilarityFunction.COSINE, //
VectorScoringFunctions.EUCLIDEAN, SimilarityFunction.EUCLIDEAN, //
VectorScoringFunctions.DOT_PRODUCT, SimilarityFunction.DOT_PRODUCT);

private static final Log LOG = LogFactory.getLog(QueryStatementCreator.class);

Expand Down Expand Up @@ -148,7 +148,7 @@ private SimilarityFunction getSimilarityFunction(@Nullable ScoringFunction funct

if (function == null) {
throw new IllegalStateException(
"Cannot determine ScoringFunction. No Score or bounded Score Range parameters provided.");
"Cannot determine ScoringFunction. No ScoringFunction, Score/Similarity or bounded Score Range parameters provided.");
}

SimilarityFunction similarityFunction = SIMILARITY_FUNCTIONS.get(function);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import java.util.List;

import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher;

import org.springframework.core.convert.converter.Converter;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.data.cassandra.core.ReactiveCassandraOperations;
Expand All @@ -31,6 +33,10 @@
import org.springframework.data.convert.DtoInstantiatingConverter;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.SearchResult;
import org.springframework.data.domain.Similarity;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.mapping.context.MappingContext;
Expand Down Expand Up @@ -152,6 +158,44 @@ public Publisher<? extends Object> execute(Statement<?> statement, Class<?> type

}

/**
 * {@link ReactiveCassandraQueryExecution} that wraps every mapped entity into a {@link SearchResult}.
 * <p>
 * The score is taken from the first column found among {@code __score__} and {@code score} (in that order). Rows
 * exposing neither column produce a {@link SearchResult} with a score of {@code 0}.
 */
final class SearchExecution implements ReactiveCassandraQueryExecution {

	/** Candidate score column names, in lookup order. */
	private static final List<String> SCORE_COLUMNS = List.of("__score__", "score");

	private final ReactiveCassandraOperations operations;
	private final CassandraParameterAccessor accessor;

	/**
	 * Create a new {@code SearchExecution}.
	 *
	 * @param operations the operations used to run the statement.
	 * @param accessor the accessor providing the scoring function for the invocation.
	 */
	public SearchExecution(ReactiveCassandraOperations operations, CassandraParameterAccessor accessor) {

		this.operations = operations;
		this.accessor = accessor;
	}

	@Override
	public Publisher<? extends Object> execute(Statement<?> statement, Class<?> type) {

		// Resolved once per execution; may be null when the invocation carries no scoring function.
		ScoringFunction scoringFunction = accessor.getScoringFunction();

		return operations.query(statement).as(type).map((row, reader) -> {

			Object content = reader.get();

			for (String column : SCORE_COLUMNS) {
				if (row.getColumnDefinitions().contains(column)) {
					return new SearchResult<>(content, toScore(row, column, scoringFunction));
				}
			}

			// No score column selected: fall back to a zero score.
			return new SearchResult<>(content, 0);
		}).all();
	}

	/**
	 * Convert the numeric value of {@code columnName} into a {@link Score}.
	 *
	 * @param row the row to read from.
	 * @param columnName name of the column holding the score value.
	 * @param function the scoring function to use; an unspecified function is used when {@code null}.
	 * @return the resulting score.
	 */
	private Score toScore(Row row, String columnName, @Nullable ScoringFunction function) {

		double value = ((Number) row.getObject(columnName)).doubleValue();

		if (function == null) {
			return Similarity.raw(value, ScoringFunction.unspecified());
		}

		return Similarity.raw(value, function);
	}
}

/**
* {@link ReactiveCassandraQueryExecution} to return a single entity.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright 2016-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.cassandra.repository;

import static org.assertj.core.api.Assertions.*;

import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;

import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.FilterType;
import org.springframework.data.annotation.Id;
import org.springframework.data.annotation.PersistenceCreator;
import org.springframework.data.cassandra.config.SchemaAction;
import org.springframework.data.cassandra.core.mapping.SaiIndexed;
import org.springframework.data.cassandra.core.mapping.Table;
import org.springframework.data.cassandra.core.mapping.VectorType;
import org.springframework.data.cassandra.repository.config.EnableReactiveCassandraRepositories;
import org.springframework.data.cassandra.repository.support.AbstractSpringDataEmbeddedCassandraIntegrationTest;
import org.springframework.data.cassandra.repository.support.IntegrationTestConfig;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.SearchResult;
import org.springframework.data.domain.Similarity;
import org.springframework.data.domain.Vector;
import org.springframework.data.domain.VectorScoringFunctions;
import org.springframework.data.repository.reactive.ReactiveCrudRepository;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;

/**
 * Integration tests exercising Vector Search through a reactive Cassandra repository.
 *
 * @author Mark Paluch
 */
@SpringJUnitConfig
class ReactiveVectorSearchIntegrationTests extends AbstractSpringDataEmbeddedCassandraIntegrationTest {

	// Query vector; identical to the embedding stored for the "two" fixture row.
	Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f);

	@Configuration
	@EnableReactiveCassandraRepositories(basePackageClasses = ReactiveVectorSearchRepository.class,
			considerNestedRepositories = true,
			includeFilters = @ComponentScan.Filter(classes = ReactiveVectorSearchRepository.class,
					type = FilterType.ASSIGNABLE_TYPE))
	public static class Config extends IntegrationTestConfig {

		@Override
		protected Set<Class<?>> getInitialEntitySet() {
			return Collections.singleton(WithVectorFields.class);
		}

		@Override
		public SchemaAction getSchemaAction() {
			return SchemaAction.RECREATE_DROP_UNUSED;
		}
	}

	@Autowired ReactiveVectorSearchRepository repository;

	/**
	 * Clears the table and stores four fixture rows before each test.
	 */
	@BeforeEach
	void setUp() {

		repository.deleteAll().as(StepVerifier::create).verifyComplete();

		WithVectorFields one = new WithVectorFields("de", "one",
				Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f));
		WithVectorFields two = new WithVectorFields("de", "two",
				Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f));
		WithVectorFields three = new WithVectorFields("en", "three",
				Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f));
		WithVectorFields four = new WithVectorFields("de", "four",
				Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f));

		repository.saveAll(List.of(one, two, three, four)) //
				.as(StepVerifier::create) //
				.expectNextCount(4) //
				.verifyComplete();
	}

	@Test // GH-
	void shouldConsiderScoringFunction() {

		Vector probe = Vector.of(0.9f, 0.54f, 0.12f, 0.1f, 0.95f);

		List<SearchResult<WithVectorFields>> cosineHits = repository
				.searchByEmbeddingNear(probe, VectorScoringFunctions.COSINE, Limit.of(100)).collectList().block();

		assertFourSimilaritiesNotCloseTo(cosineHits, 0d);

		List<SearchResult<WithVectorFields>> euclideanHits = repository
				.searchByEmbeddingNear(VECTOR, VectorScoringFunctions.EUCLIDEAN, Limit.of(100)).collectList().block();

		assertFourSimilaritiesNotCloseTo(euclideanHits, 0.3d);
	}

	@Test // GH-
	void shouldRunAnnotatedSearchByVector() {

		List<SearchResult<WithVectorFields>> hits = repository.searchAnnotatedByEmbeddingNear(VECTOR, Limit.of(100))
				.collectList().block();

		assertFourSimilaritiesNotCloseTo(hits, 0d);
	}

	@Test // GH-
	void shouldFindByVector() {

		List<WithVectorFields> entities = repository.findByEmbeddingNear(VECTOR, Limit.of(100)).collectList().block();

		assertThat(entities).hasSize(4);
	}

	/**
	 * Asserts that exactly four results were returned and that each carries a {@link Similarity} score whose value
	 * is not within {@code 0.1} of {@code excluded}.
	 *
	 * @param hits the search results to verify.
	 * @param excluded the value the scores must stay away from.
	 */
	private static void assertFourSimilaritiesNotCloseTo(List<SearchResult<WithVectorFields>> hits, double excluded) {

		assertThat(hits).hasSize(4);

		hits.forEach(hit -> {
			assertThat(hit.getScore()).isInstanceOf(Similarity.class);
			assertThat(hit.getScore().getValue()).isNotCloseTo(excluded, offset(0.1d));
		});
	}

	interface ReactiveVectorSearchRepository extends ReactiveCrudRepository<WithVectorFields, UUID> {

		Flux<SearchResult<WithVectorFields>> searchByEmbeddingNear(Vector embedding, ScoringFunction function, Limit limit);

		Flux<WithVectorFields> findByEmbeddingNear(Vector embedding, Limit limit);

		@Query("SELECT id,description,country,similarity_cosine(embedding,:embedding) AS score FROM withvectorfields ORDER BY embedding ANN OF :embedding LIMIT :limit")
		Flux<SearchResult<WithVectorFields>> searchAnnotatedByEmbeddingNear(Vector embedding, Limit limit);

	}

	/**
	 * Entity with a five-dimensional, SAI-indexed vector column used as the search target.
	 */
	@Table
	static class WithVectorFields {

		@Id String id;
		String country;
		String description;

		@VectorType(dimensions = 5)
		@SaiIndexed Vector embedding;

		@PersistenceCreator
		public WithVectorFields(String id, String country, String description, Vector embedding) {
			this.id = id;
			this.country = country;
			this.description = description;
			this.embedding = embedding;
		}

		/**
		 * Creates an instance with a randomly generated identifier.
		 */
		public WithVectorFields(String country, String description, Vector embedding) {
			this(UUID.randomUUID().toString(), country, description, embedding);
		}

		public String getId() {
			return this.id;
		}

		public String getCountry() {
			return this.country;
		}

		public String getDescription() {
			return this.description;
		}

		public Vector getEmbedding() {
			return this.embedding;
		}

		@Override
		public String toString() {
			return "WithVectorFields{id='" + id + "', country='" + country + "', description='" + description + "'}";
		}
	}

}
Loading