Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
import com.mongodb.connection.ClusterId;
import com.mongodb.connection.ClusterSettings;
import com.mongodb.connection.ClusterType;
import com.mongodb.connection.ServerDescription;
import com.mongodb.lang.Nullable;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import static com.mongodb.assertions.Assertions.assertNotNull;

Expand All @@ -38,7 +40,6 @@ public final class DnsMultiServerCluster extends AbstractMultiServerCluster {
private final DnsSrvRecordMonitor dnsSrvRecordMonitor;
private volatile MongoException srvResolutionException;


public DnsMultiServerCluster(final ClusterId clusterId, final ClusterSettings settings, final ClusterableServerFactory serverFactory,
final DnsSrvRecordMonitorFactory dnsSrvRecordMonitorFactory) {
super(clusterId, settings, serverFactory);
Expand All @@ -57,17 +58,33 @@ public void initialize(final Collection<ServerAddress> hosts) {
}
}

private Collection<ServerAddress> applySrvMaxHosts(final Collection<ServerAddress> hosts) {
Collection<ServerAddress> newHosts = hosts;
private Collection<ServerAddress> applySrvMaxHosts(final Collection<ServerAddress> latestSrvHosts) {
Integer srvMaxHosts = getSettings().getSrvMaxHosts();
if (srvMaxHosts != null && srvMaxHosts > 0) {
if (srvMaxHosts < hosts.size()) {
List<ServerAddress> newHostsList = new ArrayList<>(hosts);
Collections.shuffle(newHostsList, ThreadLocalRandom.current());
newHosts = newHostsList.subList(0, srvMaxHosts);
}
if (srvMaxHosts == null || srvMaxHosts <= 0 || latestSrvHosts.size() <= srvMaxHosts) {
return new ArrayList<>(latestSrvHosts);
}
return newHosts;
List<ServerAddress> activeHosts = getActivePriorHosts(latestSrvHosts);
int numNewHostsToAdd = srvMaxHosts - activeHosts.size();
activeHosts.addAll(addShuffledHosts(latestSrvHosts, activeHosts, numNewHostsToAdd));

return activeHosts;
}

private List<ServerAddress> getActivePriorHosts(final Collection<ServerAddress> latestSrvHosts) {
List<ServerAddress> priorHosts = DnsMultiServerCluster.this.getCurrentDescription().getServerDescriptions().stream()
.map(ServerDescription::getAddress).collect(Collectors.toList());
priorHosts.removeIf(host -> !latestSrvHosts.contains(host));

return priorHosts;
}

private List<ServerAddress> addShuffledHosts(final Collection<ServerAddress> latestSrvHosts,
final List<ServerAddress> activePriorHosts, final int numNewHostsToAdd) {
List<ServerAddress> addedHosts = new ArrayList<>(latestSrvHosts);
addedHosts.removeAll(activePriorHosts);
Collections.shuffle(addedHosts, ThreadLocalRandom.current());

return addedHosts.subList(0, numNewHostsToAdd);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,10 @@ public void shouldUseAllRecordsWhenSrvMaxHostsIsGreaterThanOrEqualToNumSrvRecord
public void shouldUseSrvMaxHostsWhenSrvMaxHostsIsLessThanNumSrvRecords() {
int srvMaxHosts = 2;
List<String> updatedHosts = asList(firstHost, thirdHost, fourthHost);

initCluster(updatedHosts, srvMaxHosts);

assertEquals(srvMaxHosts, clusterHostsSet().size());
assertTrue(updatedHosts.contains(firstHost));
assertTrue(updatedHosts.containsAll(clusterHostsSet()));
}

Expand Down