Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public float score(
*
* <p>The results are stored in the provided scores array, and the maximum of those scores is returned.
*/
public void scoreBulk(
public float scoreBulk(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
Expand All @@ -158,6 +158,7 @@ public void scoreBulk(
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
}
in.readFloats(additionalCorrections, 0, BULK_SIZE);
float maxScore = Float.NEGATIVE_INFINITY;
for (int i = 0; i < BULK_SIZE; i++) {
scores[i] = score(
queryLowerInterval,
Expand All @@ -172,6 +173,10 @@ public void scoreBulk(
additionalCorrections[i],
scores[i]
);
if (scores[i] > maxScore) {
maxScore = scores[i];
}
}
return maxScore;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO
}

@Override
public void scoreBulk(
public float scoreBulk(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
Expand All @@ -366,7 +366,7 @@ public void scoreBulk(
// 128 / 8 == 16
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
score256Bulk(
return score256Bulk(
q,
queryLowerInterval,
queryUpperInterval,
Expand All @@ -376,9 +376,8 @@ public void scoreBulk(
centroidDp,
scores
);
return;
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
score128Bulk(
return score128Bulk(
q,
queryLowerInterval,
queryUpperInterval,
Expand All @@ -388,10 +387,9 @@ public void scoreBulk(
centroidDp,
scores
);
return;
}
}
super.scoreBulk(
return super.scoreBulk(
q,
queryLowerInterval,
queryUpperInterval,
Expand All @@ -403,7 +401,7 @@ public void scoreBulk(
);
}

private void score128Bulk(
private float score128Bulk(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
Expand All @@ -420,6 +418,7 @@ private void score128Bulk(
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
Expand Down Expand Up @@ -453,6 +452,7 @@ private void score128Bulk(
if (similarityFunction == EUCLIDEAN) {
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
maxScore = res.reduceLanes(VectorOperators.MAX);
res.intoArray(scores, i);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
Expand All @@ -463,17 +463,20 @@ private void score128Bulk(
// not sure how to do it better
for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) {
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
maxScore = Math.max(maxScore, scores[i + j]);
}
} else {
res = res.add(1f).mul(0.5f).max(0);
res.intoArray(scores, i);
maxScore = res.reduceLanes(VectorOperators.MAX);
}
}
}
in.seek(offset + 14L * BULK_SIZE);
return maxScore;
}

private void score256Bulk(
private float score256Bulk(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
Expand All @@ -490,6 +493,7 @@ private void score256Bulk(
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
float y1 = queryComponentSum;
float maxScore = Float.NEGATIVE_INFINITY;
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var lx = FloatVector.fromMemorySegment(
Expand Down Expand Up @@ -523,6 +527,7 @@ private void score256Bulk(
if (similarityFunction == EUCLIDEAN) {
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
maxScore = res.reduceLanes(VectorOperators.MAX);
res.intoArray(scores, i);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
Expand All @@ -533,13 +538,16 @@ private void score256Bulk(
// not sure how to do it better
for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) {
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
maxScore = Math.max(maxScore, scores[i + j]);
}
} else {
res = res.add(1f).mul(0.5f).max(0);
maxScore = res.reduceLanes(VectorOperators.MAX);
res.intoArray(scores, i);
}
}
}
in.seek(offset + 14L * BULK_SIZE);
return maxScore;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ public int resetPostingsScorer(long offset) throws IOException {
return vectors;
}

void scoreIndividually(int offset) throws IOException {
float scoreIndividually(int offset) throws IOException {
float maxScore = Float.NEGATIVE_INFINITY;
// score individually, first the quantized byte chunk
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[j + offset];
Expand Down Expand Up @@ -407,8 +408,35 @@ void scoreIndividually(int offset) throws IOException {
correctionsAdd[j],
scores[j]
);
if (scores[j] > maxScore) {
maxScore = scores[j];
}
}
}
return maxScore;
}

/**
 * Applies the predicate to each doc id in the bulk window that begins at {@code offset} and
 * overwrites every rejected entry with {@code -1}.
 *
 * @param docIds doc id buffer; rejected entries inside the window are replaced in place
 * @param offset index of the first doc id of the window
 * @param needsScoring predicate that returns {@code true} for doc ids that must be kept
 * @return the number of doc ids in the window that were rejected
 */
private static int filterDocs(int[] docIds, int offset, IntPredicate needsScoring) {
    final int end = offset + ES91OSQVectorsScorer.BULK_SIZE;
    int rejected = 0;
    for (int idx = offset; idx < end; idx++) {
        if (needsScoring.test(docIds[idx])) {
            // accepted doc ids are left untouched
            continue;
        }
        docIds[idx] = -1;
        rejected++;
    }
    return rejected;
}

/**
 * Passes every doc id in the bulk window that is not {@code -1} to the collector, together
 * with the score stored at the same position within the window.
 *
 * @param docIds doc id buffer; entries equal to {@code -1} are skipped
 * @param offset index of the first doc id of the window
 * @param knnCollector collector receiving each (doc id, score) pair
 * @param scores scores for the window, indexed from {@code 0} (not from {@code offset})
 * @return the number of doc ids handed to the collector
 */
private static int collect(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {
    int collected = 0;
    for (int slot = 0; slot < ES91OSQVectorsScorer.BULK_SIZE; slot++) {
        final int docId = docIds[offset + slot];
        if (docId == -1) {
            continue;
        }
        collected++;
        knnCollector.collect(docId, scores[slot]);
    }
    return collected;
}

@Override
Expand All @@ -418,23 +446,17 @@ public int visit(KnnCollector knnCollector) throws IOException {
int limit = vectors - BULK_SIZE + 1;
int i = 0;
for (; i < limit; i += BULK_SIZE) {
int docsToScore = BULK_SIZE;
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[i + j];
if (needsScoring.test(doc) == false) {
docIdsScratch[i + j] = -1;
docsToScore--;
}
}
int docsToScore = BULK_SIZE - filterDocs(docIdsScratch, i, needsScoring);
if (docsToScore == 0) {
continue;
}
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
float maxScore = Float.NEGATIVE_INFINITY;
if (docsToScore < BULK_SIZE / 2) {
scoreIndividually(i);
maxScore = scoreIndividually(i);
} else {
osqVectorsScorer.scoreBulk(
maxScore = osqVectorsScorer.scoreBulk(
quantizedQueryScratch,
queryCorrections.lowerInterval(),
queryCorrections.upperInterval(),
Expand All @@ -445,12 +467,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
scores
);
}
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[i + j];
if (doc != -1) {
scoredDocs++;
knnCollector.collect(doc, scores[j]);
}
if (knnCollector.minCompetitiveSimilarity() < maxScore) {
scoredDocs += collect(docIdsScratch, i, knnCollector, scores);
}
}
// process tail
Expand Down