Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/117199.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117199
summary: Speed up bit compared with floats or bytes script operations
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,7 @@ public static int ipByteBit(byte[] q, byte[] d) {
if (q.length != d.length * Byte.SIZE) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
}
int result = 0;
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = Byte.SIZE - 1; j >= 0; j--) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
}
}
}
return result;
return IMPL.ipByteBit(q, d);
}

/**
Expand All @@ -87,16 +77,7 @@ public static float ipFloatBit(float[] q, byte[] d) {
if (q.length != d.length * Byte.SIZE) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
}
float result = 0;
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = Byte.SIZE - 1; j >= 0; j--) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
}
}
}
return result;
return IMPL.ipFloatBit(q, d);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,81 @@
package org.elasticsearch.simdvec.internal.vectorization;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;

final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {

/**
 * Computes {@code a * b + c}.
 *
 * <p>When Lucene reports {@link Constants#HAS_FAST_SCALAR_FMA}, the fused {@link Math#fma(float, float, float)}
 * is used (a single rounding step); otherwise the plain multiply followed by an add is used.
 *
 * @param a first factor
 * @param b second factor
 * @param c addend
 * @return the multiply-add result
 */
private static float fma(float a, float b, float c) {
    return Constants.HAS_FAST_SCALAR_FMA ? Math.fma(a, b, c) : a * b + c;
}

// Package-private no-arg constructor; there is nothing to initialize.
DefaultESVectorUtilSupport() {}

/** Delegates to the static {@link #ipByteBinByteImpl(byte[], byte[])}. */
@Override
public long ipByteBinByte(byte[] q, byte[] d) {
return ipByteBinByteImpl(q, d);
}

/**
 * Sums the byte elements of {@code q} whose corresponding bit in {@code d} is set.
 * Delegates to the static {@link #ipByteBitImpl(byte[], byte[])}.
 */
@Override
public int ipByteBit(byte[] q, byte[] d) {
return ipByteBitImpl(q, d);
}

/**
 * Sums the float elements of {@code q} whose corresponding bit in {@code d} is set.
 * Delegates to the static {@link #ipFloatBitImpl(float[], byte[])}.
 */
@Override
public float ipFloatBit(float[] q, byte[] d) {
return ipFloatBitImpl(q, d);
}

public static int ipByteBitImpl(byte[] q, byte[] d) {
assert q.length == d.length * Byte.SIZE;
int acc0 = 0;
int acc1 = 0;
int acc2 = 0;
int acc3 = 0;
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a drive-by question here (free to disregard): is this intended to allow vectorization?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@svilen-mihaylov-db it allows some vectorization via the unrolling, but it definitely isn't as fast as a custom vectorized version that we could provide with the Panama API. This solution isn't as fast as it could be, for sure.

Mainly, I discovered its much faster than the previous if block and so its a step in the right direction :)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

byte mask = d[i];
// Make sure its just 1 or 0

acc0 += q[i * Byte.SIZE + 0] * ((mask >> 7) & 1);
acc1 += q[i * Byte.SIZE + 1] * ((mask >> 6) & 1);
acc2 += q[i * Byte.SIZE + 2] * ((mask >> 5) & 1);
acc3 += q[i * Byte.SIZE + 3] * ((mask >> 4) & 1);

acc0 += q[i * Byte.SIZE + 4] * ((mask >> 3) & 1);
acc1 += q[i * Byte.SIZE + 5] * ((mask >> 2) & 1);
acc2 += q[i * Byte.SIZE + 6] * ((mask >> 1) & 1);
acc3 += q[i * Byte.SIZE + 7] * ((mask >> 0) & 1);
}
return acc0 + acc1 + acc2 + acc3;
}

/**
 * Inner product between a float vector and a bit vector: accumulates every element of
 * {@code q} multiplied by its corresponding bit (0 or 1) in {@code d}.
 *
 * <p>Bits are consumed most-significant first, so bit 7 of {@code d[i]} pairs with
 * {@code q[i * 8]} and bit 0 pairs with {@code q[i * 8 + 7]}. Because elements are multiplied
 * by the bit rather than skipped, a NaN or infinite element of {@code q} yields NaN even when
 * its bit is {@code 0}.
 *
 * @param q the float vector; its length must be {@code d.length * Byte.SIZE} (checked by assertion only)
 * @param d the packed bit vector, one bit per element of {@code q}
 * @return the sum of the four partial accumulators
 */
public static float ipFloatBitImpl(float[] q, byte[] d) {
    assert q.length == d.length * Byte.SIZE;
    // four independent partial sums; lanes 0..3 each take two of the eight bits per byte
    float sum0 = 0;
    float sum1 = 0;
    float sum2 = 0;
    float sum3 = 0;
    // offset into q of the first element covered by the current byte of d
    int offset = 0;
    for (final byte bits : d) {
        // high nibble
        sum0 = fma(q[offset], (bits >> 7) & 1, sum0);
        sum1 = fma(q[offset + 1], (bits >> 6) & 1, sum1);
        sum2 = fma(q[offset + 2], (bits >> 5) & 1, sum2);
        sum3 = fma(q[offset + 3], (bits >> 4) & 1, sum3);
        // low nibble
        sum0 = fma(q[offset + 4], (bits >> 3) & 1, sum0);
        sum1 = fma(q[offset + 5], (bits >> 2) & 1, sum1);
        sum2 = fma(q[offset + 6], (bits >> 1) & 1, sum2);
        sum3 = fma(q[offset + 7], bits & 1, sum3);
        offset += Byte.SIZE;
    }
    return sum0 + sum1 + sum2 + sum3;
}

public static long ipByteBinByteImpl(byte[] q, byte[] d) {
long ret = 0;
int size = d.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ public interface ESVectorUtilSupport {
short B_QUERY = 4;

long ipByteBinByte(byte[] q, byte[] d);

/**
 * Inner product of a byte vector {@code q} with a packed bit vector {@code d}: the sum of the
 * elements of {@code q} whose corresponding bit in {@code d} is set.
 */
int ipByteBit(byte[] q, byte[] d);

/**
 * Inner product of a float vector {@code q} with a packed bit vector {@code d}: the sum of the
 * elements of {@code q} whose corresponding bit in {@code d} is set.
 */
float ipFloatBit(float[] q, byte[] d);
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ public long ipByteBinByte(byte[] q, byte[] d) {
return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d);
}

/** Delegates to the shared {@code DefaultESVectorUtilSupport.ipByteBitImpl} implementation. */
@Override
public int ipByteBit(byte[] q, byte[] d) {
return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
}

/** Delegates to the shared {@code DefaultESVectorUtilSupport.ipFloatBitImpl} implementation. */
@Override
public float ipFloatBit(float[] q, byte[] d) {
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
}

private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;

Expand Down