Support querying multiple indices with the simplified linear retriever #133720
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| pr: 133720 | ||
| summary: Support querying multiple indices with the simplified linear retriever | ||
| area: Relevance | ||
| type: enhancement | ||
| issues: [] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -13,8 +13,14 @@ | |
| import org.elasticsearch.common.regex.Regex; | ||
| import org.elasticsearch.common.settings.Settings; | ||
| import org.elasticsearch.core.Nullable; | ||
| import org.elasticsearch.core.Tuple; | ||
| import org.elasticsearch.index.mapper.IndexFieldMapper; | ||
| import org.elasticsearch.index.query.BoolQueryBuilder; | ||
| import org.elasticsearch.index.query.MatchQueryBuilder; | ||
| import org.elasticsearch.index.query.MultiMatchQueryBuilder; | ||
| import org.elasticsearch.index.query.QueryBuilder; | ||
| import org.elasticsearch.index.query.TermQueryBuilder; | ||
| import org.elasticsearch.index.query.TermsQueryBuilder; | ||
| import org.elasticsearch.index.search.QueryParserHelper; | ||
| import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; | ||
| import org.elasticsearch.search.retriever.RetrieverBuilder; | ||
| | @@ -114,20 +120,39 @@ public static ActionRequestValidationException validateParams( | |
| * Generate the inner retriever tree for the given fields, weights, and query. The tree follows this structure: | ||
| * | ||
| * <pre> | ||
| * multi_match query on all lexical fields | ||
| * standard retriever for querying lexical fields using multi_match. | ||
| * normalizer retriever | ||
| * match query on semantic_text field A | ||
| * match query on semantic_text field B | ||
| * match query on semantic_text field A with inference ID id1 | ||
| * match query on semantic_text field A with inference ID id2 | ||
| * match query on semantic_text field B with inference ID id1 | ||
| * ... | ||
| * match query on semantic_text field Z | ||
| * match query on semantic_text field Z with inference ID idN | ||
| * </pre> | ||
| * | ||
| * <p> | ||
| * Where the normalizer retriever is constructed by the {@code innerNormalizerGenerator} function. | ||
| * </p> | ||
| * | ||
| * <p> | ||
| * When the same lexical fields are queried for all indices, we use a single multi_match query to query them. | ||
| * Otherwise, we create a boolean query with the following structure: | ||
| * </p> | ||
| * | ||
| * <pre> | ||
| * bool | ||
| * should | ||
| * bool | ||
| * match query on lexical fields for index A | ||
| * filter on indexA | ||
| * bool | ||
| * match query on lexical fields for index B | ||
| * filter on indexB | ||
| * ... | ||
| * </pre> | ||
| * | ||
| * <p> | ||
| * This tree structure is repeated for each index in {@code indicesMetadata}. That is to say, that for each index in | ||
| * {@code indicesMetadata}, (up to) a pair of retrievers will be added to the returned {@code RetrieverBuilder} list. | ||
| * The semantic_text fields are grouped by inference ID. For each (fieldName, inferenceID) pair we generate a match query. | ||
| * Since we have no way to effectively filter on inference IDs, we filter on index names instead. | ||
| * </p> | ||
| * | ||
| * @param fieldsAndWeights The fields to query and their respective weights, in "field^weight" format | ||
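For readers less familiar with the Elasticsearch query builder API, the per-index "should" clause in the structure documented above could be assembled roughly as sketched below. This is a minimal illustration, not the PR's exact code path; the field names, weights, query string, and index name are hypothetical.

```java
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;

import java.util.Map;

// Sketch only: builds one per-index "should" clause from the bool structure documented above.
// The field names, weights, query string, and index name are hypothetical.
class PerIndexLexicalClauseSketch {
    static QueryBuilder perIndexClause(String query, Map<String, Float> lexicalFields, String indexName) {
        return new BoolQueryBuilder()
            // multi_match over the lexical fields resolved for this index
            .should(new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS).fields(lexicalFields))
            // restrict the clause to documents from this index via the _index metadata field
            .filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
    }
}
```

When every queried index resolves to the same lexical field set, this wrapping is skipped and a single unfiltered multi_match query is issued instead.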
| | @@ -150,32 +175,104 @@ public static List<RetrieverBuilder> generateInnerRetrievers( | |
| if (weightValidator != null) { | ||
| parsedFieldsAndWeights.values().forEach(weightValidator); | ||
| } | ||
| | ||
| // We expect up to 2 inner retrievers to be generated for each index queried | ||
| List<RetrieverBuilder> innerRetrievers = new ArrayList<>(indicesMetadata.size() * 2); | ||
| for (IndexMetadata indexMetadata : indicesMetadata) { | ||
| innerRetrievers.addAll( | ||
| generateInnerRetrieversForIndex(parsedFieldsAndWeights, query, indexMetadata, innerNormalizerGenerator, weightValidator) | ||
| ); | ||
| List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2); | ||
| // add lexical retriever | ||
| RetrieverBuilder lexicalRetriever = generateLexicalRetriever(parsedFieldsAndWeights, indicesMetadata, query, weightValidator); | ||
| if (lexicalRetriever != null) { | ||
| innerRetrievers.add(lexicalRetriever); | ||
| } | ||
| // add semantic retriever | ||
| RetrieverBuilder semanticRetriever = generateSemanticRetriever( | ||
| parsedFieldsAndWeights, | ||
| indicesMetadata, | ||
| query, | ||
| innerNormalizerGenerator, | ||
| weightValidator | ||
| ); | ||
| if (semanticRetriever != null) { | ||
| innerRetrievers.add(semanticRetriever); | ||
| } | ||
| | ||
| return innerRetrievers; | ||
| } | ||
| | ||
| private static List<RetrieverBuilder> generateInnerRetrieversForIndex( | ||
| private static RetrieverBuilder generateSemanticRetriever( | ||
| Map<String, Float> parsedFieldsAndWeights, | ||
| Collection<IndexMetadata> indicesMetadata, | ||
| String query, | ||
| IndexMetadata indexMetadata, | ||
| Function<List<WeightedRetrieverSource>, CompoundRetrieverBuilder<?>> innerNormalizerGenerator, | ||
| @Nullable Consumer<Float> weightValidator | ||
| ) { | ||
| // Form groups of (fieldName, inferenceID) pairs that need to be queried. | ||
| // For each (fieldName, inferenceID) pair, determine the weight to apply and the indices to query. | ||
| Map<Tuple<String, String>, List<String>> groupedIndices = new HashMap<>(); | ||
| Map<Tuple<String, String>, Float> groupedWeights = new HashMap<>(); | ||
| for (IndexMetadata indexMetadata : indicesMetadata) { | ||
| inferenceFieldsAndWeightsForIndex(parsedFieldsAndWeights, indexMetadata, weightValidator).forEach((fieldName, weight) -> { | ||
| String indexName = indexMetadata.getIndex().getName(); | ||
| Tuple<String, String> fieldAndInferenceId = new Tuple<>( | ||
| fieldName, | ||
| indexMetadata.getInferenceFields().get(fieldName).getInferenceId() | ||
| ); | ||
| | ||
| if (groupedWeights.containsKey(fieldAndInferenceId) && groupedWeights.get(fieldAndInferenceId).equals(weight) == false) { | ||
| String conflictingIndexName = groupedIndices.get(fieldAndInferenceId).getFirst(); | ||
| throw new IllegalArgumentException( | ||
| "field [" + fieldName + "] has different weights in indices [" + conflictingIndexName + "] and [" + indexName + "]" | ||
| ); | ||
| } | ||
| | ||
| groupedWeights.put(fieldAndInferenceId, weight); | ||
| groupedIndices.computeIfAbsent(fieldAndInferenceId, k -> new ArrayList<>()).add(indexName); | ||
| }); | ||
| } | ||
| | ||
| // no semantic_text fields need to be queried, so there is no need to create a retriever | ||
| if (groupedIndices.isEmpty()) { | ||
| return null; | ||
| } | ||
| | ||
| // for each (fieldName, inferenceID) pair generate a standard retriever with a semantic query | ||
| List<WeightedRetrieverSource> semanticRetrievers = new ArrayList<>(groupedIndices.size()); | ||
| groupedIndices.forEach((fieldAndInferenceId, indexNames) -> { | ||
| String fieldName = fieldAndInferenceId.v1(); | ||
| Float weight = groupedWeights.get(fieldAndInferenceId); | ||
| | ||
| QueryBuilder queryBuilder = new MatchQueryBuilder(fieldName, query); | ||
| | ||
| // when we query more than one index, we need to filter on indexNames | ||
| if (indicesMetadata.size() > 1) { | ||
| queryBuilder = new BoolQueryBuilder().should(queryBuilder).filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indexNames)); | ||
| } | ||
| | ||
| RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(queryBuilder); | ||
| semanticRetrievers.add(new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), weight)); | ||
| }); | ||
| | ||
| return innerNormalizerGenerator.apply(semanticRetrievers); | ||
| } | ||
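To make the grouping step above concrete, the sketch below shows (with entirely hypothetical field, index, and inference ID names) how the (fieldName, inferenceId) keys collapse queries across indices: indices that share an inference ID for a field contribute to a single match query, while a differing inference ID yields a separate query filtered by index name.

```java
import org.elasticsearch.core.Tuple;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch only: hypothetical field names, inference IDs, and index names.
public class SemanticGroupingSketch {
    public static void main(String[] args) {
        Map<Tuple<String, String>, List<String>> groupedIndices = new HashMap<>();
        // "body" is backed by inference ID "elser-1" in index-a and index-b, but by "e5-small" in index-c
        groupedIndices.computeIfAbsent(new Tuple<>("body", "elser-1"), k -> new ArrayList<>()).add("index-a");
        groupedIndices.computeIfAbsent(new Tuple<>("body", "elser-1"), k -> new ArrayList<>()).add("index-b");
        groupedIndices.computeIfAbsent(new Tuple<>("body", "e5-small"), k -> new ArrayList<>()).add("index-c");

        // Two match queries on "body" result: one filtered to [index-a, index-b] and one to [index-c].
        // The filter targets _index names because inference IDs cannot be filtered on directly.
        groupedIndices.forEach((fieldAndInferenceId, indexNames) ->
            System.out.println(fieldAndInferenceId.v1() + " (" + fieldAndInferenceId.v2() + ") -> " + indexNames));
    }
}
```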
| | ||
| private static Map<String, Float> defaultFieldsAndWeightsForIndex( | ||
| IndexMetadata indexMetadata, | ||
| @Nullable Consumer<Float> weightValidator | ||
| ) { | ||
| Settings settings = indexMetadata.getSettings(); | ||
| List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings)); | ||
| Map<String, Float> fieldsAndWeights = QueryParserHelper.parseFieldsAndWeights(defaultFields); | ||
| if (weightValidator != null) { | ||
| fieldsAndWeights.values().forEach(weightValidator); | ||
| } | ||
| return fieldsAndWeights; | ||
| } | ||
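The per-index default fields read above use the same "field^weight" syntax as caller-supplied fields. As a rough illustration of that syntax (the field names are hypothetical), QueryParserHelper expands it into a field-to-weight map:

```java
import org.elasticsearch.index.search.QueryParserHelper;

import java.util.List;
import java.util.Map;

// Sketch only: illustrates the "field^weight" syntax used by both the caller-supplied fields
// and the per-index default-field setting. The field names below are hypothetical.
class FieldWeightParsingSketch {
    static Map<String, Float> example() {
        // "title^2" boosts title by 2; a bare field name gets the default weight of 1
        return QueryParserHelper.parseFieldsAndWeights(List.of("title^2", "body"));
        // expected shape: {"title" -> 2.0f, "body" -> 1.0f}
    }
}
```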
| | ||
| private static Map<String, Float> inferenceFieldsAndWeightsForIndex( | ||
| Map<String, Float> parsedFieldsAndWeights, | ||
| IndexMetadata indexMetadata, | ||
| @Nullable Consumer<Float> weightValidator | ||
| ) { | ||
| Map<String, Float> fieldsAndWeightsToQuery = parsedFieldsAndWeights; | ||
| if (fieldsAndWeightsToQuery.isEmpty()) { | ||
| Settings settings = indexMetadata.getSettings(); | ||
| List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings)); | ||
| fieldsAndWeightsToQuery = QueryParserHelper.parseFieldsAndWeights(defaultFields); | ||
| if (weightValidator != null) { | ||
| fieldsAndWeightsToQuery.values().forEach(weightValidator); | ||
| } | ||
| fieldsAndWeightsToQuery = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator); | ||
| } | ||
| | ||
| Map<String, Float> inferenceFields = new HashMap<>(); | ||
| | @@ -198,30 +295,77 @@ private static List<RetrieverBuilder> generateInnerRetrieversForIndex( | |
| } | ||
| } | ||
| } | ||
| return inferenceFields; | ||
| } | ||
| | ||
| private static Map<String, Float> nonInferenceFieldsAndWeightsForIndex( | ||
| Map<String, Float> fieldsAndWeightsToQuery, | ||
| IndexMetadata indexMetadata, | ||
| @Nullable Consumer<Float> weightValidator | ||
| ) { | ||
| Map<String, Float> nonInferenceFields = new HashMap<>(fieldsAndWeightsToQuery); | ||
| nonInferenceFields.keySet().removeAll(inferenceFields.keySet()); // Remove all inference fields from non-inference fields map | ||
| | ||
| // TODO: Set index pre-filters on returned retrievers when we want to implement multi-index support | ||
| List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2); | ||
| if (nonInferenceFields.isEmpty() == false) { | ||
| MultiMatchQueryBuilder nonInferenceFieldQueryBuilder = new MultiMatchQueryBuilder(query).type( | ||
| MultiMatchQueryBuilder.Type.MOST_FIELDS | ||
| ).fields(nonInferenceFields); | ||
| innerRetrievers.add(new StandardRetrieverBuilder(nonInferenceFieldQueryBuilder)); | ||
| if (nonInferenceFields.isEmpty()) { | ||
| nonInferenceFields = defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator); | ||
| } | ||
| if (inferenceFields.isEmpty() == false) { | ||
| List<WeightedRetrieverSource> inferenceFieldRetrievers = new ArrayList<>(inferenceFields.size()); | ||
| inferenceFields.forEach((f, w) -> { | ||
| RetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(new MatchQueryBuilder(f, query)); | ||
| inferenceFieldRetrievers.add( | ||
| new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), w) | ||
| ); | ||
| }); | ||
| | ||
| innerRetrievers.add(innerNormalizerGenerator.apply(inferenceFieldRetrievers)); | ||
| nonInferenceFields.keySet().removeAll(indexMetadata.getInferenceFields().keySet()); | ||
| return nonInferenceFields; | ||
| } | ||
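A small sketch of the subtraction performed above (hypothetical field names): the lexical set for an index is the requested (or per-index default) fields minus that index's inference fields, so a semantic_text field never leaks into the multi_match clause.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Sketch only: hypothetical field names. Mirrors the keySet().removeAll(...) step above.
class LexicalFieldSubtractionSketch {
    static Map<String, Float> lexicalFields() {
        Map<String, Float> requested = new HashMap<>(Map.of("title", 1.0f, "body", 2.0f));
        // stand-in for indexMetadata.getInferenceFields().keySet() on an index where "body" is semantic_text
        Set<String> inferenceFieldsInIndex = Set.of("body");
        requested.keySet().removeAll(inferenceFieldsInIndex);
        return requested; // {"title" -> 1.0f}: only lexical fields remain
    }
}
```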
| | ||
| private static RetrieverBuilder generateLexicalRetriever( | ||
| Map<String, Float> fieldsAndWeightsToQuery, | ||
| Collection<IndexMetadata> indicesMetadata, | ||
| String query, | ||
| @Nullable Consumer<Float> weightValidator | ||
| ) { | ||
| List<QueryBuilder> lexicalQueryBuilders = new ArrayList<>(); | ||
| Map<String, Float> nonInferenceFields = null; | ||
| boolean differentNonInferenceFields = false; | ||
| | ||
| for (IndexMetadata indexMetadata : indicesMetadata) { | ||
| Map<String, Float> nonInferenceFieldsForIndex = nonInferenceFieldsAndWeightsForIndex( | ||
| fieldsAndWeightsToQuery, | ||
| indexMetadata, | ||
| weightValidator | ||
| ); | ||
| | ||
| if (nonInferenceFields == null) { | ||
| nonInferenceFields = nonInferenceFieldsForIndex; | ||
| } else if (nonInferenceFields.equals(nonInferenceFieldsForIndex) == false) { | ||
| differentNonInferenceFields = true; | ||
| } | ||
| | ||
| if (nonInferenceFieldsForIndex.isEmpty()) { | ||
| continue; | ||
| } | ||
| | ||
| lexicalQueryBuilders.add( | ||
| new BoolQueryBuilder().should( | ||
| new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS).fields(nonInferenceFieldsForIndex) | ||
| ).filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexMetadata.getIndex().getName())) | ||
| ); | ||
| } | ||
| return innerRetrievers; | ||
| // no lexical fields need to be queried, so there is no need to create a retriever | ||
| if (lexicalQueryBuilders.isEmpty()) { | ||
| return null; | ||
| } | ||
| | ||
| // all indices query the same non-inference fields, so we can return a single multi_match query | ||
| if (differentNonInferenceFields == false) { | ||
| return new StandardRetrieverBuilder( | ||
| new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS).fields(nonInferenceFields) | ||
| ); | ||
| } | ||
| | ||
| // only a single lexical query, no need to wrap in a boolean query | ||
| if (lexicalQueryBuilders.size() == 1) { | ||
| return new StandardRetrieverBuilder(lexicalQueryBuilders.getFirst()); | ||
| } | ||
| | ||
| BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); | ||
| lexicalQueryBuilders.forEach(boolQueryBuilder::should); | ||
| return new StandardRetrieverBuilder(boolQueryBuilder); | ||
| } | ||
| | ||
| private static void addToInferenceFieldsMap(Map<String, Float> inferenceFields, String field, Float weight) { | ||
| | ||