From 6c6919d238df850a6eb0e39728a68a9f453fe05b Mon Sep 17 00:00:00 2001
From: Bouncheck <36934780+Bouncheck@users.noreply.github.com>
Date: Mon, 14 Apr 2025 11:59:20 +0200
Subject: [PATCH 1/4] Add config options for advanced shard awareness
Adds config options for toggling the feature on and defining
port ranges to use.
By default the feature will be enabled.
---
.../api/core/config/DefaultDriverOption.java | 11 +++++
.../driver/api/core/config/OptionsMap.java | 3 ++
.../api/core/config/TypedDriverOption.java | 12 +++++
core/src/main/resources/reference.conf | 46 +++++++++++++++++++
4 files changed, 72 insertions(+)
diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java
index 7ff3b2719ba..a16c04c6c1e 100644
--- a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java
+++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java
@@ -141,6 +141,17 @@ public enum DefaultDriverOption implements DriverOption {
*
Value-type: boolean
*/
CONNECTION_WARN_INIT_ERROR("advanced.connection.warn-on-init-error"),
+ /**
+ * Whether to use advanced shard awareness.
+ *
+ *
Value-type: boolean
+ */
+ CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED(
+ "advanced.connection.advanced-shard-awareness.enabled"),
+ /** Inclusive lower bound of port range to use in advanced shard awareness */
+ ADVANCED_SHARD_AWARENESS_PORT_LOW("advanced.connection.advanced-shard-awareness.port-low"),
+ /** Inclusive upper bound of port range to use in advanced shard awareness */
+ ADVANCED_SHARD_AWARENESS_PORT_HIGH("advanced.connection.advanced-shard-awareness.port-high"),
/**
* The number of connections in the LOCAL pool.
*
diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java
index 64665709028..bc5238908be 100644
--- a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java
+++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java
@@ -276,6 +276,9 @@ protected static void fillWithDriverDefaults(OptionsMap map) {
map.put(TypedDriverOption.CONNECTION_MAX_REQUESTS, 1024);
map.put(TypedDriverOption.CONNECTION_MAX_ORPHAN_REQUESTS, 256);
map.put(TypedDriverOption.CONNECTION_WARN_INIT_ERROR, true);
+ map.put(TypedDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true);
+ map.put(TypedDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000);
+ map.put(TypedDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 65535);
map.put(TypedDriverOption.RECONNECT_ON_INIT, false);
map.put(TypedDriverOption.RECONNECTION_POLICY_CLASS, "ExponentialReconnectionPolicy");
map.put(TypedDriverOption.RECONNECTION_BASE_DELAY, Duration.ofSeconds(1));
diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java
index b6051749c71..a37236a0e0f 100644
--- a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java
+++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java
@@ -175,6 +175,18 @@ public String toString() {
/** Whether to log non-fatal errors when the driver tries to open a new connection. */
public static final TypedDriverOption CONNECTION_WARN_INIT_ERROR =
new TypedDriverOption<>(DefaultDriverOption.CONNECTION_WARN_INIT_ERROR, GenericType.BOOLEAN);
+ /** Whether to use advanced shard awareness */
+ public static final TypedDriverOption CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED =
+ new TypedDriverOption<>(
+ DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, GenericType.BOOLEAN);
+ /** Inclusive lower bound of port range to use in advanced shard awareness */
+ public static final TypedDriverOption ADVANCED_SHARD_AWARENESS_PORT_LOW =
+ new TypedDriverOption<>(
+ DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, GenericType.INTEGER);
+ /** Inclusive upper bound of port range to use in advanced shard awareness */
+ public static final TypedDriverOption ADVANCED_SHARD_AWARENESS_PORT_HIGH =
+ new TypedDriverOption<>(
+ DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, GenericType.INTEGER);
/** The number of connections in the LOCAL pool. */
public static final TypedDriverOption CONNECTION_POOL_LOCAL_SIZE =
new TypedDriverOption<>(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, GenericType.INTEGER);
diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf
index 1ac37d14132..e29eb5e57b5 100644
--- a/core/src/main/resources/reference.conf
+++ b/core/src/main/resources/reference.conf
@@ -535,6 +535,52 @@ datastax-java-driver {
# change.
# Overridable in a profile: no
warn-on-init-error = true
+
+
+ advanced-shard-awareness {
+ # Whether to use advanced shard awareness when trying to open new connections.
+ #
+ # Requires passing shard-aware port as contact point (usually 19042 or 19142(ssl)).
+ # Having this enabled makes sense only for ScyllaDB clusters.
+ #
+ # Reduces number of reconnections driver needs to fully initialize connection pool.
+ # In short it's a feature that allows targeting particular shard when connecting to a node by using specific
+ # local port number.
+ # For context see https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/
+ #
+ # If set to false the driver will not attempt to use this feature. This means connection's local port
+ # will be random according to system rules and driver will keep opening connections until it gets right shards.
+ # In such case non-shard aware port is recommended (by default 9042 or 9142).
+ # If set to true the driver will attempt to use it and will log warnings each time something
+ # makes it not possible.
+ #
+ # If the node for some reason does not report it's sharding info the driver
+ # will log a warning and create connection the same way as if this feature was disabled.
+ # If the cluster ignores the request for specific shard warning will also be logged,
+ # although the local port will already be chosen according to advanced shard awareness rules.
+ #
+ # Required: yes
+ # Modifiable at runtime: yes, the new value will be used for connections created after the
+ # change.
+ # Overridable in a profile: no
+ enabled = true
+
+ # Inclusive lower bound of port range to use in advanced shard awareness
+ # The driver will attempt to reserve ports for connection only within the range.
+ # Required: yes
+ # Modifiable at runtime: yes, the new value will be used for calls after the
+ # change.
+ # Overridable in a profile: no
+ port-low = 10000
+
+ # Inclusive upper bound of port range to use in advanced shard awareness.
+ # The driver will attempt to reserve ports for connection only within the range.
+ # Required: yes
+ # Modifiable at runtime: yes, the new value will be used for calls after the
+ # change.
+ # Overridable in a profile: no
+ port-high = 65535
+ }
}
# Advanced options for the built-in load-balancing policies.
From eca9460dbe7e37bc049a5791b37c50f22e6268d3 Mon Sep 17 00:00:00 2001
From: Bouncheck <36934780+Bouncheck@users.noreply.github.com>
Date: Mon, 14 Apr 2025 11:59:51 +0200
Subject: [PATCH 2/4] Copy 3.x PortAllocator
---
.../internal/core/channel/ChannelFactory.java | 70 +++++++++++++++++++
1 file changed, 70 insertions(+)
diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
index 02898a1fedd..40db6e6aad5 100644
--- a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
+++ b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
@@ -51,6 +51,9 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -58,6 +61,7 @@
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -391,4 +395,70 @@ protected void initChannel(Channel channel) {
}
}
}
+
+ static class PortAllocator {
+ private static final AtomicInteger lastPort = new AtomicInteger(-1);
+
+ public static int getNextAvailablePort(int shardCount, int shardId, int lowPort, int highPort) {
+ int lastPortValue, foundPort = -1;
+ do {
+ lastPortValue = lastPort.get();
+
+ // We will scan from lastPortValue
+ // (or lowPort is there was no lastPort or lastPort is too low)
+ int scanStart = lastPortValue == -1 ? lowPort : lastPortValue;
+ if (scanStart < lowPort) {
+ scanStart = lowPort;
+ }
+
+ // Round it up to "% shardCount == shardId"
+ scanStart += (shardCount - scanStart % shardCount) + shardId;
+
+ // Scan from scanStart upwards to highPort.
+ for (int port = scanStart; port <= highPort; port += shardCount) {
+ if (isTcpPortAvailable(port)) {
+ foundPort = port;
+ break;
+ }
+ }
+
+ // If we started scanning from a high scanStart port
+ // there might have been not enough ports left that are
+ // smaller than highPort. Scan from the beginning
+ // from the lowPort.
+ if (foundPort == -1) {
+ scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId;
+
+ for (int port = scanStart; port <= highPort; port += shardCount) {
+ if (isTcpPortAvailable(port)) {
+ foundPort = port;
+ break;
+ }
+ }
+ }
+
+ // No luck! All ports taken!
+ if (foundPort == -1) {
+ return -1;
+ }
+ } while (!lastPort.compareAndSet(lastPortValue, foundPort));
+
+ return foundPort;
+ }
+
+ public static boolean isTcpPortAvailable(int port) {
+ try {
+ ServerSocket serverSocket = new ServerSocket();
+ try {
+ serverSocket.setReuseAddress(false);
+ serverSocket.bind(new InetSocketAddress(port), 1);
+ return true;
+ } finally {
+ serverSocket.close();
+ }
+ } catch (IOException ex) {
+ return false;
+ }
+ }
+ }
}
From 5e7ee8747861fed1a5c011e8f5077736ca8c8d1f Mon Sep 17 00:00:00 2001
From: Bouncheck <36934780+Bouncheck@users.noreply.github.com>
Date: Mon, 14 Apr 2025 12:00:19 +0200
Subject: [PATCH 3/4] Extend ChannelFactory methods with shardId params
With advanced shard awareness target shard is a necessary parameter
for connecting. Extends non public methods with new parameters and
provides one additional public method.
Passing `null` as shard id or ShardingInfo will lead to previous non-shard
aware behavior.
---
.../internal/core/channel/ChannelFactory.java | 54 +++++++++++++++++--
.../ChannelFactoryAvailableIdsTest.java | 6 ++-
.../ChannelFactoryClusterNameTest.java | 24 +++++++--
...ChannelFactoryProtocolNegotiationTest.java | 36 ++++++++++---
.../ChannelFactorySupportedOptionsTest.java | 12 ++++-
5 files changed, 115 insertions(+), 17 deletions(-)
diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
index 40db6e6aad5..cdd82314144 100644
--- a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
+++ b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
@@ -29,8 +29,10 @@
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfig;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
+import com.datastax.oss.driver.api.core.context.DriverContext;
import com.datastax.oss.driver.api.core.metadata.EndPoint;
import com.datastax.oss.driver.api.core.metadata.Node;
+import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo;
import com.datastax.oss.driver.api.core.metrics.DefaultNodeMetric;
import com.datastax.oss.driver.api.core.metrics.DefaultSessionMetric;
import com.datastax.oss.driver.internal.core.config.typesafe.TypesafeDriverConfig;
@@ -157,12 +159,27 @@ public CompletionStage connect(Node node, DriverChannelOptions op
} else {
nodeMetricUpdater = NoopNodeMetricUpdater.INSTANCE;
}
- return connect(node.getEndPoint(), options, nodeMetricUpdater);
+ return connect(node.getEndPoint(), null, null, options, nodeMetricUpdater);
+ }
+
+ public CompletionStage connect(
+ Node node, Integer shardId, DriverChannelOptions options) {
+ NodeMetricUpdater nodeMetricUpdater;
+ if (node instanceof DefaultNode) {
+ nodeMetricUpdater = ((DefaultNode) node).getMetricUpdater();
+ } else {
+ nodeMetricUpdater = NoopNodeMetricUpdater.INSTANCE;
+ }
+ return connect(node.getEndPoint(), node.getShardingInfo(), shardId, options, nodeMetricUpdater);
}
@VisibleForTesting
CompletionStage connect(
- EndPoint endPoint, DriverChannelOptions options, NodeMetricUpdater nodeMetricUpdater) {
+ EndPoint endPoint,
+ NodeShardingInfo shardingInfo,
+ Integer shardId,
+ DriverChannelOptions options,
+ NodeMetricUpdater nodeMetricUpdater) {
CompletableFuture resultFuture = new CompletableFuture<>();
ProtocolVersion currentVersion;
@@ -178,6 +195,8 @@ CompletionStage connect(
connect(
endPoint,
+ shardingInfo,
+ shardId,
options,
nodeMetricUpdater,
currentVersion,
@@ -189,6 +208,8 @@ CompletionStage connect(
private void connect(
EndPoint endPoint,
+ NodeShardingInfo shardingInfo,
+ Integer shardId,
DriverChannelOptions options,
NodeMetricUpdater nodeMetricUpdater,
ProtocolVersion currentVersion,
@@ -208,7 +229,20 @@ private void connect(
nettyOptions.afterBootstrapInitialized(bootstrap);
- ChannelFuture connectFuture = bootstrap.connect(endPoint.resolve());
+ ChannelFuture connectFuture;
+ if (shardId == null || shardingInfo == null) {
+ if (shardId != null) {
+ LOG.debug(
+ "Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port.",
+ shardId,
+ endPoint);
+ }
+ connectFuture = bootstrap.connect(endPoint.resolve());
+ } else {
+ int localPort =
+ PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
+ connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
+ }
connectFuture.addListener(
cf -> {
@@ -257,6 +291,8 @@ private void connect(
downgraded.get());
connect(
endPoint,
+ shardingInfo,
+ shardId,
options,
nodeMetricUpdater,
downgraded.get(),
@@ -399,7 +435,17 @@ protected void initChannel(Channel channel) {
static class PortAllocator {
private static final AtomicInteger lastPort = new AtomicInteger(-1);
- public static int getNextAvailablePort(int shardCount, int shardId, int lowPort, int highPort) {
+ public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
+ int lowPort =
+ context
+ .getConfig()
+ .getDefaultProfile()
+ .getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
+ int highPort =
+ context
+ .getConfig()
+ .getDefaultProfile()
+ .getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
int lastPortValue, foundPort = -1;
do {
lastPortValue = lastPort.get();
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java
index a1eab41b998..d219bc3c73b 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java
@@ -61,7 +61,11 @@ public void should_report_available_ids() {
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.builder().build(), NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.builder().build(),
+ NoopNodeMetricUpdater.INSTANCE);
completeSimpleChannelInit();
// Then
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java
index d9793247c9c..4734a8ffdeb 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java
@@ -41,7 +41,11 @@ public void should_set_cluster_name_from_first_connection() {
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
writeInboundFrame(
readOutboundFrame(), TestResponses.supportedResponse("mock_key", "mock_value"));
@@ -63,7 +67,11 @@ public void should_check_cluster_name_for_next_connections() throws Throwable {
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
// open a first connection that will define the cluster name
writeInboundFrame(
readOutboundFrame(), TestResponses.supportedResponse("mock_key", "mock_value"));
@@ -73,7 +81,11 @@ public void should_check_cluster_name_for_next_connections() throws Throwable {
// open a second connection that returns the same cluster name
channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
writeInboundFrame(readOutboundFrame(), new Ready());
writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("mockClusterName"));
@@ -84,7 +96,11 @@ public void should_check_cluster_name_for_next_connections() throws Throwable {
// open a third connection that returns a different cluster name
channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
writeInboundFrame(readOutboundFrame(), new Ready());
writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("wrongClusterName"));
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java
index b9738a140c0..fceb8777904 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java
@@ -50,7 +50,11 @@ public void should_succeed_if_version_specified_and_supported_by_server() {
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
completeSimpleChannelInit();
@@ -72,7 +76,11 @@ public void should_fail_if_version_specified_and_not_supported_by_server(int err
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
Frame requestFrame = readOutboundFrame();
assertThat(requestFrame.message).isInstanceOf(Options.class);
@@ -107,7 +115,11 @@ public void should_fail_if_version_specified_and_considered_beta_by_server() {
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
Frame requestFrame = readOutboundFrame();
assertThat(requestFrame.message).isInstanceOf(Options.class);
@@ -144,7 +156,11 @@ public void should_succeed_if_version_not_specified_and_server_supports_latest_s
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
Frame requestFrame = readOutboundFrame();
assertThat(requestFrame.message).isInstanceOf(Options.class);
@@ -176,7 +192,11 @@ public void should_negotiate_if_version_not_specified_and_server_supports_legacy
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
Frame requestFrame = readOutboundFrame();
assertThat(requestFrame.message).isInstanceOf(Options.class);
@@ -219,7 +239,11 @@ public void should_fail_if_negotiation_finds_no_matching_version(int errorCode)
// When
CompletionStage channelFuture =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
Frame requestFrame = readOutboundFrame();
assertThat(requestFrame.message).isInstanceOf(Options.class);
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java
index 559e11e0bc2..87407921619 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java
@@ -41,7 +41,11 @@ public void should_query_supported_options_on_first_channel() throws Throwable {
// When
CompletionStage channelFuture1 =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
writeInboundFrame(
readOutboundFrame(), TestResponses.supportedResponse("mock_key", "mock_value"));
writeInboundFrame(readOutboundFrame(), new Ready());
@@ -56,7 +60,11 @@ public void should_query_supported_options_on_first_channel() throws Throwable {
// When
CompletionStage channelFuture2 =
factory.connect(
- SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE);
+ SERVER_ADDRESS,
+ null,
+ null,
+ DriverChannelOptions.DEFAULT,
+ NoopNodeMetricUpdater.INSTANCE);
writeInboundFrame(readOutboundFrame(), new Ready());
writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("mockClusterName"));
From 1d92051d6a5e73b8c2d0e4236526dcaf8e969859 Mon Sep 17 00:00:00 2001
From: Bouncheck <36934780+Bouncheck@users.noreply.github.com>
Date: Mon, 14 Apr 2025 12:00:49 +0200
Subject: [PATCH 4/4] Enable advanced shard awareness
Makes `addMissingChannels` use advanced shard awareness.
It will now specify target shard when adding missing channels for
specific shards.
In case returned channels do not match requested shards warnings are logged.
Initial connection to the node works on previous rules, meaning it uses
arbitrary local port for connection to arbitrary shard.
Adds AdvancedShardAwarenessIT that has several methods displaying
the difference between establishing connections with option enabled
and disabled.
Adds ChannelPoolShardAwareInitTest
---
.../internal/core/channel/ChannelFactory.java | 30 +-
.../internal/core/pool/ChannelPool.java | 38 ++-
.../pool/ChannelPoolShardAwareInitTest.java | 105 ++++++
.../core/pool/ChannelPoolTestBase.java | 3 +
.../core/pool/AdvancedShardAwarenessIT.java | 299 ++++++++++++++++++
5 files changed, 467 insertions(+), 8 deletions(-)
create mode 100644 core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java
create mode 100644 integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java
diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
index cdd82314144..0d29b59c2ad 100644
--- a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
+++ b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java
@@ -241,7 +241,15 @@ private void connect(
} else {
int localPort =
PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
- connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
+ if (localPort == -1) {
+ LOG.warn(
+ "Could not find free port for shard {} at {}. Falling back to arbitrary local port.",
+ shardId,
+ endPoint);
+ connectFuture = bootstrap.connect(endPoint.resolve());
+ } else {
+ connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
+ }
}
connectFuture.addListener(
@@ -434,6 +442,7 @@ protected void initChannel(Channel channel) {
static class PortAllocator {
private static final AtomicInteger lastPort = new AtomicInteger(-1);
+ private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class);
public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
int lowPort =
@@ -446,6 +455,13 @@ public static int getNextAvailablePort(int shardCount, int shardId, DriverContex
.getConfig()
.getDefaultProfile()
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
+ if (highPort - lowPort < shardCount) {
+ LOG.error(
+ "There is not enough ports in range [{},{}] for {} shards. Update your configuration.",
+ lowPort,
+ highPort,
+ shardCount);
+ }
int lastPortValue, foundPort = -1;
do {
lastPortValue = lastPort.get();
@@ -462,7 +478,7 @@ public static int getNextAvailablePort(int shardCount, int shardId, DriverContex
// Scan from scanStart upwards to highPort.
for (int port = scanStart; port <= highPort; port += shardCount) {
- if (isTcpPortAvailable(port)) {
+ if (isTcpPortAvailable(port, context)) {
foundPort = port;
break;
}
@@ -476,7 +492,7 @@ public static int getNextAvailablePort(int shardCount, int shardId, DriverContex
scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId;
for (int port = scanStart; port <= highPort; port += shardCount) {
- if (isTcpPortAvailable(port)) {
+ if (isTcpPortAvailable(port, context)) {
foundPort = port;
break;
}
@@ -492,11 +508,15 @@ public static int getNextAvailablePort(int shardCount, int shardId, DriverContex
return foundPort;
}
- public static boolean isTcpPortAvailable(int port) {
+ public static boolean isTcpPortAvailable(int port, DriverContext context) {
try {
ServerSocket serverSocket = new ServerSocket();
try {
- serverSocket.setReuseAddress(false);
+ serverSocket.setReuseAddress(
+ context
+ .getConfig()
+ .getDefaultProfile()
+ .getBoolean(DefaultDriverOption.SOCKET_REUSE_ADDRESS, false));
serverSocket.bind(new InetSocketAddress(port), 1);
return true;
} finally {
diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java b/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java
index cea941ad96f..1b0a59fb505 100644
--- a/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java
+++ b/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java
@@ -58,6 +58,7 @@
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
+import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -489,9 +490,23 @@ private CompletionStage addMissingChannels() {
channels.length * wantedCount - Arrays.stream(channels).mapToInt(ChannelSet::size).sum();
LOG.debug("[{}] Trying to create {} missing channels", logPrefix, missing);
DriverChannelOptions options = buildDriverOptions();
- for (int i = 0; i < missing; i++) {
- CompletionStage channelFuture = channelFactory.connect(node, options);
- pendingChannels.add(channelFuture);
+ for (int shard = 0; shard < channels.length; shard++) {
+ LOG.trace(
+ "[{}] Missing {} channels for shard {}",
+ logPrefix,
+ wantedCount - channels[shard].size(),
+ shard);
+ for (int p = channels[shard].size(); p < wantedCount; p++) {
+ CompletionStage channelFuture;
+ if (config
+ .getDefaultProfile()
+ .getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) {
+ channelFuture = channelFactory.connect(node, shard, options);
+ } else {
+ channelFuture = channelFactory.connect(node, options);
+ }
+ pendingChannels.add(channelFuture);
+ }
}
return CompletableFutures.allDone(pendingChannels)
.thenApplyAsync(this::onAllConnected, adminExecutor);
@@ -551,6 +566,23 @@ private boolean onAllConnected(@SuppressWarnings("unused") Void v) {
channel);
channel.forceClose();
} else {
+ if (config
+ .getDefaultProfile()
+ .getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)
+ && channel.localAddress() instanceof InetSocketAddress
+ && channel.getShardingInfo() != null) {
+ int port = ((InetSocketAddress) channel.localAddress()).getPort();
+ int actualShard = channel.getShardId();
+ int targetShard = port % channel.getShardingInfo().getShardsCount();
+ if (actualShard != targetShard) {
+ LOG.warn(
+ "[{}] New channel {} connected to shard {}, but shard {} was requested. If this is not transient check your driver AND cluster configuration of shard aware port.",
+ logPrefix,
+ channel,
+ actualShard,
+ targetShard);
+ }
+ }
LOG.debug("[{}] New channel added {}", logPrefix, channel);
if (channels[channel.getShardId()].size() < wantedCount) {
addChannel(channel);
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java
new file mode 100644
index 00000000000..77296b16376
--- /dev/null
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java
@@ -0,0 +1,105 @@
+package com.datastax.oss.driver.internal.core.pool;
+
+import static com.datastax.oss.driver.Assertions.assertThatStage;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
+import com.datastax.oss.driver.api.core.loadbalancing.NodeDistance;
+import com.datastax.oss.driver.api.core.metadata.Node;
+import com.datastax.oss.driver.internal.core.channel.ChannelEvent;
+import com.datastax.oss.driver.internal.core.channel.DriverChannel;
+import com.datastax.oss.driver.internal.core.channel.DriverChannelOptions;
+import com.datastax.oss.driver.internal.core.protocol.ShardingInfo;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
+import org.mockito.Mockito;
+
+public class ChannelPoolShardAwareInitTest extends ChannelPoolTestBase {
+
+ @Before
+ @Override
+ public void setup() {
+ super.setup();
+ when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED))
+ .thenReturn(true);
+ when(defaultProfile.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW))
+ .thenReturn(10000);
+ when(defaultProfile.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH))
+ .thenReturn(60000);
+ }
+
+ @Test
+ public void should_initialize_when_all_channels_succeed() throws Exception {
+ when(defaultProfile.getInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE)).thenReturn(4);
+ int shardsPerNode = 2;
+ DriverChannel channel1 = newMockDriverChannel(1);
+ DriverChannel channel2 = newMockDriverChannel(2);
+ DriverChannel channel3 = newMockDriverChannel(3);
+ DriverChannel channel4 = newMockDriverChannel(4);
+
+ ShardingInfo shardingInfo = mock(ShardingInfo.class);
+ when(shardingInfo.getShardsCount()).thenReturn(shardsPerNode);
+ node.setShardingInfo(shardingInfo);
+
+ when(channel1.getShardingInfo()).thenReturn(shardingInfo);
+ when(channel2.getShardingInfo()).thenReturn(shardingInfo);
+ when(channel3.getShardingInfo()).thenReturn(shardingInfo);
+ when(channel4.getShardingInfo()).thenReturn(shardingInfo);
+
+ when(channel1.getShardId()).thenReturn(0);
+ when(channel2.getShardId()).thenReturn(0);
+ when(channel3.getShardId()).thenReturn(1);
+ when(channel4.getShardId()).thenReturn(1);
+
+ when(channelFactory.connect(eq(node), any(DriverChannelOptions.class)))
+ .thenReturn(CompletableFuture.completedFuture(channel1));
+ when(channelFactory.connect(eq(node), eq(0), any(DriverChannelOptions.class)))
+ .thenReturn(CompletableFuture.completedFuture(channel2));
+ when(channelFactory.connect(eq(node), eq(1), any(DriverChannelOptions.class)))
+ .thenReturn(CompletableFuture.completedFuture(channel3))
+ .thenReturn(CompletableFuture.completedFuture(channel4));
+
+ CompletionStage poolFuture =
+ ChannelPool.init(node, null, NodeDistance.LOCAL, context, "test");
+
+ ArgumentCaptor optionsCaptor =
+ ArgumentCaptor.forClass(DriverChannelOptions.class);
+ InOrder inOrder = Mockito.inOrder(channelFactory);
+ inOrder
+ .verify(channelFactory, timeout(500).atLeast(1))
+ .connect(eq(node), optionsCaptor.capture());
+ int num = optionsCaptor.getAllValues().size();
+ assertThat(num).isEqualTo(1);
+ inOrder
+ .verify(channelFactory, timeout(500).atLeast(3))
+ .connect(eq(node), anyInt(), optionsCaptor.capture());
+ int num2 = optionsCaptor.getAllValues().size();
+ assertThat(num2).isEqualTo(4);
+
+ assertThatStage(poolFuture)
+ .isSuccess(
+ pool -> {
+ assertThat(pool.channels[0]).containsOnly(channel1, channel2);
+ assertThat(pool.channels[1]).containsOnly(channel3, channel4);
+ });
+ verify(eventBus, VERIFY_TIMEOUT.times(4)).fire(ChannelEvent.channelOpened(node));
+
+ inOrder
+ .verify(channelFactory, timeout(500).times(0))
+ .connect(any(Node.class), any(DriverChannelOptions.class));
+ inOrder
+ .verify(channelFactory, timeout(500).times(0))
+ .connect(any(Node.class), anyInt(), any(DriverChannelOptions.class));
+ }
+}
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java
index 2f8056e49e0..eb0b5633d60 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java
@@ -24,6 +24,7 @@
import static org.mockito.Mockito.when;
import com.datastax.oss.driver.api.core.CqlIdentifier;
+import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfig;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.connection.ReconnectionPolicy;
@@ -77,6 +78,8 @@ public void setup() {
when(nettyOptions.adminEventExecutorGroup()).thenReturn(adminEventLoopGroup);
when(context.getConfig()).thenReturn(config);
when(config.getDefaultProfile()).thenReturn(defaultProfile);
+ when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED))
+ .thenReturn(false);
this.eventBus = spy(new EventBus("test"));
when(context.getEventBus()).thenReturn(eventBus);
when(context.getChannelFactory()).thenReturn(channelFactory);
diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java
new file mode 100644
index 00000000000..5734036bc72
--- /dev/null
+++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java
@@ -0,0 +1,299 @@
+package com.datastax.oss.driver.core.pool;
+
+import static junit.framework.TestCase.fail;
+
+import ch.qos.logback.classic.Level;
+import ch.qos.logback.classic.Logger;
+import ch.qos.logback.classic.spi.ILoggingEvent;
+import ch.qos.logback.core.read.ListAppender;
+import com.datastax.oss.driver.api.core.CqlSession;
+import com.datastax.oss.driver.api.core.CqlSessionBuilder;
+import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
+import com.datastax.oss.driver.api.core.config.DriverConfigLoader;
+import com.datastax.oss.driver.api.core.session.Session;
+import com.datastax.oss.driver.api.testinfra.CassandraSkip;
+import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule;
+import com.datastax.oss.driver.api.testinfra.session.SessionUtils;
+import com.datastax.oss.driver.internal.core.pool.ChannelPool;
+import com.datastax.oss.driver.internal.core.util.concurrent.CompletableFutures;
+import com.datastax.oss.driver.internal.core.util.concurrent.Reconnection;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.Uninterruptibles;
+import com.tngtech.java.junit.dataprovider.DataProvider;
+import com.tngtech.java.junit.dataprovider.DataProviderRunner;
+import com.tngtech.java.junit.dataprovider.UseDataProvider;
+import java.net.InetSocketAddress;
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Pattern;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.slf4j.LoggerFactory;
+
+@CassandraSkip(description = "Advanced shard awareness relies on ScyllaDB's shard aware port")
+@RunWith(DataProviderRunner.class)
+public class AdvancedShardAwarenessIT {
+
+ @ClassRule
+ public static final CustomCcmRule CCM_RULE =
+ CustomCcmRule.builder().withNodes(2).withJvmArgs("--smp=3").build();
+
+ public static ch.qos.logback.classic.Logger channelPoolLogger =
+ (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(ChannelPool.class);
+ public static ch.qos.logback.classic.Logger reconnectionLogger =
+ (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Reconnection.class);
+ ListAppender appender;
+ Level originalLevelChannelPool;
+ Level originalLevelReconnection;
+ private final Pattern shardMismatchPattern =
+ Pattern.compile(".*r configuration of shard aware port.*");
+ private final Pattern reconnectionPattern =
+ Pattern.compile(".*Scheduling next reconnection in.*");
+ Set forbiddenOccurences = ImmutableSet.of(shardMismatchPattern, reconnectionPattern);
+
+ @DataProvider
+ public static Object[][] reuseAddressOption() {
+ return new Object[][] {{true}, {false}};
+ }
+
+ @Before
+ public void startCapturingLogs() {
+ originalLevelChannelPool = channelPoolLogger.getLevel();
+ originalLevelReconnection = reconnectionLogger.getLevel();
+ channelPoolLogger.setLevel(Level.DEBUG);
+ reconnectionLogger.setLevel(Level.DEBUG);
+ appender = new ListAppender<>();
+ appender.setContext(
+ ((Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME)).getLoggerContext());
+ channelPoolLogger.addAppender(appender);
+ reconnectionLogger.addAppender(appender);
+ appender.list.clear();
+ appender.start();
+ }
+
+ @After
+ public void stopCapturingLogs() {
+ appender.stop();
+ appender.list.clear();
+ channelPoolLogger.setLevel(originalLevelChannelPool);
+ reconnectionLogger.setLevel(originalLevelReconnection);
+ channelPoolLogger.detachAppender(appender);
+ reconnectionLogger.detachAppender(appender);
+ }
+
+ @Test
+ @UseDataProvider("reuseAddressOption")
+ public void should_initialize_all_channels(boolean reuseAddress) {
+ Map expectedOccurences =
+ ImmutableMap.of(
+ Pattern.compile(
+ ".*127\\.0\\.0\\.2:19042.*Reconnection attempt complete, 6/6 channels.*"),
+ 1,
+ Pattern.compile(
+ ".*127\\.0\\.0\\.1:19042.*Reconnection attempt complete, 6/6 channels.*"),
+ 1,
+ Pattern.compile(".*Reconnection attempt complete.*"), 2,
+ Pattern.compile(".*127\\.0\\.0\\.1:19042.*New channel added \\[.*"), 5,
+ Pattern.compile(".*127\\.0\\.0\\.2:19042.*New channel added \\[.*"), 5,
+ Pattern.compile(".*127\\.0\\.0\\.1:19042\\] Trying to create 5 missing channels.*"), 1,
+ Pattern.compile(".*127\\.0\\.0\\.2:19042\\] Trying to create 5 missing channels.*"), 1);
+ DriverConfigLoader loader =
+ SessionUtils.configLoaderBuilder()
+ .withBoolean(DefaultDriverOption.SOCKET_REUSE_ADDRESS, reuseAddress)
+ .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true)
+ .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000)
+ .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000)
+ // Due to rounding up the connections per shard this will result in 6 connections per
+ // node
+ .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 4)
+ .build();
+ try (Session session =
+ CqlSession.builder()
+ .addContactPoint(
+ new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042))
+ .withConfigLoader(loader)
+ .build()) {
+ Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
+ expectedOccurences.forEach(
+ (pattern, times) -> assertMatchesExactly(pattern, times, appender.list));
+ forbiddenOccurences.forEach(pattern -> assertNoLogMatches(pattern, appender.list));
+ }
+ }
+
+ @Test
+ public void should_see_mismatched_shard() {
+ DriverConfigLoader loader =
+ SessionUtils.configLoaderBuilder()
+ .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true)
+ .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000)
+ .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000)
+ .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64)
+ .build();
+ try (Session session =
+ CqlSession.builder()
+ .addContactPoint(
+ new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042))
+ .withConfigLoader(loader)
+ .build()) {
+ Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
+ assertMatchesAtLeast(shardMismatchPattern, 5, appender.list);
+ }
+ }
+
+ // There is no need to run this as a test, but it serves as a comparison
+ @SuppressWarnings("unused")
+ public void should_struggle_to_fill_pools() {
+ DriverConfigLoader loader =
+ SessionUtils.configLoaderBuilder()
+ .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, false)
+ .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64)
+ .withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(200))
+ .withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(4000))
+ .build();
+ CqlSessionBuilder builder =
+ CqlSession.builder()
+ .addContactPoint(
+ new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042))
+ .withConfigLoader(loader);
+ CompletionStage stage1 = builder.buildAsync();
+ CompletionStage stage2 = builder.buildAsync();
+ CompletionStage stage3 = builder.buildAsync();
+ CompletionStage stage4 = builder.buildAsync();
+ try (CqlSession session1 = CompletableFutures.getUninterruptibly(stage1);
+ CqlSession session2 = CompletableFutures.getUninterruptibly(stage2);
+ CqlSession session3 = CompletableFutures.getUninterruptibly(stage3);
+ CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) {
+ Uninterruptibles.sleepUninterruptibly(20, TimeUnit.SECONDS);
+ assertNoLogMatches(shardMismatchPattern, appender.list);
+ assertMatchesAtLeast(reconnectionPattern, 8, appender.list);
+ }
+ }
+
+ @Test
+ public void should_not_struggle_to_fill_pools() {
+ DriverConfigLoader loader =
+ SessionUtils.configLoaderBuilder()
+ .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true)
+ .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 66)
+ .withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(10))
+ .withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(20))
+ .build();
+ CqlSessionBuilder builder =
+ CqlSession.builder()
+ .addContactPoint(
+ new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042))
+ .withConfigLoader(loader);
+ CompletionStage stage1 = builder.buildAsync();
+ CompletionStage stage2 = builder.buildAsync();
+ CompletionStage stage3 = builder.buildAsync();
+ CompletionStage stage4 = builder.buildAsync();
+ int sessions = 4;
+ try (CqlSession session1 = CompletableFutures.getUninterruptibly(stage1);
+ CqlSession session2 = CompletableFutures.getUninterruptibly(stage2);
+ CqlSession session3 = CompletableFutures.getUninterruptibly(stage3);
+ CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) {
+ Uninterruptibles.sleepUninterruptibly(8, TimeUnit.SECONDS);
+ int tolerance = 2; // Sometimes socket ends up already in use
+ Map expectedOccurences =
+ ImmutableMap.of(
+ Pattern.compile(
+ ".*127\\.0\\.0\\.2:19042.*Reconnection attempt complete, 66/66 channels.*"),
+ 1 * sessions,
+ Pattern.compile(
+ ".*127\\.0\\.0\\.1:19042.*Reconnection attempt complete, 66/66 channels.*"),
+ 1 * sessions,
+ Pattern.compile(".*Reconnection attempt complete.*"), 2 * sessions,
+ Pattern.compile(".*127\\.0\\.0\\.1:19042.*New channel added \\[.*"),
+ 65 * sessions - tolerance,
+ Pattern.compile(".*127\\.0\\.0\\.2:19042.*New channel added \\[.*"),
+ 65 * sessions - tolerance,
+ Pattern.compile(".*127\\.0\\.0\\.1:19042\\] Trying to create 65 missing channels.*"),
+ 1 * sessions,
+ Pattern.compile(".*127\\.0\\.0\\.2:19042\\] Trying to create 65 missing channels.*"),
+ 1 * sessions);
+ expectedOccurences.forEach(
+ (pattern, times) -> assertMatchesAtLeast(pattern, times, appender.list));
+ assertNoLogMatches(shardMismatchPattern, appender.list);
+ assertMatchesAtMost(reconnectionPattern, tolerance, appender.list);
+ }
+ }
+
+ private void assertNoLogMatches(Pattern pattern, List logs) {
+ for (ILoggingEvent log : logs) {
+ if (pattern.matcher(log.getFormattedMessage()).matches()) {
+ fail(
+ "Logs should not contain pattern ["
+ + pattern.toString()
+ + "] but found in ["
+ + log.getFormattedMessage()
+ + "]");
+ }
+ }
+ }
+
+ private void assertMatchesExactly(Pattern pattern, Integer times, List logs) {
+ int occurences = 0;
+ for (ILoggingEvent log : logs) {
+ if (pattern.matcher(log.getFormattedMessage()).matches()) {
+ occurences++;
+ }
+ }
+ if (occurences != times) {
+ fail(
+ "Expected to find pattern exactly "
+ + times
+ + " times but found it "
+ + occurences
+ + " times. Pattern: ["
+ + pattern.toString()
+ + "]");
+ }
+ }
+
+ private void assertMatchesAtLeast(Pattern pattern, Integer times, List logs) {
+ int occurences = 0;
+ for (ILoggingEvent log : logs) {
+ if (pattern.matcher(log.getFormattedMessage()).matches()) {
+ occurences++;
+ if (occurences >= times) {
+ return;
+ }
+ }
+ }
+ fail(
+ "Expected to find pattern at least "
+ + times
+ + " times but found only "
+ + occurences
+ + " times. Pattern: ["
+ + pattern.toString()
+ + "]");
+ }
+
+ private void assertMatchesAtMost(Pattern pattern, Integer times, List logs) {
+ int occurences = 0;
+ for (ILoggingEvent log : logs) {
+ if (pattern.matcher(log.getFormattedMessage()).matches()) {
+ occurences++;
+ if (occurences > times) {
+ fail(
+ "Expected to find pattern at most "
+ + times
+ + " times but found it "
+ + occurences
+ + " times. Pattern: ["
+ + pattern.toString()
+ + "]");
+ }
+ }
+ }
+ }
+}