Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import com.datastax.oss.driver.api.core.ConsistencyLevel;
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.RequestRoutingType;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.context.DriverContext;
Expand Down Expand Up @@ -63,6 +64,9 @@
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -71,8 +75,10 @@
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -113,6 +119,11 @@
@ThreadSafe
public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {

public enum RequestRoutingMethod {
REGULAR,
PRESERVE_REPLICA_ORDER
}

private static final Logger LOG = LoggerFactory.getLogger(BasicLoadBalancingPolicy.class);

protected static final IntUnaryOperator INCREMENT = i -> (i == Integer.MAX_VALUE) ? 0 : i + 1;
Expand All @@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
private final int maxNodesPerRemoteDc;
private final boolean allowDcFailoverForLocalCl;
private final ConsistencyLevel defaultConsistencyLevel;
private final RequestRoutingMethod lwtRequestRoutingMethod;

// private because they should be set in init() and never be modified after
private volatile DistanceReporter distanceReporter;
Expand Down Expand Up @@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String
new LinkedHashSet<>(
profile.getStringList(
DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS));
this.lwtRequestRoutingMethod = parseLwtRequestRoutingMethod();
}

@NonNull
private RequestRoutingMethod parseLwtRequestRoutingMethod() {
String methodString =
profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD);
try {
return RequestRoutingMethod.valueOf(methodString.toUpperCase());
} catch (IllegalArgumentException e) {
LOG.warn(
"[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER",
logPrefix,
methodString);
return RequestRoutingMethod.PRESERVE_REPLICA_ORDER;
}
}

@NonNull
public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) {
if (request == null) {
return RequestRoutingMethod.REGULAR;
}
if (request.getRequestRoutingType() == RequestRoutingType.LWT) {
return lwtRequestRoutingMethod;
} else {
return RequestRoutingMethod.REGULAR;
}
}

/**
Expand Down Expand Up @@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator(
@NonNull
@Override
public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session session) {
switch (getRequestRoutingMethod(request)) {
case PRESERVE_REPLICA_ORDER:
return newQueryPlanPreserveReplicas(request, session);
case REGULAR:
default:
return newQueryPlanRegular(request, session);
}
}

@NonNull
protected Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
// Take a snapshot since the set is concurrent:
Object[] currentNodes = liveNodes.dc(localDc).toArray();

Expand Down Expand Up @@ -294,6 +345,101 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
return maybeAddDcFailover(request, plan);
}

/**
* Builds a query plan that preserves replica order: local replicas, remote replicas, local
* non-replicas (rotated), remote non-replicas (rotated).
*/
@NonNull
protected Queue<Node> newQueryPlanPreserveReplicas(
@Nullable Request request, @Nullable Session session) {
List<Node> replicas = getReplicas(request, session);
String localDc = getLocalDatacenter();
List<Node> queryPlan = new ArrayList<>();

if (localDc == null) {
// No local DC: all replicas first, then rotated non-replicas
List<Node> allNodes = new ArrayList<>();
for (Object obj : getLiveNodes().dc(null).toArray()) {
allNodes.add((Node) obj);
}
queryPlan.addAll(replicas);
addRotatedNonReplicas(queryPlan, allNodes, replicas, request);
} else {
// With local DC: prioritize local, then remote
Map<String, List<Node>> nodesByDc = getAllNodesByDc();
addReplicasByDc(queryPlan, replicas, localDc);
addNonReplicasByDc(queryPlan, nodesByDc, replicas, localDc, request);
}

return new SimpleQueryPlan(queryPlan.toArray());
}

/** Collect all live nodes grouped by DC. */
private Map<String, List<Node>> getAllNodesByDc() {
Map<String, List<Node>> nodesByDc = new HashMap<>();
for (String dc : getLiveNodes().dcs()) {
List<Node> dcNodes = new ArrayList<>();
for (Object obj : getLiveNodes().dc(dc).toArray()) {
dcNodes.add((Node) obj);
}
nodesByDc.put(dc, dcNodes);
}
return nodesByDc;
}

