29
29
import com .datastax .oss .driver .api .core .config .DefaultDriverOption ;
30
30
import com .datastax .oss .driver .api .core .config .DriverConfig ;
31
31
import com .datastax .oss .driver .api .core .config .DriverExecutionProfile ;
32
- import com .datastax .oss .driver .api .core .context .DriverContext ;
33
32
import com .datastax .oss .driver .api .core .metadata .EndPoint ;
34
33
import com .datastax .oss .driver .api .core .metadata .Node ;
35
34
import com .datastax .oss .driver .api .core .metadata .NodeShardingInfo ;
53
52
import io .netty .channel .ChannelInitializer ;
54
53
import io .netty .channel .ChannelOption ;
55
54
import io .netty .channel .ChannelPipeline ;
56
- import java .io .IOException ;
57
55
import java .net .InetSocketAddress ;
58
- import java .net . ServerSocket ;
56
+ import java .util . Iterator ;
59
57
import java .util .List ;
60
58
import java .util .Map ;
61
59
import java .util .Optional ;
62
60
import java .util .concurrent .CompletableFuture ;
63
61
import java .util .concurrent .CompletionStage ;
64
62
import java .util .concurrent .CopyOnWriteArrayList ;
65
63
import java .util .concurrent .atomic .AtomicBoolean ;
66
- import java .util .concurrent .atomic .AtomicInteger ;
67
64
import net .jcip .annotations .ThreadSafe ;
68
65
import org .slf4j .Logger ;
69
66
import org .slf4j .LoggerFactory ;
@@ -193,28 +190,43 @@ CompletionStage<DriverChannel> connect(
193
190
isNegotiating = true ;
194
191
}
195
192
193
+ PortIterator portIterator = null ;
194
+
195
+ if (shardId != null && shardingInfo != null ) {
196
+ int lowestPort =
197
+ context
198
+ .getConfig ()
199
+ .getDefaultProfile ()
200
+ .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_LOW );
201
+ int highestPort =
202
+ context
203
+ .getConfig ()
204
+ .getDefaultProfile ()
205
+ .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_HIGH );
206
+ portIterator =
207
+ new PortIterator (lowestPort , highestPort , shardingInfo .getShardsCount (), shardId );
208
+ }
209
+
196
210
connect (
197
211
endPoint ,
198
- shardingInfo ,
199
- shardId ,
200
212
options ,
201
213
nodeMetricUpdater ,
202
214
currentVersion ,
203
215
isNegotiating ,
204
216
attemptedVersions ,
217
+ portIterator ,
205
218
resultFuture );
206
219
return resultFuture ;
207
220
}
208
221
209
222
private void connect (
210
223
EndPoint endPoint ,
211
- NodeShardingInfo shardingInfo ,
212
- Integer shardId ,
213
224
DriverChannelOptions options ,
214
225
NodeMetricUpdater nodeMetricUpdater ,
215
226
ProtocolVersion currentVersion ,
216
227
boolean isNegotiating ,
217
228
List <ProtocolVersion > attemptedVersions ,
229
+ PortIterator portIterator ,
218
230
CompletableFuture <DriverChannel > resultFuture ) {
219
231
220
232
NettyOptions nettyOptions = context .getNettyOptions ();
@@ -230,26 +242,14 @@ private void connect(
230
242
nettyOptions .afterBootstrapInitialized (bootstrap );
231
243
232
244
ChannelFuture connectFuture ;
233
- if (shardId == null || shardingInfo == null ) {
234
- if (shardId != null ) {
235
- LOG .debug (
236
- "Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port." ,
237
- shardId ,
238
- endPoint );
239
- }
245
+ if (portIterator == null ) {
240
246
connectFuture = bootstrap .connect (endPoint .resolve ());
241
247
} else {
242
- int localPort =
243
- PortAllocator .getNextAvailablePort (shardingInfo .getShardsCount (), shardId , context );
244
- if (localPort == -1 ) {
245
- LOG .warn (
246
- "Could not find free port for shard {} at {}. Falling back to arbitrary local port." ,
247
- shardId ,
248
- endPoint );
249
- connectFuture = bootstrap .connect (endPoint .resolve ());
250
- } else {
251
- connectFuture = bootstrap .connect (endPoint .resolve (), new InetSocketAddress (localPort ));
248
+ if (!portIterator .hasNext ()) {
249
+ portIterator .reset ();
252
250
}
251
+ connectFuture =
252
+ bootstrap .connect (endPoint .resolve (), new InetSocketAddress (portIterator .next ()));
253
253
}
254
254
255
255
connectFuture .addListener (
@@ -299,19 +299,28 @@ private void connect(
299
299
downgraded .get ());
300
300
connect (
301
301
endPoint ,
302
- shardingInfo ,
303
- shardId ,
304
302
options ,
305
303
nodeMetricUpdater ,
306
304
downgraded .get (),
307
305
true ,
308
306
attemptedVersions ,
307
+ portIterator ,
309
308
resultFuture );
310
309
} else {
311
310
resultFuture .completeExceptionally (
312
311
UnsupportedProtocolVersionException .forNegotiation (
313
312
endPoint , attemptedVersions ));
314
313
}
314
+ } else if (isBindException (error )) {
315
+ connect (
316
+ endPoint ,
317
+ options ,
318
+ nodeMetricUpdater ,
319
+ currentVersion ,
320
+ true ,
321
+ attemptedVersions ,
322
+ portIterator ,
323
+ resultFuture );
315
324
} else {
316
325
// Note: might be completed already if the failure happened in initializer(), this is
317
326
// fine
@@ -321,6 +330,16 @@ private void connect(
321
330
});
322
331
}
323
332
333
+ private static boolean isBindException (Throwable error ) {
334
+ while (error != null ) {
335
+ if (error instanceof java .net .BindException ) {
336
+ return true ;
337
+ }
338
+ error = error .getCause ();
339
+ }
340
+ return false ;
341
+ }
342
+
324
343
@ VisibleForTesting
325
344
ChannelInitializer <Channel > initializer (
326
345
EndPoint endPoint ,
@@ -330,7 +349,7 @@ ChannelInitializer<Channel> initializer(
330
349
CompletableFuture <DriverChannel > resultFuture ) {
331
350
return new ChannelFactoryInitializer (
332
351
endPoint , protocolVersion , options , nodeMetricUpdater , resultFuture );
333
- };
352
+ }
334
353
335
354
class ChannelFactoryInitializer extends ChannelInitializer <Channel > {
336
355
@@ -440,87 +459,35 @@ protected void initChannel(Channel channel) {
440
459
}
441
460
}
442
461
443
- static class PortAllocator {
444
- private static final AtomicInteger lastPort = new AtomicInteger (-1 );
445
- private static final Logger LOG = LoggerFactory .getLogger (PortAllocator .class );
462
+ static class PortIterator implements Iterator <Integer > {
463
+ private final int highestPort ;
464
+ private final int startPort ;
465
+ private int currentPort ;
466
+ private final int shardCount ;
467
+
468
+ PortIterator (int lowestPort , int highestPort , int shardCount , int shardId ) {
469
+ this .highestPort = highestPort ;
470
+ this .currentPort = lowestPort + (shardCount - lowestPort % shardCount ) + shardId ;
471
+ this .startPort = currentPort ;
472
+ this .shardCount = shardCount ;
473
+ }
446
474
447
- public static int getNextAvailablePort (int shardCount , int shardId , DriverContext context ) {
448
- int lowPort =
449
- context
450
- .getConfig ()
451
- .getDefaultProfile ()
452
- .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_LOW );
453
- int highPort =
454
- context
455
- .getConfig ()
456
- .getDefaultProfile ()
457
- .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_HIGH );
458
- if (highPort - lowPort < shardCount ) {
459
- LOG .error (
460
- "There is not enough ports in range [{},{}] for {} shards. Update your configuration." ,
461
- lowPort ,
462
- highPort ,
463
- shardCount );
475
+ @ Override
476
+ public Integer next () {
477
+ try {
478
+ return currentPort ;
479
+ } finally {
480
+ this .currentPort += shardCount ;
464
481
}
465
- int lastPortValue , foundPort = -1 ;
466
- do {
467
- lastPortValue = lastPort .get ();
468
-
469
- // We will scan from lastPortValue
470
- // (or lowPort is there was no lastPort or lastPort is too low)
471
- int scanStart = lastPortValue == -1 ? lowPort : lastPortValue ;
472
- if (scanStart < lowPort ) {
473
- scanStart = lowPort ;
474
- }
475
-
476
- // Round it up to "% shardCount == shardId"
477
- scanStart += (shardCount - scanStart % shardCount ) + shardId ;
478
-
479
- // Scan from scanStart upwards to highPort.
480
- for (int port = scanStart ; port <= highPort ; port += shardCount ) {
481
- if (isTcpPortAvailable (port )) {
482
- foundPort = port ;
483
- break ;
484
- }
485
- }
486
-
487
- // If we started scanning from a high scanStart port
488
- // there might have been not enough ports left that are
489
- // smaller than highPort. Scan from the beginning
490
- // from the lowPort.
491
- if (foundPort == -1 ) {
492
- scanStart = lowPort + (shardCount - lowPort % shardCount ) + shardId ;
493
-
494
- for (int port = scanStart ; port <= highPort ; port += shardCount ) {
495
- if (isTcpPortAvailable (port )) {
496
- foundPort = port ;
497
- break ;
498
- }
499
- }
500
- }
501
-
502
- // No luck! All ports taken!
503
- if (foundPort == -1 ) {
504
- return -1 ;
505
- }
506
- } while (!lastPort .compareAndSet (lastPortValue , foundPort ));
482
+ }
507
483
508
- return foundPort ;
484
+ @ Override
485
+ public boolean hasNext () {
486
+ return currentPort <= highestPort ;
509
487
}
510
488
511
- public static boolean isTcpPortAvailable (int port ) {
512
- try {
513
- ServerSocket serverSocket = new ServerSocket ();
514
- try {
515
- serverSocket .setReuseAddress (false );
516
- serverSocket .bind (new InetSocketAddress (port ), 1 );
517
- return true ;
518
- } finally {
519
- serverSocket .close ();
520
- }
521
- } catch (IOException ex ) {
522
- return false ;
523
- }
489
+ public void reset () {
490
+ this .currentPort = startPort ;
524
491
}
525
492
}
526
493
}
0 commit comments