Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ private KnnFloatBenchmarkFunction(int dims, boolean normalize) {
private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {

final BytesRef docVector;
final float[] docFloatVector;
final float[] queryVector;

private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
super(dims);

float[] docVector = new float[dims];
docFloatVector = new float[dims];
queryVector = new float[dims];

float docMagnitude = 0f;
float queryMagnitude = 0f;

for (int i = 0; i < dims; ++i) {
docVector[i] = (float) (dims - i);
docFloatVector[i] = (float) (dims - i);
queryVector[i] = (float) i;

docMagnitude += (float) (dims - i);
Expand All @@ -136,11 +137,11 @@ private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {

for (int i = 0; i < dims; ++i) {
if (normalize) {
docVector[i] /= docMagnitude;
docFloatVector[i] /= docMagnitude;
queryVector[i] /= queryMagnitude;
}

byteBuffer.putFloat(docVector[i]);
byteBuffer.putFloat(docFloatVector[i]);
}

byteBuffer.putFloat(docMagnitude);
Expand Down Expand Up @@ -178,6 +179,7 @@ private KnnByteBenchmarkFunction(int dims) {
private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {

final BytesRef docVector;
final byte[] vectorValue;
final byte[] queryVector;

final float queryMagnitude;
Expand All @@ -187,12 +189,14 @@ private BinaryByteBenchmarkFunction(int dims) {

ByteBuffer docVector = ByteBuffer.allocate(dims + 4);
queryVector = new byte[dims];
vectorValue = new byte[dims];

float docMagnitude = 0f;
float queryMagnitude = 0f;

for (int i = 0; i < dims; ++i) {
docVector.put((byte) (dims - i));
vectorValue[i] = (byte) (dims - i);
queryVector[i] = (byte) i;

docMagnitude += (float) (dims - i);
Expand Down Expand Up @@ -238,7 +242,7 @@ private DotBinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).dotProduct(queryVector);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).dotProduct(queryVector);
}
}

Expand All @@ -250,7 +254,7 @@ private DotBinaryByteBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(docVector, dims).dotProduct(queryVector);
new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector);
}
}

Expand Down Expand Up @@ -286,7 +290,7 @@ private CosineBinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
}
}

Expand All @@ -298,7 +302,7 @@ private CosineBinaryByteBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
}
}

Expand Down Expand Up @@ -334,7 +338,7 @@ private L1BinaryFloatBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
}
}

Expand All @@ -346,7 +350,7 @@ private L1BinaryByteBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(docVector, dims).l1Norm(queryVector);
new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector);
}
}

Expand Down Expand Up @@ -382,7 +386,7 @@ private L2BinaryFloatBenchmarkFunction(int dims) {

/**
 * Runs one iteration of the L2 (Euclidean) distance benchmark for binary-encoded float vectors.
 *
 * @param consumer sink supplied by the benchmark harness; not used by this implementation
 */
@Override
public void execute(Consumer<Object> consumer) {
    // This is the L2 benchmark, so it has to exercise l2Norm. The earlier l1Norm call made this
    // benchmark re-measure the L1 code path (compare the byte variant, which calls l2Norm).
    new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l2Norm(queryVector);
}
}

Expand All @@ -394,7 +398,7 @@ private L2BinaryByteBenchmarkFunction(int dims) {

@Override
public void execute(Consumer<Object> consumer) {
consumer.accept(new ByteBinaryDenseVector(docVector, dims).l2Norm(queryVector));
consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector));
}
}

Expand Down
5 changes: 5 additions & 0 deletions docs/changelog/96617.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96617
summary: Improve brute force vector search speed by using Lucene functions
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,23 @@ public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
/**
* Calculates vector magnitude
*/
private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
final int length = denseVectorLength(indexVersion, vectorBR);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
private static float calculateMagnitude(float[] decodedVector) {
double magnitude = 0.0f;
for (int i = 0; i < length; i++) {
float value = byteBuffer.getFloat();
magnitude += value * value;
for (int i = 0; i < decodedVector.length; i++) {
magnitude += decodedVector[i] * decodedVector[i];
}
magnitude = Math.sqrt(magnitude);
return (float) magnitude;
}

public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
public static float getMagnitude(Version indexVersion, BytesRef vectorBR, float[] decodedVector) {
if (vectorBR == null) {
throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
}
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
return decodeMagnitude(indexVersion, vectorBR);
} else {
return calculateMagnitude(indexVersion, vectorBR);
return calculateMagnitude(decodedVector);
}
}

Expand All @@ -70,7 +67,7 @@ public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
}
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < vector.length; dim++) {
vector[dim] = byteBuffer.getFloat();
vector[dim] = byteBuffer.getFloat((dim * Float.BYTES) + vectorBR.offset);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vectorBR.offset is already taken into account to create the ByteBuffer, so we shouldn't add it there again?

Also unrelated to your PR, I would expect ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).asFloatBuffer().get(vector) to run faster than reading a float at a time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jpountz I just asked @mayya-sharipova about this same thing. Using getFloat(int) ignores the current position of the buffer. It may make sense to wrap without the offset and length parameters.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Also unrelated to your PR, I would expect ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).asFloatBuffer().get(vector) to run faster than reading a float at a time.

