Skip to content

Commit fb6d7db

Browse files
authored
Semantic search with query builder rewrite (#118676)
* Semantic search with query builder rewrite
* Address review feedback
* Add feature behind snapshot
* Use after/before instead of afterClass/beforeClass
* Call onFailure instead of throwing exception
* Fix KqlFunctionIT by requiring KqlPlugin
* Update scoring tests now that they are enabled
* Drop the score column for now
1 parent 724e9be commit fb6d7db

File tree

18 files changed

+629
-7
lines changed

18 files changed

+629
-7
lines changed

server/src/main/java/org/elasticsearch/action/ResolvedIndices.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,26 @@ public static ResolvedIndices resolveWithIndicesRequest(
150150
RemoteClusterService remoteClusterService,
151151
long startTimeInMillis
152152
) {
153-
final Map<String, OriginalIndices> remoteClusterIndices = remoteClusterService.groupIndices(
153+
return resolveWithIndexNamesAndOptions(
154+
request.indices(),
154155
request.indicesOptions(),
155-
request.indices()
156+
clusterState,
157+
indexNameExpressionResolver,
158+
remoteClusterService,
159+
startTimeInMillis
156160
);
161+
}
162+
163+
public static ResolvedIndices resolveWithIndexNamesAndOptions(
164+
String[] indexNames,
165+
IndicesOptions indicesOptions,
166+
ClusterState clusterState,
167+
IndexNameExpressionResolver indexNameExpressionResolver,
168+
RemoteClusterService remoteClusterService,
169+
long startTimeInMillis
170+
) {
171+
final Map<String, OriginalIndices> remoteClusterIndices = remoteClusterService.groupIndices(indicesOptions, indexNames);
172+
157173
final OriginalIndices localIndices = remoteClusterIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
158174

159175
Index[] concreteLocalIndices = localIndices == null
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.qa.multi_node;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
11+
12+
import org.elasticsearch.test.TestClustersThreadFilter;
13+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
14+
import org.elasticsearch.xpack.esql.qa.rest.SemanticMatchTestCase;
15+
import org.junit.ClassRule;
16+
17+
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
18+
public class SemanticMatchIT extends SemanticMatchTestCase {
19+
@ClassRule
20+
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));
21+
22+
@Override
23+
protected String getTestRestCluster() {
24+
return cluster.getHttpAddresses();
25+
}
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.qa.single_node;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
11+
12+
import org.elasticsearch.test.TestClustersThreadFilter;
13+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
14+
import org.elasticsearch.xpack.esql.qa.rest.SemanticMatchTestCase;
15+
import org.junit.ClassRule;
16+
17+
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
18+
public class SemanticMatchIT extends SemanticMatchTestCase {
19+
@ClassRule
20+
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));
21+
22+
@Override
23+
protected String getTestRestCluster() {
24+
return cluster.getHttpAddresses();
25+
}
26+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.qa.rest;
9+
10+
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.ResponseException;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.test.rest.ESRestTestCase;
14+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
15+
import org.junit.After;
16+
import org.junit.Before;
17+
18+
import java.io.IOException;
19+
import java.util.Map;
20+
21+
import static org.hamcrest.core.StringContains.containsString;
22+
23+
public abstract class SemanticMatchTestCase extends ESRestTestCase {
24+
public void testWithMultipleInferenceIds() throws IOException {
25+
String query = """
26+
from test-semantic1,test-semantic2
27+
| where match(semantic_text_field, "something")
28+
""";
29+
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query));
30+
31+
assertThat(re.getMessage(), containsString("Field [semantic_text_field] has multiple inference IDs associated with it"));
32+
33+
assertEquals(400, re.getResponse().getStatusLine().getStatusCode());
34+
}
35+
36+
public void testWithInferenceNotConfigured() {
37+
String query = """
38+
from test-semantic3
39+
| where match(semantic_text_field, "something")
40+
""";
41+
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query));
42+
43+
assertThat(re.getMessage(), containsString("Inference endpoint not found"));
44+
assertEquals(404, re.getResponse().getStatusLine().getStatusCode());
45+
}
46+
47+
@Before
48+
public void setUpIndices() throws IOException {
49+
assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.SEMANTIC_TEXT_TYPE.isEnabled());
50+
51+
var settings = Settings.builder().build();
52+
53+
String mapping1 = """
54+
"properties": {
55+
"semantic_text_field": {
56+
"type": "semantic_text",
57+
"inference_id": "test_sparse_inference"
58+
}
59+
}
60+
""";
61+
createIndex(adminClient(), "test-semantic1", settings, mapping1);
62+
63+
String mapping2 = """
64+
"properties": {
65+
"semantic_text_field": {
66+
"type": "semantic_text",
67+
"inference_id": "test_dense_inference"
68+
}
69+
}
70+
""";
71+
createIndex(adminClient(), "test-semantic2", settings, mapping2);
72+
73+
String mapping3 = """
74+
"properties": {
75+
"semantic_text_field": {
76+
"type": "semantic_text",
77+
"inference_id": "inexistent"
78+
}
79+
}
80+
""";
81+
createIndex(adminClient(), "test-semantic3", settings, mapping3);
82+
}
83+
84+
@Before
85+
public void setUpTextEmbeddingInferenceEndpoint() throws IOException {
86+
assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.SEMANTIC_TEXT_TYPE.isEnabled());
87+
Request request = new Request("PUT", "_inference/text_embedding/test_dense_inference");
88+
request.setJsonEntity("""
89+
{
90+
"service": "test_service",
91+
"service_settings": {
92+
"model": "my_model",
93+
"api_key": "abc64"
94+
},
95+
"task_settings": {
96+
}
97+
}
98+
""");
99+
adminClient().performRequest(request);
100+
}
101+
102+
@After
103+
public void wipeData() throws IOException {
104+
assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.SEMANTIC_TEXT_TYPE.isEnabled());
105+
adminClient().performRequest(new Request("DELETE", "*"));
106+
107+
try {
108+
adminClient().performRequest(new Request("DELETE", "_inference/test_dense_inference"));
109+
} catch (ResponseException e) {
110+
// 404 here means the endpoint was not created
111+
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
112+
throw e;
113+
}
114+
}
115+
}
116+
117+
private Map<String, Object> runEsqlQuery(String query) throws IOException {
118+
RestEsqlTestCase.RequestObjectBuilder builder = RestEsqlTestCase.requestObjectBuilder().query(query);
119+
return RestEsqlTestCase.runEsqlSync(builder);
120+
}
121+
}

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
7171
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
7272
import org.elasticsearch.xpack.esql.session.Configuration;
73+
import org.elasticsearch.xpack.esql.session.QueryBuilderResolver;
7374
import org.elasticsearch.xpack.esql.stats.Metrics;
7475
import org.elasticsearch.xpack.esql.stats.SearchStats;
7576
import org.elasticsearch.xpack.versionfield.Version;
@@ -351,6 +352,8 @@ public String toString() {
351352

352353
public static final Verifier TEST_VERIFIER = new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L));
353354

