diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicy.java index 3ce0f7d08d2..5f2a1924968 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicy.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -63,6 +64,9 @@ import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -71,8 +75,10 @@ import java.util.Queue; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.IntUnaryOperator; +import java.util.stream.Collectors; import net.jcip.annotations.ThreadSafe; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -113,6 +119,11 @@ @ThreadSafe public class BasicLoadBalancingPolicy implements LoadBalancingPolicy { + public enum RequestRoutingMethod { + REGULAR, + PRESERVE_REPLICA_ORDER + } + private static final Logger LOG = LoggerFactory.getLogger(BasicLoadBalancingPolicy.class); protected static final IntUnaryOperator INCREMENT = i -> (i == Integer.MAX_VALUE) ? 0 : i + 1; @@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy { private final int maxNodesPerRemoteDc; private final boolean allowDcFailoverForLocalCl; private final ConsistencyLevel defaultConsistencyLevel; + private final RequestRoutingMethod lwtRequestRoutingMethod; // private because they should be set in init() and never be modified after private volatile DistanceReporter distanceReporter; @@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String new LinkedHashSet<>( profile.getStringList( DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS)); + this.lwtRequestRoutingMethod = parseLwtRequestRoutingMethod(); + } + + @NonNull + private RequestRoutingMethod parseLwtRequestRoutingMethod() { + String methodString = + profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD); + try { + return RequestRoutingMethod.valueOf(methodString.toUpperCase()); + } catch (IllegalArgumentException e) { + LOG.warn( + "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER", + logPrefix, + methodString); + return RequestRoutingMethod.PRESERVE_REPLICA_ORDER; + } + } + + @NonNull + public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) { + if (request == null) { + return RequestRoutingMethod.REGULAR; + } + if (request.getRequestRoutingType() == RequestRoutingType.LWT) { + return lwtRequestRoutingMethod; + } else { + return RequestRoutingMethod.REGULAR; + } } /** @@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator( @NonNull @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { + switch (getRequestRoutingMethod(request)) { + case PRESERVE_REPLICA_ORDER: + return newQueryPlanPreserveReplicas(request, session); + case REGULAR: + default: + return newQueryPlanRegular(request, session); + } + } + + @NonNull + protected Queue newQueryPlanRegular(@Nullable Request request, @Nullable Session session) { // Take a snapshot since the set is concurrent: Object[] currentNodes = liveNodes.dc(localDc).toArray(); @@ -294,6 +345,101 @@ public Queue newQueryPlan(@Nullable Request request, @Nullable Session ses return maybeAddDcFailover(request, plan); } + /** + * Builds a query plan that preserves replica order: local replicas, remote replicas, local + * non-replicas (rotated), remote non-replicas (rotated). + */ + @NonNull + protected Queue newQueryPlanPreserveReplicas( + @Nullable Request request, @Nullable Session session) { + List replicas = getReplicas(request, session); + String localDc = getLocalDatacenter(); + List queryPlan = new ArrayList<>(); + + if (localDc == null) { + // No local DC: all replicas first, then rotated non-replicas + List allNodes = new ArrayList<>(); + for (Object obj : getLiveNodes().dc(null).toArray()) { + allNodes.add((Node) obj); + } + queryPlan.addAll(replicas); + addRotatedNonReplicas(queryPlan, allNodes, replicas, request); + } else { + // With local DC: prioritize local, then remote + Map> nodesByDc = getAllNodesByDc(); + addReplicasByDc(queryPlan, replicas, localDc); + addNonReplicasByDc(queryPlan, nodesByDc, replicas, localDc, request); + } + + return new SimpleQueryPlan(queryPlan.toArray()); + } + + /** Collect all live nodes grouped by DC. */ + private Map> getAllNodesByDc() { + Map> nodesByDc = new HashMap<>(); + for (String dc : getLiveNodes().dcs()) { + List dcNodes = new ArrayList<>(); + for (Object obj : getLiveNodes().dc(dc).toArray()) { + dcNodes.add((Node) obj); + } + nodesByDc.put(dc, dcNodes); + } + return nodesByDc; + } + + /** Add replicas with local DC first, then remote DCs. */ + private void addReplicasByDc(List queryPlan, List replicas, String localDc) { + replicas.stream() + .filter(r -> Objects.equals(r.getDatacenter(), localDc)) + .forEach(queryPlan::add); + replicas.stream() + .filter(r -> !Objects.equals(r.getDatacenter(), localDc)) + .forEach(queryPlan::add); + } + + /** Add non-replicas with local DC first, then remote DCs (all rotated). */ + private void addNonReplicasByDc( + List queryPlan, + Map> nodesByDc, + List replicas, + String localDc, + Request request) { + // Local DC non-replicas first + addRotatedNonReplicas( + queryPlan, nodesByDc.getOrDefault(localDc, new ArrayList<>()), replicas, request); + // Remote DC non-replicas + for (Map.Entry> entry : nodesByDc.entrySet()) { + if (!Objects.equals(entry.getKey(), localDc)) { + addRotatedNonReplicas(queryPlan, entry.getValue(), replicas, request); + } + } + } + + /** Add non-replica nodes from given list with rotation. */ + private void addRotatedNonReplicas( + List queryPlan, List nodes, List replicas, Request request) { + List nonReplicas = + nodes.stream().filter(n -> !replicas.contains(n)).collect(Collectors.toList()); + if (!nonReplicas.isEmpty()) { + rotateNonReplicas(nonReplicas, request); + queryPlan.addAll(nonReplicas); + } + } + + /** Rotates nodes based on routing key (consistent) or randomly. */ + private void rotateNonReplicas(List nodes, @Nullable Request request) { + if (nodes.size() <= 1) return; + + int rotationAmount = + (request != null && request.getRoutingKey() != null) + ? Math.abs(request.getRoutingKey().hashCode()) % nodes.size() + : randomNextInt(nodes.size()); + + if (rotationAmount > 0) { + Collections.rotate(nodes, -rotationAmount); + } + } + @NonNull protected List getReplicas(@Nullable Request request, @Nullable Session session) { if (request == null || session == null) { @@ -441,6 +587,11 @@ protected Object[] computeNodes() { return new CompositeQueryPlan(queryPlans); } + /** Exposed as a protected method so that it can be accessed by tests */ + protected int randomNextInt(int bound) { + return ThreadLocalRandom.current().nextInt(bound); + } + /** Exposed as a protected method so that it can be accessed by tests */ protected void shuffleHead(Object[] currentNodes, int headLength) { ArrayUtils.shuffleHead(currentNodes, headLength); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index 2d1a283f657..6b0935a887d 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,7 +20,6 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; -import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -48,7 +47,6 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLongArray; import net.jcip.annotations.ThreadSafe; import org.slf4j.Logger; @@ -96,11 +94,6 @@ @ThreadSafe public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy implements RequestTracker { - public enum RequestRoutingMethod { - REGULAR, - PRESERVE_REPLICA_ORDER - } - private static final Logger LOG = LoggerFactory.getLogger(DefaultLoadBalancingPolicy.class); private static final long NEWLY_UP_INTERVAL_NANOS = MINUTES.toNanos(1); @@ -110,31 +103,14 @@ public enum RequestRoutingMethod { protected final ConcurrentMap responseTimes; protected final Map upTimes = new ConcurrentHashMap<>(); private final boolean avoidSlowReplicas; - private final RequestRoutingMethod lwtRequestRoutingMethod; public DefaultLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String profileName) { super(context, profileName); this.avoidSlowReplicas = profile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true); - this.lwtRequestRoutingMethod = getDefaultLWTRequestRoutingMethod(); this.responseTimes = new MapMaker().weakKeys().makeMap(); } - @NonNull - private RequestRoutingMethod getDefaultLWTRequestRoutingMethod() { - String methodString = - profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD); - try { - return RequestRoutingMethod.valueOf(methodString.toUpperCase()); - } catch (IllegalArgumentException e) { - LOG.warn( - "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER", - logPrefix, - methodString); - return RequestRoutingMethod.PRESERVE_REPLICA_ORDER; - } - } - @NonNull @Override public Optional getRequestTracker() { @@ -151,52 +127,13 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { return new MandatoryLocalDcHelper(context, profile, logPrefix).discoverLocalDc(nodes); } - @NonNull - public RequestRoutingMethod getDefaultLWTRequestRoutingMethod(@Nullable Request request) { - if (request == null) { - return RequestRoutingMethod.REGULAR; - } - if (request.getRequestRoutingType() == RequestRoutingType.LWT) { - return lwtRequestRoutingMethod; - } else { - return RequestRoutingMethod.REGULAR; - } - } - - @NonNull - @Override - public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { - switch (getDefaultLWTRequestRoutingMethod(request)) { - case PRESERVE_REPLICA_ORDER: - return newQueryPlanPreserveReplicas(request, session); - case REGULAR: - default: - return newQueryPlanRegular(request, session); - } - } - - /** - * Builds a query plan that preserves the replica order as returned by the load balancing - * strategy, while pushing non-local replicas after local ones. - */ - @NonNull - public Queue newQueryPlanPreserveReplicas( - @Nullable Request request, @Nullable Session session) { - List replicas = getReplicas(request, session); - String localDc = getLocalDatacenter(); - if (localDc == null || replicas.isEmpty()) { - return new SimpleQueryPlan(replicas.toArray()); - } - - return new SimpleQueryPlan(moveNonLocalReplicasToTheEnd(replicas, localDc)); - } - /** * Builds a query plan that prioritizes local replicas, shuffles them for balance, and then * round-robins the remaining local nodes. */ @NonNull - public Queue newQueryPlanRegular(@Nullable Request request, @Nullable Session session) { + @Override + protected Queue newQueryPlanRegular(@Nullable Request request, @Nullable Session session) { List replicas = getReplicas(request, session); Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); int replicaCount = 0; // in currentNodes @@ -228,26 +165,6 @@ public Queue newQueryPlanRegular(@Nullable Request request, @Nullable Sess return maybeAddDcFailover(request, plan); } - /** - * Returns a replica array with local-datacenter replicas first and remote replicas preserved at - * the end. - */ - private static Object[] moveNonLocalReplicasToTheEnd(List replicas, String localDc) { - Object[] orderedReplicas = new Object[replicas.size()]; - int index = 0; - for (Node replica : replicas) { - if (Objects.equals(replica.getDatacenter(), localDc)) { - orderedReplicas[index++] = replica; - } - } - for (Node replica : replicas) { - if (!Objects.equals(replica.getDatacenter(), localDc)) { - orderedReplicas[index++] = replica; - } - } - return orderedReplicas; - } - private int[] moveReplicasToFront(Object[] currentNodes, List allReplicas) { int replicaCount = 0, localRackReplicaCount = 0; for (int i = 0; i < currentNodes.length; i++) { @@ -329,7 +246,7 @@ private void avoidSlowReplicas( // - the replica in first or second position is the most recent replica marked as UP and // - dice roll 1d4 != 1 else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) - && diceRoll1d4() != 1) { + && randomNextInt(4) != 1) { // Send it to the back of the replicas ArrayUtils.bubbleDown( @@ -370,11 +287,6 @@ protected long nanoTime() { return System.nanoTime(); } - /** Exposed as a protected method so that it can be accessed by tests */ - protected int diceRoll1d4() { - return ThreadLocalRandom.current().nextInt(4); - } - protected boolean isUnhealthy(@NonNull Node node, @NonNull Session session, long now) { return isBusy(node, session) && isResponseRateInsufficient(node, now); } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyLwtRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyLwtRoutingTest.java new file mode 100644 index 00000000000..72890cb2741 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyLwtRoutingTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.context.DriverContext; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; +import java.util.Queue; +import java.util.UUID; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.Silent.class) +public class BasicLoadBalancingPolicyLwtRoutingTest extends LwtRoutingTestBase { + + @Override + protected BasicLoadBalancingPolicy createPolicy(DriverContext context, String profileName) { + return new BasicLoadBalancingPolicy(context, profileName); + } + + @Test + public void should_handle_null_local_datacenter() { + when(defaultProfile.isDefined(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER)) + .thenReturn(false); + + BasicLoadBalancingPolicy noDcPolicy = + createPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + noDcPolicy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3, + UUID.randomUUID(), node4, + UUID.randomUUID(), node5), + distanceReporter); + + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + Queue plan = noDcPolicy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node1); + assertThat(plan.poll()).isEqualTo(node2); + assertThat(plan).containsExactlyInAnyOrder(node3, node4, node5); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyQueryPlanTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyQueryPlanTest.java index 86223bb887f..1e543c0a180 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyQueryPlanTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyQueryPlanTest.java @@ -43,7 +43,7 @@ protected long nanoTime() { } @Override - protected int diceRoll1d4() { + protected int randomNextInt(int bound) { return diceRoll; } }); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java index 1e16aafa5f2..ea784e4277d 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java @@ -23,231 +23,15 @@ */ package com.datastax.oss.driver.internal.core.loadbalancing; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.when; - -import com.datastax.oss.driver.api.core.CqlIdentifier; -import com.datastax.oss.driver.api.core.RequestRoutingType; -import com.datastax.oss.driver.api.core.config.DefaultDriverOption; -import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; -import com.datastax.oss.driver.api.core.metadata.Metadata; -import com.datastax.oss.driver.api.core.metadata.Node; -import com.datastax.oss.driver.api.core.metadata.TokenMap; -import com.datastax.oss.driver.api.core.metadata.token.Token; -import com.datastax.oss.driver.api.core.session.Request; -import com.datastax.oss.driver.internal.core.session.DefaultSession; -import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; -import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; -import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; -import com.datastax.oss.protocol.internal.util.Bytes; -import java.nio.ByteBuffer; -import java.util.Optional; -import java.util.Queue; -import java.util.UUID; -import org.junit.Before; -import org.junit.Test; +import com.datastax.oss.driver.api.core.context.DriverContext; import org.junit.runner.RunWith; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.Silent.class) -public class DefaultLoadBalancingPolicyLwtRoutingTest extends LoadBalancingPolicyTestBase { - - private static final CqlIdentifier KEYSPACE = CqlIdentifier.fromInternal("ks"); - private static final ByteBuffer ROUTING_KEY = Bytes.fromHexString("0xdeadbeef"); - - @Mock protected Request request; - @Mock protected DefaultSession session; - @Mock protected Metadata metadata; - @Mock protected TokenMap tokenMap; - @Mock protected Token routingToken; +public class DefaultLoadBalancingPolicyLwtRoutingTest extends LwtRoutingTestBase { - private DefaultLoadBalancingPolicy policy; - - @Before @Override - public void setup() { - super.setup(); - when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); - when(metadataManager.getMetadata()).thenReturn(metadata); - when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); - - // Set up nodes with proper DCs - when(node1.getDatacenter()).thenReturn("dc1"); - when(node2.getDatacenter()).thenReturn("dc1"); - when(node3.getDatacenter()).thenReturn("dc1"); - when(node4.getDatacenter()).thenReturn("dc2"); - when(node5.getDatacenter()).thenReturn("dc2"); - - // Configure for PRESERVE_REPLICA_ORDER routing for LWT - when(defaultProfile.getString( - DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) - .thenReturn("PRESERVE_REPLICA_ORDER"); - - policy = new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); - policy.init( - ImmutableMap.of( - UUID.randomUUID(), node1, - UUID.randomUUID(), node2, - UUID.randomUUID(), node3, - UUID.randomUUID(), node4, - UUID.randomUUID(), node5), - distanceReporter); - } - - @Test - public void should_preserve_replica_order_with_empty_replicas() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)).willReturn(ImmutableList.of()); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - assertThat(plan).isEmpty(); - } - - @Test - public void should_preserve_replica_order_with_single_local_replica() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node2)); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - assertThat(plan).containsExactly(node2); - } - - @Test - public void should_preserve_replica_order_with_multiple_local_replicas() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node3, node1, node2)); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - order preserved exactly as returned from token map - assertThat(plan).containsExactly(node3, node1, node2); - } - - @Test - public void should_push_remote_replicas_to_end() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - // Token map returns replicas in mixed order: remote, local, remote, local - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node4, node1, node5, node2)); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - local replicas first (preserving their order), remote replicas last (preserving their - // order) - assertThat(plan).containsExactly(node1, node2, node4, node5); - } - - @Test - public void should_preserve_replica_order_with_all_remote_replicas() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node5, node4)); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - all remote replicas, order preserved - assertThat(plan).containsExactly(node5, node4); - } - - @Test - public void should_handle_null_local_datacenter() { - // Given - when(defaultProfile.isDefined(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER)) - .thenReturn(false); - - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node1, node2)); - - // When - calling with request that might not have local DC set - // The method should handle null localDc gracefully and just return replicas as-is - Queue plan = policy.newQueryPlanPreserveReplicas(request, session); - - // Then - returns all replicas in order when localDc is not defined - assertThat(plan).containsExactly(node1, node2); - } - - @Test - public void should_preserve_order_when_no_routing_key() { - // Given - given(request.getRoutingKeyspace()).willReturn(null); - given(request.getRoutingKey()).willReturn(null); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - with no routing key, no replicas identified, falls back to empty or default behavior - // This tests the edge case where getReplicas returns empty list - assertThat(plan).isNotNull(); - } - - @Test - public void should_dispatch_to_preserve_replicas_when_lwt_and_config_set() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node1, node2)); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - verify it used preserve replica order (no shuffling) - // Call multiple times to ensure order is always preserved (not shuffled) - Queue plan2 = policy.newQueryPlan(request, session); - Queue plan3 = policy.newQueryPlan(request, session); - - assertThat(plan).containsExactly(node1, node2); - assertThat(plan2).containsExactly(node1, node2); - assertThat(plan3).containsExactly(node1, node2); - } - - @Test - public void should_not_add_non_replicas_in_preserve_mode() { - // Given - given(request.getRoutingKeyspace()).willReturn(KEYSPACE); - given(request.getRoutingKey()).willReturn(ROUTING_KEY); - given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); - // Only node1 is a replica - given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) - .willReturn(ImmutableList.of(node1)); - - // When - Queue plan = policy.newQueryPlan(request, session); - - // Then - only the replica is in the plan, other live nodes are NOT added - assertThat(plan).containsExactly(node1); + protected BasicLoadBalancingPolicy createPolicy(DriverContext context, String profileName) { + return new DefaultLoadBalancingPolicy(context, profileName); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java index f9445b84d76..0bf7469dfc4 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java @@ -116,7 +116,7 @@ public void should_prioritize_and_shuffle_2_replicas() { then(dsePolicy).should(times(3)).shuffleHead(any(), anyInt()); then(dsePolicy).should(never()).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -144,7 +144,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_all_healthy_an then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -172,7 +172,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_all_healthy_an then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(times(2)).diceRoll1d4(); + then(dsePolicy).should(times(2)).randomNextInt(4); } @Test @@ -201,7 +201,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_all_healthy_an then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(times(2)).diceRoll1d4(); + then(dsePolicy).should(times(2)).randomNextInt(4); } @Test @@ -232,7 +232,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_first_unhealth then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -263,7 +263,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_first_unhealth then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -289,7 +289,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_last_unhealthy then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -315,7 +315,7 @@ public void should_prioritize_and_shuffle_3_or_more_replicas_when_majority_unhea then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -340,7 +340,7 @@ public void should_reorder_first_two_replicas_when_first_has_more_in_flight_than then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); then(dsePolicy).should(times(2)).nanoTime(); - then(dsePolicy).should(never()).diceRoll1d4(); + then(dsePolicy).should(never()).randomNextInt(4); } @Test @@ -475,7 +475,7 @@ protected long nanoTime() { } @Override - protected int diceRoll1d4() { + protected int randomNextInt(int bound) { return diceRoll; } }); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java index 9aef1825329..4877659092f 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java @@ -36,7 +36,7 @@ import com.datastax.oss.driver.api.core.metadata.TokenMap; import com.datastax.oss.driver.api.core.metadata.token.Token; import com.datastax.oss.driver.api.core.session.Request; -import com.datastax.oss.driver.internal.core.loadbalancing.DefaultLoadBalancingPolicy.RequestRoutingMethod; +import com.datastax.oss.driver.internal.core.loadbalancing.BasicLoadBalancingPolicy.RequestRoutingMethod; import com.datastax.oss.driver.internal.core.session.DefaultSession; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; @@ -99,7 +99,7 @@ public void should_return_regular_when_request_is_null() { initPolicy("REGULAR"); // When - RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(null); + RequestRoutingMethod method = policy.getRequestRoutingMethod(null); // Then assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); @@ -112,7 +112,7 @@ public void should_return_regular_when_routing_type_is_regular() { given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); // When - RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method = policy.getRequestRoutingMethod(request); // Then assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); @@ -125,7 +125,7 @@ public void should_return_regular_for_lwt_when_config_is_regular() { given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); // When - RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method = policy.getRequestRoutingMethod(request); // Then assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); @@ -138,7 +138,7 @@ public void should_return_preserve_replica_order_for_lwt_when_config_is_preserve given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); // When - RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method = policy.getRequestRoutingMethod(request); // Then assertThat(method).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); @@ -180,10 +180,10 @@ public void should_dispatch_to_preserve_query_plan_when_lwt_and_config_preserve( Queue plan2 = policy.newQueryPlan(request, session); Queue plan3 = policy.newQueryPlan(request, session); - // Then - preserve routing maintains exact order - assertThat(plan1).containsExactly(node2, node1); - assertThat(plan2).containsExactly(node2, node1); - assertThat(plan3).containsExactly(node2, node1); + // Then - preserve routing maintains replica order, non-replicas follow + assertThat(plan1).containsExactly(node2, node1, node3); + assertThat(plan2).containsExactly(node2, node1, node3); + assertThat(plan3).containsExactly(node2, node1, node3); } @Test @@ -228,7 +228,7 @@ public void should_use_regular_routing_for_unknown_routing_type() { .willReturn(ImmutableList.of(node1)); // When - RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method = policy.getRequestRoutingMethod(request); // Then - defaults to REGULAR for any unrecognized type assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); @@ -245,9 +245,9 @@ public void should_consistently_route_same_request_type() { .willReturn(ImmutableList.of(node1, node2, node3)); // When - call multiple times - RequestRoutingMethod method1 = policy.getDefaultLWTRequestRoutingMethod(request); - RequestRoutingMethod method2 = policy.getDefaultLWTRequestRoutingMethod(request); - RequestRoutingMethod method3 = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method1 = policy.getRequestRoutingMethod(request); + RequestRoutingMethod method2 = policy.getRequestRoutingMethod(request); + RequestRoutingMethod method3 = policy.getRequestRoutingMethod(request); // Then - should always return the same method assertThat(method1).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LwtRoutingTestBase.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LwtRoutingTestBase.java new file mode 100644 index 00000000000..0960eef4a4b --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LwtRoutingTestBase.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.context.DriverContext; +import com.datastax.oss.driver.api.core.metadata.Metadata; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.internal.core.session.DefaultSession; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.datastax.oss.protocol.internal.util.Bytes; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; + +/** + * Abstract base for testing LWT preserve-replica-order routing on both {@link + * BasicLoadBalancingPolicy} and {@link DefaultLoadBalancingPolicy}. + */ +public abstract class LwtRoutingTestBase extends LoadBalancingPolicyTestBase { + + protected static final CqlIdentifier KEYSPACE = CqlIdentifier.fromInternal("ks"); + protected static final ByteBuffer ROUTING_KEY = Bytes.fromHexString("0xdeadbeef"); + + @Mock protected Request request; + @Mock protected DefaultSession session; + @Mock protected Metadata metadata; + @Mock protected TokenMap tokenMap; + @Mock protected Token routingToken; + + protected BasicLoadBalancingPolicy policy; + + protected abstract BasicLoadBalancingPolicy createPolicy( + DriverContext context, String profileName); + + @Before + @Override + public void setup() { + super.setup(); + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + when(metadataManager.getMetadata()).thenReturn(metadata); + when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + + // Enable remote DC nodes + when(defaultProfile.getInt( + DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_MAX_NODES_PER_REMOTE_DC)) + .thenReturn(2); + + // Configure for PRESERVE_REPLICA_ORDER routing for LWT + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn("PRESERVE_REPLICA_ORDER"); + + // Set up nodes with proper DCs + when(node1.getDatacenter()).thenReturn("dc1"); + when(node2.getDatacenter()).thenReturn("dc1"); + when(node3.getDatacenter()).thenReturn("dc1"); + when(node4.getDatacenter()).thenReturn("dc2"); + when(node5.getDatacenter()).thenReturn("dc2"); + + policy = createPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + policy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3, + UUID.randomUUID(), node4, + UUID.randomUUID(), node5), + distanceReporter); + } + + @Test + public void should_fallback_to_all_nodes_when_empty_replicas() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)).willReturn(ImmutableList.of()); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3, node4, node5); + } + + @Test + public void should_preserve_replica_order_with_single_local_replica() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node2)); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node2); + assertThat(plan).containsExactlyInAnyOrder(node1, node3, node4, node5); + } + + @Test + public void should_preserve_replica_order_with_multiple_local_replicas() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node3, node1, node2)); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node3); + assertThat(plan.poll()).isEqualTo(node1); + assertThat(plan.poll()).isEqualTo(node2); + assertThat(plan).containsExactlyInAnyOrder(node4, node5); + } + + @Test + public void should_push_remote_replicas_to_end() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node4, node1, node5, node2)); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node1); // local replica + assertThat(plan.poll()).isEqualTo(node2); // local replica + assertThat(plan.poll()).isEqualTo(node4); // remote replica + assertThat(plan.poll()).isEqualTo(node5); // remote replica + assertThat(plan.poll()).isEqualTo(node3); // local non-replica + } + + @Test + public void should_preserve_replica_order_with_all_remote_replicas() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node5, node4)); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node5); + assertThat(plan.poll()).isEqualTo(node4); + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_preserve_order_when_no_routing_key() { + given(request.getRoutingKeyspace()).willReturn(null); + given(request.getRoutingKey()).willReturn(null); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).isNotNull(); + } + + @Test + public void should_dispatch_to_preserve_replicas_when_lwt_and_config_set() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + Queue plan1 = policy.newQueryPlan(request, session); + Queue plan2 = policy.newQueryPlan(request, session); + Queue plan3 = policy.newQueryPlan(request, session); + + assertThat(plan1).hasSize(5); + assertThat(plan2).hasSize(5); + assertThat(plan3).hasSize(5); + + Node[] plan1Array = plan1.toArray(new Node[0]); + Node[] plan2Array = plan2.toArray(new Node[0]); + Node[] plan3Array = plan3.toArray(new Node[0]); + + assertThat(plan1Array[0]).isEqualTo(node1); + assertThat(plan1Array[1]).isEqualTo(node2); + assertThat(plan2Array[0]).isEqualTo(node1); + assertThat(plan2Array[1]).isEqualTo(node2); + assertThat(plan3Array[0]).isEqualTo(node1); + assertThat(plan3Array[1]).isEqualTo(node2); + } + + @Test + public void should_add_non_replicas_after_replicas_in_preserve_mode() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1)); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node1); + assertThat(plan).containsExactlyInAnyOrder(node2, node3, node4, node5); + } + + @Test + public void should_fallback_to_all_live_nodes_when_lwt_has_no_routing_info() { + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(request.getKeyspace()).willReturn(null); + given(request.getRoutingKeyspace()).willReturn(null); + given(request.getRoutingKey()).willReturn(null); + given(request.getRoutingToken()).willReturn(null); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3, node4, node5); + } + + @Test + public void + should_maintain_node_priority_order_local_replicas_then_remote_then_local_non_replicas() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node2, node5)); + + Queue plan = policy.newQueryPlan(request, session); + + assertThat(plan).hasSize(5); + assertThat(plan.poll()).isEqualTo(node2); // local replica + assertThat(plan.poll()).isEqualTo(node5); // remote replica + assertThat(plan).containsExactlyInAnyOrder(node1, node3, node4); + } + + @Test + public void should_rotate_non_replicas_with_controlled_randomness() { + // Put all nodes in dc1 so we have 3 non-replicas for controlled rotation + when(node4.getDatacenter()).thenReturn("dc1"); + + BasicLoadBalancingPolicy spyPolicy = + spy(createPolicy(context, DriverExecutionProfile.DEFAULT_NAME)); + spyPolicy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3, + UUID.randomUUID(), node4), + distanceReporter); + + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(null); // null key = random rotation + given(request.getRoutingToken()).willReturn(routingToken); // token for replica lookup + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, routingToken)).willReturn(ImmutableList.of(node1)); + + doReturn(0).when(spyPolicy).randomNextInt(3); + Queue plan1 = spyPolicy.newQueryPlan(request, session); + + doReturn(1).when(spyPolicy).randomNextInt(3); + Queue plan2 = spyPolicy.newQueryPlan(request, session); + + doReturn(2).when(spyPolicy).randomNextInt(3); + Queue plan3 = spyPolicy.newQueryPlan(request, session); + + Node[] plan1Array = plan1.toArray(new Node[0]); + Node[] plan2Array = plan2.toArray(new Node[0]); + Node[] plan3Array = plan3.toArray(new Node[0]); + + assertThat(plan1Array[0]).isEqualTo(node1); + assertThat(plan2Array[0]).isEqualTo(node1); + assertThat(plan3Array[0]).isEqualTo(node1); + + assertThat(plan1Array).isNotEqualTo(plan2Array); + assertThat(plan2Array).isNotEqualTo(plan3Array); + + assertThat(plan1).hasSize(4); + assertThat(plan1).containsExactlyInAnyOrder(plan2Array); + assertThat(plan1).containsExactlyInAnyOrder(plan3Array); + } + + @Test + public void should_rotate_non_replicas_consistently_when_routing_key_present() { + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1)); + + Queue plan1 = policy.newQueryPlan(request, session); + Queue plan2 = policy.newQueryPlan(request, session); + Queue plan3 = policy.newQueryPlan(request, session); + + assertThat(plan1).containsExactly(plan2.toArray(new Node[0])); + assertThat(plan1).containsExactly(plan3.toArray(new Node[0])); + assertThat(plan1).hasSize(5); + assertThat(plan1.poll()).isEqualTo(node1); + } +}