Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/96279.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96279
summary: Improve cancellability in `TransportTasksAction`
area: Task Management
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,36 @@

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.NoSuchNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.CancellableFanOut;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;

import static java.util.Collections.emptyList;

/**
* The base class for transport actions that are interacting with currently running tasks.
Expand Down Expand Up @@ -85,67 +81,113 @@ protected TransportTasksAction(

@Override
protected void doExecute(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
new AsyncAction(task, request, listener).start();
}
final var discoveryNodes = clusterService.state().nodes();
final String[] nodeIds = resolveNodes(request, discoveryNodes);

new CancellableFanOut<String, NodeTasksResponse, TasksResponse>() {
final ArrayList<TaskResponse> taskResponses = new ArrayList<>();
final ArrayList<TaskOperationFailure> taskOperationFailures = new ArrayList<>();
final ArrayList<FailedNodeException> failedNodeExceptions = new ArrayList<>();
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());

/**
 * Dispatches the per-node request for {@code nodeId} as a child of {@code task}. IDs that are absent from the
 * captured {@code discoveryNodes} fail the item listener with a {@link NoSuchNodeException} instead.
 */
@Override
protected void sendItemRequest(String nodeId, ActionListener<NodeTasksResponse> listener) {
    final var targetNode = discoveryNodes.get(nodeId);
    if (targetNode != null) {
        final var nodeRequest = new NodeTaskRequest(request);
        transportService.sendChildRequest(
            targetNode,
            transportNodeAction,
            nodeRequest,
            task,
            transportRequestOptions,
            new ActionListenerResponseHandler<>(listener, nodeResponseReader)
        );
    } else {
        // No such node in the snapshot of the cluster state taken when the action started.
        listener.onFailure(new NoSuchNodeException(nodeId));
    }
}

/**
 * Folds one node's reply into the shared accumulators: first its task results, then its per-task failures.
 */
@Override
protected void onItemResponse(String nodeId, NodeTasksResponse response) {
    // Each accumulator is guarded by its own monitor inside addAllSynchronized.
    addAllSynchronized(taskResponses, response.results);
    addAllSynchronized(taskOperationFailures, response.exceptions);
}

/**
 * Appends every element of {@code response} to {@code allResults} while holding the monitor of {@code allResults}.
 * A {@code null} or empty {@code response} is a no-op, and in that case the monitor is never acquired.
 *
 * @param allResults shared accumulator; all writers must go through this method so they lock the same monitor
 * @param response   elements to append; may be {@code null}, which is treated like an empty collection
 */
@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
private static <T> void addAllSynchronized(List<T> allResults, Collection<T> response) {
    // Guard against a missing list as well as an empty one so that a response without this section cannot throw here.
    if (response == null || response.isEmpty()) {
        return;
    }
    synchronized (allResults) {
        allResults.addAll(response);
    }
}

/**
 * Records the failure of a single node in {@code failedNodeExceptions} so that it can be included in the
 * response built by {@code onCompletion}.
 *
 * @param nodeId the ID of the node whose request failed
 * @param e      the reason for the failure; logged at debug level and kept as the cause of the recorded exception
 */
@Override
protected void onItemFailure(String nodeId, Exception e) {
    // Strings.format uses java.util.Formatter syntax, so the placeholder must be %s: a log4j-style {} would be
    // printed verbatim and the node ID would be missing from the message.
    logger.debug(() -> Strings.format("failed to execute on node [%s]", nodeId), e);
    synchronized (failedNodeExceptions) {
        failedNodeExceptions.add(new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", e));
    }
}

/**
 * Builds the final response from everything the item callbacks accumulated.
 */
@Override
protected TasksResponse onCompletion() {
    // All item refs have been released by now and those releases happen-before this call,
    // so the accumulator lists can be read without taking their monitors.
    final TasksResponse combined = newResponse(request, taskResponses, taskOperationFailures, failedNodeExceptions);
    return combined;
}

private void nodeOperation(CancellableTask task, NodeTaskRequest nodeTaskRequest, ActionListener<NodeTasksResponse> listener) {
TasksRequest request = nodeTaskRequest.tasksRequest;
processTasks(request, ActionListener.wrap(tasks -> nodeOperation(task, listener, request, tasks), listener::onFailure));
// Describes this fan-out by returning {@code actionName}.
@Override
public String toString() {
return actionName;
}
}.run(task, Iterators.forArray(nodeIds), listener);
}

// not an inline method reference to avoid capturing CancellableFanOut.this.
private final Writeable.Reader<NodeTasksResponse> nodeResponseReader = NodeTasksResponse::new;

