29
29
import com .datastax .oss .driver .api .core .config .DefaultDriverOption ;
30
30
import com .datastax .oss .driver .api .core .config .DriverConfig ;
31
31
import com .datastax .oss .driver .api .core .config .DriverExecutionProfile ;
32
+ import com .datastax .oss .driver .api .core .context .DriverContext ;
32
33
import com .datastax .oss .driver .api .core .metadata .EndPoint ;
33
34
import com .datastax .oss .driver .api .core .metadata .Node ;
35
+ import com .datastax .oss .driver .api .core .metadata .NodeShardingInfo ;
34
36
import com .datastax .oss .driver .api .core .metrics .DefaultNodeMetric ;
35
37
import com .datastax .oss .driver .api .core .metrics .DefaultSessionMetric ;
36
38
import com .datastax .oss .driver .internal .core .config .typesafe .TypesafeDriverConfig ;
@@ -157,12 +159,27 @@ public CompletionStage<DriverChannel> connect(Node node, DriverChannelOptions op
157
159
} else {
158
160
nodeMetricUpdater = NoopNodeMetricUpdater .INSTANCE ;
159
161
}
160
- return connect (node .getEndPoint (), options , nodeMetricUpdater );
162
+ return connect (node .getEndPoint (), null , null , options , nodeMetricUpdater );
163
+ }
164
+
165
+ public CompletionStage <DriverChannel > connect (
166
+ Node node , Integer shardId , DriverChannelOptions options ) {
167
+ NodeMetricUpdater nodeMetricUpdater ;
168
+ if (node instanceof DefaultNode ) {
169
+ nodeMetricUpdater = ((DefaultNode ) node ).getMetricUpdater ();
170
+ } else {
171
+ nodeMetricUpdater = NoopNodeMetricUpdater .INSTANCE ;
172
+ }
173
+ return connect (node .getEndPoint (), node .getShardingInfo (), shardId , options , nodeMetricUpdater );
161
174
}
162
175
163
176
@ VisibleForTesting
164
177
CompletionStage <DriverChannel > connect (
165
- EndPoint endPoint , DriverChannelOptions options , NodeMetricUpdater nodeMetricUpdater ) {
178
+ EndPoint endPoint ,
179
+ NodeShardingInfo shardingInfo ,
180
+ Integer shardId ,
181
+ DriverChannelOptions options ,
182
+ NodeMetricUpdater nodeMetricUpdater ) {
166
183
CompletableFuture <DriverChannel > resultFuture = new CompletableFuture <>();
167
184
168
185
ProtocolVersion currentVersion ;
@@ -178,6 +195,8 @@ CompletionStage<DriverChannel> connect(
178
195
179
196
connect (
180
197
endPoint ,
198
+ shardingInfo ,
199
+ shardId ,
181
200
options ,
182
201
nodeMetricUpdater ,
183
202
currentVersion ,
@@ -189,6 +208,8 @@ CompletionStage<DriverChannel> connect(
189
208
190
209
private void connect (
191
210
EndPoint endPoint ,
211
+ NodeShardingInfo shardingInfo ,
212
+ Integer shardId ,
192
213
DriverChannelOptions options ,
193
214
NodeMetricUpdater nodeMetricUpdater ,
194
215
ProtocolVersion currentVersion ,
@@ -208,7 +229,20 @@ private void connect(
208
229
209
230
nettyOptions .afterBootstrapInitialized (bootstrap );
210
231
211
- ChannelFuture connectFuture = bootstrap .connect (endPoint .resolve ());
232
+ ChannelFuture connectFuture ;
233
+ if (shardId == null || shardingInfo == null ) {
234
+ if (shardId != null ) {
235
+ LOG .debug (
236
+ "Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port." ,
237
+ shardId ,
238
+ endPoint );
239
+ }
240
+ connectFuture = bootstrap .connect (endPoint .resolve ());
241
+ } else {
242
+ int localPort =
243
+ PortAllocator .getNextAvailablePort (shardingInfo .getShardsCount (), shardId , context );
244
+ connectFuture = bootstrap .connect (endPoint .resolve (), new InetSocketAddress (localPort ));
245
+ }
212
246
213
247
connectFuture .addListener (
214
248
cf -> {
@@ -257,6 +291,8 @@ private void connect(
257
291
downgraded .get ());
258
292
connect (
259
293
endPoint ,
294
+ shardingInfo ,
295
+ shardId ,
260
296
options ,
261
297
nodeMetricUpdater ,
262
298
downgraded .get (),
@@ -399,7 +435,17 @@ protected void initChannel(Channel channel) {
399
435
static class PortAllocator {
400
436
private static final AtomicInteger lastPort = new AtomicInteger (-1 );
401
437
402
- public static int getNextAvailablePort (int shardCount , int shardId , int lowPort , int highPort ) {
438
+ public static int getNextAvailablePort (int shardCount , int shardId , DriverContext context ) {
439
+ int lowPort =
440
+ context
441
+ .getConfig ()
442
+ .getDefaultProfile ()
443
+ .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_LOW );
444
+ int highPort =
445
+ context
446
+ .getConfig ()
447
+ .getDefaultProfile ()
448
+ .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_HIGH );
403
449
int lastPortValue , foundPort = -1 ;
404
450
do {
405
451
lastPortValue = lastPort .get ();
0 commit comments