Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.retriever;

import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.junit.Before;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public class RetrieverRewriteIT extends ESIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(MockSearchService.TestPlugin.class);
}

private static String INDEX_DOCS = "docs";
private static String INDEX_QUERIES = "queries";
private static final String ID_FIELD = "_id";
private static final String QUERY_FIELD = "query";

@Before
public void setup() throws Exception {
createIndex(INDEX_DOCS);
index(INDEX_DOCS, "doc_0", "{}");
index(INDEX_DOCS, "doc_1", "{}");
index(INDEX_DOCS, "doc_2", "{}");
refresh(INDEX_DOCS);

createIndex(INDEX_QUERIES);
index(INDEX_QUERIES, "query_0", "{ \"" + QUERY_FIELD + "\": \"doc_2\"}");
index(INDEX_QUERIES, "query_1", "{ \"" + QUERY_FIELD + "\": \"doc_1\"}");
index(INDEX_QUERIES, "query_2", "{ \"" + QUERY_FIELD + "\": \"doc_0\"}");
refresh(INDEX_QUERIES);
}

public void testRewrite() {
SearchSourceBuilder source = new SearchSourceBuilder();
StandardRetrieverBuilder standard = new StandardRetrieverBuilder();
standard.queryBuilder = QueryBuilders.termQuery(ID_FIELD, "doc_0");
source.retriever(new AssertingRetrieverBuilder(standard));
SearchRequestBuilder req = client().prepareSearch(INDEX_DOCS, INDEX_QUERIES).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_0"));
});
}

public void testRewriteCompound() {
SearchSourceBuilder source = new SearchSourceBuilder();
source.retriever(new AssertingCompoundRetrieverBuilder("query_0"));
SearchRequestBuilder req = client().prepareSearch(INDEX_DOCS, INDEX_QUERIES).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
});
}

private static class AssertingRetrieverBuilder extends RetrieverBuilder {
private final RetrieverBuilder innerRetriever;

private AssertingRetrieverBuilder(RetrieverBuilder innerRetriever) {
this.innerRetriever = innerRetriever;
}

@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
assertNull(ctx.getPointInTimeBuilder());
assertNull(ctx.convertToInnerHitsRewriteContext());
assertNull(ctx.convertToCoordinatorRewriteContext());
assertNull(ctx.convertToIndexMetadataContext());
assertNull(ctx.convertToSearchExecutionContext());
assertNull(ctx.convertToDataRewriteContext());
var newRetriever = innerRetriever.rewrite(ctx);
if (newRetriever != innerRetriever) {
return new AssertingRetrieverBuilder(newRetriever);
}
return this;
}

@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder sourceBuilder, boolean compoundUsed) {
assertNull(sourceBuilder.retriever());
innerRetriever.extractToSearchSourceBuilder(sourceBuilder, compoundUsed);
}

@Override
public String getName() {
return "asserting";
}

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {}

@Override
protected boolean doEquals(Object o) {
return false;
}

@Override
protected int doHashCode() {
return innerRetriever.doHashCode();
}
}

private static class AssertingCompoundRetrieverBuilder extends RetrieverBuilder {
private final String id;
private final SetOnce<RetrieverBuilder> innerRetriever;

private AssertingCompoundRetrieverBuilder(String id) {
this.id = id;
this.innerRetriever = new SetOnce<>(null);
}

private AssertingCompoundRetrieverBuilder(String id, SetOnce<RetrieverBuilder> innerRetriever) {
this.id = id;
this.innerRetriever = innerRetriever;
}

@Override
public boolean isCompound() {
return true;
}

@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
assertNotNull(ctx.getPointInTimeBuilder());
assertNull(ctx.convertToInnerHitsRewriteContext());
assertNull(ctx.convertToCoordinatorRewriteContext());
assertNull(ctx.convertToIndexMetadataContext());
assertNull(ctx.convertToSearchExecutionContext());
assertNull(ctx.convertToDataRewriteContext());
if (innerRetriever.get() != null) {
return this;
}
SetOnce<RetrieverBuilder> innerRetriever = new SetOnce<>();
ctx.registerAsyncAction((client, actionListener) -> {
SearchSourceBuilder source = new SearchSourceBuilder().pointInTimeBuilder(ctx.getPointInTimeBuilder())
.query(QueryBuilders.termQuery(ID_FIELD, id))
.fetchField(QUERY_FIELD);
client.search(new SearchRequest().source(source), new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
String query = response.getHits().getAt(0).field(QUERY_FIELD).getValue();
StandardRetrieverBuilder standard = new StandardRetrieverBuilder();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume in the main code we want retrievers to be rewritten into some kind of ScoreDocQuery, but not to another retriever?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retrievers will always rewrite to a retriever, that's enforced by the design. Using some kind of ScoreDocQuery will be internally done by the rewritten retriever when extractToSourceBuilder is called (when the retriever is guaranteed to be fully rewritten).

standard.queryBuilder = QueryBuilders.termQuery(ID_FIELD, query);
innerRetriever.set(standard);
actionListener.onResponse(null);
}

@Override
public void onFailure(Exception e) {
actionListener.onFailure(e);
}
});
});
return new AssertingCompoundRetrieverBuilder(id, innerRetriever);
}

