Skip to content

Commit c22faad

Browse files
committed
Resume Driver on cancelled or early finished
1 parent b1f4fa7 commit c22faad

File tree

7 files changed

+127
-48
lines changed

7 files changed

+127
-48
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.action.support.ContextPreservingActionListener;
1212
import org.elasticsearch.action.support.SubscribableListener;
1313
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
14+
import org.elasticsearch.common.util.concurrent.EsExecutors;
1415
import org.elasticsearch.common.util.concurrent.ThreadContext;
1516
import org.elasticsearch.compute.Describable;
1617
import org.elasticsearch.compute.data.Page;
@@ -74,10 +75,12 @@ public class Driver implements Releasable, Describable {
7475
private final long statusNanos;
7576

7677
private final AtomicReference<String> cancelReason = new AtomicReference<>();
77-
private final AtomicReference<SubscribableListener<Void>> blocked = new AtomicReference<>();
7878

7979
private final AtomicBoolean started = new AtomicBoolean();
80+
private final AtomicBoolean earlyFinished = new AtomicBoolean();
8081
private final SubscribableListener<Void> completionListener = new SubscribableListener<>();
82+
private final AtomicReference<SubscribableListener<Void>> blocked = new AtomicReference<>();
83+
private final AtomicReference<Runnable> scheduledTask = new AtomicReference<>();
8184

8285
/**
8386
* Status reported to the tasks API. We write the status at most once every
@@ -215,6 +218,23 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
215218
}
216219
}
217220

221+
private void earlyFinish() {
222+
earlyFinished.set(true);
223+
tryResumeDriver();
224+
}
225+
226+
private void tryResumeDriver() {
227+
// Attempt to run the scheduled task on the current thread to finish the driver by closing operators
228+
SubscribableListener<Void> blockingFuture = blocked.get();
229+
if (blockingFuture != null) {
230+
blockingFuture.onResponse(null);
231+
}
232+
Runnable task = scheduledTask.getAndSet(null);
233+
if (task != null) {
234+
task.run();
235+
}
236+
}
237+
218238
/**
219239
* Whether the driver has run the chain of operators to completion.
220240
*/
@@ -245,31 +265,33 @@ private IsBlockedResult runSingleLoopIteration() {
245265
ensureNotCancelled();
246266
boolean movedPage = false;
247267

248-
for (int i = 0; i < activeOperators.size() - 1; i++) {
249-
Operator op = activeOperators.get(i);
250-
Operator nextOp = activeOperators.get(i + 1);
268+
if (earlyFinished.get() == false) {
269+
for (int i = 0; i < activeOperators.size() - 1; i++) {
270+
Operator op = activeOperators.get(i);
271+
Operator nextOp = activeOperators.get(i + 1);
251272

252-
// skip blocked operator
253-
if (op.isBlocked().listener().isDone() == false) {
254-
continue;
255-
}
273+
// skip blocked operator
274+
if (op.isBlocked().listener().isDone() == false) {
275+
continue;
276+
}
256277

257-
if (op.isFinished() == false && nextOp.needsInput()) {
258-
Page page = op.getOutput();
259-
if (page == null) {
260-
// No result, just move to the next iteration
261-
} else if (page.getPositionCount() == 0) {
262-
// Empty result, release any memory it holds immediately and move to the next iteration
263-
page.releaseBlocks();
264-
} else {
265-
// Non-empty result from the previous operation, move it to the next operation
266-
nextOp.addInput(page);
267-
movedPage = true;
278+
if (op.isFinished() == false && nextOp.needsInput()) {
279+
Page page = op.getOutput();
280+
if (page == null) {
281+
// No result, just move to the next iteration
282+
} else if (page.getPositionCount() == 0) {
283+
// Empty result, release any memory it holds immediately and move to the next iteration
284+
page.releaseBlocks();
285+
} else {
286+
// Non-empty result from the previous operation, move it to the next operation
287+
nextOp.addInput(page);
288+
movedPage = true;
289+
}
268290
}
269-
}
270291

271-
if (op.isFinished()) {
272-
nextOp.finish();
292+
if (op.isFinished()) {
293+
nextOp.finish();
294+
}
273295
}
274296
}
275297

@@ -312,19 +334,10 @@ private IsBlockedResult runSingleLoopIteration() {
312334

313335
public void cancel(String reason) {
314336
if (cancelReason.compareAndSet(null, reason)) {
315-
synchronized (this) {
316-
SubscribableListener<Void> fut = this.blocked.get();
317-
if (fut != null) {
318-
fut.onFailure(new TaskCancelledException(reason));
319-
}
320-
}
337+
tryResumeDriver();
321338
}
322339
}
323340

324-
private boolean isCancelled() {
325-
return cancelReason.get() != null;
326-
}
327-
328341
private void ensureNotCancelled() {
329342
String reason = cancelReason.get();
330343
if (reason != null) {
@@ -342,6 +355,12 @@ public static void start(
342355
driver.completionListener.addListener(listener);
343356
if (driver.started.compareAndSet(false, true)) {
344357
driver.updateStatus(0, 0, DriverStatus.Status.STARTING, "driver starting");
358+
if (driver.activeOperators.isEmpty() == false) {
359+
var onFinishedListener = driver.activeOperators.getLast().onFinishedListener();
360+
if (onFinishedListener != null) {
361+
onFinishedListener.addListener(ActionListener.running(driver::earlyFinish));
362+
}
363+
}
345364
schedule(DEFAULT_TIME_BEFORE_YIELDING, maxIterations, threadContext, executor, driver, driver.completionListener);
346365
}
347366
}
@@ -371,7 +390,7 @@ private static void schedule(
371390
Driver driver,
372391
ActionListener<Void> listener
373392
) {
374-
executor.execute(new AbstractRunnable() {
393+
final var task = new AbstractRunnable() {
375394

376395
@Override
377396
protected void doRun() {
@@ -383,16 +402,12 @@ protected void doRun() {
383402
if (fut.isDone()) {
384403
schedule(maxTime, maxIterations, threadContext, executor, driver, listener);
385404
} else {
386-
synchronized (driver) {
387-
if (driver.isCancelled() == false) {
388-
driver.blocked.set(fut);
389-
}
390-
}
391405
ActionListener<Void> readyListener = ActionListener.wrap(
392406
ignored -> schedule(maxTime, maxIterations, threadContext, executor, driver, listener),
393407
this::onFailure
394408
);
395409
fut.addListener(ContextPreservingActionListener.wrapPreservingContext(readyListener, threadContext));
410+
driver.blocked.set(fut);
396411
}
397412
}
398413

@@ -405,6 +420,17 @@ public void onFailure(Exception e) {
405420
void onComplete(ActionListener<Void> listener) {
406421
driver.driverContext.waitForAsyncActions(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
407422
}
423+
};
424+
final Runnable existing = driver.scheduledTask.getAndSet(task);
425+
assert existing == null : existing;
426+
final boolean runOnCurrentThread = driver.earlyFinished.get() || driver.cancelReason.get() != null;
427+
final Executor executorToUse = runOnCurrentThread ? EsExecutors.DIRECT_EXECUTOR_SERVICE : executor;
428+
executorToUse.execute(() -> {
429+
final Runnable next = driver.scheduledTask.getAndSet(null);
430+
if (next != null) {
431+
assert next == task;
432+
next.run();
433+
}
408434
});
409435
}
410436

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Operator.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.compute.Describable;
1414
import org.elasticsearch.compute.data.Block;
1515
import org.elasticsearch.compute.data.Page;
16+
import org.elasticsearch.core.Nullable;
1617
import org.elasticsearch.core.Releasable;
1718
import org.elasticsearch.xcontent.ToXContentObject;
1819

@@ -62,6 +63,14 @@ public interface Operator extends Releasable {
6263
*/
6364
boolean isFinished();
6465

66+
/**
67+
* Returns a listener that is notified when the operator is finished. This returns null if the operator doesn't support this feature.
68+
*/
69+
@Nullable
70+
default SubscribableListener<Void> onFinishedListener() {
71+
return null;
72+
}
73+
6574
/**
6675
* returns non-null if output page available. Only called when isFinished() == false
6776
* @throws UnsupportedOperationException if the operator is a {@link SinkOperator}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSink.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.compute.operator.exchange;
99

10+
import org.elasticsearch.action.support.SubscribableListener;
1011
import org.elasticsearch.compute.data.Page;
1112
import org.elasticsearch.compute.operator.IsBlockedResult;
1213

@@ -26,12 +27,12 @@ public interface ExchangeSink {
2627
void finish();
2728

2829
/**
29-
* Whether the sink has received all pages
30+
* Whether the sink is blocked on adding more pages
3031
*/
31-
boolean isFinished();
32+
IsBlockedResult waitForWriting();
3233

3334
/**
34-
* Whether the sink is blocked on adding more pages
35+
* Returns a listener that is notified when the sink is finished
3536
*/
36-
IsBlockedResult waitForWriting();
37+
SubscribableListener<Void> onFinished();
3738
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkHandler.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ public ExchangeSinkHandler(BlockFactory blockFactory, int maxBufferSize, LongSup
5252

5353
private class ExchangeSinkImpl implements ExchangeSink {
5454
boolean finished;
55+
private final SubscribableListener<Void> onFinished = new SubscribableListener<>();
5556

5657
ExchangeSinkImpl() {
5758
onChanged();
59+
buffer.addCompletionListener(onFinished);
5860
outstandingSinks.incrementAndGet();
5961
}
6062

@@ -68,6 +70,7 @@ public void addPage(Page page) {
6870
public void finish() {
6971
if (finished == false) {
7072
finished = true;
73+
onFinished.onResponse(null);
7174
onChanged();
7275
if (outstandingSinks.decrementAndGet() == 0) {
7376
buffer.finish(false);
@@ -77,8 +80,8 @@ public void finish() {
7780
}
7881

7982
@Override
80-
public boolean isFinished() {
81-
return finished || buffer.isFinished();
83+
public SubscribableListener<Void> onFinished() {
84+
return onFinished;
8285
}
8386

8487
@Override

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperator.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.action.support.SubscribableListener;
1213
import org.elasticsearch.common.Strings;
1314
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1415
import org.elasticsearch.common.io.stream.StreamInput;
@@ -56,7 +57,12 @@ public ExchangeSinkOperator(ExchangeSink sink, Function<Page, Page> transformer)
5657

5758
@Override
5859
public boolean isFinished() {
59-
return sink.isFinished();
60+
return sink.onFinished().isDone();
61+
}
62+
63+
@Override
64+
public SubscribableListener<Void> onFinishedListener() {
65+
return sink.onFinished();
6066
}
6167

6268
@Override

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverTests.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.ActionRunnable;
12+
import org.elasticsearch.action.support.PlainActionFuture;
1213
import org.elasticsearch.common.breaker.CircuitBreaker;
1314
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.common.unit.ByteSizeValue;
@@ -21,6 +22,10 @@
2122
import org.elasticsearch.compute.data.BlockFactory;
2223
import org.elasticsearch.compute.data.ElementType;
2324
import org.elasticsearch.compute.data.Page;
25+
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
26+
import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperator;
27+
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
28+
import org.elasticsearch.compute.operator.exchange.ExchangeSourceOperator;
2429
import org.elasticsearch.core.TimeValue;
2530
import org.elasticsearch.test.ESTestCase;
2631
import org.elasticsearch.threadpool.FixedExecutorBuilder;
@@ -35,8 +40,10 @@
3540
import java.util.concurrent.CountDownLatch;
3641
import java.util.concurrent.CyclicBarrier;
3742
import java.util.concurrent.TimeUnit;
43+
import java.util.function.Function;
3844
import java.util.function.LongSupplier;
3945

46+
import static org.hamcrest.Matchers.either;
4047
import static org.hamcrest.Matchers.equalTo;
4148

4249
public class DriverTests extends ESTestCase {
@@ -273,6 +280,33 @@ public Page getOutput() {
273280
}
274281
}
275282

283+
public void testResumeOnEarlyFinish() throws Exception {
284+
DriverContext driverContext = driverContext();
285+
ThreadPool threadPool = threadPool();
286+
try {
287+
PlainActionFuture<Void> sourceFuture = new PlainActionFuture<>();
288+
var sourceHandler = new ExchangeSourceHandler(between(1, 5), threadPool.executor("esql"), sourceFuture);
289+
var sinkHandler = new ExchangeSinkHandler(driverContext.blockFactory(), between(1, 5), System::currentTimeMillis);
290+
var sourceOperator = new ExchangeSourceOperator(sourceHandler.createExchangeSource());
291+
var sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink(), Function.identity());
292+
Driver driver = new Driver(driverContext, sourceOperator, List.of(), sinkOperator, () -> {});
293+
PlainActionFuture<Void> future = new PlainActionFuture<>();
294+
Driver.start(threadPool.getThreadContext(), threadPool.executor("esql"), driver, between(1, 1000), future);
295+
assertBusy(
296+
() -> assertThat(
297+
driver.status().status(),
298+
either(equalTo(DriverStatus.Status.ASYNC)).or(equalTo(DriverStatus.Status.STARTING))
299+
)
300+
);
301+
sinkHandler.fetchPageAsync(true, ActionListener.noop());
302+
future.actionGet(5, TimeUnit.SECONDS);
303+
assertThat(driver.status().status(), equalTo(DriverStatus.Status.DONE));
304+
sourceFuture.actionGet(5, TimeUnit.SECONDS);
305+
} finally {
306+
terminate(threadPool);
307+
}
308+
}
309+
276310
private static void assertRunningWithRegularUser(ThreadPool threadPool) {
277311
String user = threadPool.getThreadContext().getHeader("user");
278312
assertThat(user, equalTo("user1"));

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public void testBasic() throws Exception {
129129
// sink buffer is full
130130
assertFalse(randomFrom(sink1, sink2).waitForWriting().listener().isDone());
131131
sink1.finish();
132-
assertTrue(sink1.isFinished());
132+
assertTrue(sink1.onFinished().isDone());
133133
for (int i = 0; i < 5; i++) {
134134
assertBusy(() -> assertTrue(source.waitForReading().listener().isDone()));
135135
assertEquals(pages[2 + i], source.pollPage());
@@ -138,7 +138,7 @@ public void testBasic() throws Exception {
138138
assertFalse(source.waitForReading().listener().isDone());
139139
assertBusy(() -> assertTrue(sink2.waitForWriting().listener().isDone()));
140140
sink2.finish();
141-
assertTrue(sink2.isFinished());
141+
assertTrue(sink2.onFinished().isDone());
142142
assertTrue(source.isFinished());
143143
assertFalse(sourceCompletion.isDone());
144144
source.finish();
@@ -440,7 +440,7 @@ public void testClosingSinks() {
440440
assertTrue(resp.finished());
441441
assertNull(resp.takePage());
442442
assertTrue(sink.waitForWriting().listener().isDone());
443-
assertTrue(sink.isFinished());
443+
assertTrue(sink.onFinished().isDone());
444444
}
445445

446446
public void testFinishEarly() throws Exception {

0 commit comments

Comments
 (0)