@jpountz It doesn't. Reading each float at an absolute index, as this change does, is about 30% faster than wrapping the bytes as a float buffer. For some reason, the absolute-index read is faster. @ChrisHegarty might have an intuition as to why.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> The vectorBR.offset is already taken into account to create the ByteBuffer, so we shouldn't add it there again?

I honestly don't know. I added a test to verify (see: https://github.com/elastic/elasticsearch/pull/96617/files#diff-2d953a23603f9d7ef2f18d9f7bff3960307afc1a763e28fa2c7eca0ee3a65599). That test creates a bytebuffer with a custom length & offset and we get the expected results with my change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jdconrad, you are right indeed! Thanks @benwtrent for adding a test.

@benwtrent Thanks for checking performance. I remember that MikeS found that wrapping as a float buffer was faster when adding DataInput#readFloats, but it was with a direct byte buffer with little-endian byte order, these differences might matter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's the endianness of the bytebuffers we're using here - big in this case. It's most efficient to use the native endianness, otherwise byte swapping will occur. Since these are already in the index, is it possible to switch them to little?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jpountz here are some benchmark results:

These measure the average time to decode a float[] from a ByteBuffer. Here is the benchmark code, in case you can spot any wonkiness: https://gist.github.com/benwtrent/29cc0338cd851c345cace5c486095507

Direct read (via an absolute index into the buffer) is always faster. FloatBuffer and ByteBuffer perform almost identically for both decoding kinds (iteration vs. direct read).

Benchmark                                                          (floatArraySizes)  Mode  Cnt         Score         Error  Units
ByteBufferFloatDecodeLatencyBenchmark.decodeByteBufferDirectRead                  96  avgt    5   3348858.476 ±   13314.154  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeByteBufferDirectRead                 768  avgt    5  19960837.116 ±  137071.324  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeByteBufferDirectRead                2048  avgt    5  49142580.882 ±  216614.762  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeByteBufferIteration                   96  avgt    5   4992161.167 ±   47263.300  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeByteBufferIteration                  768  avgt    5  33762525.212 ±  110609.781  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeByteBufferIteration                 2048  avgt    5  89409177.001 ± 1122561.441  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeFloatBufferDirectRead                 96  avgt    5   3529213.563 ±   24635.601  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeFloatBufferDirectRead                768  avgt    5  18152123.115 ±   49712.077  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeFloatBufferDirectRead               2048  avgt    5  46610729.883 ± 2462771.340  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeFloatBufferIteration                  96  avgt    5   5085876.957 ±   75157.360  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeFloatBufferIteration                 768  avgt    5  32359519.955 ±  399297.947  ns/op
ByteBufferFloatDecodeLatencyBenchmark.decodeFloatBufferIteration                2048  avgt    5  86083215.393 ± 1120197.606  ns/op
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,36 @@
package org.elasticsearch.script.field.vectors;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;

import java.nio.ByteBuffer;
import java.util.List;

