Skip to content
Merged
6 changes: 6 additions & 0 deletions docs/changelog/135209.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 135209
summary: Fix expiration time in ES|QL async
area: ES|QL
type: bug
issues:
- 135169
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ public AsyncTaskManagementService(
public void asyncExecute(
Request request,
TimeValue waitForCompletionTimeout,
TimeValue keepAlive,
boolean keepOnCompletion,
ActionListener<Response> listener
) {
Expand All @@ -190,7 +189,7 @@ public void asyncExecute(
operation.execute(
request,
searchTask,
wrapStoringListener(searchTask, waitForCompletionTimeout, keepAlive, keepOnCompletion, listener)
wrapStoringListener(searchTask, waitForCompletionTimeout, keepOnCompletion, listener)
);
operationStarted = true;
} finally {
Expand All @@ -205,7 +204,6 @@ public void asyncExecute(
private ActionListener<Response> wrapStoringListener(
T searchTask,
TimeValue waitForCompletionTimeout,
TimeValue keepAlive,
boolean keepOnCompletion,
ActionListener<Response> listener
) {
Expand All @@ -227,7 +225,7 @@ private ActionListener<Response> wrapStoringListener(
if (keepOnCompletion) {
storeResults(
searchTask,
new StoredAsyncResponse<>(response, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()),
new StoredAsyncResponse<>(response, searchTask.getExpirationTimeMillis()),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines are the main fix.

ActionListener.running(() -> acquiredListener.onResponse(response))
);
} else {
Expand All @@ -239,7 +237,7 @@ private ActionListener<Response> wrapStoringListener(
// We finished after timeout - saving results
storeResults(
searchTask,
new StoredAsyncResponse<>(response, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()),
new StoredAsyncResponse<>(response, searchTask.getExpirationTimeMillis()),
ActionListener.running(response::decRef)
);
}
Expand All @@ -251,7 +249,7 @@ private ActionListener<Response> wrapStoringListener(
if (keepOnCompletion) {
storeResults(
searchTask,
new StoredAsyncResponse<>(e, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()),
new StoredAsyncResponse<>(e, searchTask.getExpirationTimeMillis()),
ActionListener.running(() -> acquiredListener.onFailure(e))
);
} else {
Expand All @@ -261,7 +259,7 @@ private ActionListener<Response> wrapStoringListener(
}
} else {
// We finished after timeout - saving exception
storeResults(searchTask, new StoredAsyncResponse<>(e, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()));
storeResults(searchTask, new StoredAsyncResponse<>(e, searchTask.getExpirationTimeMillis()));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.LegacyActionRequest;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -40,6 +41,7 @@

import static org.elasticsearch.xpack.esql.core.async.AsyncTaskManagementService.addCompletionListener;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

Expand All @@ -52,9 +54,11 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {

public static class TestRequest extends LegacyActionRequest {
private final String string;
private final TimeValue keepAlive;

public TestRequest(String string) {
public TestRequest(String string, TimeValue keepAlive) {
this.string = string;
this.keepAlive = keepAlive;
}

@Override
Expand Down Expand Up @@ -129,7 +133,7 @@ public TestTask createTask(
headers,
originHeaders,
asyncExecutionId,
TimeValue.timeValueDays(5)
request.keepAlive
);
}

Expand Down Expand Up @@ -172,7 +176,7 @@ public void setup() {
);
results = new AsyncResultsService<>(
store,
true,
false,
TestTask.class,
(task, listener, timeout) -> addCompletionListener(transportService.getThreadPool(), task, listener, timeout),
transportService.getTaskManager(),
Expand Down Expand Up @@ -212,23 +216,17 @@ public void testReturnBeforeTimeout() throws Exception {
boolean success = randomBoolean();
boolean keepOnCompletion = randomBoolean();
CountDownLatch latch = new CountDownLatch(1);
TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die");
service.asyncExecute(
request,
TimeValue.timeValueMinutes(1),
TimeValue.timeValueMinutes(10),
keepOnCompletion,
ActionListener.wrap(r -> {
assertThat(success, equalTo(true));
assertThat(r.string, equalTo("response for [" + request.string + "]"));
assertThat(r.id, notNullValue());
latch.countDown();
}, e -> {
assertThat(success, equalTo(false));
assertThat(e.getMessage(), equalTo("test exception"));
latch.countDown();
})
);
TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die", TimeValue.timeValueDays(1));
service.asyncExecute(request, TimeValue.timeValueMinutes(1), keepOnCompletion, ActionListener.wrap(r -> {
assertThat(success, equalTo(true));
assertThat(r.string, equalTo("response for [" + request.string + "]"));
assertThat(r.id, notNullValue());
latch.countDown();
}, e -> {
assertThat(success, equalTo(false));
assertThat(e.getMessage(), equalTo("test exception"));
latch.countDown();
}));
assertThat(latch.await(10, TimeUnit.SECONDS), equalTo(true));
}

Expand All @@ -252,20 +250,14 @@ public void execute(TestRequest request, TestTask task, ActionListener<TestRespo
boolean timeoutOnFirstAttempt = randomBoolean();
boolean waitForCompletion = randomBoolean();
CountDownLatch latch = new CountDownLatch(1);
TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die");
TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die", TimeValue.timeValueDays(1));
AtomicReference<TestResponse> responseHolder = new AtomicReference<>();
service.asyncExecute(
request,
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(10),
keepOnCompletion,
ActionTestUtils.assertNoFailureListener(r -> {
assertThat(r.string, nullValue());
assertThat(r.id, notNullValue());
assertThat(responseHolder.getAndSet(r), nullValue());
latch.countDown();
})
);
service.asyncExecute(request, TimeValue.timeValueMillis(1), keepOnCompletion, ActionTestUtils.assertNoFailureListener(r -> {
assertThat(r.string, nullValue());
assertThat(r.id, notNullValue());
assertThat(responseHolder.getAndSet(r), nullValue());
latch.countDown();
}));
assertThat(latch.await(20, TimeUnit.SECONDS), equalTo(true));

if (timeoutOnFirstAttempt) {
Expand All @@ -281,17 +273,11 @@ public void execute(TestRequest request, TestTask task, ActionListener<TestRespo
if (waitForCompletion) {
// now we are waiting for the task to finish
logger.trace("Waiting for response to complete");
AtomicReference<StoredAsyncResponse<TestResponse>> responseRef = new AtomicReference<>();
CountDownLatch getResponseCountDown = getResponse(
responseHolder.get().id,
TimeValue.timeValueSeconds(5),
ActionTestUtils.assertNoFailureListener(responseRef::set)
);
var getFuture = getResponse(responseHolder.get().id, TimeValue.timeValueSeconds(5), TimeValue.MINUS_ONE);

executionLatch.countDown();
assertThat(getResponseCountDown.await(10, TimeUnit.SECONDS), equalTo(true));
var response = safeGet(getFuture);

StoredAsyncResponse<TestResponse> response = responseRef.get();
if (success) {
assertThat(response.getException(), nullValue());
assertThat(response.getResponse(), notNullValue());
Expand Down Expand Up @@ -326,26 +312,46 @@ public void execute(TestRequest request, TestTask task, ActionListener<TestRespo
}
}

public void testUpdateKeepAliveToTask() throws Exception {
long now = System.currentTimeMillis();
CountDownLatch executionLatch = new CountDownLatch(1);
AsyncTaskManagementService<TestRequest, TestResponse, TestTask> service = createManagementService(new TestOperation() {
@Override
public void execute(TestRequest request, TestTask task, ActionListener<TestResponse> listener) {
executorService.submit(() -> {
try {
assertThat(executionLatch.await(10, TimeUnit.SECONDS), equalTo(true));
} catch (InterruptedException ex) {
throw new AssertionError(ex);
}
super.execute(request, task, listener);
});
}
});
TestRequest request = new TestRequest(randomAlphaOfLength(10), TimeValue.timeValueHours(1));
PlainActionFuture<TestResponse> submitResp = new PlainActionFuture<>();
try {
service.asyncExecute(request, TimeValue.timeValueMillis(1), true, submitResp);
String id = submitResp.get().id;
assertThat(id, notNullValue());
TimeValue keepAlive = TimeValue.timeValueDays(between(1, 10));
var resp1 = safeGet(getResponse(id, TimeValue.ZERO, keepAlive));
assertThat(resp1.getExpirationTime(), greaterThanOrEqualTo(now + keepAlive.millis()));
} finally {
executionLatch.countDown();
}
}

private StoredAsyncResponse<TestResponse> getResponse(String id, TimeValue timeout) throws InterruptedException {
AtomicReference<StoredAsyncResponse<TestResponse>> response = new AtomicReference<>();
assertThat(
getResponse(id, timeout, ActionTestUtils.assertNoFailureListener(response::set)).await(10, TimeUnit.SECONDS),
equalTo(true)
);
return response.get();
return safeGet(getResponse(id, timeout, TimeValue.MINUS_ONE));
}

private CountDownLatch getResponse(String id, TimeValue timeout, ActionListener<StoredAsyncResponse<TestResponse>> listener) {
CountDownLatch responseLatch = new CountDownLatch(1);
private PlainActionFuture<StoredAsyncResponse<TestResponse>> getResponse(String id, TimeValue timeout, TimeValue keepAlive) {
PlainActionFuture<StoredAsyncResponse<TestResponse>> future = new PlainActionFuture<>();
GetAsyncResultRequest getResultsRequest = new GetAsyncResultRequest(id).setWaitForCompletionTimeout(timeout);
results.retrieveResult(getResultsRequest, ActionListener.wrap(r -> {
listener.onResponse(r);
responseLatch.countDown();
}, e -> {
listener.onFailure(e);
responseLatch.countDown();
}));
return responseLatch;
getResultsRequest.setKeepAlive(keepAlive);
results.retrieveResult(getResultsRequest, future);
return future;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ public void testUpdateKeepAlive() throws Exception {
assertThat(resp.isRunning(), is(false));
}
});
assertThat(getExpirationFromDoc(asyncId), greaterThanOrEqualTo(nowInMillis + keepAlive.getMillis()));
// update the keepAlive after the query has completed
int iters = between(1, 5);
for (int i = 0; i < iters; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,7 @@ protected void doExecute(Task task, EsqlQueryRequest request, ActionListener<Esq
private void doExecuteForked(Task task, EsqlQueryRequest request, ActionListener<EsqlQueryResponse> listener) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH);
if (requestIsAsync(request)) {
asyncTaskManagementService.asyncExecute(
request,
request.waitForCompletionTimeout(),
request.keepAlive(),
request.keepOnCompletion(),
listener
);
asyncTaskManagementService.asyncExecute(request, request.waitForCompletionTimeout(), request.keepOnCompletion(), listener);
} else {
innerExecute(task, request, listener);
}
Expand Down