@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder sourceBuilder, boolean compoundUsed) {
assertNull(sourceBuilder.retriever());
innerRetriever.get().extractToSearchSourceBuilder(sourceBuilder, compoundUsed);
}

@Override
public String getName() {
return "asserting";
}

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
throw new AssertionError("not implemented");
}

@Override
protected boolean doEquals(Object o) {
return false;
}

@Override
protected int doHashCode() {
return id.hashCode();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ protected void doExecute(Task task, ValidateQueryRequest request, ActionListener
if (request.query() == null) {
rewriteListener.onResponse(request.query());
} else {
Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices), rewriteListener);
Rewriteable.rewriteAndFetch(
request.query(),
searchService.getRewriteContext(timeProvider, resolvedIndices, null),
rewriteListener
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener<Expla

assert request.query() != null;
LongSupplier timeProvider = () -> request.nowInMillis;
Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices), rewriteListener);
Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices, null), rewriteListener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,9 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At
if (buildPointInTimeFromSearchResults()) {
searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minTransportVersion);
} else {
if (request.source() != null && request.source().pointInTimeBuilder() != null) {
if (request.source() != null
&& request.source().pointInTimeBuilder() != null
&& request.source().pointInTimeBuilder().singleSession() == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to rely on the keepAlive to determine explicit releasing in this case? Could this possible cause any side effects with any user provided requests with -1 as keep_alive ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just about returning the pit id in the response. The logic to release the pit automatically is restricted to the retriever builder case. It might be surprising for existing pit users though since -1 is a valid keep-alive value. We should be able to restrict this behavior to retrievers entirely.

searchContextId = request.source().pointInTimeBuilder().getEncodedId();
} else {
searchContextId = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
Expand Down Expand Up @@ -487,11 +488,72 @@ void executeRequest(
}
}
});
final SearchSourceBuilder source = original.source();
if (shouldOpenPIT(source)) {
openPIT(client, original, searchService.getDefaultKeepAliveInMillis(), listener.delegateFailureAndWrap((delegate, resp) -> {
// We set the keep alive to -1 to indicate that we don't need the pit id in the response.
// This is needed since we delete the pit prior to sending the response so the id doesn't exist anymore.
source.pointInTimeBuilder(new PointInTimeBuilder(resp.getPointInTimeId()).setKeepAlive(TimeValue.MINUS_ONE));
executeRequest(task, original, new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
// we need to close the PIT first so we delay the release of the response to after the closing
response.incRef();
closePIT(
client,
original.source().pointInTimeBuilder(),
() -> ActionListener.respondAndRelease(listener, response)
);
}

Rewriteable.rewriteAndFetch(
original,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices),
rewriteListener
@Override
public void onFailure(Exception e) {
closePIT(client, original.source().pointInTimeBuilder(), () -> listener.onFailure(e));
}
}, searchPhaseProvider);
}));
} else {
Rewriteable.rewriteAndFetch(
original,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder()),
rewriteListener
);
}
}

/**
* Returns true if the provided source needs to open a shared point in time prior to executing the request.
*/
private boolean shouldOpenPIT(SearchSourceBuilder source) {
if (source == null) {
return false;
}
if (source.pointInTimeBuilder() != null) {
return false;
}
var retriever = source.retriever();
return retriever != null && retriever.isCompound();
}

static void openPIT(Client client, SearchRequest request, long keepAliveMillis, ActionListener<OpenPointInTimeResponse> listener) {
OpenPointInTimeRequest pitReq = new OpenPointInTimeRequest(request.indices()).indicesOptions(request.indicesOptions())
.preference(request.preference())
.routing(request.routing())
.keepAlive(TimeValue.timeValueMillis(keepAliveMillis));
client.execute(TransportOpenPointInTimeAction.TYPE, pitReq, listener);
}

static void closePIT(Client client, PointInTimeBuilder pit, Runnable next) {
client.execute(
TransportClosePointInTimeAction.TYPE,
new ClosePointInTimeRequest(pit.getEncodedId()),
ActionListener.runAfter(new ActionListener<>() {
@Override
public void onResponse(ClosePointInTimeResponse closePointInTimeResponse) {}

@Override
public void onFailure(Exception e) {}
}, next)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ protected void doExecute(Task task, SearchShardsRequest searchShardsRequest, Act

Rewriteable.rewriteAndFetch(
original,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices),
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null),
listener.delegateFailureAndWrap((delegate, searchRequest) -> {
Index[] concreteIndices = resolvedIndices.getConcreteLocalIndices();
final Set<String> indicesAndAliases = indexNameExpressionResolver.resolveExpressions(clusterState, searchRequest.indices());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ public QueryRewriteContext newQueryRewriteContext(
valuesSourceRegistry,
allowExpensiveQueries,
scriptService,
null,
null
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public CoordinatorRewriteContext(
null,
null,
null,
null,
null
);
this.indexLongFieldRange = indexLongFieldRange;
Expand Down
Loading