Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/96716.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96716
summary: Speed up binary vector decoding
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.time.ZoneId;
import java.util.Locale;
import java.util.Map;
Expand All @@ -64,6 +65,8 @@
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public static final Version MAGNITUDE_STORED_INDEX_VERSION = Version.V_7_5_0;
public static final Version LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = Version.V_8_9_0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be the last version prior to your change's version using the new TransportVersion constants?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jdconrad this isn't a transport/wire serialization thing — it's an index version thing. From my understanding, index versioning is different. I will see what I can find.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked @thecoop and me just using Version here is OK. We may need to update it to IndexVersion depending on which commits make it in first :)


public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; // maximum allowed number of dimensions
Expand Down Expand Up @@ -353,6 +356,11 @@ public Field parseKnnVector(DocumentParserContext context, DenseVectorFieldMappe
fieldMapper.checkDimensionMatches(index, context);
return dotProduct;
}

@Override
ByteBuffer createByteBuffer(Version indexVersion, int numBytes) {
    // Byte vectors have no multi-byte elements, so byte order is irrelevant:
    // a plain big-endian heap buffer is used for every index version.
    return ByteBuffer.allocate(numBytes);
}
},

FLOAT(4) {
Expand Down Expand Up @@ -460,6 +468,13 @@ public Field parseKnnVector(DocumentParserContext context, DenseVectorFieldMappe
checkVectorBounds(vector);
return dotProduct;
}

@Override
ByteBuffer createByteBuffer(Version indexVersion, int numBytes) {
    // Float vectors are stored little-endian from LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION
    // onwards; older indices keep the legacy big-endian (JVM default) layout.
    ByteBuffer buffer = ByteBuffer.wrap(new byte[numBytes]);
    if (indexVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)) {
        buffer.order(ByteOrder.LITTLE_ENDIAN);
    }
    return buffer;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

};

final int elementBytes;
Expand All @@ -483,6 +498,8 @@ public Field parseKnnVector(DocumentParserContext context, DenseVectorFieldMappe
abstract double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer)
throws IOException;

abstract ByteBuffer createByteBuffer(Version indexVersion, int numBytes);

public abstract void checkVectorBounds(float[] vector);

abstract void checkVectorMagnitude(
Expand Down Expand Up @@ -890,18 +907,18 @@ private Field parseKnnVector(DocumentParserContext context) throws IOException {
private Field parseBinaryDocValuesVector(DocumentParserContext context) throws IOException {
// encode array of floats as array of integers and store into buf
// this code is here and not int the VectorEncoderDecoder so not to create extra arrays
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to your change, but would you mind fixing the not int -> not in?

byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_5_0)
? new byte[dims * elementType.elementBytes + MAGNITUDE_BYTES]
: new byte[dims * elementType.elementBytes];
int numBytes = indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)
? dims * elementType.elementBytes + MAGNITUDE_BYTES
: dims * elementType.elementBytes;

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
ByteBuffer byteBuffer = elementType.createByteBuffer(indexCreatedVersion, numBytes);
double dotProduct = elementType.parseKnnVectorToByteBuffer(context, this, byteBuffer);
if (indexCreatedVersion.onOrAfter(Version.V_7_5_0)) {
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
return new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes));
return new BinaryDocValuesField(fieldType().name(), new BytesRef(byteBuffer.array()));
}

private void checkDimensionExceeded(int index, DocumentParserContext context) {
Expand Down Expand Up @@ -1000,7 +1017,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() {
if (indexed) {
return new IndexedSyntheticFieldLoader();
}
return new DocValuesSyntheticFieldLoader();
return new DocValuesSyntheticFieldLoader(indexCreatedVersion);
}

private class IndexedSyntheticFieldLoader implements SourceLoader.SyntheticFieldLoader {
Expand Down Expand Up @@ -1060,6 +1077,11 @@ public void write(XContentBuilder b) throws IOException {
private class DocValuesSyntheticFieldLoader implements SourceLoader.SyntheticFieldLoader {
private BinaryDocValues values;
private boolean hasValue;
private final Version indexCreatedVersion;

// Capture the index creation version so doc-values decoding can select the stored byte order.
private DocValuesSyntheticFieldLoader(Version indexCreatedVersion) {
this.indexCreatedVersion = indexCreatedVersion;
}

@Override
public Stream<Map.Entry<String, StoredFieldLoader>> storedFieldLoaders() {
Expand Down Expand Up @@ -1091,6 +1113,9 @@ public void write(XContentBuilder b) throws IOException {
b.startArray(simpleName());
BytesRef ref = values.binaryValue();
ByteBuffer byteBuffer = ByteBuffer.wrap(ref.bytes, ref.offset, ref.length);
if (indexCreatedVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)) {
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
}
for (int dim = 0; dim < dims; dim++) {
elementType.readAndWriteValue(byteBuffer, b);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
import org.elasticsearch.Version;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION;

public final class VectorEncoderDecoder {
public static final byte INT_BYTES = 4;

private VectorEncoderDecoder() {}

/**
 * Returns the number of float dimensions encoded in {@code vectorBR}. Indices on or after
 * {@code MAGNITUDE_STORED_INDEX_VERSION} store one extra trailing float (the magnitude),
 * which is not part of the vector itself.
 */
public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
    return indexVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)
        ? (vectorBR.length - INT_BYTES) / INT_BYTES
        : vectorBR.length / INT_BYTES;
}

/**
Expand All @@ -28,8 +35,10 @@ public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
* equal to 7.5.0, since vectors created prior to that do not store the magnitude.
*/
public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
    assert indexVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION);
    // Floats are stored little-endian from LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION onwards;
    // older indices use the legacy big-endian layout.
    ByteBuffer byteBuffer = indexVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)
        ? ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN)
        : ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
    // The magnitude is the last float; the absolute index is relative to the backing array,
    // so the BytesRef offset must be added.
    return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES);
}

