Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a0be9df
Adds base and classic retrievers (#100731)
jdconrad Oct 13, 2023
1407559
Merge branch 'main' into retrievers
jdconrad Oct 16, 2023
987a4ac
Merge branch 'main' into retrievers
jdconrad Oct 18, 2023
addda00
Merge remote-tracking branch 'upstream/main' into retrievers
mayya-sharipova Oct 20, 2023
69cfca8
Introduce kNN retriever (#101166)
mayya-sharipova Oct 20, 2023
b89b582
Add an RRF Retriever (#101090)
jdconrad Oct 24, 2023
c62d4ad
Merge branch 'main' into retrievers
jdconrad Nov 7, 2023
5b4a214
Merge branch 'main' into retrievers
jdconrad Nov 13, 2023
cb6e5f6
Merge branch 'main' into retrievers
jdconrad Nov 14, 2023
cb487ac
Add basic linear combination retriever (#102100)
jdconrad Nov 16, 2023
32f97a4
Merge branch 'main' into retrievers
jdconrad Nov 16, 2023
3b0c859
Merge branch 'main' into retrievers
jdconrad Nov 27, 2023
0ceb78a
Merge branch 'main' into retrievers
jdconrad Nov 28, 2023
5553191
Merge branch 'main' into retrievers
jdconrad Dec 5, 2023
9426c04
Merge branch 'main' into retrievers
jdconrad Dec 6, 2023
6a36899
Merge branch 'main' into retrievers
jdconrad Dec 7, 2023
349a966
Merge branch 'main' into retrievers
jdconrad Dec 10, 2023
c95cb61
Merge branch 'main' into retrievers
jdconrad Dec 14, 2023
e6f35e1
Merge branch 'main' into retrievers
jdconrad Jan 8, 2024
8270e05
Merge branch 'main' into retrievers
jdconrad Jan 16, 2024
fcbd7f8
Merge branch 'main' into retrievers
jdconrad Jan 17, 2024
bf746d9
Merge branch 'main' into retrievers
jdconrad Jan 17, 2024
90e2d36
Yaml tests for standard retriever (#104482)
jdconrad Jan 19, 2024
706e3d4
Merge branch 'main' into retrievers
jdconrad Jan 22, 2024
945caba
Merge remote-tracking branch 'upstream/retrievers' into retrievers
jdconrad Jan 22, 2024
74ab5e4
Merge branch 'main' into retrievers
jdconrad Jan 29, 2024
7ee27cf
Remove extraneous code from retrievers (#104888)
jdconrad Jan 30, 2024
948ce3d
Merge branch 'main' into retrievers
jdconrad Jan 30, 2024
4bb0feb
Merge branch 'main' into retrievers
jdconrad Jan 31, 2024
710d98d
Add yaml tests for rrf retreiver (#104992)
jdconrad Jan 31, 2024
60faec0
Merge branch 'main' into retrievers
jdconrad Feb 1, 2024
3526a8c
Merge branch 'main' into retrievers
jdconrad Feb 1, 2024
d8f9af6
add feature versioning to retrievers
jdconrad Feb 1, 2024
c9090d8
fix bindings
jdconrad Feb 2, 2024
fc8a987
update error messages
jdconrad Feb 2, 2024
ecc3a16
set number of shards for test consistency
jdconrad Feb 2, 2024
a6d5bc9
add rrf dependency for testing
jdconrad Feb 2, 2024
9d16732
Merge branch 'main' into retver
jdconrad Feb 12, 2024
a4b57d2
remove retriever specific code
jdconrad Feb 12, 2024
6851460
spotless
jdconrad Feb 12, 2024
17fc95b
remove additional retrievers code
jdconrad Feb 12, 2024
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public List<RestHandler> getRestHandlers(
Predicate<NodeFeature> clusterSupportsFeature
) {
return Arrays.asList(
new RestSearchTemplateAction(namedWriteableRegistry),
new RestSearchTemplateAction(namedWriteableRegistry, clusterSupportsFeature),
new RestMultiSearchTemplateAction(settings),
new RestRenderSearchTemplateAction()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.Scope;
Expand All @@ -23,6 +24,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;

import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.rest.RestRequest.Method.POST;
Expand All @@ -35,9 +37,11 @@ public class RestSearchTemplateAction extends BaseRestHandler {
private static final Set<String> RESPONSE_PARAMS = Set.of(TYPED_KEYS_PARAM, RestSearchAction.TOTAL_HITS_AS_INT_PARAM);

private final NamedWriteableRegistry namedWriteableRegistry;
private final Predicate<NodeFeature> clusterSupportsFeature;

public RestSearchTemplateAction(NamedWriteableRegistry namedWriteableRegistry) {
public RestSearchTemplateAction(NamedWriteableRegistry namedWriteableRegistry, Predicate<NodeFeature> clusterSupportsFeature) {
this.namedWriteableRegistry = namedWriteableRegistry;
this.clusterSupportsFeature = clusterSupportsFeature;
}

@Override
Expand Down Expand Up @@ -70,6 +74,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
request,
null,
namedWriteableRegistry,
clusterSupportsFeature,
size -> searchRequest.source().size(size)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
Expand All @@ -29,6 +33,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

import static org.elasticsearch.script.mustache.TransportSearchTemplateAction.convert;

Expand All @@ -38,6 +43,7 @@ public class TransportMultiSearchTemplateAction extends HandledTransportAction<M

private final ScriptService scriptService;
private final NamedXContentRegistry xContentRegistry;
private final Predicate<NodeFeature> clusterSupportsFeature;
private final NodeClient client;
private final SearchUsageHolder searchUsageHolder;

Expand All @@ -48,7 +54,9 @@ public TransportMultiSearchTemplateAction(
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
NodeClient client,
UsageService usageService
UsageService usageService,
ClusterService clusterService,
FeatureService featureService
) {
super(
MustachePlugin.MULTI_SEARCH_TEMPLATE_ACTION.name(),
Expand All @@ -59,6 +67,10 @@ public TransportMultiSearchTemplateAction(
);
this.scriptService = scriptService;
this.xContentRegistry = xContentRegistry;
this.clusterSupportsFeature = f -> {
ClusterState state = clusterService.state();
return state.clusterRecovered() && featureService.clusterHasFeature(state, f);
};
this.client = client;
this.searchUsageHolder = usageService.getSearchUsageHolder();
}
Expand All @@ -78,7 +90,14 @@ protected void doExecute(Task task, MultiSearchTemplateRequest request, ActionLi
SearchTemplateResponse searchTemplateResponse = new SearchTemplateResponse();
SearchRequest searchRequest;
try {
searchRequest = convert(searchTemplateRequest, searchTemplateResponse, scriptService, xContentRegistry, searchUsageHolder);
searchRequest = convert(
searchTemplateRequest,
searchTemplateResponse,
scriptService,
xContentRegistry,
clusterSupportsFeature,
searchUsageHolder
);
} catch (Exception e) {
searchTemplateResponse.decRef();
items[i] = new MultiSearchTemplateResponse.Item(null, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.rest.action.search.RestSearchAction;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
Expand All @@ -36,13 +40,15 @@

import java.io.IOException;
import java.util.Collections;
import java.util.function.Predicate;

public class TransportSearchTemplateAction extends HandledTransportAction<SearchTemplateRequest, SearchTemplateResponse> {

private static final String TEMPLATE_LANG = MustacheScriptEngine.NAME;

private final ScriptService scriptService;
private final NamedXContentRegistry xContentRegistry;
private final Predicate<NodeFeature> clusterSupportsFeature;
private final NodeClient client;
private final SearchUsageHolder searchUsageHolder;

Expand All @@ -53,7 +59,9 @@ public TransportSearchTemplateAction(
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
NodeClient client,
UsageService usageService
UsageService usageService,
ClusterService clusterService,
FeatureService featureService
) {
super(
MustachePlugin.SEARCH_TEMPLATE_ACTION.name(),
Expand All @@ -64,6 +72,10 @@ public TransportSearchTemplateAction(
);
this.scriptService = scriptService;
this.xContentRegistry = xContentRegistry;
this.clusterSupportsFeature = f -> {
ClusterState state = clusterService.state();
return state.clusterRecovered() && featureService.clusterHasFeature(state, f);
};
this.client = client;
this.searchUsageHolder = usageService.getSearchUsageHolder();
}
Expand All @@ -73,7 +85,14 @@ protected void doExecute(Task task, SearchTemplateRequest request, ActionListene
final SearchTemplateResponse response = new SearchTemplateResponse();
boolean success = false;
try {
SearchRequest searchRequest = convert(request, response, scriptService, xContentRegistry, searchUsageHolder);
SearchRequest searchRequest = convert(
request,
response,
scriptService,
xContentRegistry,
clusterSupportsFeature,
searchUsageHolder
);
if (searchRequest != null) {
client.search(searchRequest, listener.delegateResponse((l, e) -> {
response.decRef();
Expand Down Expand Up @@ -102,6 +121,7 @@ static SearchRequest convert(
SearchTemplateResponse response,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
Predicate<NodeFeature> clusterSupportsFeature,
SearchUsageHolder searchUsageHolder
) throws IOException {
Script script = new Script(
Expand All @@ -121,7 +141,7 @@ static SearchRequest convert(
XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry)
.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, source)) {
builder.parseXContent(parser, false, searchUsageHolder);
builder.parseXContent(parser, false, searchUsageHolder, clusterSupportsFeature);
}

if (searchTemplateRequest.isSimulate()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public final class RestSearchTemplateActionTests extends RestActionTestCase {

@Before
public void setUpAction() {
controller().registerHandler(new RestSearchTemplateAction(mock(NamedWriteableRegistry.class)));
controller().registerHandler(new RestSearchTemplateAction(mock(NamedWriteableRegistry.class), nf -> false));
verifyingClient.setExecuteVerifier((actionType, request) -> mock(SearchTemplateResponse.class));
verifyingClient.setExecuteLocallyVerifier((actionType, request) -> mock(SearchTemplateResponse.class));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public List<RestHandler> getRestHandlers(
Supplier<DiscoveryNodes> nodesInCluster,
Predicate<NodeFeature> clusterSupportsFeature
) {
return Collections.singletonList(new RestRankEvalAction());
return Collections.singletonList(new RestRankEvalAction(clusterSupportsFeature));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.script.Script;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -30,6 +31,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.Predicate;

/**
* Specification of the ranking evaluation request.<br>
Expand Down Expand Up @@ -132,13 +134,13 @@ public void setMaxConcurrentSearches(int maxConcurrentSearches) {
private static final ParseField REQUESTS_FIELD = new ParseField("requests");
private static final ParseField MAX_CONCURRENT_SEARCHES_FIELD = new ParseField("max_concurrent_searches");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<RankEvalSpec, Void> PARSER = new ConstructingObjectParser<>(
private static final ConstructingObjectParser<RankEvalSpec, Predicate<NodeFeature>> PARSER = new ConstructingObjectParser<>(
"rank_eval",
a -> new RankEvalSpec((List<RatedRequest>) a[0], (EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2])
);

static {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedRequest.fromXContent(p), REQUESTS_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedRequest.fromXContent(p, c), REQUESTS_FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseMetric(p), METRIC_FIELD);
PARSER.declareObjectArray(
ConstructingObjectParser.optionalConstructorArg(),
Expand All @@ -156,8 +158,8 @@ private static EvaluationMetric parseMetric(XContentParser parser) throws IOExce
return metric;
}

public static RankEvalSpec parse(XContentParser parser) {
return PARSER.apply(parser, null);
public static RankEvalSpec parse(XContentParser parser, Predicate<NodeFeature> clusterSupportsFeature) {
return PARSER.apply(parser, clusterSupportsFeature);
}

static class ScriptWithId {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
Expand All @@ -31,6 +32,7 @@
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;

/**
* Definition of a particular query in the ranking evaluation request.<br>
Expand Down Expand Up @@ -246,7 +248,7 @@ public void addSummaryFields(List<String> summaryFieldsToAdd) {
private static final ParseField TEMPLATE_ID_FIELD = new ParseField("template_id");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<RatedRequest, Void> PARSER = new ConstructingObjectParser<>(
private static final ConstructingObjectParser<RatedRequest, Predicate<NodeFeature>> PARSER = new ConstructingObjectParser<>(
"request",
a -> new RatedRequest(
(String) a[0],
Expand All @@ -262,7 +264,7 @@ public void addSummaryFields(List<String> summaryFieldsToAdd) {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> RatedDocument.fromXContent(p), RATINGS_FIELD);
PARSER.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> new SearchSourceBuilder().parseXContent(p, false),
(p, c) -> new SearchSourceBuilder().parseXContent(p, false, c),
REQUEST_FIELD
);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> p.map(), PARAMS_FIELD);
Expand All @@ -273,8 +275,8 @@ public void addSummaryFields(List<String> summaryFieldsToAdd) {
/**
* parse from rest representation
*/
public static RatedRequest fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
public static RatedRequest fromXContent(XContentParser parser, Predicate<NodeFeature> clusterSupportsFeature) {
return PARSER.apply(parser, clusterSupportsFeature);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.Strings;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.Scope;
Expand All @@ -21,6 +22,7 @@

import java.io.IOException;
import java.util.List;
import java.util.function.Predicate;

import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.rest.RestRequest.Method.POST;
Expand Down Expand Up @@ -82,6 +84,12 @@ public class RestRankEvalAction extends BaseRestHandler {

public static final String ENDPOINT = "_rank_eval";

private Predicate<NodeFeature> clusterSupportsFeature;

public RestRankEvalAction(Predicate<NodeFeature> clusterSupportsFeature) {
this.clusterSupportsFeature = clusterSupportsFeature;
}

@Override
public List<Route> routes() {
return List.of(
Expand All @@ -96,7 +104,7 @@ public List<Route> routes() {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
RankEvalRequest rankEvalRequest = new RankEvalRequest();
try (XContentParser parser = request.contentOrSourceParamParser()) {
parseRankEvalRequest(rankEvalRequest, request, parser);
parseRankEvalRequest(rankEvalRequest, request, parser, clusterSupportsFeature);
}
return channel -> client.executeLocally(
RankEvalPlugin.ACTION,
Expand All @@ -105,13 +113,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
);
}

private static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, XContentParser parser) {
private static void parseRankEvalRequest(
RankEvalRequest rankEvalRequest,
RestRequest request,
XContentParser parser,
Predicate<NodeFeature> clusterSupportsFeature
) {
rankEvalRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
rankEvalRequest.indicesOptions(IndicesOptions.fromRequest(request, rankEvalRequest.indicesOptions()));
if (request.hasParam("search_type")) {
rankEvalRequest.searchType(SearchType.fromString(request.param("search_type")));
}
RankEvalSpec spec = RankEvalSpec.parse(parser);
RankEvalSpec spec = RankEvalSpec.parse(parser, clusterSupportsFeature);
rankEvalRequest.setRankEvalSpec(spec);
}

Expand Down
Loading