Skip to content

Commit a9d6c12

Browse files
authored
Fix: avoid further dispatch when ThreadContext population fails (#121665)
1 parent d7db1f5 commit a9d6c12

File tree

4 files changed

+140
-20
lines changed

4 files changed

+140
-20
lines changed

server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.rest;
1111

1212
import org.elasticsearch.client.Request;
13+
import org.elasticsearch.client.RequestOptions;
1314
import org.elasticsearch.client.ResponseException;
1415
import org.elasticsearch.client.internal.node.NodeClient;
1516
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
@@ -20,6 +21,7 @@
2021
import org.elasticsearch.common.settings.IndexScopedSettings;
2122
import org.elasticsearch.common.settings.Settings;
2223
import org.elasticsearch.common.settings.SettingsFilter;
24+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
2325
import org.elasticsearch.features.NodeFeature;
2426
import org.elasticsearch.logging.LogManager;
2527
import org.elasticsearch.logging.Logger;
@@ -30,15 +32,21 @@
3032
import org.elasticsearch.telemetry.Measurement;
3133
import org.elasticsearch.telemetry.TestTelemetryPlugin;
3234
import org.elasticsearch.test.ESIntegTestCase;
35+
import org.elasticsearch.xcontent.XContentParser;
3336

3437
import java.io.IOException;
3538
import java.util.ArrayList;
3639
import java.util.Collection;
40+
import java.util.HashMap;
3741
import java.util.List;
3842
import java.util.function.Consumer;
3943
import java.util.function.Predicate;
4044
import java.util.function.Supplier;
4145

46+
import static org.elasticsearch.test.rest.ESRestTestCase.responseAsParser;
47+
import static org.hamcrest.Matchers.containsInAnyOrder;
48+
import static org.hamcrest.Matchers.containsString;
49+
import static org.hamcrest.Matchers.equalTo;
4250
import static org.hamcrest.Matchers.hasEntry;
4351
import static org.hamcrest.Matchers.hasSize;
4452
import static org.hamcrest.Matchers.instanceOf;
@@ -58,6 +66,49 @@ public void testHeadersEmittedWithChunkedResponses() throws IOException {
5866
assertEquals(ChunkedResponseWithHeadersPlugin.HEADER_VALUE, response.getHeader(ChunkedResponseWithHeadersPlugin.HEADER_NAME));
5967
}
6068

69+
public void testHeadersAreCollapsed() throws IOException {
70+
final var client = getRestClient();
71+
final var request = new Request("GET", TestEchoHeadersPlugin.ROUTE);
72+
request.setOptions(RequestOptions.DEFAULT.toBuilder().addHeader("X-Foo", "1").addHeader("X-Foo", "2").build());
73+
final var response = client.performRequest(request);
74+
var responseMap = responseAsParser(response).map(HashMap::new, XContentParser::list);
75+
assertThat(responseMap, hasEntry(equalTo("X-Foo"), containsInAnyOrder("1", "2")));
76+
}
77+
78+
public void testHeadersTreatedCaseInsensitive() throws IOException {
79+
final var client = getRestClient();
80+
final var request = new Request("GET", TestEchoHeadersPlugin.ROUTE);
81+
request.setOptions(RequestOptions.DEFAULT.toBuilder().addHeader("X-Foo", "1").addHeader("x-foo", "2").build());
82+
final var response = client.performRequest(request);
83+
var responseMap = responseAsParser(response).map(HashMap::new, XContentParser::list);
84+
assertThat(responseMap, hasEntry(equalTo("x-foo"), containsInAnyOrder("1", "2")));
85+
assertThat(responseMap, hasEntry(equalTo("X-Foo"), containsInAnyOrder("1", "2")));
86+
}
87+
88+
public void testThreadContextPopulationFromMultipleHeadersFailsWithCorrectError() {
89+
final var client = getRestClient();
90+
final var sameCaseRequest = new Request("GET", TestEchoHeadersPlugin.ROUTE);
91+
sameCaseRequest.setOptions(
92+
RequestOptions.DEFAULT.toBuilder()
93+
.addHeader("x-elastic-product-origin", "elastic")
94+
.addHeader("x-elastic-product-origin", "other")
95+
);
96+
var exception1 = expectThrows(ResponseException.class, () -> client.performRequest(sameCaseRequest));
97+
assertThat(exception1.getMessage(), containsString("multiple values for single-valued header [X-elastic-product-origin]"));
98+
}
99+
100+
public void testMultipleProductOriginHeadersWithDifferentCaseFailsWithCorrectError() {
101+
final var client = getRestClient();
102+
final var differentCaseRequest = new Request("GET", TestEchoHeadersPlugin.ROUTE);
103+
differentCaseRequest.setOptions(
104+
RequestOptions.DEFAULT.toBuilder()
105+
.addHeader("X-elastic-product-origin", "elastic")
106+
.addHeader("x-elastic-product-origin", "other")
107+
);
108+
var exception2 = expectThrows(ResponseException.class, () -> client.performRequest(differentCaseRequest));
109+
assertThat(exception2.getMessage(), containsString("multiple values for single-valued header [X-elastic-product-origin]"));
110+
}
111+
61112
public void testMetricsEmittedOnSuccess() throws Exception {
62113
final var client = getRestClient();
63114
final var request = new Request("GET", TestEchoStatusCodePlugin.ROUTE);
@@ -125,7 +176,12 @@ private void assertMeasurement(Consumer<Measurement> measurementConsumer) throws
125176

126177
@Override
127178
protected Collection<Class<? extends Plugin>> nodePlugins() {
128-
return List.of(ChunkedResponseWithHeadersPlugin.class, TestEchoStatusCodePlugin.class, TestTelemetryPlugin.class);
179+
return List.of(
180+
ChunkedResponseWithHeadersPlugin.class,
181+
TestEchoStatusCodePlugin.class,
182+
TestEchoHeadersPlugin.class,
183+
TestTelemetryPlugin.class
184+
);
129185
}
130186

131187
public static class TestEchoStatusCodePlugin extends Plugin implements ActionPlugin {
@@ -181,6 +237,62 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
181237
}
182238
}
183239

240+
public static class TestEchoHeadersPlugin extends Plugin implements ActionPlugin {
241+
static final String ROUTE = "/_test/echo_headers";
242+
static final String NAME = "test_echo_headers";
243+
244+
private static final Logger logger = LogManager.getLogger(TestEchoStatusCodePlugin.class);
245+
246+
@Override
247+
public Collection<RestHandler> getRestHandlers(
248+
Settings settings,
249+
NamedWriteableRegistry namedWriteableRegistry,
250+
RestController restController,
251+
ClusterSettings clusterSettings,
252+
IndexScopedSettings indexScopedSettings,
253+
SettingsFilter settingsFilter,
254+
IndexNameExpressionResolver indexNameExpressionResolver,
255+
Supplier<DiscoveryNodes> nodesInCluster,
256+
Predicate<NodeFeature> clusterSupportsFeature
257+
) {
258+
return List.of(new BaseRestHandler() {
259+
@Override
260+
public String getName() {
261+
return NAME;
262+
}
263+
264+
@Override
265+
public List<Route> routes() {
266+
return List.of(new Route(RestRequest.Method.GET, ROUTE), new Route(RestRequest.Method.POST, ROUTE));
267+
}
268+
269+
@Override
270+
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
271+
var headers = request.getHeaders();
272+
logger.info("received header echo request for [{}]", String.join(",", headers.keySet()));
273+
274+
return channel -> {
275+
final var response = RestResponse.chunked(
276+
RestStatus.OK,
277+
ChunkedRestResponseBodyPart.fromXContent(
278+
params -> Iterators.concat(
279+
ChunkedToXContentHelper.startObject(),
280+
Iterators.map(headers.entrySet().iterator(), e -> (b, p) -> b.field(e.getKey(), e.getValue())),
281+
ChunkedToXContentHelper.endObject()
282+
),
283+
request,
284+
channel
285+
),
286+
null
287+
);
288+
channel.sendResponse(response);
289+
logger.info("sent response");
290+
};
291+
}
292+
});
293+
}
294+
}
295+
184296
public static class ChunkedResponseWithHeadersPlugin extends Plugin implements ActionPlugin {
185297

186298
static final String ROUTE = "/_test/chunked_response_with_headers";

server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import org.elasticsearch.core.RefCounted;
3939
import org.elasticsearch.rest.RestChannel;
4040
import org.elasticsearch.rest.RestRequest;
41-
import org.elasticsearch.rest.RestResponse;
4241
import org.elasticsearch.tasks.Task;
4342
import org.elasticsearch.telemetry.tracing.Tracer;
4443
import org.elasticsearch.threadpool.ThreadPool;
@@ -484,25 +483,22 @@ void dispatchRequest(final RestRequest restRequest, final RestChannel channel, f
484483
if (badRequestCause != null) {
485484
dispatcher.dispatchBadRequest(channel, threadContext, badRequestCause);
486485
} else {
487-
populatePerRequestThreadContext0(restRequest, channel, threadContext);
486+
try {
487+
populatePerRequestThreadContext(restRequest, threadContext);
488+
} catch (Exception e) {
489+
try {
490+
dispatcher.dispatchBadRequest(channel, threadContext, e);
491+
} catch (Exception inner) {
492+
inner.addSuppressed(e);
493+
logger.error(() -> "failed to send failure response for uri [" + restRequest.uri() + "]", inner);
494+
}
495+
return;
496+
}
488497
dispatcher.dispatchRequest(restRequest, channel, threadContext);
489498
}
490499
}
491500
}
492501