355+
public static final QueryBuilderResolver MOCK_QUERY_BUILDER_RESOLVER = new MockQueryBuilderResolver();
356+
354357
private EsqlTestUtils() {}
355358

356359
public static Configuration configuration(QueryPragmas pragmas, String query) {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
12+
import org.elasticsearch.xpack.esql.session.QueryBuilderResolver;
13+
import org.elasticsearch.xpack.esql.session.Result;
14+
15+
import java.util.function.BiConsumer;
16+
17+
public class MockQueryBuilderResolver extends QueryBuilderResolver {
18+
public MockQueryBuilderResolver() {
19+
super(null, null, null, null);
20+
}
21+
22+
@Override
23+
public void resolveQueryBuilders(
24+
LogicalPlan plan,
25+
ActionListener<Result> listener,
26+
BiConsumer<LogicalPlan, ActionListener<Result>> callback
27+
) {
28+
callback.accept(plan, listener);
29+
}
30+
}

x-pack/plugin/esql/qa/testFixtures/src/main/resources/match-function.csv-spec

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,74 @@ from employees,employees_incompatible
597597

598598
emp_no_bool:boolean
599599
;
600+
601+
testMatchWithSemanticText
602+
required_capability: match_function
603+
required_capability: semantic_text_type
604+
605+
from semantic_text
606+
| where match(semantic_text_field, "something")
607+
| keep semantic_text_field
608+
| sort semantic_text_field asc
609+
;
610+
611+
semantic_text_field:semantic_text
612+
all we have to decide is what to do with the time that is given to us
613+
be excellent to each other
614+
live long and prosper
615+
;
616+
617+
testMatchWithSemanticTextAndKeyword
618+
required_capability: match_function
619+
required_capability: semantic_text_type
620+
621+
from semantic_text
622+
| where match(semantic_text_field, "something") AND match(host, "host1")
623+
| keep semantic_text_field, host
624+
;
625+
626+
semantic_text_field:semantic_text | host:keyword
627+
live long and prosper | host1
628+
;
629+
630+
testMatchWithSemanticTextMultiValueField
631+
required_capability: match_function
632+
required_capability: semantic_text_type
633+
634+
from semantic_text metadata _id
635+
| where match(st_multi_value, "something") AND match(host, "host1")
636+
| keep _id, st_multi_value
637+
;
638+
639+
_id: keyword | st_multi_value:semantic_text
640+
1 | ["Hello there!", "This is a random value", "for testing purposes"]
641+
;
642+
643+
testMatchWithSemanticTextWithEvalsAndOtherFunctionsAndStats
644+
required_capability: match_function
645+
required_capability: semantic_text_type
646+
647+
from semantic_text
648+
| where qstr("description:some*")
649+
| eval size = mv_count(st_multi_value)
650+
| where match(semantic_text_field, "something") AND size > 1 AND match(host, "host1")
651+
| STATS result = count(*)
652+
;
653+
654+
result:long
655+
1
656+
;
657+
658+
testMatchWithSemanticTextAndKql
659+
required_capability: match_function
660+
required_capability: semantic_text_type
661+
required_capability: kql_function
662+
663+
from semantic_text
664+
| where kql("host:host1") AND match(semantic_text_field, "something")
665+
| KEEP host, semantic_text_field
666+
;
667+
668+
host:keyword | semantic_text_field:semantic_text
669+
"host1" | live long and prosper
670+
;

0 commit comments

Comments
 (0)