private void nodeOperation(
CancellableTask task,
CancellableTask nodeTask,
ActionListener<NodeTasksResponse> listener,
TasksRequest request,
List<OperationTask> tasks
List<OperationTask> operationTasks
) {
if (tasks.isEmpty()) {
listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), emptyList(), emptyList()));
return;
}
AtomicArray<Tuple<TaskResponse, Exception>> responses = new AtomicArray<>(tasks.size());
final AtomicInteger counter = new AtomicInteger(tasks.size());
for (int i = 0; i < tasks.size(); i++) {
final int taskIndex = i;
ActionListener<TaskResponse> taskListener = new ActionListener<TaskResponse>() {
@Override
public void onResponse(TaskResponse response) {
responses.setOnce(taskIndex, response == null ? null : new Tuple<>(response, null));
respondIfFinished();
}
new CancellableFanOut<OperationTask, TaskResponse, NodeTasksResponse>() {

@Override
public void onFailure(Exception e) {
responses.setOnce(taskIndex, new Tuple<>(null, e));
respondIfFinished();
final ArrayList<TaskResponse> results = new ArrayList<>(operationTasks.size());
final ArrayList<TaskOperationFailure> exceptions = new ArrayList<>();

@Override
protected void sendItemRequest(OperationTask operationTask, ActionListener<TaskResponse> listener) {
ActionListener.run(listener, l -> taskOperation(nodeTask, request, operationTask, l));
}

@Override
protected void onItemResponse(OperationTask operationTask, TaskResponse taskResponse) {
synchronized (results) {
results.add(taskResponse);
}
}

private void respondIfFinished() {
if (counter.decrementAndGet() != 0) {
return;
}
List<TaskResponse> results = new ArrayList<>();
List<TaskOperationFailure> exceptions = new ArrayList<>();
for (Tuple<TaskResponse, Exception> response : responses.asList()) {
if (response.v1() == null) {
assert response.v2() != null;
exceptions.add(
new TaskOperationFailure(clusterService.localNode().getId(), tasks.get(taskIndex).getId(), response.v2())
);
} else {
assert response.v2() == null;
results.add(response.v1());
}
}
listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions));
@Override
protected void onItemFailure(OperationTask operationTask, Exception e) {
synchronized (exceptions) {
exceptions.add(new TaskOperationFailure(clusterService.localNode().getId(), operationTask.getId(), e));
}
};
try {
taskOperation(task, request, tasks.get(taskIndex), taskListener);
} catch (Exception e) {
taskListener.onFailure(e);
}
}

@Override
protected NodeTasksResponse onCompletion() {
// ref releases all happen-before here so no need to be synchronized
return new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions);
}

@Override
public String toString() {
return transportNodeAction;
}
}.run(nodeTask, operationTasks.iterator(), listener);
}

protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNodes) {
Expand Down Expand Up @@ -192,28 +234,6 @@ protected abstract TasksResponse newResponse(
List<FailedNodeException> failedNodeExceptions
);

/**
 * Splits the per-node slots of {@code responses} into task results, per-task failures and node-level failures,
 * then delegates to the abstract {@code newResponse} overload to build the final response.
 * Each slot is expected to hold either a {@link FailedNodeException} or a {@code NodeTasksResponse}.
 */
@SuppressWarnings("unchecked")
protected TasksResponse newResponse(TasksRequest request, AtomicReferenceArray<?> responses) {
    final List<TaskResponse> collectedTasks = new ArrayList<>();
    final List<FailedNodeException> collectedNodeFailures = new ArrayList<>();
    final List<TaskOperationFailure> collectedTaskFailures = new ArrayList<>();
    final int slotCount = responses.length();
    for (int slot = 0; slot < slotCount; slot++) {
        final Object entry = responses.get(slot);
        if (entry instanceof FailedNodeException nodeFailure) {
            collectedNodeFailures.add(nodeFailure);
            continue;
        }
        // Anything that is not a node failure must be that node's response.
        final NodeTasksResponse nodeResponse = (NodeTasksResponse) entry;
        if (nodeResponse.results != null) {
            collectedTasks.addAll(nodeResponse.results);
        }
        if (nodeResponse.exceptions != null) {
            collectedTaskFailures.addAll(nodeResponse.exceptions);
        }
    }
    return newResponse(request, collectedTasks, collectedTaskFailures, collectedNodeFailures);
}

/**
* Perform the required operation on the task. It is OK to start an asynchronous operation or to throw an exception but not both.
* @param actionTask The related transport action task. Can be used to create a task ID to handle upstream transport cancellations.
Expand All @@ -228,120 +248,18 @@ protected abstract void taskOperation(
ActionListener<TaskResponse> listener
);

/**
 * A single fan-out of a tasks request: sends one {@link NodeTaskRequest} to every resolved node, stores each node's
 * outcome in its own slot of {@code responses}, and completes {@code listener} once every slot has been accounted for.
 */
private class AsyncAction {

private final TasksRequest request;
// IDs of the target nodes, as returned by resolveNodes.
private final String[] nodesIds;
// nodes[i] is the node looked up for nodesIds[i]; it is null when the cluster state has no node with that ID.
private final DiscoveryNode[] nodes;
private final ActionListener<TasksResponse> listener;
// Slot i ends up holding either a NodeTasksResponse or a FailedNodeException for nodesIds[i].
private final AtomicReferenceArray<Object> responses;
// Number of nodes whose outcome (response or failure) has been recorded so far.
private final AtomicInteger counter = new AtomicInteger();
private final Task task;

/**
 * Resolves the target nodes against the current cluster state and sizes {@code responses} to match.
 */
private AsyncAction(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
this.task = task;
this.request = request;
this.listener = listener;
final DiscoveryNodes discoveryNodes = clusterService.state().nodes();
this.nodesIds = resolveNodes(request, discoveryNodes);
Map<String, DiscoveryNode> nodes = discoveryNodes.getNodes();
this.nodes = new DiscoveryNode[nodesIds.length];
for (int i = 0; i < this.nodesIds.length; i++) {
this.nodes[i] = nodes.get(this.nodesIds[i]);
}
this.responses = new AtomicReferenceArray<>(this.nodesIds.length);
}

/**
 * Sends the request to every target node. With no target nodes, the listener is completed immediately with a
 * response built from the empty {@code responses} array.
 */
private void start() {
if (nodesIds.length == 0) {
// nothing to do
try {
listener.onResponse(newResponse(request, responses));
} catch (Exception e) {
logger.debug("failed to generate empty response", e);
listener.onFailure(e);
}
} else {
// The same timeout options are shared by every per-node request.
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());
for (int i = 0; i < nodesIds.length; i++) {
final String nodeId = nodesIds[i];
final int idx = i;
final DiscoveryNode node = nodes[i];
try {
if (node == null) {
// The ID did not match any node in the cluster state: record it as a failed node.
onFailure(idx, nodeId, new NoSuchNodeException(nodeId));
} else {
NodeTaskRequest nodeRequest = new NodeTaskRequest(request);
// Mark the per-node request as a child of this action's task on the local node.
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
transportService.sendRequest(
node,
transportNodeAction,
nodeRequest,
transportRequestOptions,
new TransportResponseHandler<NodeTasksResponse>() {
@Override
public NodeTasksResponse read(StreamInput in) throws IOException {
return new NodeTasksResponse(in);
}

@Override
public void handleResponse(NodeTasksResponse response) {
onOperation(idx, response);
}

@Override
public void handleException(TransportException exp) {
onFailure(idx, node.getId(), exp);
}
}
);
}
} catch (Exception e) {
// Anything thrown while dispatching to this node is recorded as that node's failure.
onFailure(idx, nodeId, e);
}
}
}
}

/**
 * Records a node's successful response and finishes the action if it was the last outstanding node.
 */
private void onOperation(int idx, NodeTasksResponse nodeResponse) {
responses.set(idx, nodeResponse);
if (counter.incrementAndGet() == responses.length()) {
finishHim();
}
}

/**
 * Records a node's failure, wrapped in a {@link FailedNodeException}, and finishes the action if it was the last
 * outstanding node.
 */
private void onFailure(int idx, String nodeId, Throwable t) {
logger.debug(() -> "failed to execute on node [" + nodeId + "]", t);

responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t));

if (counter.incrementAndGet() == responses.length()) {
finishHim();
}
}

/**
 * Completes the listener after every node has reported: with the combined response, or with the exception thrown
 * while combining the per-node outcomes.
 */
private void finishHim() {
// NOTE(review): notifyIfCancelled presumably completes the listener itself when the task was cancelled - confirm.
if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
return;
}
TasksResponse finalResponse;
try {
finalResponse = newResponse(request, responses);
} catch (Exception e) {
logger.debug("failed to combine responses from nodes", e);
listener.onFailure(e);
return;
}
// Called outside the try block so an exception thrown by the listener is not reported back to it as a failure.
listener.onResponse(finalResponse);
}
}

class NodeTransportHandler implements TransportRequestHandler<NodeTaskRequest> {

@Override
public void messageReceived(final NodeTaskRequest request, final TransportChannel channel, Task task) throws Exception {
assert task instanceof CancellableTask;
nodeOperation((CancellableTask) task, request, new ChannelActionListener<>(channel));
TasksRequest tasksRequest = request.tasksRequest;
processTasks(
tasksRequest,
new ChannelActionListener<NodeTasksResponse>(channel).delegateFailure(
(l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks)
)
);
}
}

Expand Down
Loading