493-
private void populatePerRequestThreadContext0(RestRequest restRequest, RestChannel channel, ThreadContext threadContext) {
494-
try {
495-
populatePerRequestThreadContext(restRequest, threadContext);
496-
} catch (Exception e) {
497-
try {
498-
channel.sendResponse(new RestResponse(channel, e));
499-
} catch (Exception inner) {
500-
inner.addSuppressed(e);
501-
logger.error(() -> "failed to send failure response for uri [" + restRequest.uri() + "]", inner);
502-
}
503-
}
504-
}
505-
506502
protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {}
507503

508504
private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {

server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,6 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th
355355
}
356356

357357
};
358-
// the set of headers to copy
359-
Set<RestHeaderDefinition> headers = Set.of(new RestHeaderDefinition(Task.TRACE_PARENT_HTTP_HEADER, false));
360358
// sample request headers to test with
361359
Map<String, List<String>> restHeaders = new HashMap<>();
362360
restHeaders.put(Task.TRACE_PARENT_HTTP_HEADER, Collections.singletonList(traceParentValue));
@@ -397,7 +395,7 @@ public HttpStats stats() {
397395

398396
@Override
399397
protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
400-
getFakeActionModule(headers).copyRequestHeadersToThreadContext(restRequest.getHttpRequest(), threadContext);
398+
getFakeActionModule(Set.of()).copyRequestHeadersToThreadContext(restRequest.getHttpRequest(), threadContext);
401399
}
402400
}
403401
) {

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,21 @@ public void testHttpHeaderAuthnBypassHeaderValidator() throws Exception {
370370
new NetworkService(List.of()),
371371
testThreadPool,
372372
xContentRegistry(),
373-
new NullDispatcher(),
373+
new HttpServerTransport.Dispatcher() {
374+
@Override
375+
public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
376+
fail("Request should not be dispatched");
377+
}
378+
379+
@Override
380+
public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, Throwable cause) {
381+
try {
382+
channel.sendResponse(new RestResponse(channel, (Exception) cause));
383+
} catch (IOException e) {
384+
fail(e, "Unexpected exception dispatching bad request");
385+
}
386+
}
387+
},
374388
randomClusterSettings(),
375389
new SharedGroupFactory(settings),
376390
Tracer.NOOP,

0 commit comments

Comments
 (0)