Skip to content

Commit 33a9063

Browse files
committed
Enable advanced shard awareness
Makes `addMissingChannels` use advanced shard awareness. It will now specify target shard when adding missing channels for specific shards. In case returned channels do not match requested shards warnings are logged. Initial connection to the node works on previous rules, meaning it uses arbitrary local port for connection to arbitrary shard. Adds AdvancedShardAwarenessIT that has several methods displaying the difference between establishing connections with option enabled and disabled. Adds ChannelPoolShardAwareInitTest
1 parent 1342eac commit 33a9063

File tree

5 files changed

+447
-4
lines changed

5 files changed

+447
-4
lines changed

core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,15 @@ private void connect(
241241
} else {
242242
int localPort =
243243
PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
244-
connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
244+
if (localPort == -1) {
245+
LOG.warn(
246+
"Could not find free port for shard {} at {}. Falling back to arbitrary local port.",
247+
shardId,
248+
endPoint);
249+
connectFuture = bootstrap.connect(endPoint.resolve());
250+
} else {
251+
connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
252+
}
245253
}
246254

247255
connectFuture.addListener(
@@ -434,6 +442,7 @@ protected void initChannel(Channel channel) {
434442

435443
static class PortAllocator {
436444
private static final AtomicInteger lastPort = new AtomicInteger(-1);
445+
private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class);
437446

438447
public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
439448
int lowPort =
@@ -446,6 +455,13 @@ public static int getNextAvailablePort(int shardCount, int shardId, DriverContex
446455
.getConfig()
447456
.getDefaultProfile()
448457
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
458+
if (highPort - lowPort < shardCount) {
459+
LOG.error(
460+
"There is not enough ports in range [{},{}] for {} shards. Update your configuration.",
461+
lowPort,
462+
highPort,
463+
shardCount);
464+
}
449465
int lastPortValue, foundPort = -1;
450466
do {
451467
lastPortValue = lastPort.get();

core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java

+35-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import io.netty.util.concurrent.EventExecutor;
5959
import io.netty.util.concurrent.Future;
6060
import io.netty.util.concurrent.GenericFutureListener;
61+
import java.net.InetSocketAddress;
6162
import java.util.ArrayList;
6263
import java.util.Arrays;
6364
import java.util.HashSet;
@@ -489,9 +490,23 @@ private CompletionStage<Boolean> addMissingChannels() {
489490
channels.length * wantedCount - Arrays.stream(channels).mapToInt(ChannelSet::size).sum();
490491
LOG.debug("[{}] Trying to create {} missing channels", logPrefix, missing);
491492
DriverChannelOptions options = buildDriverOptions();
492-
for (int i = 0; i < missing; i++) {
493-
CompletionStage<DriverChannel> channelFuture = channelFactory.connect(node, options);
494-
pendingChannels.add(channelFuture);
493+
for (int shard = 0; shard < channels.length; shard++) {
494+
LOG.trace(
495+
"[{}] Missing {} channels for shard {}",
496+
logPrefix,
497+
wantedCount - channels[shard].size(),
498+
shard);
499+
for (int p = channels[shard].size(); p < wantedCount; p++) {
500+
CompletionStage<DriverChannel> channelFuture;
501+
if (config
502+
.getDefaultProfile()
503+
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) {
504+
channelFuture = channelFactory.connect(node, shard, options);
505+
} else {
506+
channelFuture = channelFactory.connect(node, options);
507+
}
508+
pendingChannels.add(channelFuture);
509+
}
495510
}
496511
return CompletableFutures.allDone(pendingChannels)
497512
.thenApplyAsync(this::onAllConnected, adminExecutor);
@@ -551,6 +566,23 @@ private boolean onAllConnected(@SuppressWarnings("unused") Void v) {
551566
channel);
552567
channel.forceClose();
553568
} else {
569+
if (config
570+
.getDefaultProfile()
571+
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)
572+
&& channel.localAddress() instanceof InetSocketAddress
573+
&& channel.getShardingInfo() != null) {
574+
int port = ((InetSocketAddress) channel.localAddress()).getPort();
575+
int actualShard = channel.getShardId();
576+
int targetShard = port % channel.getShardingInfo().getShardsCount();
577+
if (actualShard != targetShard) {
578+
LOG.warn(
579+
"[{}] New channel {} connected to shard {}, but shard {} was requested. If this is not transient check your driver AND cluster configuration of shard aware port.",
580+
logPrefix,
581+
channel,
582+
actualShard,
583+
targetShard);
584+
}
585+
}
554586
LOG.debug("[{}] New channel added {}", logPrefix, channel);
555587
if (channels[channel.getShardId()].size() < wantedCount) {
556588
addChannel(channel);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package com.datastax.oss.driver.internal.core.pool;
2+
3+
import static com.datastax.oss.driver.Assertions.assertThatStage;
4+
import static org.assertj.core.api.Assertions.assertThat;
5+
import static org.mockito.ArgumentMatchers.any;
6+
import static org.mockito.ArgumentMatchers.anyInt;
7+
import static org.mockito.ArgumentMatchers.eq;
8+
import static org.mockito.Mockito.mock;
9+
import static org.mockito.Mockito.timeout;
10+
import static org.mockito.Mockito.verify;
11+
import static org.mockito.Mockito.when;
12+
13+
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
14+
import com.datastax.oss.driver.api.core.loadbalancing.NodeDistance;
15+
import com.datastax.oss.driver.api.core.metadata.Node;
16+
import com.datastax.oss.driver.internal.core.channel.ChannelEvent;
17+
import com.datastax.oss.driver.internal.core.channel.DriverChannel;
18+
import com.datastax.oss.driver.internal.core.channel.DriverChannelOptions;
19+
import com.datastax.oss.driver.internal.core.protocol.ShardingInfo;
20+
import java.util.concurrent.CompletableFuture;
21+
import java.util.concurrent.CompletionStage;
22+
import org.junit.Before;
23+
import org.junit.Test;
24+
import org.mockito.ArgumentCaptor;
25+
import org.mockito.InOrder;
26+
import org.mockito.Mockito;
27+
28+
public class ChannelPoolShardAwareInitTest extends ChannelPoolTestBase {
29+
30+
@Before
31+
@Override
32+
public void setup() {
33+
super.setup();
34+
when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED))
35+
.thenReturn(true);
36+
when(defaultProfile.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW))
37+
.thenReturn(10000);
38+
when(defaultProfile.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH))
39+
.thenReturn(60000);
40+
}
41+
42+
@Test
43+
public void should_initialize_when_all_channels_succeed() throws Exception {
44+
when(defaultProfile.getInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE)).thenReturn(4);
45+
int shardsPerNode = 2;
46+
DriverChannel channel1 = newMockDriverChannel(1);
47+
DriverChannel channel2 = newMockDriverChannel(2);
48+
DriverChannel channel3 = newMockDriverChannel(3);
49+
DriverChannel channel4 = newMockDriverChannel(4);
50+
51+
ShardingInfo shardingInfo = mock(ShardingInfo.class);
52+
when(shardingInfo.getShardsCount()).thenReturn(shardsPerNode);
53+
node.setShardingInfo(shardingInfo);
54+
55+
when(channel1.getShardingInfo()).thenReturn(shardingInfo);
56+
when(channel2.getShardingInfo()).thenReturn(shardingInfo);
57+
when(channel3.getShardingInfo()).thenReturn(shardingInfo);
58+
when(channel4.getShardingInfo()).thenReturn(shardingInfo);
59+
60+
when(channel1.getShardId()).thenReturn(0);
61+
when(channel2.getShardId()).thenReturn(0);
62+
when(channel3.getShardId()).thenReturn(1);
63+
when(channel4.getShardId()).thenReturn(1);
64+
65+
when(channelFactory.connect(eq(node), any(DriverChannelOptions.class)))
66+
.thenReturn(CompletableFuture.completedFuture(channel1));
67+
when(channelFactory.connect(eq(node), eq(0), any(DriverChannelOptions.class)))
68+
.thenReturn(CompletableFuture.completedFuture(channel2));
69+
when(channelFactory.connect(eq(node), eq(1), any(DriverChannelOptions.class)))
70+
.thenReturn(CompletableFuture.completedFuture(channel3))
71+
.thenReturn(CompletableFuture.completedFuture(channel4));
72+
73+
CompletionStage<ChannelPool> poolFuture =
74+
ChannelPool.init(node, null, NodeDistance.LOCAL, context, "test");
75+
76+
ArgumentCaptor<DriverChannelOptions> optionsCaptor =
77+
ArgumentCaptor.forClass(DriverChannelOptions.class);
78+
InOrder inOrder = Mockito.inOrder(channelFactory);
79+
inOrder
80+
.verify(channelFactory, timeout(500).atLeast(1))
81+
.connect(eq(node), optionsCaptor.capture());
82+
int num = optionsCaptor.getAllValues().size();
83+
assertThat(num).isEqualTo(1);
84+
inOrder
85+
.verify(channelFactory, timeout(500).atLeast(3))
86+
.connect(eq(node), anyInt(), optionsCaptor.capture());
87+
int num2 = optionsCaptor.getAllValues().size();
88+
assertThat(num2).isEqualTo(4);
89+
90+
assertThatStage(poolFuture)
91+
.isSuccess(
92+
pool -> {
93+
assertThat(pool.channels[0]).containsOnly(channel1, channel2);
94+
assertThat(pool.channels[1]).containsOnly(channel3, channel4);
95+
});
96+
verify(eventBus, VERIFY_TIMEOUT.times(4)).fire(ChannelEvent.channelOpened(node));
97+
98+
inOrder
99+
.verify(channelFactory, timeout(500).times(0))
100+
.connect(any(Node.class), any(DriverChannelOptions.class));
101+
inOrder
102+
.verify(channelFactory, timeout(500).times(0))
103+
.connect(any(Node.class), anyInt(), any(DriverChannelOptions.class));
104+
}
105+
}

core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.mockito.Mockito.when;
2525

2626
import com.datastax.oss.driver.api.core.CqlIdentifier;
27+
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
2728
import com.datastax.oss.driver.api.core.config.DriverConfig;
2829
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
2930
import com.datastax.oss.driver.api.core.connection.ReconnectionPolicy;
@@ -77,6 +78,8 @@ public void setup() {
7778
when(nettyOptions.adminEventExecutorGroup()).thenReturn(adminEventLoopGroup);
7879
when(context.getConfig()).thenReturn(config);
7980
when(config.getDefaultProfile()).thenReturn(defaultProfile);
81+
when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED))
82+
.thenReturn(false);
8083
this.eventBus = spy(new EventBus("test"));
8184
when(context.getEventBus()).thenReturn(eventBus);
8285
when(context.getChannelFactory()).thenReturn(channelFactory);

0 commit comments

Comments
 (0)