public class BinaryDenseVector implements DenseVector {

protected final BytesRef docVector;
protected final int dims;
protected final Version indexVersion;
private final BytesRef docVector;

protected float[] decodedDocVector;
private final int dims;
private final Version indexVersion;

public BinaryDenseVector(BytesRef docVector, int dims, Version indexVersion) {
private final float[] decodedDocVector;

/**
 * Builds a dense-vector view for a single document.
 *
 * @param decodedDocVector the document vector as floats; the array reference is stored as-is, not copied
 * @param docVector the binary doc value for the document (presumably the encoded form of
 *                  {@code decodedDocVector}; the caller is responsible for keeping the two in sync)
 * @param dims the value reported by {@link #getDims()}
 * @param indexVersion the index version handed to {@code VectorEncoderDecoder.getMagnitude}
 */
public BinaryDenseVector(float[] decodedDocVector, BytesRef docVector, int dims, Version indexVersion) {
    this.docVector = docVector;
    this.decodedDocVector = decodedDocVector;
    this.dims = dims;
    this.indexVersion = indexVersion;
}

@Override
public float[] getVector() {
if (decodedDocVector == null) {
decodedDocVector = new float[dims];
VectorEncoderDecoder.decodeDenseVector(docVector, decodedDocVector);
}
return decodedDocVector;
}

@Override
public float getMagnitude() {
return VectorEncoderDecoder.getMagnitude(indexVersion, docVector);
return VectorEncoderDecoder.getMagnitude(indexVersion, docVector, decodedDocVector);
}

@Override
Expand All @@ -50,22 +48,14 @@ public int dotProduct(byte[] queryVector) {

@Override
public double dotProduct(float[] queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double dotProduct = 0;
for (float v : queryVector) {
dotProduct += byteBuffer.getFloat() * v;
}
return dotProduct;
return VectorUtil.dotProduct(decodedDocVector, queryVector);
}

@Override
public double dotProduct(List<Number> queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double dotProduct = 0;
for (int i = 0; i < queryVector.size(); i++) {
dotProduct += byteBuffer.getFloat() * queryVector.get(i).floatValue();
dotProduct += decodedDocVector[i] * queryVector.get(i).floatValue();
}
return dotProduct;
}
Expand All @@ -77,22 +67,18 @@ public int l1Norm(byte[] queryVector) {

@Override
public double l1Norm(float[] queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double l1norm = 0;
for (float v : queryVector) {
l1norm += Math.abs(v - byteBuffer.getFloat());
for (int i = 0; i < queryVector.length; i++) {
l1norm += Math.abs(queryVector[i] - decodedDocVector[i]);
}
return l1norm;
}

@Override
public double l1Norm(List<Number> queryVector) {
ByteBuffer byteBuffer = wrap(docVector);

double l1norm = 0;
for (int i = 0; i < queryVector.size(); i++) {
l1norm += Math.abs(queryVector.get(i).floatValue() - byteBuffer.getFloat());
l1norm += Math.abs(queryVector.get(i).floatValue() - decodedDocVector[i]);
}
return l1norm;
}
Expand All @@ -104,21 +90,14 @@ public double l2Norm(byte[] queryVector) {

@Override
public double l2Norm(float[] queryVector) {
ByteBuffer byteBuffer = wrap(docVector);
double l2norm = 0;
for (float queryValue : queryVector) {
double diff = byteBuffer.getFloat() - queryValue;
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
return Math.sqrt(VectorUtil.squareDistance(queryVector, decodedDocVector));
}

@Override
public double l2Norm(List<Number> queryVector) {
ByteBuffer byteBuffer = wrap(docVector);
double l2norm = 0;
for (Number number : queryVector) {
double diff = byteBuffer.getFloat() - number.floatValue();
for (int i = 0; i < queryVector.size(); i++) {
double diff = decodedDocVector[i] - queryVector.get(i).floatValue();
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
Expand Down Expand Up @@ -156,8 +135,4 @@ public boolean isEmpty() {
public int getDims() {
return dims;
}

private static ByteBuffer wrap(BytesRef dv) {
return ByteBuffer.wrap(dv.bytes, dv.offset, dv.length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,30 @@
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;

import java.io.IOException;

public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {

protected final BinaryDocValues input;
protected final Version indexVersion;
protected final int dims;
protected BytesRef value;
private final BinaryDocValues input;
private final float[] vectorValue;
private final Version indexVersion;
private boolean decoded;
private final int dims;
private BytesRef value;

/**
 * Creates a doc-values field backed by binary doc values.
 *
 * @param input the binary doc values to read vectors from
 * @param name the field name, forwarded to the superclass
 * @param elementType the element type, forwarded to the superclass
 * @param dims the number of dimensions; also the size of the decode buffer
 * @param indexVersion the index version passed on to each {@code BinaryDenseVector}
 */
public BinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims, Version indexVersion) {
    super(name, elementType);
    this.dims = dims;
    // One buffer per field instance: decodeVectorIfNecessary() overwrites it for each document
    // instead of allocating a fresh float[] every time.
    this.vectorValue = new float[dims];
    this.indexVersion = indexVersion;
    this.input = input;
}

@Override
public void setNextDocId(int docId) throws IOException {
decoded = false;
if (input.advanceExact(docId)) {
value = input.binaryValue();
} else {
Expand All @@ -54,20 +59,28 @@ public DenseVector get() {
if (isEmpty()) {
return DenseVector.EMPTY;
}

return new BinaryDenseVector(value, dims, indexVersion);
decodeVectorIfNecessary();
return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
}

@Override
public DenseVector get(DenseVector defaultValue) {
if (isEmpty()) {
return defaultValue;
}
return new BinaryDenseVector(value, dims, indexVersion);
decodeVectorIfNecessary();
return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
}

/**
 * Returns the vector for the current document, or {@code null} when the field is empty.
 */
@Override
public DenseVector getInternal() {
    // get(DenseVector) hands back its argument when the field is empty, so a null default
    // yields null for documents without a vector.
    final DenseVector absent = null;
    return get(absent);
}

/**
 * Decodes the current binary value into {@code vectorValue}, at most once per document.
 * Does nothing when the value was already decoded or when there is no value.
 */
private void decodeVectorIfNecessary() {
    if (decoded || value == null) {
        return;
    }
    VectorEncoderDecoder.decodeDenseVector(value, vectorValue);
    // setNextDocId() resets this flag when the field moves to another document.
    decoded = true;
}
}
Loading