From 5bff3d34807d51a4255cfe1be6a404e58160471e Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Fri, 9 Jan 2026 16:28:58 +0100 Subject: [PATCH 1/2] fix(#384): refactoring `DefaultLoadBalancingPolicy` to include LWT logic --- .../core/graph/BytecodeGraphStatement.java | 8 + .../graph/DefaultBatchGraphStatement.java | 8 + .../graph/DefaultFluentGraphStatement.java | 8 + .../graph/DefaultScriptGraphStatement.java | 7 + .../driver/api/core/RequestRoutingMethod.java | 7 + .../driver/api/core/RequestRoutingType.java | 6 + .../api/core/cql/BatchStatementBuilder.java | 5 +- .../oss/driver/api/core/session/Request.java | 16 ++ .../internal/core/cql/CqlRequestHandler.java | 30 --- .../core/cql/DefaultBatchStatement.java | 61 ++--- .../core/cql/DefaultBoundStatement.java | 7 + .../core/cql/DefaultPrepareRequest.java | 7 + .../core/cql/DefaultSimpleStatement.java | 7 + .../DefaultLoadBalancingPolicy.java | 214 +++++++++--------- .../oss/driver/core/metadata/NodeStateIT.java | 3 +- .../example/guava/internal/KeyRequest.java | 6 + 16 files changed, 241 insertions(+), 159 deletions(-) create mode 100644 core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java create mode 100644 core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java index b6fe05a987c..5e1f2f7a1ea 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java @@ -19,9 +19,11 @@ import com.datastax.dse.driver.api.core.graph.FluentGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import 
com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.metadata.Node; +import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collections; @@ -127,4 +129,10 @@ protected BytecodeGraphStatement newInstance( readConsistencyLevel, writeConsistencyLevel); } + + @Nullable + @Override + public RequestRoutingType getRequestType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java index e16287c415d..3dc07f21752 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java @@ -19,10 +19,12 @@ import com.datastax.dse.driver.api.core.graph.BatchGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Iterator; @@ -151,4 +153,10 @@ protected BatchGraphStatement newInstance( public Iterator iterator() { return this.traversals.iterator(); } + + @Nullable + @Override + public RequestRoutingType getRequestType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java index 0f6f1faabbf..acb1ebba638 100644 --- 
a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java @@ -19,9 +19,11 @@ import com.datastax.dse.driver.api.core.graph.FluentGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Map; @@ -103,4 +105,10 @@ protected FluentGraphStatement newInstance( public GraphTraversal getTraversal() { return traversal; } + + @Nullable + @Override + public RequestRoutingType getRequestType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java index 71f79134237..c495f364d89 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.ScriptGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableMap; @@ -204,4 +205,10 @@ protected ScriptGraphStatement newInstance( public String toString() { return String.format("ScriptGraphStatement['%s', params: %s]", this.script, this.queryParams); } 
+ + @Nullable + @Override + public RequestRoutingType getRequestType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java new file mode 100644 index 00000000000..205f40b1408 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java @@ -0,0 +1,7 @@ +package com.datastax.oss.driver.api.core; + +public enum RequestRoutingMethod { + REGULAR, + PRESERVE_REPLICA_ORDER, + TOKEN_BASED_REPLICA_SHUFFLING +} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java new file mode 100644 index 00000000000..d8f6d6b9d68 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java @@ -0,0 +1,6 @@ +package com.datastax.oss.driver.api.core; + +public enum RequestRoutingType { + REGULAR, + LWT +} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java index 26e0aef8ca1..abf3ef0892e 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java @@ -18,6 +18,7 @@ package com.datastax.oss.driver.api.core.cql; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.internal.core.cql.DefaultBatchStatement; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; @@ -152,6 +153,8 @@ public BatchStatementBuilder clearStatements() { @NonNull public BatchStatement build() { List> statements = statementsBuilder.build(); + 
+ RequestRoutingType routingType = + isLWT != null ? (isLWT ? RequestRoutingType.LWT : RequestRoutingType.REGULAR) : null; return new DefaultBatchStatement( batchType, statements, @@ -172,7 +175,7 @@ public BatchStatement build() { timeout, node, nowInSeconds, - isLWT); + routingType); } public int getStatementsCount() { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java index 92c25e146c7..8a94f528eb4 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java @@ -25,6 +25,8 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingMethod; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; @@ -199,4 +201,18 @@ default Partitioner getPartitioner() { /** @return The node configured on this statement, or null if none is configured. */ @Nullable Node getNode(); + + /** + * Returns the routing type configured on this statement. + * + * @return The routing type configured on this statement, or {@link RequestRoutingType#REGULAR} + * if none is configured. 
+ */ + @Nullable + RequestRoutingType getRequestType(); + + @Nullable + default RequestRoutingMethod getRoutingMethod() { + return RequestRoutingMethod.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java index 80eece271a8..4008dd528f0 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java @@ -97,11 +97,9 @@ import java.util.List; import java.util.Map; import java.util.Queue; -import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -208,14 +206,6 @@ public void onThrottleReady(boolean wasDelayed) { Queue queryPlan; if (this.initialStatement.getNode() != null) { queryPlan = new SimpleQueryPlan(this.initialStatement.getNode()); - } else if (this.initialStatement.isLWT()) { - queryPlan = - getReplicas( - session.getKeyspace().orElse(null), - this.initialStatement, - context - .getLoadBalancingPolicyWrapper() - .newQueryPlan(initialStatement, executionProfile.getName(), session)); } else { queryPlan = context @@ -226,26 +216,6 @@ public void onThrottleReady(boolean wasDelayed) { sendRequest(initialStatement, null, queryPlan, 0, 0, true); } - private Queue getReplicas( - CqlIdentifier loggedKeyspace, Statement statement, Queue fallback) { - Token routingToken = getRoutingToken(statement); - CqlIdentifier keyspace = statement.getKeyspace(); - if (keyspace == null) { - keyspace = statement.getRoutingKeyspace(); - if (keyspace == null) { - keyspace = loggedKeyspace; - } - } - - TokenMap tokenMap = 
context.getMetadataManager().getMetadata().getTokenMap().orElse(null); - if (routingToken == null || keyspace == null || tokenMap == null) { - return fallback; - } - - Set replicas = tokenMap.getReplicas(keyspace, routingToken); - return new ConcurrentLinkedQueue<>(replicas); - } - public CompletionStage handle() { return result; } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index c8cb5b7a084..0447981ef21 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchType; @@ -69,7 +70,7 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - private final Boolean isLWT; + private final RequestRoutingType routingType; public DefaultBatchStatement( BatchType batchType, @@ -91,7 +92,7 @@ public DefaultBatchStatement( Duration timeout, Node node, int nowInSeconds, - Boolean isLWT) { + RequestRoutingType routingType) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null @@ -123,7 +124,7 @@ public DefaultBatchStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; - this.isLWT = isLWT; + this.routingType = routingType; } @NonNull @@ -155,7 +156,7 @@ public BatchStatement setBatchType(@NonNull BatchType newBatchType) { timeout, node, 
nowInSeconds, - isLWT); + routingType); } @NonNull @@ -181,7 +182,7 @@ public BatchStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -211,7 +212,7 @@ public BatchStatement add(@NonNull BatchableStatement statement) { timeout, node, nowInSeconds, - isLWT); + routingType); } } @@ -245,7 +246,7 @@ public BatchStatement addAll(@NonNull Iterable> timeout, node, nowInSeconds, - isLWT); + routingType); } } @@ -277,7 +278,7 @@ public BatchStatement clear() { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -314,7 +315,7 @@ public BatchStatement setPagingState(ByteBuffer newPagingState) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -345,7 +346,7 @@ public BatchStatement setPageSize(int newPageSize) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Nullable @@ -377,7 +378,7 @@ public BatchStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste timeout, node, nowInSeconds, - isLWT); + routingType); } @Nullable @@ -410,7 +411,7 @@ public BatchStatement setSerialConsistencyLevel( timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -441,7 +442,7 @@ public BatchStatement setExecutionProfileName(@Nullable String newConfigProfileN timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -472,7 +473,7 @@ public DefaultBatchStatement setExecutionProfile(@Nullable DriverExecutionProfil timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -538,7 +539,7 @@ public BatchStatement setRoutingKeyspace(CqlIdentifier newRoutingKeyspace) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -564,7 +565,7 @@ public BatchStatement setNode(@Nullable Node newNode) { timeout, newNode, nowInSeconds, - isLWT); + routingType); } @Nullable @@ -573,6 +574,12 @@ public Node getNode() { return node; } + @Nullable + @Override + public RequestRoutingType getRequestType() { + return isLWT() ? 
RequestRoutingType.LWT : RequestRoutingType.REGULAR; + } + @Override public ByteBuffer getRoutingKey() { if (routingKey != null) { @@ -611,7 +618,7 @@ public BatchStatement setRoutingKey(ByteBuffer newRoutingKey) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -652,7 +659,7 @@ public BatchStatement setRoutingToken(Token newRoutingToken) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -684,7 +691,7 @@ public DefaultBatchStatement setCustomPayload(@NonNull Map n timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -721,7 +728,7 @@ public DefaultBatchStatement setIdempotent(Boolean newIdempotence) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -752,7 +759,7 @@ public BatchStatement setTracing(boolean newTracing) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -783,7 +790,7 @@ public BatchStatement setQueryTimestamp(long newTimestamp) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -809,7 +816,7 @@ public BatchStatement setTimeout(@Nullable Duration newTimeout) { newTimeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -840,12 +847,14 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { timeout, node, newNowInSeconds, - isLWT); + routingType); } @NonNull @Override public BatchStatement setIsLWT(Boolean newIsLWT) { + RequestRoutingType routingType = + newIsLWT != null ? (newIsLWT ? 
RequestRoutingType.LWT : RequestRoutingType.REGULAR) : null; return new DefaultBatchStatement( batchType, statements, @@ -866,12 +875,12 @@ public BatchStatement setIsLWT(Boolean newIsLWT) { timeout, node, nowInSeconds, - newIsLWT); + routingType); } @Override public boolean isLWT() { - if (isLWT != null) return isLWT; + if (routingType != null) return routingType == RequestRoutingType.LWT; return statements.stream().anyMatch(Statement::isLWT); } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java index 05673692ce9..0856fa4f89c 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java @@ -26,6 +26,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.ColumnDefinitions; @@ -784,4 +785,10 @@ public BoundStatement setNowInSeconds(int newNowInSeconds) { public boolean isLWT() { return this.getPreparedStatement().isLWT(); } + + @Nullable + @Override + public RequestRoutingType getRequestType() { + return isLWT() ? 
RequestRoutingType.LWT : RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java index 7f87dbe5b51..68e569ab66d 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.PrepareRequest; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -197,6 +198,12 @@ public Node getNode() { return null; } + @Nullable + @Override + public RequestRoutingType getRequestType() { + return RequestRoutingType.REGULAR; + } + @Override public boolean areBoundStatementsTracing() { return statement.isTracing(); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java index 0af32b988fe..6387d41c5d8 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.metadata.Node; @@ -754,6 +755,12 @@ public 
boolean isLWT() { return false; } + @Nullable + @Override + public RequestRoutingType getRequestType() { + return isLWT() ? RequestRoutingType.LWT : RequestRoutingType.REGULAR; + } + public static Map wrapKeys(Map namedValues) { NullAllowingImmutableMap.Builder builder = NullAllowingImmutableMap.builder(); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index 469e54c56fb..bc145e2ccd4 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,6 +20,8 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; +import com.datastax.oss.driver.api.core.RequestRoutingMethod; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -50,6 +52,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLongArray; import net.jcip.annotations.ThreadSafe; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -131,113 +134,26 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { @NonNull @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { - if (!avoidSlowReplicas) { - return super.newQueryPlan(request, session); + Set replicas = getReplicas(request, session); + boolean isLWT = Objects.nonNull(request) && request.getRequestType() == RequestRoutingType.LWT; + Object[] currentNodes = + isLWT ? 
replicas.toArray() : getLiveNodes().dc(getLocalDatacenter()).toArray(); + if (Objects.nonNull(request) + && request.getRoutingMethod() == RequestRoutingMethod.PRESERVE_REPLICA_ORDER) { + return new SimpleQueryPlan(currentNodes); } - // Take a snapshot since the set is concurrent: - Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); - - Set allReplicas = getReplicas(request, session); int replicaCount = 0; // in currentNodes - int localRackReplicaCount = 0; // in currentNodes - String localRack = getLocalRack(); - - if (!allReplicas.isEmpty()) { - - // Move replicas to the beginning of the plan - // Replicas in local rack should precede other replicas - for (int i = 0; i < currentNodes.length; i++) { - Node node = (Node) currentNodes[i]; - if (allReplicas.contains(node)) { - if (Objects.equals(node.getRack(), localRack) - && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { - ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); - localRackReplicaCount++; - } else { - ArrayUtils.bubbleUp(currentNodes, i, replicaCount); - } - replicaCount++; - } - } + if (!replicas.isEmpty()) { + Pair counts = moveReplicasToFront(currentNodes, replicas); + replicaCount = counts.getLeft(); + int localRackReplicaCount = counts.getRight(); // in currentNodes if (replicaCount > 1) { - if (localRack != null && localRackReplicaCount > 0) { - // Shuffle only replicas that are in the local rack - shuffleHead(currentNodes, localRackReplicaCount); - // Shuffles only replicas that are not in local rack - shuffleInRange(currentNodes, localRackReplicaCount, replicaCount - 1); - } else { - shuffleHead(currentNodes, replicaCount); - } + shuffleLocalRackReplicasAndReplicas(currentNodes, replicaCount, localRackReplicaCount); - if (replicaCount > 2) { - - assert session != null; - - // Test replicas health - Node newestUpReplica = null; - BitSet unhealthyReplicas = null; // bit mask storing indices of unhealthy replicas - long mostRecentUpTimeNanos = -1; - 
long now = nanoTime(); - for (int i = 0; i < replicaCount; i++) { - Node node = (Node) currentNodes[i]; - assert node != null; - Long upTimeNanos = upTimes.get(node); - if (upTimeNanos != null - && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0 - && upTimeNanos - mostRecentUpTimeNanos > 0) { - newestUpReplica = node; - mostRecentUpTimeNanos = upTimeNanos; - } - if (newestUpReplica == null && isUnhealthy(node, session, now)) { - if (unhealthyReplicas == null) { - unhealthyReplicas = new BitSet(replicaCount); - } - unhealthyReplicas.set(i); - } - } - - // When: - // - there isn't any newly UP replica and - // - there is one or more unhealthy replicas and - // - there is a majority of healthy replicas - int unhealthyReplicasCount = - unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality(); - if (newestUpReplica == null - && unhealthyReplicasCount > 0 - && unhealthyReplicasCount < (replicaCount / 2.0)) { - - // Reorder the unhealthy replicas to the back of the list - // Start from the back of the replicas, then move backwards; - // stop once all unhealthy replicas are moved to the back. - int counter = 0; - for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) { - if (unhealthyReplicas.get(i)) { - ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter); - counter++; - } - } - } - - // When: - // - there is a newly UP replica and - // - the replica in first or second position is the most recent replica marked as UP and - // - dice roll 1d4 != 1 - else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) - && diceRoll1d4() != 1) { - - // Send it to the back of the replicas - ArrayUtils.bubbleDown( - currentNodes, newestUpReplica == currentNodes[0] ? 
0 : 1, replicaCount - 1); - } - - // Reorder the first two replicas in the shuffled list based on the number of - // in-flight requests - if (getInFlight((Node) currentNodes[0], session) - > getInFlight((Node) currentNodes[1], session)) { - ArrayUtils.swap(currentNodes, 0, 1); - } + if (replicaCount > 2 && avoidSlowReplicas) { + avoidSlowReplicas(Objects.requireNonNull(session), currentNodes, replicaCount); } } } @@ -255,6 +171,102 @@ > getInFlight((Node) currentNodes[1], session)) { return maybeAddDcFailover(request, plan); } + private Pair moveReplicasToFront(Object[] currentNodes, Set allReplicas) { + int replicaCount = 0, localRackReplicaCount = 0; + for (int i = 0; i < currentNodes.length; i++) { + Node node = (Node) currentNodes[i]; + if (allReplicas.contains(node)) { + if (Objects.equals(node.getRack(), getLocalRack()) + && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { + ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); + localRackReplicaCount++; + } else { + ArrayUtils.bubbleUp(currentNodes, i, replicaCount); + } + replicaCount++; + } + } + return Pair.of(replicaCount, localRackReplicaCount); + } + + private void shuffleLocalRackReplicasAndReplicas( + Object[] currentNodes, int replicaCount, int localRackReplicaCount) { + if (getLocalRack() != null && localRackReplicaCount > 0) { + // Shuffle only replicas that are in the local rack + shuffleHead(currentNodes, localRackReplicaCount); + // Shuffles only replicas that are not in local rack + shuffleInRange(currentNodes, localRackReplicaCount, replicaCount - 1); + } else { + shuffleHead(currentNodes, replicaCount); + } + } + + private void avoidSlowReplicas( + @NonNull Session session, Object[] currentNodes, int replicaCount) { + // Test replicas health + Node newestUpReplica = null; + BitSet unhealthyReplicas = null; // bit mask storing indices of unhealthy replicas + long mostRecentUpTimeNanos = -1; + long now = nanoTime(); + for (int i = 0; i < replicaCount; i++) { + Node node = 
(Node) currentNodes[i]; + assert node != null; + Long upTimeNanos = upTimes.get(node); + if (upTimeNanos != null + && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0 + && upTimeNanos - mostRecentUpTimeNanos > 0) { + newestUpReplica = node; + mostRecentUpTimeNanos = upTimeNanos; + } + if (newestUpReplica == null && isUnhealthy(node, session, now)) { + if (unhealthyReplicas == null) { + unhealthyReplicas = new BitSet(replicaCount); + } + unhealthyReplicas.set(i); + } + } + + // When: + // - there isn't any newly UP replica and + // - there is one or more unhealthy replicas and + // - there is a majority of healthy replicas + int unhealthyReplicasCount = unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality(); + if (newestUpReplica == null + && unhealthyReplicasCount > 0 + && unhealthyReplicasCount < (replicaCount / 2.0)) { + + // Reorder the unhealthy replicas to the back of the list + // Start from the back of the replicas, then move backwards; + // stop once all unhealthy replicas are moved to the back. + int counter = 0; + for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) { + if (unhealthyReplicas.get(i)) { + ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter); + counter++; + } + } + } + + // When: + // - there is a newly UP replica and + // - the replica in first or second position is the most recent replica marked as UP and + // - dice roll 1d4 != 1 + else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) + && diceRoll1d4() != 1) { + + // Send it to the back of the replicas + ArrayUtils.bubbleDown( + currentNodes, newestUpReplica == currentNodes[0] ? 
0 : 1, replicaCount - 1); + } + + // Reorder the first two replicas in the shuffled list based on the number of + // in-flight requests + if (getInFlight((Node) currentNodes[0], session) + > getInFlight((Node) currentNodes[1], session)) { + ArrayUtils.swap(currentNodes, 0, 1); + } + } + @Override public void onNodeSuccess( @NonNull Request request, diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java index e468e0a10d7..dc7590da2ec 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java @@ -57,6 +57,7 @@ import com.datastax.oss.simulacron.server.BoundNode; import com.datastax.oss.simulacron.server.RejectScope; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -703,7 +704,7 @@ public void stopIgnoring(Node node) { @NonNull @Override - public Queue newQueryPlan(@NonNull Request request, @NonNull Session session) { + public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { Object[] snapshot = liveNodes.toArray(); Queue queryPlan = new ConcurrentLinkedQueue<>(); int start = offset.getAndIncrement(); // Note: offset overflow won't be an issue in tests diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java b/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java index ef582cce1b9..83201a9198c 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java @@ -18,6 +18,7 @@ package 
com.datastax.oss.driver.example.guava.internal; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; @@ -94,4 +95,9 @@ public Duration getTimeout() { public Node getNode() { return null; } + + @Override + public @Nullable RequestRoutingType getRequestType() { + return null; + } } From 7fc1eac3a899a347628bbb3f927936f6ab9051ca Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Fri, 16 Jan 2026 18:40:42 +0100 Subject: [PATCH 2/2] fix(#384): Ignore local racks for LWT statements in `DefaultLoadBalancingPolicy.newQueryPlan()`. --- .../DefaultLoadBalancingPolicy.java | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index bc145e2ccd4..d2dd049ac11 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -135,7 +135,9 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { Set replicas = getReplicas(request, session); - boolean isLWT = Objects.nonNull(request) && request.getRequestType() == RequestRoutingType.LWT; + RequestRoutingType requestType = + Objects.nonNull(request) ? request.getRequestType() : RequestRoutingType.REGULAR; + boolean isLWT = requestType == RequestRoutingType.LWT; Object[] currentNodes = isLWT ? 
replicas.toArray() : getLiveNodes().dc(getLocalDatacenter()).toArray(); if (Objects.nonNull(request) @@ -145,12 +147,13 @@ public Queue newQueryPlan(@Nullable Request request, @Nullable Session ses int replicaCount = 0; // in currentNodes if (!replicas.isEmpty()) { - Pair counts = moveReplicasToFront(currentNodes, replicas); + Pair counts = moveReplicasToFront(requestType, currentNodes, replicas); replicaCount = counts.getLeft(); int localRackReplicaCount = counts.getRight(); // in currentNodes if (replicaCount > 1) { - shuffleLocalRackReplicasAndReplicas(currentNodes, replicaCount, localRackReplicaCount); + shuffleLocalRackReplicasAndReplicas( + requestType, currentNodes, replicaCount, localRackReplicaCount); if (replicaCount > 2 && avoidSlowReplicas) { avoidSlowReplicas(Objects.requireNonNull(session), currentNodes, replicaCount); @@ -171,13 +174,15 @@ public Queue newQueryPlan(@Nullable Request request, @Nullable Session ses return maybeAddDcFailover(request, plan); } - private Pair moveReplicasToFront(Object[] currentNodes, Set allReplicas) { + private Pair moveReplicasToFront( + RequestRoutingType routingType, Object[] currentNodes, Set allReplicas) { int replicaCount = 0, localRackReplicaCount = 0; for (int i = 0; i < currentNodes.length; i++) { Node node = (Node) currentNodes[i]; if (allReplicas.contains(node)) { if (Objects.equals(node.getRack(), getLocalRack()) - && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { + && Objects.equals(node.getDatacenter(), getLocalDatacenter()) + && routingType != RequestRoutingType.LWT) { ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); localRackReplicaCount++; } else { @@ -190,8 +195,13 @@ private Pair moveReplicasToFront(Object[] currentNodes, Set 0) { + RequestRoutingType routingType, + Object[] currentNodes, + int replicaCount, + int localRackReplicaCount) { + if (routingType != RequestRoutingType.LWT + && getLocalRack() != null + && localRackReplicaCount > 0) { // Shuffle only replicas that 
are in the local rack shuffleHead(currentNodes, localRackReplicaCount); // Shuffles only replicas that are not in local rack