docs/changelog/125103.yaml (5 additions & 0 deletions)
@@ -0,0 +1,5 @@
+pr: 125103
+summary: Fix LTR query feature with phrases (and two-phase) queries
+area: Ranking
+type: bug
+issues: []
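
Background for the fix: phrase queries (and other two-phase queries) expose a TwoPhaseIterator whose approximation may sit on candidate documents that do not actually match; only TwoPhaseIterator#matches() confirms the hit. The old extractor compared scorer.docID() against the target document without that confirmation, so phrase features could be emitted for false candidates. A minimal sketch of the Lucene two-phase contract the fix builds on (illustrative only; weight, leafContext and target are assumed to be in scope):

    // Two-phase consumption: advance the cheap approximation first,
    // then confirm the expensive match (e.g. phrase positions) before scoring.
    Scorer scorer = weight.scorer(leafContext);
    TwoPhaseIterator twoPhase = scorer.twoPhaseIterator(); // null when the query is not two-phase
    DocIdSetIterator approximation = twoPhase == null ? scorer.iterator() : twoPhase.approximation();
    if (approximation.advance(target) == target && (twoPhase == null || twoPhase.matches())) {
        float score = scorer.score(); // safe: the match is confirmed
    }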
QueryFeatureExtractor.java
@@ -15,7 +15,6 @@
 import org.apache.lucene.search.Weight;

 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;

@@ -25,52 +24,52 @@
  * respective feature name.
  */
 public class QueryFeatureExtractor implements FeatureExtractor {

     private final List<String> featureNames;
     private final List<Weight> weights;
-    private final List<Scorer> scorers;
-    private DisjunctionDISIApproximation rankerIterator;

+    private final DisiPriorityQueue subScorers;
+    private DisjunctionDISIApproximation approximation;

     public QueryFeatureExtractor(List<String> featureNames, List<Weight> weights) {
         if (featureNames.size() != weights.size()) {
             throw new IllegalArgumentException("[featureNames] and [weights] must be the same size.");
         }
         this.featureNames = featureNames;
         this.weights = weights;
-        this.scorers = new ArrayList<>(weights.size());
+        this.subScorers = new DisiPriorityQueue(weights.size());
     }

     @Override
     public void setNextReader(LeafReaderContext segmentContext) throws IOException {
-        DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
-        scorers.clear();
-        for (Weight weight : weights) {
+        subScorers.clear();
+        for (int i = 0; i < weights.size(); i++) {
+            var weight = weights.get(i);
             if (weight == null) {
-                scorers.add(null);
                 continue;
             }
             Scorer scorer = weight.scorer(segmentContext);
             if (scorer != null) {
-                disiPriorityQueue.add(new DisiWrapper(scorer, false));
+                subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i)));
             }
-            scorers.add(scorer);
         }

-        rankerIterator = disiPriorityQueue.size() > 0 ? new DisjunctionDISIApproximation(disiPriorityQueue) : null;
+        approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null;
     }

     @Override
     public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
-        if (rankerIterator == null) {
+        if (approximation == null || approximation.docID() > docId) {
             return;
         }

-        rankerIterator.advance(docId);
-        for (int i = 0; i < featureNames.size(); i++) {
-            Scorer scorer = scorers.get(i);
-            // Do we have a scorer, and does it match the provided document?
-            if (scorer != null && scorer.docID() == docId) {
-                featureMap.put(featureNames.get(i), scorer.score());
+        if (approximation.docID() < docId) {
+            approximation.advance(docId);
+        }
+        if (approximation.docID() != docId) {
+            return;
+        }
+        var w = (FeatureDisiWrapper) subScorers.topList();
+        for (; w != null; w = (FeatureDisiWrapper) w.next) {
+            if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
+                featureMap.put(w.featureName, w.scorable.score());
             }
         }
     }
@@ -80,4 +79,12 @@ public List<String> featureNames() {
         return featureNames;
     }

+    private static class FeatureDisiWrapper extends DisiWrapper {
+        final String featureName;
+
+        FeatureDisiWrapper(Scorer scorer, String featureName) {
+            super(scorer, false);
+            this.featureName = featureName;
+        }
+    }
 }
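
Why the docID guards above: documents are fed to the extractor in ascending order and DocIdSetIterators only move forward, so a shared disjunction already past docId means no sub-query matches it, and it may only be advanced while it is still behind docId. A sketch of how a caller is expected to drive the extractor per segment (it mirrors the test below; searcher, featureNames and weights are assumed to be in scope):

    QueryFeatureExtractor extractor = new QueryFeatureExtractor(featureNames, weights);
    for (LeafReaderContext leaf : searcher.getLeafContexts()) {
        extractor.setNextReader(leaf); // rebuilds the scorers and the shared disjunction
        for (int docId = 0; docId < leaf.reader().maxDoc(); docId++) {
            Map<String, Object> features = new HashMap<>();
            extractor.addFeatures(features, docId); // fills only confirmed matches
        }
    }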
QueryFeatureExtractorTests.java
@@ -12,7 +12,7 @@
 import org.apache.lucene.document.IntField;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
@@ -43,13 +43,11 @@

 public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {

-    private Directory dir;
-    private IndexReader reader;
-    private IndexSearcher searcher;

-    private void addDocs(String[] textValues, int[] numberValues) throws IOException {
-        dir = newDirectory();
-        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) {
+    private IndexReader addDocs(Directory dir, String[] textValues, int[] numberValues) throws IOException {
+        var config = newIndexWriterConfig();
+        // override the merge policy to ensure that docs remain in the same ingestion order
+        config.setMergePolicy(newLogMergePolicy(random()));
+        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir, config)) {
             for (int i = 0; i < textValues.length; i++) {
                 Document doc = new Document();
                 doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO));
@@ -59,98 +57,119 @@ private void addDocs(String[] textValues, int[] numberValues) throws IOException
                     indexWriter.flush();
                 }
             }
-            reader = indexWriter.getReader();
+            return indexWriter.getReader();
         }
-        searcher = newSearcher(reader);
-        searcher.setSimilarity(new ClassicSimilarity());
     }

-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98127")
     public void testQueryExtractor() throws IOException {
-        addDocs(
-            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
-            new int[] { 5, 10, 12, 11 }
-        );
-        QueryRewriteContext ctx = createQueryRewriteContext();
-        List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
-            new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
-                .rewrite(ctx),
-            new QueryExtractorBuilder(
-                "number_score",
-                QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
-            ).rewrite(ctx),
-            new QueryExtractorBuilder(
-                "matching_none",
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
-            ).rewrite(ctx),
-            new QueryExtractorBuilder(
-                "matching_missing_field",
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
-            ).rewrite(ctx)
-        );
-        SearchExecutionContext dummySEC = createSearchExecutionContext();
-        List<Weight> weights = new ArrayList<>();
-        List<String> featureNames = new ArrayList<>();
-        for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
-            Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
-            Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
-            weights.add(weight);
-            featureNames.add(qeb.featureName());
-        }
-        QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
-        List<Map<String, Object>> extractedFeatures = new ArrayList<>();
-        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
-            int maxDoc = leafReaderContext.reader().maxDoc();
-            queryFeatureExtractor.setNextReader(leafReaderContext);
-            for (int i = 0; i < maxDoc; i++) {
-                Map<String, Object> featureMap = new HashMap<>();
-                queryFeatureExtractor.addFeatures(featureMap, i);
-                extractedFeatures.add(featureMap);
+        try (var dir = newDirectory()) {
+            try (
+                var reader = addDocs(
+                    dir,
+                    new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+                    new int[] { 5, 10, 12, 11 }
+                )
+            ) {
+                var searcher = newSearcher(reader);
+                searcher.setSimilarity(new ClassicSimilarity());
+                QueryRewriteContext ctx = createQueryRewriteContext();
+                List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
+                    new QueryExtractorBuilder(
+                        "text_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "number_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "matching_none",
+                        QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "matching_missing_field",
+                        QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "phrase_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
+                    ).rewrite(ctx)
+                );
+                SearchExecutionContext dummySEC = createSearchExecutionContext();
+                List<Weight> weights = new ArrayList<>();
+                List<String> featureNames = new ArrayList<>();
+                for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
+                    Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
+                    Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
+                    weights.add(weight);
+                    featureNames.add(qeb.featureName());
+                }
+                QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                List<Map<String, Object>> extractedFeatures = new ArrayList<>();
+                for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+                    int maxDoc = leafReaderContext.reader().maxDoc();
+                    queryFeatureExtractor.setNextReader(leafReaderContext);
+                    for (int i = 0; i < maxDoc; i++) {
+                        Map<String, Object> featureMap = new HashMap<>();
+                        queryFeatureExtractor.addFeatures(featureMap, i);
+                        extractedFeatures.add(featureMap);
+                    }
+                }
+                assertThat(extractedFeatures, hasSize(4));
+                // Should never add features for queries that don't match a document or on documents where the field is missing
+                for (Map<String, Object> features : extractedFeatures) {
+                    assertThat(features, not(hasKey("matching_none")));
+                    assertThat(features, not(hasKey("matching_missing_field")));
+                }
+                // First two only match the text field
+                assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
+                assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
+                assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
+                assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
+                assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
+                assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));

+                // Only matches the range query
+                assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
+                assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
+                assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));

+                // No query matches
+                assertThat(extractedFeatures.get(3), anEmptyMap());
+            }
+        }
-        assertThat(extractedFeatures, hasSize(4));
-        // Should never add features for queries that don't match a document or on documents where the field is missing
-        for (Map<String, Object> features : extractedFeatures) {
-            assertThat(features, not(hasKey("matching_none")));
-            assertThat(features, not(hasKey("matching_missing_field")));
-        }
-        // First two only match the text field
-        assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
-        assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
-        assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
-        assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
-        // Only matches the range query
-        assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
-        assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
-        // No query matches
-        assertThat(extractedFeatures.get(3), anEmptyMap());
-        reader.close();
-        dir.close();
     }

     public void testEmptyDisiPriorityQueue() throws IOException {
-        addDocs(
-            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
-            new int[] { 5, 10, 12, 11 }
-        );
+        try (var dir = newDirectory()) {
+            var config = newIndexWriterConfig();
+            config.setMergePolicy(NoMergePolicy.INSTANCE);
+            try (
+                var reader = addDocs(
+                    dir,
+                    new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+                    new int[] { 5, 10, 12, 11 }
+                )
+            ) {

-        // Scorers returned by weights are null
-        List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
-        List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
+                var searcher = newSearcher(reader);
+                searcher.setSimilarity(new ClassicSimilarity());

-        QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                // Scorers returned by weights are null
+                List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
+                List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();

-        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
-            int maxDoc = leafReaderContext.reader().maxDoc();
-            featureExtractor.setNextReader(leafReaderContext);
-            for (int i = 0; i < maxDoc; i++) {
-                Map<String, Object> featureMap = new HashMap<>();
-                featureExtractor.addFeatures(featureMap, i);
-                assertThat(featureMap, anEmptyMap());
+                QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+                    int maxDoc = leafReaderContext.reader().maxDoc();
+                    featureExtractor.setNextReader(leafReaderContext);
+                    for (int i = 0; i < maxDoc; i++) {
+                        Map<String, Object> featureMap = new HashMap<>();
+                        featureExtractor.addFeatures(featureMap, i);
+                        assertThat(featureMap, anEmptyMap());
+                    }
+                }
+            }
+        }

-        reader.close();
-        dir.close();
     }
 }