/** Add replicas with local DC first, then remote DCs. */
private void addReplicasByDc(List<Node> queryPlan, List<Node> replicas, String localDc) {
replicas.stream()
.filter(r -> Objects.equals(r.getDatacenter(), localDc))
.forEach(queryPlan::add);
replicas.stream()
.filter(r -> !Objects.equals(r.getDatacenter(), localDc))
.forEach(queryPlan::add);
}

/** Add non-replicas with local DC first, then remote DCs (all rotated). */
private void addNonReplicasByDc(
List<Node> queryPlan,
Map<String, List<Node>> nodesByDc,
List<Node> replicas,
String localDc,
Request request) {
// Local DC non-replicas first
addRotatedNonReplicas(
queryPlan, nodesByDc.getOrDefault(localDc, new ArrayList<>()), replicas, request);
// Remote DC non-replicas
for (Map.Entry<String, List<Node>> entry : nodesByDc.entrySet()) {
if (!Objects.equals(entry.getKey(), localDc)) {
addRotatedNonReplicas(queryPlan, entry.getValue(), replicas, request);
}
}
}

/** Add non-replica nodes from given list with rotation. */
private void addRotatedNonReplicas(
List<Node> queryPlan, List<Node> nodes, List<Node> replicas, Request request) {
List<Node> nonReplicas =
nodes.stream().filter(n -> !replicas.contains(n)).collect(Collectors.toList());
if (!nonReplicas.isEmpty()) {
rotateNonReplicas(nonReplicas, request);
queryPlan.addAll(nonReplicas);
}
}

/** Rotates nodes based on routing key (consistent) or randomly. */
private void rotateNonReplicas(List<Node> nodes, @Nullable Request request) {
if (nodes.size() <= 1) return;

int rotationAmount =
(request != null && request.getRoutingKey() != null)
? Math.abs(request.getRoutingKey().hashCode()) % nodes.size()
: randomNextInt(nodes.size());

if (rotationAmount > 0) {
Collections.rotate(nodes, -rotationAmount);
}
}

