Skip to content

Commit 75a2d85

Browse files
authored
4.x: Add advanced shard awareness (#517)
* Add config options for advanced shard awareness Adds config options for toggling the feature on and defining port ranges to use. By default the feature will be enabled. * Copy 3.x PortAllocator * Extend ChannelFactory methods with shardId params With advanced shard awareness target shard is a necessary parameter for connecting. Extends non public methods with new parameters and provides one additional public method. Passing `null` as shard id or ShardingInfo will lead to previous non-shard aware behavior. * Enable advanced shard awareness Makes `addMissingChannels` use advanced shard awareness. It will now specify target shard when adding missing channels for specific shards. In case returned channels do not match requested shards warnings are logged. Initial connection to the node works on previous rules, meaning it uses arbitrary local port for connection to arbitrary shard. Adds AdvancedShardAwarenessIT that has several methods displaying the difference between establishing connections with option enabled and disabled. Adds ChannelPoolShardAwareInitTest
1 parent 67edb36 commit 75a2d85

File tree

13 files changed

+718
-19
lines changed

13 files changed

+718
-19
lines changed

core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java

+11
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ public enum DefaultDriverOption implements DriverOption {
141141
* <p>Value-type: boolean
142142
*/
143143
CONNECTION_WARN_INIT_ERROR("advanced.connection.warn-on-init-error"),
144+
/**
145+
* Whether to use advanced shard awareness.
146+
*
147+
* <p>Value-type: boolean
148+
*/
149+
CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED(
150+
"advanced.connection.advanced-shard-awareness.enabled"),
151+
/** Inclusive lower bound of port range to use in advanced shard awareness */
152+
ADVANCED_SHARD_AWARENESS_PORT_LOW("advanced.connection.advanced-shard-awareness.port-low"),
153+
/** Inclusive upper bound of port range to use in advanced shard awareness */
154+
ADVANCED_SHARD_AWARENESS_PORT_HIGH("advanced.connection.advanced-shard-awareness.port-high"),
144155
/**
145156
* The number of connections in the LOCAL pool.
146157
*

core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java

+3
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ protected static void fillWithDriverDefaults(OptionsMap map) {
276276
map.put(TypedDriverOption.CONNECTION_MAX_REQUESTS, 1024);
277277
map.put(TypedDriverOption.CONNECTION_MAX_ORPHAN_REQUESTS, 256);
278278
map.put(TypedDriverOption.CONNECTION_WARN_INIT_ERROR, true);
279+
map.put(TypedDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true);
280+
map.put(TypedDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000);
281+
map.put(TypedDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 65535);
279282
map.put(TypedDriverOption.RECONNECT_ON_INIT, false);
280283
map.put(TypedDriverOption.RECONNECTION_POLICY_CLASS, "ExponentialReconnectionPolicy");
281284
map.put(TypedDriverOption.RECONNECTION_BASE_DELAY, Duration.ofSeconds(1));

core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java

+12
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,18 @@ public String toString() {
175175
/** Whether to log non-fatal errors when the driver tries to open a new connection. */
176176
public static final TypedDriverOption<Boolean> CONNECTION_WARN_INIT_ERROR =
177177
new TypedDriverOption<>(DefaultDriverOption.CONNECTION_WARN_INIT_ERROR, GenericType.BOOLEAN);
178+
/** Whether to use advanced shard awareness */
179+
public static final TypedDriverOption<Boolean> CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED =
180+
new TypedDriverOption<>(
181+
DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, GenericType.BOOLEAN);
182+
/** Inclusive lower bound of port range to use in advanced shard awareness */
183+
public static final TypedDriverOption<Integer> ADVANCED_SHARD_AWARENESS_PORT_LOW =
184+
new TypedDriverOption<>(
185+
DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, GenericType.INTEGER);
186+
/** Inclusive upper bound of port range to use in advanced shard awareness */
187+
public static final TypedDriverOption<Integer> ADVANCED_SHARD_AWARENESS_PORT_HIGH =
188+
new TypedDriverOption<>(
189+
DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, GenericType.INTEGER);
178190
/** The number of connections in the LOCAL pool. */
179191
public static final TypedDriverOption<Integer> CONNECTION_POOL_LOCAL_SIZE =
180192
new TypedDriverOption<>(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, GenericType.INTEGER);

core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java

+139-3
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
3030
import com.datastax.oss.driver.api.core.config.DriverConfig;
3131
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
32+
import com.datastax.oss.driver.api.core.context.DriverContext;
3233
import com.datastax.oss.driver.api.core.metadata.EndPoint;
3334
import com.datastax.oss.driver.api.core.metadata.Node;
35+
import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo;
3436
import com.datastax.oss.driver.api.core.metrics.DefaultNodeMetric;
3537
import com.datastax.oss.driver.api.core.metrics.DefaultSessionMetric;
3638
import com.datastax.oss.driver.internal.core.config.typesafe.TypesafeDriverConfig;
@@ -51,13 +53,17 @@
5153
import io.netty.channel.ChannelInitializer;
5254
import io.netty.channel.ChannelOption;
5355
import io.netty.channel.ChannelPipeline;
56+
import java.io.IOException;
57+
import java.net.InetSocketAddress;
58+
import java.net.ServerSocket;
5459
import java.util.List;
5560
import java.util.Map;
5661
import java.util.Optional;
5762
import java.util.concurrent.CompletableFuture;
5863
import java.util.concurrent.CompletionStage;
5964
import java.util.concurrent.CopyOnWriteArrayList;
6065
import java.util.concurrent.atomic.AtomicBoolean;
66+
import java.util.concurrent.atomic.AtomicInteger;
6167
import net.jcip.annotations.ThreadSafe;
6268
import org.slf4j.Logger;
6369
import org.slf4j.LoggerFactory;
@@ -153,12 +159,27 @@ public CompletionStage<DriverChannel> connect(Node node, DriverChannelOptions op
153159
} else {
154160
nodeMetricUpdater = NoopNodeMetricUpdater.INSTANCE;
155161
}
156-
return connect(node.getEndPoint(), options, nodeMetricUpdater);
162+
return connect(node.getEndPoint(), null, null, options, nodeMetricUpdater);
163+
}
164+
165+
public CompletionStage<DriverChannel> connect(
166+
Node node, Integer shardId, DriverChannelOptions options) {
167+
NodeMetricUpdater nodeMetricUpdater;
168+
if (node instanceof DefaultNode) {
169+
nodeMetricUpdater = ((DefaultNode) node).getMetricUpdater();
170+
} else {
171+
nodeMetricUpdater = NoopNodeMetricUpdater.INSTANCE;
172+
}
173+
return connect(node.getEndPoint(), node.getShardingInfo(), shardId, options, nodeMetricUpdater);
157174
}
158175

159176
@VisibleForTesting
160177
CompletionStage<DriverChannel> connect(
161-
EndPoint endPoint, DriverChannelOptions options, NodeMetricUpdater nodeMetricUpdater) {
178+
EndPoint endPoint,
179+
NodeShardingInfo shardingInfo,
180+
Integer shardId,
181+
DriverChannelOptions options,
182+
NodeMetricUpdater nodeMetricUpdater) {
162183
CompletableFuture<DriverChannel> resultFuture = new CompletableFuture<>();
163184

164185
ProtocolVersion currentVersion;
@@ -174,6 +195,8 @@ CompletionStage<DriverChannel> connect(
174195

175196
connect(
176197
endPoint,
198+
shardingInfo,
199+
shardId,
177200
options,
178201
nodeMetricUpdater,
179202
currentVersion,
@@ -185,6 +208,8 @@ CompletionStage<DriverChannel> connect(
185208

186209
private void connect(
187210
EndPoint endPoint,
211+
NodeShardingInfo shardingInfo,
212+
Integer shardId,
188213
DriverChannelOptions options,
189214
NodeMetricUpdater nodeMetricUpdater,
190215
ProtocolVersion currentVersion,
@@ -204,7 +229,28 @@ private void connect(
204229

205230
nettyOptions.afterBootstrapInitialized(bootstrap);
206231

207-
ChannelFuture connectFuture = bootstrap.connect(endPoint.resolve());
232+
ChannelFuture connectFuture;
233+
if (shardId == null || shardingInfo == null) {
234+
if (shardId != null) {
235+
LOG.debug(
236+
"Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port.",
237+
shardId,
238+
endPoint);
239+
}
240+
connectFuture = bootstrap.connect(endPoint.resolve());
241+
} else {
242+
int localPort =
243+
PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
244+
if (localPort == -1) {
245+
LOG.warn(
246+
"Could not find free port for shard {} at {}. Falling back to arbitrary local port.",
247+
shardId,
248+
endPoint);
249+
connectFuture = bootstrap.connect(endPoint.resolve());
250+
} else {
251+
connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
252+
}
253+
}
208254

209255
connectFuture.addListener(
210256
cf -> {
@@ -253,6 +299,8 @@ private void connect(
253299
downgraded.get());
254300
connect(
255301
endPoint,
302+
shardingInfo,
303+
shardId,
256304
options,
257305
nodeMetricUpdater,
258306
downgraded.get(),
@@ -391,4 +439,92 @@ protected void initChannel(Channel channel) {
391439
}
392440
}
393441
}
442+
443+
static class PortAllocator {
444+
private static final AtomicInteger lastPort = new AtomicInteger(-1);
445+
private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class);
446+
447+
public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
448+
int lowPort =
449+
context
450+
.getConfig()
451+
.getDefaultProfile()
452+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
453+
int highPort =
454+
context
455+
.getConfig()
456+
.getDefaultProfile()
457+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
458+
if (highPort - lowPort < shardCount) {
459+
LOG.error(
460+
"There is not enough ports in range [{},{}] for {} shards. Update your configuration.",
461+
lowPort,
462+
highPort,
463+
shardCount);
464+
}
465+
int lastPortValue, foundPort = -1;
466+
do {
467+
lastPortValue = lastPort.get();
468+
469+
// We will scan from lastPortValue
470+
// (or lowPort is there was no lastPort or lastPort is too low)
471+
int scanStart = lastPortValue == -1 ? lowPort : lastPortValue;
472+
if (scanStart < lowPort) {
473+
scanStart = lowPort;
474+
}
475+
476+
// Round it up to "% shardCount == shardId"
477+
scanStart += (shardCount - scanStart % shardCount) + shardId;
478+
479+
// Scan from scanStart upwards to highPort.
480+
for (int port = scanStart; port <= highPort; port += shardCount) {
481+
if (isTcpPortAvailable(port, context)) {
482+
foundPort = port;
483+
break;
484+
}
485+
}
486+
487+
// If we started scanning from a high scanStart port
488+
// there might have been not enough ports left that are
489+
// smaller than highPort. Scan from the beginning
490+
// from the lowPort.
491+
if (foundPort == -1) {
492+
scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId;
493+
494+
for (int port = scanStart; port <= highPort; port += shardCount) {
495+
if (isTcpPortAvailable(port, context)) {
496+
foundPort = port;
497+
break;
498+
}
499+
}
500+
}
501+
502+
// No luck! All ports taken!
503+
if (foundPort == -1) {
504+
return -1;
505+
}
506+
} while (!lastPort.compareAndSet(lastPortValue, foundPort));
507+
508+
return foundPort;
509+
}
510+
511+
public static boolean isTcpPortAvailable(int port, DriverContext context) {
512+
try {
513+
ServerSocket serverSocket = new ServerSocket();
514+
try {
515+
serverSocket.setReuseAddress(
516+
context
517+
.getConfig()
518+
.getDefaultProfile()
519+
.getBoolean(DefaultDriverOption.SOCKET_REUSE_ADDRESS, false));
520+
serverSocket.bind(new InetSocketAddress(port), 1);
521+
return true;
522+
} finally {
523+
serverSocket.close();
524+
}
525+
} catch (IOException ex) {
526+
return false;
527+
}
528+
}
529+
}
394530
}

core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java

+35-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import io.netty.util.concurrent.EventExecutor;
5959
import io.netty.util.concurrent.Future;
6060
import io.netty.util.concurrent.GenericFutureListener;
61+
import java.net.InetSocketAddress;
6162
import java.util.ArrayList;
6263
import java.util.Arrays;
6364
import java.util.HashSet;
@@ -489,9 +490,23 @@ private CompletionStage<Boolean> addMissingChannels() {
489490
channels.length * wantedCount - Arrays.stream(channels).mapToInt(ChannelSet::size).sum();
490491
LOG.debug("[{}] Trying to create {} missing channels", logPrefix, missing);
491492
DriverChannelOptions options = buildDriverOptions();
492-
for (int i = 0; i < missing; i++) {
493-
CompletionStage<DriverChannel> channelFuture = channelFactory.connect(node, options);
494-
pendingChannels.add(channelFuture);
493+
for (int shard = 0; shard < channels.length; shard++) {
494+
LOG.trace(
495+
"[{}] Missing {} channels for shard {}",
496+
logPrefix,
497+
wantedCount - channels[shard].size(),
498+
shard);
499+
for (int p = channels[shard].size(); p < wantedCount; p++) {
500+
CompletionStage<DriverChannel> channelFuture;
501+
if (config
502+
.getDefaultProfile()
503+
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) {
504+
channelFuture = channelFactory.connect(node, shard, options);
505+
} else {
506+
channelFuture = channelFactory.connect(node, options);
507+
}
508+
pendingChannels.add(channelFuture);
509+
}
495510
}
496511
return CompletableFutures.allDone(pendingChannels)
497512
.thenApplyAsync(this::onAllConnected, adminExecutor);
@@ -551,6 +566,23 @@ private boolean onAllConnected(@SuppressWarnings("unused") Void v) {
551566
channel);
552567
channel.forceClose();
553568
} else {
569+
if (config
570+
.getDefaultProfile()
571+
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)
572+
&& channel.localAddress() instanceof InetSocketAddress
573+
&& channel.getShardingInfo() != null) {
574+
int port = ((InetSocketAddress) channel.localAddress()).getPort();
575+
int actualShard = channel.getShardId();
576+
int targetShard = port % channel.getShardingInfo().getShardsCount();
577+
if (actualShard != targetShard) {
578+
LOG.warn(
579+
"[{}] New channel {} connected to shard {}, but shard {} was requested. If this is not transient check your driver AND cluster configuration of shard aware port.",
580+
logPrefix,
581+
channel,
582+
actualShard,
583+
targetShard);
584+
}
585+
}
554586
LOG.debug("[{}] New channel added {}", logPrefix, channel);
555587
if (channels[channel.getShardId()].size() < wantedCount) {
556588
addChannel(channel);

core/src/main/resources/reference.conf

+46
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,52 @@ datastax-java-driver {
535535
# change.
536536
# Overridable in a profile: no
537537
warn-on-init-error = true
538+
539+
540+
advanced-shard-awareness {
541+
# Whether to use advanced shard awareness when trying to open new connections.
542+
#
543+
# Requires passing shard-aware port as contact point (usually 19042 or 19142(ssl)).
544+
# Having this enabled makes sense only for ScyllaDB clusters.
545+
#
546+
# Reduces number of reconnections driver needs to fully initialize connection pool.
547+
# In short it's a feature that allows targeting particular shard when connecting to a node by using specific
548+
# local port number.
549+
# For context see https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/
550+
#
551+
# If set to false the driver will not attempt to use this feature. This means connection's local port
552+
# will be random according to system rules and driver will keep opening connections until it gets right shards.
553+
# In such case non-shard aware port is recommended (by default 9042 or 9142).
554+
# If set to true the driver will attempt to use it and will log warnings each time something
555+
# makes it not possible.
556+
#
557+
# If the node for some reason does not report it's sharding info the driver
558+
# will log a warning and create connection the same way as if this feature was disabled.
559+
# If the cluster ignores the request for specific shard warning will also be logged,
560+
# although the local port will already be chosen according to advanced shard awareness rules.
561+
#
562+
# Required: yes
563+
# Modifiable at runtime: yes, the new value will be used for connections created after the
564+
# change.
565+
# Overridable in a profile: no
566+
enabled = true
567+
568+
# Inclusive lower bound of port range to use in advanced shard awareness
569+
# The driver will attempt to reserve ports for connection only within the range.
570+
# Required: yes
571+
# Modifiable at runtime: yes, the new value will be used for calls after the
572+
# change.
573+
# Overridable in a profile: no
574+
port-low = 10000
575+
576+
# Inclusive upper bound of port range to use in advanced shard awareness.
577+
# The driver will attempt to reserve ports for connection only within the range.
578+
# Required: yes
579+
# Modifiable at runtime: yes, the new value will be used for calls after the
580+
# change.
581+
# Overridable in a profile: no
582+
port-high = 65535
583+
}
538584
}
539585

540586
# Advanced options for the built-in load-balancing policies.

core/src/test/java/com/datastax/oss/driver/internal/core/channel/ChannelFactoryAvailableIdsTest.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ public void should_report_available_ids() {
6161
// When
6262
CompletionStage<DriverChannel> channelFuture =
6363
factory.connect(
64-
SERVER_ADDRESS, DriverChannelOptions.builder().build(), NoopNodeMetricUpdater.INSTANCE);
64+
SERVER_ADDRESS,
65+
null,
66+
null,
67+
DriverChannelOptions.builder().build(),
68+
NoopNodeMetricUpdater.INSTANCE);
6569
completeSimpleChannelInit();
6670

6771
// Then

0 commit comments

Comments
 (0)