2929import com .datastax .oss .driver .api .core .config .DefaultDriverOption ;
3030import com .datastax .oss .driver .api .core .config .DriverConfig ;
3131import com .datastax .oss .driver .api .core .config .DriverExecutionProfile ;
32- import com .datastax .oss .driver .api .core .context .DriverContext ;
3332import com .datastax .oss .driver .api .core .metadata .EndPoint ;
3433import com .datastax .oss .driver .api .core .metadata .Node ;
3534import com .datastax .oss .driver .api .core .metadata .NodeShardingInfo ;
5352import io .netty .channel .ChannelInitializer ;
5453import io .netty .channel .ChannelOption ;
5554import io .netty .channel .ChannelPipeline ;
56- import java .io .IOException ;
5755import java .net .InetSocketAddress ;
58- import java .net . ServerSocket ;
56+ import java .util . Iterator ;
5957import java .util .List ;
6058import java .util .Map ;
6159import java .util .Optional ;
6260import java .util .concurrent .CompletableFuture ;
6361import java .util .concurrent .CompletionStage ;
6462import java .util .concurrent .CopyOnWriteArrayList ;
6563import java .util .concurrent .atomic .AtomicBoolean ;
66- import java .util .concurrent .atomic .AtomicInteger ;
6764import net .jcip .annotations .ThreadSafe ;
6865import org .slf4j .Logger ;
6966import org .slf4j .LoggerFactory ;
@@ -193,28 +190,43 @@ CompletionStage<DriverChannel> connect(
193190 isNegotiating = true ;
194191 }
195192
193+ PortIterator portIterator = null ;
194+
195+ if (shardId != null && shardingInfo != null ) {
196+ int lowestPort =
197+ context
198+ .getConfig ()
199+ .getDefaultProfile ()
200+ .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_LOW );
201+ int highestPort =
202+ context
203+ .getConfig ()
204+ .getDefaultProfile ()
205+ .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_HIGH );
206+ portIterator =
207+ new PortIterator (lowestPort , highestPort , shardingInfo .getShardsCount (), shardId );
208+ }
209+
196210 connect (
197211 endPoint ,
198- shardingInfo ,
199- shardId ,
200212 options ,
201213 nodeMetricUpdater ,
202214 currentVersion ,
203215 isNegotiating ,
204216 attemptedVersions ,
217+ portIterator ,
205218 resultFuture );
206219 return resultFuture ;
207220 }
208221
209222 private void connect (
210223 EndPoint endPoint ,
211- NodeShardingInfo shardingInfo ,
212- Integer shardId ,
213224 DriverChannelOptions options ,
214225 NodeMetricUpdater nodeMetricUpdater ,
215226 ProtocolVersion currentVersion ,
216227 boolean isNegotiating ,
217228 List <ProtocolVersion > attemptedVersions ,
229+ PortIterator portIterator ,
218230 CompletableFuture <DriverChannel > resultFuture ) {
219231
220232 NettyOptions nettyOptions = context .getNettyOptions ();
@@ -230,26 +242,14 @@ private void connect(
230242 nettyOptions .afterBootstrapInitialized (bootstrap );
231243
232244 ChannelFuture connectFuture ;
233- if (shardId == null || shardingInfo == null ) {
234- if (shardId != null ) {
235- LOG .debug (
236- "Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port." ,
237- shardId ,
238- endPoint );
239- }
245+ if (portIterator == null ) {
240246 connectFuture = bootstrap .connect (endPoint .resolve ());
241247 } else {
242- int localPort =
243- PortAllocator .getNextAvailablePort (shardingInfo .getShardsCount (), shardId , context );
244- if (localPort == -1 ) {
245- LOG .warn (
246- "Could not find free port for shard {} at {}. Falling back to arbitrary local port." ,
247- shardId ,
248- endPoint );
249- connectFuture = bootstrap .connect (endPoint .resolve ());
250- } else {
251- connectFuture = bootstrap .connect (endPoint .resolve (), new InetSocketAddress (localPort ));
248+ if (!portIterator .hasNext ()) {
249+ portIterator .reset ();
252250 }
251+ connectFuture =
252+ bootstrap .connect (endPoint .resolve (), new InetSocketAddress (portIterator .next ()));
253253 }
254254
255255 connectFuture .addListener (
@@ -299,19 +299,28 @@ private void connect(
299299 downgraded .get ());
300300 connect (
301301 endPoint ,
302- shardingInfo ,
303- shardId ,
304302 options ,
305303 nodeMetricUpdater ,
306304 downgraded .get (),
307305 true ,
308306 attemptedVersions ,
307+ portIterator ,
309308 resultFuture );
310309 } else {
311310 resultFuture .completeExceptionally (
312311 UnsupportedProtocolVersionException .forNegotiation (
313312 endPoint , attemptedVersions ));
314313 }
314+ } else if (isBindException (error )) {
315+ connect (
316+ endPoint ,
317+ options ,
318+ nodeMetricUpdater ,
319+ currentVersion ,
320+ true ,
321+ attemptedVersions ,
322+ portIterator ,
323+ resultFuture );
315324 } else {
316325 // Note: might be completed already if the failure happened in initializer(), this is
317326 // fine
@@ -321,6 +330,16 @@ private void connect(
321330 });
322331 }
323332
333+ private static boolean isBindException (Throwable error ) {
334+ while (error != null ) {
335+ if (error instanceof java .net .BindException ) {
336+ return true ;
337+ }
338+ error = error .getCause ();
339+ }
340+ return false ;
341+ }
342+
324343 @ VisibleForTesting
325344 ChannelInitializer <Channel > initializer (
326345 EndPoint endPoint ,
@@ -330,7 +349,7 @@ ChannelInitializer<Channel> initializer(
330349 CompletableFuture <DriverChannel > resultFuture ) {
331350 return new ChannelFactoryInitializer (
332351 endPoint , protocolVersion , options , nodeMetricUpdater , resultFuture );
333- };
352+ }
334353
335354 class ChannelFactoryInitializer extends ChannelInitializer <Channel > {
336355
@@ -440,87 +459,35 @@ protected void initChannel(Channel channel) {
440459 }
441460 }
442461
443- static class PortAllocator {
444- private static final AtomicInteger lastPort = new AtomicInteger (-1 );
445- private static final Logger LOG = LoggerFactory .getLogger (PortAllocator .class );
462+ static class PortIterator implements Iterator <Integer > {
463+ private final int highestPort ;
464+ private final int startPort ;
465+ private int currentPort ;
466+ private final int shardCount ;
467+
468+ PortIterator (int lowestPort , int highestPort , int shardCount , int shardId ) {
469+ this .highestPort = highestPort ;
470+ this .currentPort = lowestPort + (shardCount - lowestPort % shardCount ) + shardId ;
471+ this .startPort = currentPort ;
472+ this .shardCount = shardCount ;
473+ }
446474
447- public static int getNextAvailablePort (int shardCount , int shardId , DriverContext context ) {
448- int lowPort =
449- context
450- .getConfig ()
451- .getDefaultProfile ()
452- .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_LOW );
453- int highPort =
454- context
455- .getConfig ()
456- .getDefaultProfile ()
457- .getInt (DefaultDriverOption .ADVANCED_SHARD_AWARENESS_PORT_HIGH );
458- if (highPort - lowPort < shardCount ) {
459- LOG .error (
460- "There is not enough ports in range [{},{}] for {} shards. Update your configuration." ,
461- lowPort ,
462- highPort ,
463- shardCount );
475+ @ Override
476+ public Integer next () {
477+ try {
478+ return currentPort ;
479+ } finally {
480+ this .currentPort += shardCount ;
464481 }
465- int lastPortValue , foundPort = -1 ;
466- do {
467- lastPortValue = lastPort .get ();
468-
469- // We will scan from lastPortValue
470- // (or lowPort is there was no lastPort or lastPort is too low)
471- int scanStart = lastPortValue == -1 ? lowPort : lastPortValue ;
472- if (scanStart < lowPort ) {
473- scanStart = lowPort ;
474- }
475-
476- // Round it up to "% shardCount == shardId"
477- scanStart += (shardCount - scanStart % shardCount ) + shardId ;
478-
479- // Scan from scanStart upwards to highPort.
480- for (int port = scanStart ; port <= highPort ; port += shardCount ) {
481- if (isTcpPortAvailable (port )) {
482- foundPort = port ;
483- break ;
484- }
485- }
486-
487- // If we started scanning from a high scanStart port
488- // there might have been not enough ports left that are
489- // smaller than highPort. Scan from the beginning
490- // from the lowPort.
491- if (foundPort == -1 ) {
492- scanStart = lowPort + (shardCount - lowPort % shardCount ) + shardId ;
493-
494- for (int port = scanStart ; port <= highPort ; port += shardCount ) {
495- if (isTcpPortAvailable (port )) {
496- foundPort = port ;
497- break ;
498- }
499- }
500- }
501-
502- // No luck! All ports taken!
503- if (foundPort == -1 ) {
504- return -1 ;
505- }
506- } while (!lastPort .compareAndSet (lastPortValue , foundPort ));
482+ }
507483
508- return foundPort ;
484+ @ Override
485+ public boolean hasNext () {
486+ return currentPort <= highestPort ;
509487 }
510488
511- public static boolean isTcpPortAvailable (int port ) {
512- try {
513- ServerSocket serverSocket = new ServerSocket ();
514- try {
515- serverSocket .setReuseAddress (false );
516- serverSocket .bind (new InetSocketAddress (port ), 1 );
517- return true ;
518- } finally {
519- serverSocket .close ();
520- }
521- } catch (IOException ex ) {
522- return false ;
523- }
489+ public void reset () {
490+ this .currentPort = startPort ;
524491 }
525492 }
526493}
0 commit comments