Expand All @@ -49,7 +58,7 @@ public static float getMagnitude(Version indexVersion, BytesRef vectorBR, float[
if (vectorBR == null) {
throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
}
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
if (indexVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
return decodeMagnitude(indexVersion, vectorBR);
} else {
return calculateMagnitude(decodedVector);
Expand All @@ -61,13 +70,20 @@ public static float getMagnitude(Version indexVersion, BytesRef vectorBR, float[
* @param vectorBR - dense vector encoded in BytesRef
* @param vector - array of floats where the decoded vector should be stored
*/
public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
public static void decodeDenseVector(Version indexVersion, BytesRef vectorBR, float[] vector) {
if (vectorBR == null) {
throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
}
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < vector.length; dim++) {
vector[dim] = byteBuffer.getFloat((dim * Float.BYTES) + vectorBR.offset);
if (indexVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)) {
FloatBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
fb.get(vector);
} else {
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < vector.length; dim++) {
vector[dim] = byteBuffer.getFloat((dim * Float.BYTES) + vectorBR.offset);
}
Comment on lines +78 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could .asFloatBuffer() not be used for both little and big endian?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.asFloatBuffer() is marginally slower for BE. These implementations are the fastest I could get them.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public DenseVector getInternal() {

/**
 * Lazily decodes {@code value} into {@code vectorValue}; the index version selects the
 * stored byte order. No-op once decoded or when no vector value is present.
 */
private void decodeVectorIfNecessary() {
    if (decoded == false && value != null) {
        VectorEncoderDecoder.decodeDenseVector(indexVersion, value, vectorValue);
        decoded = true;
    }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,11 @@ public long cost() {
}

public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, Version indexVersion) {
byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0)
? new byte[elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES]
: new byte[elementType.elementBytes * values.length];
int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
? elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES
: elementType.elementBytes * values.length;
double dotProduct = 0f;

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
for (float value : values) {
if (elementType == ElementType.FLOAT) {
byteBuffer.putFloat(value);
Expand All @@ -239,12 +238,12 @@ public static BytesRef mockEncodeDenseVector(float[] values, ElementType element
dotProduct += value * value;
}

if (indexVersion.onOrAfter(Version.V_7_5_0)) {
if (indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
return new BytesRef(bytes);
return new BytesRef(byteBuffer.array());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.junit.AssumptionViolatedException;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;

Expand Down Expand Up @@ -465,11 +464,7 @@ public void testAddDocumentsToIndexBefore_V_7_5_0() throws Exception {
private static float[] decodeDenseVector(Version indexVersion, BytesRef encodedVector) {
    // Delegate to the production decoder so the test exercises the version-aware byte order.
    int dimCount = VectorEncoderDecoder.denseVectorLength(indexVersion, encodedVector);
    float[] vector = new float[dimCount];
    VectorEncoderDecoder.decodeDenseVector(indexVersion, encodedVector, vector);
    return vector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,48 @@
package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.VersionUtils;

import java.nio.ByteBuffer;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public class VectorEncoderDecoderTests extends ESTestCase {

public void testVectorDecodingWithOffset() {
float[] inputFloats = new float[] { 1f, 2f, 3f, 4f };
ByteBuffer byteBuffer = ByteBuffer.allocate(20);
double magnitude = 0.0;
for (float f : inputFloats) {
byteBuffer.putFloat(f);
magnitude += f * f;
float[] expected = new float[] { 2f, 3f, 4f };
int dims = 3;
for (Version version : List.of(
VersionUtils.randomVersionBetween(
random(),
DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION,
VersionUtils.getPreviousVersion(DenseVectorFieldMapper.LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)
),
DenseVectorFieldMapper.LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION
)) {
ByteBuffer byteBuffer = DenseVectorFieldMapper.ElementType.FLOAT.createByteBuffer(version, 20);
double magnitude = 0.0;
for (float f : inputFloats) {
byteBuffer.putFloat(f);
magnitude += f * f;
}
// Binary documents store magnitude in a float at the end of the buffer array
magnitude /= 4;
byteBuffer.putFloat((float) magnitude);
BytesRef floatBytes = new BytesRef(byteBuffer.array());
// adjust so that we have an offset ignoring the first float
floatBytes.length = 16;
floatBytes.offset = 4;
// since we are ignoring the first float to mock an offset, our dimensions can be assumed to be 3
float[] outputFloats = new float[dims];
VectorEncoderDecoder.decodeDenseVector(version, floatBytes, outputFloats);
assertArrayEquals(outputFloats, expected, 0f);
assertThat(VectorEncoderDecoder.decodeMagnitude(version, floatBytes), equalTo((float) magnitude));
}
// Binary documents store magnitude in a float at the end of the buffer array
magnitude /= 4;
byteBuffer.putFloat((float) magnitude);
BytesRef floatBytes = new BytesRef(byteBuffer.array());
// adjust so that we have an offset ignoring the first float
floatBytes.length = 16;
floatBytes.offset = 4;
// since we are ignoring the first float to mock an offset, our dimensions can be assumed to be 3
float[] outputFloats = new float[3];
VectorEncoderDecoder.decodeDenseVector(floatBytes, outputFloats);
assertArrayEquals(outputFloats, new float[] { 2f, 3f, 4f }, 0f);
}

}