Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ private KnnFloatBenchmarkFunction(int dims, boolean normalize) {
private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {

final BytesRef docVector;
final float[] docFloatVector;
final float[] queryVector;

private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
super(dims);

float[] docVector = new float[dims];
docFloatVector = new float[dims];
queryVector = new float[dims];

float docMagnitude = 0f;
float queryMagnitude = 0f;

for (int i = 0; i < dims; ++i) {
docVector[i] = (float) (dims - i);
docFloatVector[i] = (float) (dims - i);
queryVector[i] = (float) i;

docMagnitude += (float) (dims - i);
Expand All @@ -136,11 +137,11 @@ private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {

for (int i = 0; i < dims; ++i) {
if (normalize) {
docVector[i] /= docMagnitude;
docFloatVector[i] /= docMagnitude;
queryVector[i] /= queryMagnitude;
}

byteBuffer.putFloat(docVector[i]);
byteBuffer.putFloat(docFloatVector[i]);
}

byteBuffer.putFloat(docMagnitude);
Expand Down Expand Up @@ -238,7 +239,7 @@ private DotBinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).dotProduct(queryVector);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).dotProduct(queryVector);
}
}

Expand Down Expand Up @@ -286,7 +287,7 @@ private CosineBinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
}
}

Expand Down Expand Up @@ -334,7 +335,7 @@ private L1BinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
}
}

Expand Down Expand Up @@ -382,7 +383,7 @@ private L2BinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
}
}

Expand Down
5 changes: 5 additions & 0 deletions docs/changelog/96617.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96617
summary: Improve brute force vector search speed by using Lucene functions
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,23 @@ public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
/**
* Calculates vector magnitude
*/
private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
final int length = denseVectorLength(indexVersion, vectorBR);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
private static float calculateMagnitude(float[] decodedVector) {
double magnitude = 0.0f;
for (int i = 0; i < length; i++) {
float value = byteBuffer.getFloat();
magnitude += value * value;
for (int i = 0; i < decodedVector.length; i++) {
magnitude += decodedVector[i] * decodedVector[i];
}
magnitude = Math.sqrt(magnitude);
return (float) magnitude;
}

public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
public static float getMagnitude(Version indexVersion, BytesRef vectorBR, float[] decodedVector) {
if (vectorBR == null) {
throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
}
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
return decodeMagnitude(indexVersion, vectorBR);
} else {
return calculateMagnitude(indexVersion, vectorBR);
return calculateMagnitude(decodedVector);
}
}