@NonNull
protected List<Node> getReplicas(@Nullable Request request, @Nullable Session session) {
if (request == null || session == null) {
Expand Down Expand Up @@ -441,6 +587,11 @@ protected Object[] computeNodes() {
return new CompositeQueryPlan(queryPlans);
}

/** Exposed as a protected method so that it can be accessed by tests */
protected int randomNextInt(int bound) {
return ThreadLocalRandom.current().nextInt(bound);
}

/** Exposed as a protected method so that it can be accessed by tests */
protected void shuffleHead(Object[] currentNodes, int headLength) {
ArrayUtils.shuffleHead(currentNodes, headLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.MINUTES;

import com.datastax.oss.driver.api.core.RequestRoutingType;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.context.DriverContext;
Expand Down Expand Up @@ -48,7 +47,6 @@
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLongArray;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
Expand Down Expand Up @@ -96,11 +94,6 @@
@ThreadSafe
public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy implements RequestTracker {

public enum RequestRoutingMethod {
REGULAR,
PRESERVE_REPLICA_ORDER
}

private static final Logger LOG = LoggerFactory.getLogger(DefaultLoadBalancingPolicy.class);

private static final long NEWLY_UP_INTERVAL_NANOS = MINUTES.toNanos(1);
Expand All @@ -110,31 +103,14 @@ public enum RequestRoutingMethod {
protected final ConcurrentMap<Node, NodeResponseRateSample> responseTimes;
protected final Map<Node, Long> upTimes = new ConcurrentHashMap<>();
private final boolean avoidSlowReplicas;
private final RequestRoutingMethod lwtRequestRoutingMethod;

public DefaultLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String profileName) {
super(context, profileName);
this.avoidSlowReplicas =
profile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true);
this.lwtRequestRoutingMethod = getDefaultLWTRequestRoutingMethod();
this.responseTimes = new MapMaker().weakKeys().makeMap();
}

@NonNull
private RequestRoutingMethod getDefaultLWTRequestRoutingMethod() {
String methodString =
profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD);
try {
return RequestRoutingMethod.valueOf(methodString.toUpperCase());
} catch (IllegalArgumentException e) {
LOG.warn(
"[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER",
logPrefix,
methodString);
return RequestRoutingMethod.PRESERVE_REPLICA_ORDER;
}
}

@NonNull
@Override
public Optional<RequestTracker> getRequestTracker() {
Expand All @@ -151,52 +127,13 @@ protected Optional<String> discoverLocalDc(@NonNull Map<UUID, Node> nodes) {
return new MandatoryLocalDcHelper(context, profile, logPrefix).discoverLocalDc(nodes);
}

@NonNull
public RequestRoutingMethod getDefaultLWTRequestRoutingMethod(@Nullable Request request) {
if (request == null) {
return RequestRoutingMethod.REGULAR;
}
if (request.getRequestRoutingType() == RequestRoutingType.LWT) {
return lwtRequestRoutingMethod;
} else {
return RequestRoutingMethod.REGULAR;
}
}

@NonNull
@Override
public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session session) {
switch (getDefaultLWTRequestRoutingMethod(request)) {
case PRESERVE_REPLICA_ORDER:
return newQueryPlanPreserveReplicas(request, session);
case REGULAR:
default:
return newQueryPlanRegular(request, session);
}
}

/**
* Builds a query plan that preserves the replica order as returned by the load balancing
* strategy, while pushing non-local replicas after local ones.
*/
@NonNull
public Queue<Node> newQueryPlanPreserveReplicas(
@Nullable Request request, @Nullable Session session) {
List<Node> replicas = getReplicas(request, session);
String localDc = getLocalDatacenter();
if (localDc == null || replicas.isEmpty()) {
return new SimpleQueryPlan(replicas.toArray());
}

return new SimpleQueryPlan(moveNonLocalReplicasToTheEnd(replicas, localDc));
}

/**
* Builds a query plan that prioritizes local replicas, shuffles them for balance, and then
* round-robins the remaining local nodes.
*/
@NonNull
public Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
@Override
protected Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
List<Node> replicas = getReplicas(request, session);
Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray();
int replicaCount = 0; // in currentNodes
Expand Down Expand Up @@ -228,26 +165,6 @@ public Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Sess
return maybeAddDcFailover(request, plan);
}

/**
* Returns a replica array with local-datacenter replicas first and remote replicas preserved at
* the end.
*/
private static Object[] moveNonLocalReplicasToTheEnd(List<Node> replicas, String localDc) {
Object[] orderedReplicas = new Object[replicas.size()];
int index = 0;
for (Node replica : replicas) {
if (Objects.equals(replica.getDatacenter(), localDc)) {
orderedReplicas[index++] = replica;
}
}
for (Node replica : replicas) {
if (!Objects.equals(replica.getDatacenter(), localDc)) {
orderedReplicas[index++] = replica;
}
}
return orderedReplicas;
}

private int[] moveReplicasToFront(Object[] currentNodes, List<Node> allReplicas) {
int replicaCount = 0, localRackReplicaCount = 0;
for (int i = 0; i < currentNodes.length; i++) {
Expand Down Expand Up @@ -329,7 +246,7 @@ private void avoidSlowReplicas(
// - the replica in first or second position is the most recent replica marked as UP and
// - dice roll 1d4 != 1
else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1])
&& diceRoll1d4() != 1) {
&& randomNextInt(4) != 1) {

// Send it to the back of the replicas
ArrayUtils.bubbleDown(
Expand Down Expand Up @@ -370,11 +287,6 @@ protected long nanoTime() {
return System.nanoTime();
}

/** Exposed as a protected method so that it can be accessed by tests */
protected int diceRoll1d4() {
return ThreadLocalRandom.current().nextInt(4);
}

protected boolean isUnhealthy(@NonNull Node node, @NonNull Session session, long now) {
return isBusy(node, session) && isResponseRateInsufficient(node, now);
}
Expand Down
Loading