Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/122381.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 122381
summary: Adds implementations of dotProduct and cosineSimilarity painless methods to operate on float vectors for byte fields
area: Vector Search
type: enhancement
issues:
- 117274
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ public static float ipFloatBit(float[] q, byte[] d) {
return IMPL.ipFloatBit(q, d);
}

/**
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a byte vector.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
*/
public static float ipFloatByte(float[] q, byte[] d) {
if (q.length != d.length) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
}
return IMPL.ipFloatByte(q, d);
}

/**
* AND bit count computed over signed bytes.
* Copied from Lucene's XOR implementation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ public float ipFloatBit(float[] q, byte[] d) {
return ipFloatBitImpl(q, d);
}

@Override
public float ipFloatByte(float[] q, byte[] d) {
return ipFloatByteImpl(q, d);
}

public static int ipByteBitImpl(byte[] q, byte[] d) {
assert q.length == d.length * Byte.SIZE;
int acc0 = 0;
Expand Down Expand Up @@ -101,4 +106,12 @@ public static long ipByteBinByteImpl(byte[] q, byte[] d) {
}
return ret;
}

public static float ipFloatByteImpl(float[] q, byte[] d) {
float ret = 0;
for (int i = 0; i < q.length; i++) {
ret += q[i] * d[i];
}
return ret;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ public interface ESVectorUtilSupport {
int ipByteBit(byte[] q, byte[] d);

float ipFloatBit(float[] q, byte[] d);

float ipFloatByte(float[] q, byte[] d);
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public float ipFloatBit(float[] q, byte[] d) {
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
}

@Override
public float ipFloatByte(float[] q, byte[] d) {
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Panama implementation can be added as a separate PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #123270

}

private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,28 @@ public void testIpByteBit() {
public void testIpFloatBit() {
float[] q = new float[16];
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
random().nextFloat();
for (int i = 0; i < q.length; i++) {
q[i] = random().nextFloat();
}
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
}

public void testIpFloatByte() {
float[] q = new float[16];
byte[] d = new byte[16];
for (int i = 0; i < q.length; i++) {
q[i] = random().nextFloat();
}
random().nextBytes(d);

float expected = 0;
for (int i = 0; i < q.length; i++) {
expected += q[i] * d[i];
}
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
}

public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,38 @@ setup:
- match: {hits.hits.2._id: "1"}
- match: {hits.hits.2._score: 1632.0}
---
"Dot Product float":
- requires:
capabilities:
- path: /_search
capabilities: [byte_float_dot_product_capability]
test_runner_features: [capabilities]
reason: "float vector queries capability added"
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]

- match: {hits.total: 3}

- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 32865.2}

- match: {hits.hits.1._id: "3"}
- match: {hits.hits.1._score: 21413.4}

- match: {hits.hits.2._id: "1"}
- match: {hits.hits.2._score: 1862.3}
---
"Cosine Similarity":
- do:
headers:
Expand Down Expand Up @@ -198,3 +230,39 @@ setup:
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.509}
- lte: {hits.hits.2._score: 0.512}

---
"Cosine Similarity float":
- requires:
capabilities:
- path: /_search
capabilities: [byte_float_dot_product_capability]
test_runner_features: [capabilities]
reason: "float vector queries capability added"
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]

- match: {hits.total: 3}

- match: {hits.hits.0._id: "2"}
- gte: {hits.hits.0._score: 0.989}
- lte: {hits.hits.0._score: 0.992}

- match: {hits.hits.1._id: "3"}
- gte: {hits.hits.1._score: 0.885}
- lte: {hits.hits.1._score: 0.888}

- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.505}
- lte: {hits.hits.2._score: 0.508}
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,17 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
}

@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);

StringBuilder errorBuilder = null;
StringBuilder checkVectorErrors(float[] vector) {
StringBuilder errors = checkNanAndInfinite(vector);
if (errors != null) {
return errors;
}

for (int index = 0; index < vector.length; ++index) {
float value = vector[index];

if (value % 1.0f != 0.0f) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
Expand All @@ -368,7 +369,7 @@ public void checkVectorBounds(float[] vector) {
}

if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support integers between ["
Expand All @@ -385,9 +386,7 @@ public void checkVectorBounds(float[] vector) {
}
}

if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
return errors;
}

@Override
Expand Down Expand Up @@ -614,8 +613,8 @@ public FloatVectorValues getFloatVectorValues(String fieldName) throws IOExcepti
}

@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);
StringBuilder checkVectorErrors(float[] vector) {
return checkNanAndInfinite(vector);
}

@Override
Expand Down Expand Up @@ -768,16 +767,17 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
}

@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);

StringBuilder errorBuilder = null;
StringBuilder checkVectorErrors(float[] vector) {
StringBuilder errors = checkNanAndInfinite(vector);
if (errors != null) {
return errors;
}

for (int index = 0; index < vector.length; ++index) {
float value = vector[index];

if (value % 1.0f != 0.0f) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
Expand All @@ -790,7 +790,7 @@ public void checkVectorBounds(float[] vector) {
}

if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
errorBuilder = new StringBuilder(
errors = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support integers between ["
Expand All @@ -807,9 +807,7 @@ public void checkVectorBounds(float[] vector) {
}
}

if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
return errors;
}

@Override
Expand Down Expand Up @@ -993,7 +991,44 @@ public abstract VectorData parseKnnVector(

public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);

public abstract void checkVectorBounds(float[] vector);
/**
* Checks the input {@code vector} is one of the {@code possibleTypes},
* and returns the first type that it matches
*/
public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
assert possibleTypes.length != 0;
// we're looking for one valid allowed type
// assume the types are in order of specificity
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
for (int i = 0; i < possibleTypes.length; i++) {
StringBuilder error = possibleTypes[i].checkVectorErrors(vector);
if (error == null) {
// this one works - use it
return possibleTypes[i];
} else {
errors[i] = error;
}
}