Expand All @@ -70,7 +67,7 @@ public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
}
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < vector.length; dim++) {
vector[dim] = byteBuffer.getFloat();
vector[dim] = byteBuffer.getFloat((dim + vectorBR.offset) * Float.BYTES);
Copy link
Contributor

@ChrisHegarty ChrisHegarty Jun 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect that vectorBR.offset is not needed here. In fact it would be a problem if it is anything other than 0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisHegarty I guess I misunderstand. I thought getFloat was the absolute position within the underlying bytes, not relative to the starting position (dictated by the wrapping of bytes with offset).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My equation is wrong for sure, I think it should be (dim * Float.BYTES) + vectorBR.offset?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test passes with my fixed equation:

 public void testDecoding() { float[] inputFloats = new float[]{1f, 2f, 3f, 4f}; ByteBuffer byteBuffer = ByteBuffer.allocate(16); for (float f : inputFloats) { byteBuffer.putFloat(f); } BytesRef floatBytes = new BytesRef(byteBuffer.array()); floatBytes.length = 12; floatBytes.offset = 4; float[] outputFloats = new float[3]; VectorEncoderDecoder.decodeDenseVector(floatBytes, outputFloats); assertArrayEquals(outputFloats, new float[]{2f, 3f, 4f}, 0f); } 
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisHegarty what do you think? I fixed my math. Is this still an issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@benwtrent Your changes are good. The wrap with an offset and length is not relevant any more - it could be a plain wrap(byte[]) - since we are now using absolute addressing of the buffer - the complete byte[] is available to us even with wrap(byte[],int,int), which just sets up an initial position and limit.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,36 @@
package org.elasticsearch.script.field.vectors;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;

import java.nio.ByteBuffer;
import java.util.List;

public class BinaryDenseVector implements DenseVector {

protected final BytesRef docVector;
protected final int dims;
protected final Version indexVersion;
private final BytesRef docVector;

private final int dims;
private final Version indexVersion;

protected float[] decodedDocVector;

public BinaryDenseVector(BytesRef docVector, int dims, Version indexVersion) {
public BinaryDenseVector(float[] decodedDocVector, BytesRef docVector, int dims, Version indexVersion) {
this.decodedDocVector = decodedDocVector;
this.docVector = docVector;
this.indexVersion = indexVersion;
this.dims = dims;
}

@Override
public float[] getVector() {
if (decodedDocVector == null) {
decodedDocVector = new float[dims];
VectorEncoderDecoder.decodeDenseVector(docVector, decodedDocVector);
}
return decodedDocVector;
}

@Override
public float getMagnitude() {
return VectorEncoderDecoder.getMagnitude(indexVersion, docVector);
return VectorEncoderDecoder.getMagnitude(indexVersion, docVector, decodedDocVector);
}

@Override
Expand All @@ -50,22 +48,14 @@ public int dotProduct(byte[] queryVector) {

@Override
public double dotProduct(float[] queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double dotProduct = 0;
for (float v : queryVector) {
dotProduct += byteBuffer.getFloat() * v;
}
return dotProduct;
return VectorUtil.dotProduct(decodedDocVector, queryVector);
}

@Override
public double dotProduct(List<Number> queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double dotProduct = 0;
for (int i = 0; i < queryVector.size(); i++) {
dotProduct += byteBuffer.getFloat() * queryVector.get(i).floatValue();
dotProduct += decodedDocVector[i] * queryVector.get(i).floatValue();
}
return dotProduct;
}
Expand All @@ -77,22 +67,18 @@ public int l1Norm(byte[] queryVector) {

@Override
public double l1Norm(float[] queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double l1norm = 0;
for (float v : queryVector) {
l1norm += Math.abs(v - byteBuffer.getFloat());
for (int i = 0; i < queryVector.length; i++) {
l1norm += Math.abs(queryVector[i] - decodedDocVector[i]);
}
return l1norm;
}

@Override
public double l1Norm(List<Number> queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double l1norm = 0;
for (int i = 0; i < queryVector.size(); i++) {
l1norm += Math.abs(queryVector.get(i).floatValue() - byteBuffer.getFloat());
l1norm += Math.abs(queryVector.get(i).floatValue() - decodedDocVector[i]);
}
return l1norm;
}
Expand All @@ -104,21 +90,14 @@ public double l2Norm(byte[] queryVector) {

@Override
public double l2Norm(float[] queryVector) {
ByteBuffer byteBuffer = wrap(docVector);
double l2norm = 0;
for (float queryValue : queryVector) {
double diff = byteBuffer.getFloat() - queryValue;
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
return Math.sqrt(VectorUtil.squareDistance(queryVector, decodedDocVector));
}

@Override
public double l2Norm(List<Number> queryVector) {
ByteBuffer byteBuffer = wrap(docVector);
double l2norm = 0;
for (Number number : queryVector) {
double diff = byteBuffer.getFloat() - number.floatValue();
for (int i = 0; i < queryVector.size(); i++) {
double diff = decodedDocVector[i] - queryVector.get(i).floatValue();
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
Expand Down Expand Up @@ -156,8 +135,4 @@ public boolean isEmpty() {
public int getDims() {
return dims;
}

private static ByteBuffer wrap(BytesRef dv) {
return ByteBuffer.wrap(dv.bytes, dv.offset, dv.length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;

import java.io.IOException;

public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {

protected final BinaryDocValues input;
protected final Version indexVersion;
protected final int dims;
protected BytesRef value;
private final BinaryDocValues input;
private final float[] vectorValue;
private final Version indexVersion;
private final int dims;
private BytesRef value;

public BinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims, Version indexVersion) {
super(name, elementType);
this.input = input;
this.indexVersion = indexVersion;
this.dims = dims;
this.vectorValue = new float[dims];
}

@Override
Expand All @@ -54,16 +57,17 @@ public DenseVector get() {
if (isEmpty()) {
return DenseVector.EMPTY;
}

return new BinaryDenseVector(value, dims, indexVersion);
VectorEncoderDecoder.decodeDenseVector(value, vectorValue);
return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
}

@Override
public DenseVector get(DenseVector defaultValue) {
if (isEmpty()) {
return defaultValue;
}
return new BinaryDenseVector(value, dims, indexVersion);
VectorEncoderDecoder.decodeDenseVector(value, vectorValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible get is used w/ the same value twice. Wonder if it makes sense to cache it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can cache it for sure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added code to cache

return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public void testFloatVsListQueryVector() {

for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
BytesRef value = BinaryDenseVectorScriptDocValuesTests.mockEncodeDenseVector(docVector, ElementType.FLOAT, indexVersion);
BinaryDenseVector bdv = new BinaryDenseVector(value, dims, indexVersion);
BinaryDenseVector bdv = new BinaryDenseVector(docVector, value, dims, indexVersion);

assertEquals(bdv.dotProduct(arrayQV), bdv.dotProduct(listQV), 0.001f);
assertEquals(bdv.dotProduct((Object) listQV), bdv.dotProduct((Object) arrayQV), 0.001f);
Expand Down Expand Up @@ -219,7 +219,7 @@ public void testFloatUnsupported() {
e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity((Object) queryVector));
assertEquals(e.getMessage(), "use [double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector)] instead");

BinaryDenseVector binary = new BinaryDenseVector(new BytesRef(docBuffer.array()), dims, Version.CURRENT);
BinaryDenseVector binary = new BinaryDenseVector(docVector, new BytesRef(docBuffer.array()), dims, Version.CURRENT);

e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct(queryVector));
assertEquals(e.getMessage(), "use [double dotProduct(float[] queryVector)] instead");
Expand Down