@@ -25,10 +25,10 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.ClusterChangedEvent;
-import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.MasterNodeChangePredicate;
import org.elasticsearch.cluster.NotMasterException;
@@ -48,6 +48,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.node.NodeClosedException;
@@ -68,7 +69,9 @@
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
+import java.util.Objects;
import java.util.Set;
+import java.util.concurrent.ConcurrentMap;
import java.util.function.Predicate;

public class ShardStateAction extends AbstractComponent {
@@ -80,6 +83,10 @@ public class ShardStateAction extends AbstractComponent {
private final ClusterService clusterService;
private final ThreadPool threadPool;

+    // in-flight shard failures detected during replication, keyed by the failing shard's entry;
+    // we track these so that we do not send duplicate shard-failed requests for a single failing shard
+    private final ConcurrentMap<FailedShardEntry, CompositeListener> remoteFailedShardsCache = ConcurrentCollections.newConcurrentMap();

@Inject
public ShardStateAction(Settings settings, ClusterService clusterService, TransportService transportService,
AllocationService allocationService, RoutingService routingService, ThreadPool threadPool) {
@@ -146,8 +153,35 @@ private static boolean isMasterChannelException(TransportException exp) {
*/
public void remoteShardFailed(final ShardId shardId, String allocationId, long primaryTerm, boolean markAsStale, final String message, @Nullable final Exception failure, Listener listener) {
assert primaryTerm > 0L : "primary term should be strictly positive";
-        FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
-        sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, listener);
+        final FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
+        final CompositeListener compositeListener = new CompositeListener(listener);
+        // only the first caller for a given entry sends the request; concurrent callers piggyback on it
+        final CompositeListener existingListener = remoteFailedShardsCache.putIfAbsent(shardEntry, compositeListener);
+        if (existingListener == null) {
+            sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, new Listener() {
+                @Override
+                public void onSuccess() {
+                    try {
+                        compositeListener.onSuccess();
+                    } finally {
+                        remoteFailedShardsCache.remove(shardEntry);
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    try {
+                        compositeListener.onFailure(e);
+                    } finally {
+                        remoteFailedShardsCache.remove(shardEntry);
+                    }
+                }
+            });
+        } else {
+            existingListener.addListener(listener);
+        }
}

+    // package-private, exposed so that tests can observe the cache size
+    int remoteShardFailedCacheSize() {
+        return remoteFailedShardsCache.size();
+    }
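
The block above is the core of the de-duplication: `ConcurrentMap.putIfAbsent` guarantees that exactly one caller per `FailedShardEntry` wins and performs the transport request, while concurrent callers for the same entry merely attach their listeners to the winner's `CompositeListener`. The following standalone sketch shows the same putIfAbsent pattern in isolation; the class and method names are hypothetical, not Elasticsearch API, and it deliberately ignores the completed-while-adding race that `CompositeListener` (shown further down) exists to handle.

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;

class DeduplicatingSender {
    private final ConcurrentMap<String, List<Consumer<String>>> inFlight = new ConcurrentHashMap<>();
    private int requestsSent = 0;

    void send(String key, Consumer<String> listener) {
        List<Consumer<String>> fresh = new CopyOnWriteArrayList<>();
        fresh.add(listener);
        List<Consumer<String>> existing = inFlight.putIfAbsent(key, fresh);
        if (existing == null) {
            requestsSent++;             // stand-in for the actual transport call to the master
        } else {
            existing.add(listener);     // duplicate caller: piggyback on the request already in flight
        }
    }

    void complete(String key, String response) {
        // stand-in for the transport response handler: clear the entry and notify everyone
        for (Consumer<String> listener : inFlight.remove(key)) {
            listener.accept(response);
        }
    }

    public static void main(String[] args) {
        DeduplicatingSender sender = new DeduplicatingSender();
        sender.send("shard-0", r -> System.out.println("listener 1: " + r));
        sender.send("shard-0", r -> System.out.println("listener 2: " + r)); // no second request
        sender.complete("shard-0", "acked");
        System.out.println("requests sent: " + sender.requestsSent);         // prints 1
    }
}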

/**
Expand Down Expand Up @@ -414,6 +448,23 @@ public String toString() {
components.add("markAsStale [" + markAsStale + "]");
return String.join(", ", components);
}

+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            FailedShardEntry that = (FailedShardEntry) o;
+            // exclude message and failure from equals and hashCode: they do not affect the
+            // identity of a failure, so repeated reports of the same failure de-duplicate
+            // even when their messages or exceptions differ
+            return Objects.equals(this.shardId, that.shardId) &&
+                Objects.equals(this.allocationId, that.allocationId) &&
+                primaryTerm == that.primaryTerm &&
+                markAsStale == that.markAsStale;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(shardId, allocationId, primaryTerm, markAsStale);
+        }
}
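
With message and failure excluded, two failures of the same allocation in the same primary term hash to the same `remoteFailedShardsCache` key even when their error text differs, which is exactly what the cache lookup needs. A hedged illustration follows; it assumes code placed in the org.elasticsearch.cluster.action.shard package (the FailedShardEntry constructor is not public), and the index name and UUID are made up.

package org.elasticsearch.cluster.action.shard;

import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;

public class FailedShardEntryEqualityDemo {
    public static void main(String[] args) {
        ShardId shardId = new ShardId(new Index("my-index", "_na_"), 0);
        ShardStateAction.FailedShardEntry first =
            new ShardStateAction.FailedShardEntry(shardId, "alloc-1", 3L, "write failed", new Exception("disk error"), true);
        ShardStateAction.FailedShardEntry second =
            new ShardStateAction.FailedShardEntry(shardId, "alloc-1", 3L, "a different message", null, true);
        System.out.println(first.equals(second));                  // true: shard + allocation + term + markAsStale
        System.out.println(first.hashCode() == second.hashCode()); // true
    }
}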

public void shardStarted(final ShardRouting shardRouting, final String message, Listener listener) {
@@ -585,6 +636,72 @@ default void onFailure(final Exception e) {

}

+    /**
+     * A composite listener that allows registering multiple listeners dynamically.
+     * Listeners added after the composite has been notified are invoked immediately
+     * with the stored outcome.
+     */
+    static final class CompositeListener implements Listener {
+        private boolean isNotified = false;
+        private Exception failure = null;
+        private final List<Listener> listeners = new ArrayList<>();
+
+        CompositeListener(Listener listener) {
+            listeners.add(listener);
+        }
+
+        void addListener(Listener listener) {
+            final boolean ready;
+            synchronized (this) {
+                ready = this.isNotified;
+                if (ready == false) {
+                    listeners.add(listener);
+                }
+            }
+            if (ready) {
+                // the composite was already notified; deliver the stored outcome directly
+                if (failure != null) {
+                    listener.onFailure(failure);
+                } else {
+                    listener.onSuccess();
+                }
+            }
+        }
+
+        private void onCompleted(Exception failure) {
+            synchronized (this) {
+                this.failure = failure;
+                this.isNotified = true;
+            }
+            // notify outside the lock; once isNotified is set no listener can join the list
+            RuntimeException firstException = null;
+            for (Listener listener : listeners) {
+                try {
+                    if (failure != null) {
+                        listener.onFailure(failure);
+                    } else {
+                        listener.onSuccess();
+                    }
+                } catch (RuntimeException innerEx) {
+                    if (firstException == null) {
+                        firstException = innerEx;
+                    } else {
+                        firstException.addSuppressed(innerEx);
+                    }
+                }
+            }
+            if (firstException != null) {
+                throw firstException;
+            }
+        }
+
+        @Override
+        public void onSuccess() {
+            onCompleted(null);
+        }
+
+        @Override
+        public void onFailure(Exception failure) {
+            onCompleted(failure);
+        }
+    }
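
The contract worth calling out: addListener is safe to call at any point in the composite's life, and a listener registered after completion is invoked immediately with the stored outcome rather than being dropped. A small same-package usage sketch (hypothetical demo code, not part of this change):

package org.elasticsearch.cluster.action.shard;

public class CompositeListenerDemo {
    public static void main(String[] args) {
        ShardStateAction.Listener printing = new ShardStateAction.Listener() {
            @Override
            public void onSuccess() {
                System.out.println("shard-failed request acknowledged");
            }
        };
        ShardStateAction.CompositeListener composite = new ShardStateAction.CompositeListener(printing);
        composite.addListener(printing); // before completion: queued
        composite.onSuccess();           // both queued registrations fire, printing twice
        composite.addListener(printing); // after completion: fires immediately, printing a third time
    }
}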

public static class NoLongerPrimaryShardException extends ElasticsearchException {

public NoLongerPrimaryShardException(ShardId shardId, String msg) {
@@ -22,11 +22,11 @@
import com.carrotsearch.hppc.cursors.ObjectCursor;
import org.apache.lucene.index.CorruptIndexException;
import org.elasticsearch.Version;
-import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -43,6 +43,7 @@
import org.elasticsearch.cluster.routing.allocation.StaleShard;
import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.Index;
@@ -53,9 +54,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.Map;
import java.util.Set;
-import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -131,10 +130,15 @@ ClusterState applyFailedShards(ClusterState currentState, List<FailedShard> fail
tasks.addAll(failingTasks);
tasks.addAll(nonExistentTasks);
ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result = failingExecutor.execute(currentState, tasks);
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            failingTasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.failure(new RuntimeException("simulated applyFailedShards failure"))));
-        taskResultMap.putAll(nonExistentTasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.success())));
-        assertTaskResults(taskResultMap, result, currentState, false);
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = new ArrayList<>();
+        for (FailedShardEntry failingTask : failingTasks) {
+            taskResultList.add(Tuple.tuple(failingTask,
+                ClusterStateTaskExecutor.TaskResult.failure(new RuntimeException("simulated applyFailedShards failure"))));
+        }
+        for (FailedShardEntry nonExistentTask : nonExistentTasks) {
+            taskResultList.add(Tuple.tuple(nonExistentTask, ClusterStateTaskExecutor.TaskResult.success()));
+        }
+        assertTaskResults(taskResultList, result, currentState, false);
}
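
The move from Collectors.toMap to an ordered List of Tuples is presumably forced by the new equals/hashCode on FailedShardEntry: equal tasks are no longer distinct map keys, and Collectors.toMap throws IllegalStateException on duplicate keys instead of keeping both entries. A minimal standalone demonstration of that failure mode:

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ToMapDuplicateKeyDemo {
    public static void main(String[] args) {
        // two equal tasks, as two FailedShardEntry instances for the same shard now are
        List<String> tasks = Arrays.asList("shard-0[alloc-1]", "shard-0[alloc-1]");
        try {
            tasks.stream().collect(Collectors.toMap(Function.identity(), task -> "expected result"));
        } catch (IllegalStateException e) {
            System.out.println("Collectors.toMap rejects duplicate keys: " + e.getMessage());
        }
    }
}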

public void testIllegalShardFailureRequests() throws Exception {
@@ -147,14 +151,14 @@ public void testIllegalShardFailureRequests() throws Exception {
tasks.add(new FailedShardEntry(failingTask.shardId, failingTask.allocationId,
randomIntBetween(1, (int) primaryTerm - 1), failingTask.message, failingTask.failure, randomBoolean()));
}
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            tasks.stream().collect(Collectors.toMap(
-                Function.identity(),
-                task -> ClusterStateTaskExecutor.TaskResult.failure(new ShardStateAction.NoLongerPrimaryShardException(task.shardId,
-                    "primary term [" + task.primaryTerm + "] did not match current primary term [" +
-                        currentState.metaData().index(task.shardId.getIndex()).primaryTerm(task.shardId.id()) + "]"))));
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = tasks.stream()
+            .map(task -> Tuple.tuple(task, ClusterStateTaskExecutor.TaskResult.failure(
+                new ShardStateAction.NoLongerPrimaryShardException(task.shardId, "primary term ["
+                    + task.primaryTerm + "] did not match current primary term ["
+                    + currentState.metaData().index(task.shardId.getIndex()).primaryTerm(task.shardId.id()) + "]"))))
+            .collect(Collectors.toList());
ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result = executor.execute(currentState, tasks);
-        assertTaskResults(taskResultMap, result, currentState, false);
+        assertTaskResults(taskResultList, result, currentState, false);
}

public void testMarkAsStaleWhenFailingShard() throws Exception {
@@ -251,44 +255,44 @@ private static void assertTasksSuccessful(
ClusterState clusterState,
boolean clusterStateChanged
) {
-        Map<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            tasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.success()));
-        assertTaskResults(taskResultMap, result, clusterState, clusterStateChanged);
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = tasks.stream()
+            .map(t -> Tuple.tuple(t, ClusterStateTaskExecutor.TaskResult.success())).collect(Collectors.toList());
+        assertTaskResults(taskResultList, result, clusterState, clusterStateChanged);
}

private static void assertTaskResults(
-        Map<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap,
+        List<Tuple<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList,
ClusterStateTaskExecutor.ClusterTasksResult<ShardStateAction.FailedShardEntry> result,
ClusterState clusterState,
boolean clusterStateChanged
) {
// there should be as many task results as tasks
-        assertEquals(taskResultMap.size(), result.executionResults.size());
+        assertEquals(taskResultList.size(), result.executionResults.size());

-        for (Map.Entry<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultMap.entrySet()) {
+        for (Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultList) {
// every task should have a corresponding task result
-            assertTrue(result.executionResults.containsKey(entry.getKey()));
+            assertTrue(result.executionResults.containsKey(entry.v1()));

// the task results are as expected
-            assertEquals(entry.getKey().toString(), entry.getValue().isSuccess(), result.executionResults.get(entry.getKey()).isSuccess());
+            assertEquals(entry.v1().toString(), entry.v2().isSuccess(), result.executionResults.get(entry.v1()).isSuccess());
}

List<ShardRouting> shards = clusterState.getRoutingTable().allShards();
-        for (Map.Entry<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultMap.entrySet()) {
-            if (entry.getValue().isSuccess()) {
+        for (Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultList) {
+            if (entry.v2().isSuccess()) {
// the shard was successfully failed and so should not be in the routing table
for (ShardRouting shard : shards) {
if (shard.assignedToNode()) {
assertFalse("entry key " + entry.getKey() + ", shard routing " + shard,
entry.getKey().getShardId().equals(shard.shardId()) &&
entry.getKey().getAllocationId().equals(shard.allocationId().getId()));
assertFalse("entry key " + entry.v1() + ", shard routing " + shard,
entry.v1().getShardId().equals(shard.shardId()) &&
entry.v1().getAllocationId().equals(shard.allocationId().getId()));
}
}
} else {
// check we saw the expected failure
-                ClusterStateTaskExecutor.TaskResult actualResult = result.executionResults.get(entry.getKey());
-                assertThat(actualResult.getFailure(), instanceOf(entry.getValue().getFailure().getClass()));
-                assertThat(actualResult.getFailure().getMessage(), equalTo(entry.getValue().getFailure().getMessage()));
+                ClusterStateTaskExecutor.TaskResult actualResult = result.executionResults.get(entry.v1());
+                assertThat(actualResult.getFailure(), instanceOf(entry.v2().getFailure().getClass()));
+                assertThat(actualResult.getFailure().getMessage(), equalTo(entry.v2().getFailure().getMessage()));
}
}
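
For readers unfamiliar with it, org.elasticsearch.common.collect.Tuple is Elasticsearch's minimal immutable pair type: Tuple.tuple(a, b) constructs a pair and v1()/v2() read its halves, which is why the assertions above access entry.v1() for the task and entry.v2() for the expected result. A two-line reminder of the API:

import org.elasticsearch.common.collect.Tuple;

Tuple<String, Integer> pair = Tuple.tuple("shard-0", 42);
System.out.println(pair.v1() + " / " + pair.v2()); // shard-0 / 42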