// oh dear, none of the possible types work with this vector. Generate the error message and throw.
StringBuilder message = new StringBuilder();
for (int i = 0; i < possibleTypes.length; i++) {
if (i > 0) {
message.append(" ");
}
message.append("Vector is not a ").append(possibleTypes[i]).append(" vector: ").append(errors[i]);
}
throw new IllegalArgumentException(appendErrorElements(message, vector).toString());
}

public void checkVectorBounds(float[] vector) {
StringBuilder errors = checkVectorErrors(vector);
if (errors != null) {
throw new IllegalArgumentException(appendErrorElements(errors, vector).toString());
}
}

abstract StringBuilder checkVectorErrors(float[] vector);

abstract void checkVectorMagnitude(
VectorSimilarity similarity,
Expand All @@ -1017,7 +1052,7 @@ public int parseDimensionCount(DocumentParserContext context) throws IOException
return index;
}

void checkNanAndInfinite(float[] vector) {
StringBuilder checkNanAndInfinite(float[] vector) {
StringBuilder errorBuilder = null;

for (int index = 0; index < vector.length; ++index) {
Expand All @@ -1044,9 +1079,7 @@ void checkNanAndInfinite(float[] vector) {
}
}

if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
return errorBuilder;
}

static StringBuilder appendErrorElements(StringBuilder errorBuilder, float[] vector) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ private SearchCapabilities() {}
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
/** Support float query vectors on byte vectors */
private static final String BYTE_FLOAT_DOT_PRODUCT_CAPABILITY = "byte_float_dot_product_capability";
/** Support docvalue_fields parameter for `dense_vector` field. */
private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
/** Support transforming rank rrf queries to the corresponding rrf retriever. */
Expand All @@ -50,6 +52,7 @@ private SearchCapabilities() {}
capabilities.add(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY);
capabilities.add(BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY);
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
capabilities.add(BYTE_FLOAT_DOT_PRODUCT_CAPABILITY);
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
Expand Down
Loading
Loading