diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java index 7ff3b2719ba..a16c04c6c1e 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java @@ -141,6 +141,17 @@ public enum DefaultDriverOption implements DriverOption { *

Value-type: boolean */ CONNECTION_WARN_INIT_ERROR("advanced.connection.warn-on-init-error"), + /** + * Whether to use advanced shard awareness. + * + *

Value-type: boolean + */ + CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED( + "advanced.connection.advanced-shard-awareness.enabled"), + /** Inclusive lower bound of port range to use in advanced shard awareness */ + ADVANCED_SHARD_AWARENESS_PORT_LOW("advanced.connection.advanced-shard-awareness.port-low"), + /** Inclusive upper bound of port range to use in advanced shard awareness */ + ADVANCED_SHARD_AWARENESS_PORT_HIGH("advanced.connection.advanced-shard-awareness.port-high"), /** * The number of connections in the LOCAL pool. * diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java index 64665709028..bc5238908be 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java @@ -276,6 +276,9 @@ protected static void fillWithDriverDefaults(OptionsMap map) { map.put(TypedDriverOption.CONNECTION_MAX_REQUESTS, 1024); map.put(TypedDriverOption.CONNECTION_MAX_ORPHAN_REQUESTS, 256); map.put(TypedDriverOption.CONNECTION_WARN_INIT_ERROR, true); + map.put(TypedDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true); + map.put(TypedDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000); + map.put(TypedDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 65535); map.put(TypedDriverOption.RECONNECT_ON_INIT, false); map.put(TypedDriverOption.RECONNECTION_POLICY_CLASS, "ExponentialReconnectionPolicy"); map.put(TypedDriverOption.RECONNECTION_BASE_DELAY, Duration.ofSeconds(1)); diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java index b6051749c71..a37236a0e0f 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java @@ -175,6 +175,18 @@ public String toString() { /** Whether to log non-fatal errors when the driver tries to open a new connection. */ public static final TypedDriverOption CONNECTION_WARN_INIT_ERROR = new TypedDriverOption<>(DefaultDriverOption.CONNECTION_WARN_INIT_ERROR, GenericType.BOOLEAN); + /** Whether to use advanced shard awareness */ + public static final TypedDriverOption CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED = + new TypedDriverOption<>( + DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, GenericType.BOOLEAN); + /** Inclusive lower bound of port range to use in advanced shard awareness */ + public static final TypedDriverOption ADVANCED_SHARD_AWARENESS_PORT_LOW = + new TypedDriverOption<>( + DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, GenericType.INTEGER); + /** Inclusive upper bound of port range to use in advanced shard awareness */ + public static final TypedDriverOption ADVANCED_SHARD_AWARENESS_PORT_HIGH = + new TypedDriverOption<>( + DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, GenericType.INTEGER); /** The number of connections in the LOCAL pool. */ public static final TypedDriverOption CONNECTION_POOL_LOCAL_SIZE = new TypedDriverOption<>(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, GenericType.INTEGER); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java index 02898a1fedd..0d29b59c2ad 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java @@ -29,8 +29,10 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.context.DriverContext; import com.datastax.oss.driver.api.core.metadata.EndPoint; import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo; import com.datastax.oss.driver.api.core.metrics.DefaultNodeMetric; import com.datastax.oss.driver.api.core.metrics.DefaultSessionMetric; import com.datastax.oss.driver.internal.core.config.typesafe.TypesafeDriverConfig; @@ -51,6 +53,9 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; import java.util.List; import java.util.Map; import java.util.Optional; @@ -58,6 +63,7 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import net.jcip.annotations.ThreadSafe; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -153,12 +159,27 @@ public CompletionStage connect(Node node, DriverChannelOptions op } else { nodeMetricUpdater = NoopNodeMetricUpdater.INSTANCE; } - return connect(node.getEndPoint(), options, nodeMetricUpdater); + return connect(node.getEndPoint(), null, null, options, nodeMetricUpdater); + } + + public CompletionStage connect( + Node node, Integer shardId, DriverChannelOptions options) { + NodeMetricUpdater nodeMetricUpdater; + if (node instanceof DefaultNode) { + nodeMetricUpdater = ((DefaultNode) node).getMetricUpdater(); + } else { + nodeMetricUpdater = NoopNodeMetricUpdater.INSTANCE; + } + return connect(node.getEndPoint(), node.getShardingInfo(), shardId, options, nodeMetricUpdater); } @VisibleForTesting CompletionStage connect( - EndPoint endPoint, DriverChannelOptions options, NodeMetricUpdater nodeMetricUpdater) { + EndPoint endPoint, + NodeShardingInfo shardingInfo, + Integer shardId, + DriverChannelOptions options, + NodeMetricUpdater nodeMetricUpdater) { CompletableFuture resultFuture = new CompletableFuture<>(); ProtocolVersion currentVersion; @@ -174,6 +195,8 @@ CompletionStage connect( connect( endPoint, + shardingInfo, + shardId, options, nodeMetricUpdater, currentVersion, @@ -185,6 +208,8 @@ CompletionStage connect( private void connect( EndPoint endPoint, + NodeShardingInfo shardingInfo, + Integer shardId, DriverChannelOptions options, NodeMetricUpdater nodeMetricUpdater, ProtocolVersion currentVersion, @@ -204,7 +229,28 @@ private void connect( nettyOptions.afterBootstrapInitialized(bootstrap); - ChannelFuture connectFuture = bootstrap.connect(endPoint.resolve()); + ChannelFuture connectFuture; + if (shardId == null || shardingInfo == null) { + if (shardId != null) { + LOG.debug( + "Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port.", + shardId, + endPoint); + } + connectFuture = bootstrap.connect(endPoint.resolve()); + } else { + int localPort = + PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context); + if (localPort == -1) { + LOG.warn( + "Could not find free port for shard {} at {}. Falling back to arbitrary local port.", + shardId, + endPoint); + connectFuture = bootstrap.connect(endPoint.resolve()); + } else { + connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort)); + } + } connectFuture.addListener( cf -> { @@ -253,6 +299,8 @@ private void connect( downgraded.get()); connect( endPoint, + shardingInfo, + shardId, options, nodeMetricUpdater, downgraded.get(), @@ -391,4 +439,92 @@ protected void initChannel(Channel channel) { } } } + + static class PortAllocator { + private static final AtomicInteger lastPort = new AtomicInteger(-1); + private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class); + + public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) { + int lowPort = + context + .getConfig() + .getDefaultProfile() + .getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW); + int highPort = + context + .getConfig() + .getDefaultProfile() + .getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH); + if (highPort - lowPort < shardCount) { + LOG.error( + "There is not enough ports in range [{},{}] for {} shards. Update your configuration.", + lowPort, + highPort, + shardCount); + } + int lastPortValue, foundPort = -1; + do { + lastPortValue = lastPort.get(); + + // We will scan from lastPortValue + // (or lowPort is there was no lastPort or lastPort is too low) + int scanStart = lastPortValue == -1 ? lowPort : lastPortValue; + if (scanStart < lowPort) { + scanStart = lowPort; + } + + // Round it up to "% shardCount == shardId" + scanStart += (shardCount - scanStart % shardCount) + shardId; + + // Scan from scanStart upwards to highPort. + for (int port = scanStart; port <= highPort; port += shardCount) { + if (isTcpPortAvailable(port, context)) { + foundPort = port; + break; + } + } + + // If we started scanning from a high scanStart port + // there might have been not enough ports left that are + // smaller than highPort. Scan from the beginning + // from the lowPort. + if (foundPort == -1) { + scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId; + + for (int port = scanStart; port <= highPort; port += shardCount) { + if (isTcpPortAvailable(port, context)) { + foundPort = port; + break; + } + } + } + + // No luck! All ports taken! + if (foundPort == -1) { + return -1; + } + } while (!lastPort.compareAndSet(lastPortValue, foundPort)); + + return foundPort; + } + + public static boolean isTcpPortAvailable(int port, DriverContext context) { + try { + ServerSocket serverSocket = new ServerSocket(); + try { + serverSocket.setReuseAddress( + context + .getConfig() + .getDefaultProfile() + .getBoolean(DefaultDriverOption.SOCKET_REUSE_ADDRESS, false)); + serverSocket.bind(new InetSocketAddress(port), 1); + return true; + } finally { + serverSocket.close(); + } + } catch (IOException ex) { + return false; + } + } + } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java b/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java index cea941ad96f..1b0a59fb505 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java @@ -58,6 +58,7 @@ import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; +import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -489,9 +490,23 @@ private CompletionStage addMissingChannels() { channels.length * wantedCount - Arrays.stream(channels).mapToInt(ChannelSet::size).sum(); LOG.debug("[{}] Trying to create {} missing channels", logPrefix, missing); DriverChannelOptions options = buildDriverOptions(); - for (int i = 0; i < missing; i++) { - CompletionStage channelFuture = channelFactory.connect(node, options); - pendingChannels.add(channelFuture); + for (int shard = 0; shard < channels.length; shard++) { + LOG.trace( + "[{}] Missing {} channels for shard {}", + logPrefix, + wantedCount - channels[shard].size(), + shard); + for (int p = channels[shard].size(); p < wantedCount; p++) { + CompletionStage channelFuture; + if (config + .getDefaultProfile() + .getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) { + channelFuture = channelFactory.connect(node, shard, options); + } else { + channelFuture = channelFactory.connect(node, options); + } + pendingChannels.add(channelFuture); + } } return CompletableFutures.allDone(pendingChannels) .thenApplyAsync(this::onAllConnected, adminExecutor); @@ -551,6 +566,23 @@ private boolean onAllConnected(@SuppressWarnings("unused") Void v) { channel); channel.forceClose(); } else { + if (config + .getDefaultProfile() + .getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED) + && channel.localAddress() instanceof InetSocketAddress + && channel.getShardingInfo() != null) { + int port = ((InetSocketAddress) channel.localAddress()).getPort(); + int actualShard = channel.getShardId(); + int targetShard = port % channel.getShardingInfo().getShardsCount(); + if (actualShard != targetShard) { + LOG.warn( + "[{}] New channel {} connected to shard {}, but shard {} was requested. If this is not transient check your driver AND cluster configuration of shard aware port.", + logPrefix, + channel, + actualShard, + targetShard); + } + } LOG.debug("[{}] New channel added {}", logPrefix, channel); if (channels[channel.getShardId()].size() < wantedCount) { addChannel(channel); diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index 1ac37d14132..e29eb5e57b5 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -535,6 +535,52 @@ datastax-java-driver { # change. # Overridable in a profile: no warn-on-init-error = true + + + advanced-shard-awareness { + # Whether to use advanced shard awareness when trying to open new connections. + # + # Requires passing shard-aware port as contact point (usually 19042 or 19142(ssl)). + # Having this enabled makes sense only for ScyllaDB clusters. + # + # Reduces number of reconnections driver needs to fully initialize connection pool. + # In short it's a feature that allows targeting particular shard when connecting to a node by using specific + # local port number. + # For context see https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/ + # + # If set to false the driver will not attempt to use this feature. This means connection's local port + # will be random according to system rules and driver will keep opening connections until it gets right shards. + # In such case non-shard aware port is recommended (by default 9042 or 9142). + # If set to true the driver will attempt to use it and will log warnings each time something + # makes it not possible. + # + # If the node for some reason does not report it's sharding info the driver + # will log a warning and create connection the same way as if this feature was disabled. + # If the cluster ignores the request for specific shard warning will also be logged, + # although the local port will already be chosen according to advanced shard awareness rules. + # + # Required: yes + # Modifiable at runtime: yes, the new value will be used for connections created after the + # change. + # Overridable in a profile: no + enabled = true + + # Inclusive lower bound of port range to use in advanced shard awareness + # The driver will attempt to reserve ports for connection only within the range. + # Required: yes + # Modifiable at runtime: yes, the new value will be used for calls after the + # change. + # Overridable in a profile: no + port-low = 10000 + + # Inclusive upper bound of port range to use in advanced shard awareness. + # The driver will attempt to reserve ports for connection only within the range. + # Required: yes + # Modifiable at runtime: yes, the new value will be used for calls after the + # change. + # Overridable in a profile: no + port-high = 65535 + } } # Advanced options for the built-in load-balancing policies. diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java index a1eab41b998..d219bc3c73b 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java @@ -61,7 +61,11 @@ public void should_report_available_ids() { // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.builder().build(), NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.builder().build(), + NoopNodeMetricUpdater.INSTANCE); completeSimpleChannelInit(); // Then diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java index d9793247c9c..4734a8ffdeb 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryClusterNameTest.java @@ -41,7 +41,11 @@ public void should_set_cluster_name_from_first_connection() { // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); writeInboundFrame( readOutboundFrame(), TestResponses.supportedResponse("mock_key", "mock_value")); @@ -63,7 +67,11 @@ public void should_check_cluster_name_for_next_connections() throws Throwable { // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); // open a first connection that will define the cluster name writeInboundFrame( readOutboundFrame(), TestResponses.supportedResponse("mock_key", "mock_value")); @@ -73,7 +81,11 @@ public void should_check_cluster_name_for_next_connections() throws Throwable { // open a second connection that returns the same cluster name channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); writeInboundFrame(readOutboundFrame(), new Ready()); writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("mockClusterName")); @@ -84,7 +96,11 @@ public void should_check_cluster_name_for_next_connections() throws Throwable { // open a third connection that returns a different cluster name channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); writeInboundFrame(readOutboundFrame(), new Ready()); writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("wrongClusterName")); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java index b9738a140c0..fceb8777904 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryProtocolNegotiationTest.java @@ -50,7 +50,11 @@ public void should_succeed_if_version_specified_and_supported_by_server() { // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); completeSimpleChannelInit(); @@ -72,7 +76,11 @@ public void should_fail_if_version_specified_and_not_supported_by_server(int err // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); Frame requestFrame = readOutboundFrame(); assertThat(requestFrame.message).isInstanceOf(Options.class); @@ -107,7 +115,11 @@ public void should_fail_if_version_specified_and_considered_beta_by_server() { // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); Frame requestFrame = readOutboundFrame(); assertThat(requestFrame.message).isInstanceOf(Options.class); @@ -144,7 +156,11 @@ public void should_succeed_if_version_not_specified_and_server_supports_latest_s // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); Frame requestFrame = readOutboundFrame(); assertThat(requestFrame.message).isInstanceOf(Options.class); @@ -176,7 +192,11 @@ public void should_negotiate_if_version_not_specified_and_server_supports_legacy // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); Frame requestFrame = readOutboundFrame(); assertThat(requestFrame.message).isInstanceOf(Options.class); @@ -219,7 +239,11 @@ public void should_fail_if_negotiation_finds_no_matching_version(int errorCode) // When CompletionStage channelFuture = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); Frame requestFrame = readOutboundFrame(); assertThat(requestFrame.message).isInstanceOf(Options.class); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java index 559e11e0bc2..87407921619 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactorySupportedOptionsTest.java @@ -41,7 +41,11 @@ public void should_query_supported_options_on_first_channel() throws Throwable { // When CompletionStage channelFuture1 = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); writeInboundFrame( readOutboundFrame(), TestResponses.supportedResponse("mock_key", "mock_value")); writeInboundFrame(readOutboundFrame(), new Ready()); @@ -56,7 +60,11 @@ public void should_query_supported_options_on_first_channel() throws Throwable { // When CompletionStage channelFuture2 = factory.connect( - SERVER_ADDRESS, DriverChannelOptions.DEFAULT, NoopNodeMetricUpdater.INSTANCE); + SERVER_ADDRESS, + null, + null, + DriverChannelOptions.DEFAULT, + NoopNodeMetricUpdater.INSTANCE); writeInboundFrame(readOutboundFrame(), new Ready()); writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("mockClusterName")); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java new file mode 100644 index 00000000000..77296b16376 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolShardAwareInitTest.java @@ -0,0 +1,105 @@ +package com.datastax.oss.driver.internal.core.pool; + +import static com.datastax.oss.driver.Assertions.assertThatStage; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.loadbalancing.NodeDistance; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.internal.core.channel.ChannelEvent; +import com.datastax.oss.driver.internal.core.channel.DriverChannel; +import com.datastax.oss.driver.internal.core.channel.DriverChannelOptions; +import com.datastax.oss.driver.internal.core.protocol.ShardingInfo; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mockito; + +public class ChannelPoolShardAwareInitTest extends ChannelPoolTestBase { + + @Before + @Override + public void setup() { + super.setup(); + when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) + .thenReturn(true); + when(defaultProfile.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW)) + .thenReturn(10000); + when(defaultProfile.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH)) + .thenReturn(60000); + } + + @Test + public void should_initialize_when_all_channels_succeed() throws Exception { + when(defaultProfile.getInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE)).thenReturn(4); + int shardsPerNode = 2; + DriverChannel channel1 = newMockDriverChannel(1); + DriverChannel channel2 = newMockDriverChannel(2); + DriverChannel channel3 = newMockDriverChannel(3); + DriverChannel channel4 = newMockDriverChannel(4); + + ShardingInfo shardingInfo = mock(ShardingInfo.class); + when(shardingInfo.getShardsCount()).thenReturn(shardsPerNode); + node.setShardingInfo(shardingInfo); + + when(channel1.getShardingInfo()).thenReturn(shardingInfo); + when(channel2.getShardingInfo()).thenReturn(shardingInfo); + when(channel3.getShardingInfo()).thenReturn(shardingInfo); + when(channel4.getShardingInfo()).thenReturn(shardingInfo); + + when(channel1.getShardId()).thenReturn(0); + when(channel2.getShardId()).thenReturn(0); + when(channel3.getShardId()).thenReturn(1); + when(channel4.getShardId()).thenReturn(1); + + when(channelFactory.connect(eq(node), any(DriverChannelOptions.class))) + .thenReturn(CompletableFuture.completedFuture(channel1)); + when(channelFactory.connect(eq(node), eq(0), any(DriverChannelOptions.class))) + .thenReturn(CompletableFuture.completedFuture(channel2)); + when(channelFactory.connect(eq(node), eq(1), any(DriverChannelOptions.class))) + .thenReturn(CompletableFuture.completedFuture(channel3)) + .thenReturn(CompletableFuture.completedFuture(channel4)); + + CompletionStage poolFuture = + ChannelPool.init(node, null, NodeDistance.LOCAL, context, "test"); + + ArgumentCaptor optionsCaptor = + ArgumentCaptor.forClass(DriverChannelOptions.class); + InOrder inOrder = Mockito.inOrder(channelFactory); + inOrder + .verify(channelFactory, timeout(500).atLeast(1)) + .connect(eq(node), optionsCaptor.capture()); + int num = optionsCaptor.getAllValues().size(); + assertThat(num).isEqualTo(1); + inOrder + .verify(channelFactory, timeout(500).atLeast(3)) + .connect(eq(node), anyInt(), optionsCaptor.capture()); + int num2 = optionsCaptor.getAllValues().size(); + assertThat(num2).isEqualTo(4); + + assertThatStage(poolFuture) + .isSuccess( + pool -> { + assertThat(pool.channels[0]).containsOnly(channel1, channel2); + assertThat(pool.channels[1]).containsOnly(channel3, channel4); + }); + verify(eventBus, VERIFY_TIMEOUT.times(4)).fire(ChannelEvent.channelOpened(node)); + + inOrder + .verify(channelFactory, timeout(500).times(0)) + .connect(any(Node.class), any(DriverChannelOptions.class)); + inOrder + .verify(channelFactory, timeout(500).times(0)) + .connect(any(Node.class), anyInt(), any(DriverChannelOptions.class)); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java index 2f8056e49e0..eb0b5633d60 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.connection.ReconnectionPolicy; @@ -77,6 +78,8 @@ public void setup() { when(nettyOptions.adminEventExecutorGroup()).thenReturn(adminEventLoopGroup); when(context.getConfig()).thenReturn(config); when(config.getDefaultProfile()).thenReturn(defaultProfile); + when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) + .thenReturn(false); this.eventBus = spy(new EventBus("test")); when(context.getEventBus()).thenReturn(eventBus); when(context.getChannelFactory()).thenReturn(channelFactory); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java new file mode 100644 index 00000000000..5734036bc72 --- /dev/null +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java @@ -0,0 +1,299 @@ +package com.datastax.oss.driver.core.pool; + +import static junit.framework.TestCase.fail; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.session.Session; +import com.datastax.oss.driver.api.testinfra.CassandraSkip; +import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule; +import com.datastax.oss.driver.api.testinfra.session.SessionUtils; +import com.datastax.oss.driver.internal.core.pool.ChannelPool; +import com.datastax.oss.driver.internal.core.util.concurrent.CompletableFutures; +import com.datastax.oss.driver.internal.core.util.concurrent.Reconnection; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.Uninterruptibles; +import com.tngtech.java.junit.dataprovider.DataProvider; +import com.tngtech.java.junit.dataprovider.DataProviderRunner; +import com.tngtech.java.junit.dataprovider.UseDataProvider; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.slf4j.LoggerFactory; + +@CassandraSkip(description = "Advanced shard awareness relies on ScyllaDB's shard aware port") +@RunWith(DataProviderRunner.class) +public class AdvancedShardAwarenessIT { + + @ClassRule + public static final CustomCcmRule CCM_RULE = + CustomCcmRule.builder().withNodes(2).withJvmArgs("--smp=3").build(); + + public static ch.qos.logback.classic.Logger channelPoolLogger = + (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(ChannelPool.class); + public static ch.qos.logback.classic.Logger reconnectionLogger = + (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Reconnection.class); + ListAppender appender; + Level originalLevelChannelPool; + Level originalLevelReconnection; + private final Pattern shardMismatchPattern = + Pattern.compile(".*r configuration of shard aware port.*"); + private final Pattern reconnectionPattern = + Pattern.compile(".*Scheduling next reconnection in.*"); + Set forbiddenOccurences = ImmutableSet.of(shardMismatchPattern, reconnectionPattern); + + @DataProvider + public static Object[][] reuseAddressOption() { + return new Object[][] {{true}, {false}}; + } + + @Before + public void startCapturingLogs() { + originalLevelChannelPool = channelPoolLogger.getLevel(); + originalLevelReconnection = reconnectionLogger.getLevel(); + channelPoolLogger.setLevel(Level.DEBUG); + reconnectionLogger.setLevel(Level.DEBUG); + appender = new ListAppender<>(); + appender.setContext( + ((Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME)).getLoggerContext()); + channelPoolLogger.addAppender(appender); + reconnectionLogger.addAppender(appender); + appender.list.clear(); + appender.start(); + } + + @After + public void stopCapturingLogs() { + appender.stop(); + appender.list.clear(); + channelPoolLogger.setLevel(originalLevelChannelPool); + reconnectionLogger.setLevel(originalLevelReconnection); + channelPoolLogger.detachAppender(appender); + reconnectionLogger.detachAppender(appender); + } + + @Test + @UseDataProvider("reuseAddressOption") + public void should_initialize_all_channels(boolean reuseAddress) { + Map expectedOccurences = + ImmutableMap.of( + Pattern.compile( + ".*127\\.0\\.0\\.2:19042.*Reconnection attempt complete, 6/6 channels.*"), + 1, + Pattern.compile( + ".*127\\.0\\.0\\.1:19042.*Reconnection attempt complete, 6/6 channels.*"), + 1, + Pattern.compile(".*Reconnection attempt complete.*"), 2, + Pattern.compile(".*127\\.0\\.0\\.1:19042.*New channel added \\[.*"), 5, + Pattern.compile(".*127\\.0\\.0\\.2:19042.*New channel added \\[.*"), 5, + Pattern.compile(".*127\\.0\\.0\\.1:19042\\] Trying to create 5 missing channels.*"), 1, + Pattern.compile(".*127\\.0\\.0\\.2:19042\\] Trying to create 5 missing channels.*"), 1); + DriverConfigLoader loader = + SessionUtils.configLoaderBuilder() + .withBoolean(DefaultDriverOption.SOCKET_REUSE_ADDRESS, reuseAddress) + .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true) + .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000) + .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000) + // Due to rounding up the connections per shard this will result in 6 connections per + // node + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 4) + .build(); + try (Session session = + CqlSession.builder() + .addContactPoint( + new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042)) + .withConfigLoader(loader) + .build()) { + Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); + expectedOccurences.forEach( + (pattern, times) -> assertMatchesExactly(pattern, times, appender.list)); + forbiddenOccurences.forEach(pattern -> assertNoLogMatches(pattern, appender.list)); + } + } + + @Test + public void should_see_mismatched_shard() { + DriverConfigLoader loader = + SessionUtils.configLoaderBuilder() + .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true) + .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000) + .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64) + .build(); + try (Session session = + CqlSession.builder() + .addContactPoint( + new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042)) + .withConfigLoader(loader) + .build()) { + Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); + assertMatchesAtLeast(shardMismatchPattern, 5, appender.list); + } + } + + // There is no need to run this as a test, but it serves as a comparison + @SuppressWarnings("unused") + public void should_struggle_to_fill_pools() { + DriverConfigLoader loader = + SessionUtils.configLoaderBuilder() + .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, false) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64) + .withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(200)) + .withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(4000)) + .build(); + CqlSessionBuilder builder = + CqlSession.builder() + .addContactPoint( + new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042)) + .withConfigLoader(loader); + CompletionStage stage1 = builder.buildAsync(); + CompletionStage stage2 = builder.buildAsync(); + CompletionStage stage3 = builder.buildAsync(); + CompletionStage stage4 = builder.buildAsync(); + try (CqlSession session1 = CompletableFutures.getUninterruptibly(stage1); + CqlSession session2 = CompletableFutures.getUninterruptibly(stage2); + CqlSession session3 = CompletableFutures.getUninterruptibly(stage3); + CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) { + Uninterruptibles.sleepUninterruptibly(20, TimeUnit.SECONDS); + assertNoLogMatches(shardMismatchPattern, appender.list); + assertMatchesAtLeast(reconnectionPattern, 8, appender.list); + } + } + + @Test + public void should_not_struggle_to_fill_pools() { + DriverConfigLoader loader = + SessionUtils.configLoaderBuilder() + .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 66) + .withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(10)) + .withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(20)) + .build(); + CqlSessionBuilder builder = + CqlSession.builder() + .addContactPoint( + new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042)) + .withConfigLoader(loader); + CompletionStage stage1 = builder.buildAsync(); + CompletionStage stage2 = builder.buildAsync(); + CompletionStage stage3 = builder.buildAsync(); + CompletionStage stage4 = builder.buildAsync(); + int sessions = 4; + try (CqlSession session1 = CompletableFutures.getUninterruptibly(stage1); + CqlSession session2 = CompletableFutures.getUninterruptibly(stage2); + CqlSession session3 = CompletableFutures.getUninterruptibly(stage3); + CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) { + Uninterruptibles.sleepUninterruptibly(8, TimeUnit.SECONDS); + int tolerance = 2; // Sometimes socket ends up already in use + Map expectedOccurences = + ImmutableMap.of( + Pattern.compile( + ".*127\\.0\\.0\\.2:19042.*Reconnection attempt complete, 66/66 channels.*"), + 1 * sessions, + Pattern.compile( + ".*127\\.0\\.0\\.1:19042.*Reconnection attempt complete, 66/66 channels.*"), + 1 * sessions, + Pattern.compile(".*Reconnection attempt complete.*"), 2 * sessions, + Pattern.compile(".*127\\.0\\.0\\.1:19042.*New channel added \\[.*"), + 65 * sessions - tolerance, + Pattern.compile(".*127\\.0\\.0\\.2:19042.*New channel added \\[.*"), + 65 * sessions - tolerance, + Pattern.compile(".*127\\.0\\.0\\.1:19042\\] Trying to create 65 missing channels.*"), + 1 * sessions, + Pattern.compile(".*127\\.0\\.0\\.2:19042\\] Trying to create 65 missing channels.*"), + 1 * sessions); + expectedOccurences.forEach( + (pattern, times) -> assertMatchesAtLeast(pattern, times, appender.list)); + assertNoLogMatches(shardMismatchPattern, appender.list); + assertMatchesAtMost(reconnectionPattern, tolerance, appender.list); + } + } + + private void assertNoLogMatches(Pattern pattern, List logs) { + for (ILoggingEvent log : logs) { + if (pattern.matcher(log.getFormattedMessage()).matches()) { + fail( + "Logs should not contain pattern [" + + pattern.toString() + + "] but found in [" + + log.getFormattedMessage() + + "]"); + } + } + } + + private void assertMatchesExactly(Pattern pattern, Integer times, List logs) { + int occurences = 0; + for (ILoggingEvent log : logs) { + if (pattern.matcher(log.getFormattedMessage()).matches()) { + occurences++; + } + } + if (occurences != times) { + fail( + "Expected to find pattern exactly " + + times + + " times but found it " + + occurences + + " times. Pattern: [" + + pattern.toString() + + "]"); + } + } + + private void assertMatchesAtLeast(Pattern pattern, Integer times, List logs) { + int occurences = 0; + for (ILoggingEvent log : logs) { + if (pattern.matcher(log.getFormattedMessage()).matches()) { + occurences++; + if (occurences >= times) { + return; + } + } + } + fail( + "Expected to find pattern at least " + + times + + " times but found only " + + occurences + + " times. Pattern: [" + + pattern.toString() + + "]"); + } + + private void assertMatchesAtMost(Pattern pattern, Integer times, List logs) { + int occurences = 0; + for (ILoggingEvent log : logs) { + if (pattern.matcher(log.getFormattedMessage()).matches()) { + occurences++; + if (occurences > times) { + fail( + "Expected to find pattern at most " + + times + + " times but found it " + + occurences + + " times. Pattern: [" + + pattern.toString() + + "]"); + } + } + } + } +}