Skip to content

Commit 16f0fda

Browse files
committed
1
1 parent 33a9063 commit 16f0fda

File tree

1 file changed

+67
-103
lines changed

1 file changed

+67
-103
lines changed

core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java

+67-103
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
3030
import com.datastax.oss.driver.api.core.config.DriverConfig;
3131
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
32-
import com.datastax.oss.driver.api.core.context.DriverContext;
3332
import com.datastax.oss.driver.api.core.metadata.EndPoint;
3433
import com.datastax.oss.driver.api.core.metadata.Node;
3534
import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo;
@@ -53,17 +52,15 @@
5352
import io.netty.channel.ChannelInitializer;
5453
import io.netty.channel.ChannelOption;
5554
import io.netty.channel.ChannelPipeline;
56-
import java.io.IOException;
5755
import java.net.InetSocketAddress;
58-
import java.net.ServerSocket;
56+
import java.util.Iterator;
5957
import java.util.List;
6058
import java.util.Map;
6159
import java.util.Optional;
6260
import java.util.concurrent.CompletableFuture;
6361
import java.util.concurrent.CompletionStage;
6462
import java.util.concurrent.CopyOnWriteArrayList;
6563
import java.util.concurrent.atomic.AtomicBoolean;
66-
import java.util.concurrent.atomic.AtomicInteger;
6764
import net.jcip.annotations.ThreadSafe;
6865
import org.slf4j.Logger;
6966
import org.slf4j.LoggerFactory;
@@ -193,28 +190,44 @@ CompletionStage<DriverChannel> connect(
193190
isNegotiating = true;
194191
}
195192

193+
PortIterator portIterator = null;
194+
195+
if (shardId != null && shardingInfo != null) {
196+
int lowestPort =
197+
context
198+
.getConfig()
199+
.getDefaultProfile()
200+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
201+
int highestPort =
202+
context
203+
.getConfig()
204+
.getDefaultProfile()
205+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
206+
portIterator =
207+
new PortIterator(lowestPort, highestPort, shardingInfo.getShardsCount(), shardId)
208+
.iterator();
209+
}
210+
196211
connect(
197212
endPoint,
198-
shardingInfo,
199-
shardId,
200213
options,
201214
nodeMetricUpdater,
202215
currentVersion,
203216
isNegotiating,
204217
attemptedVersions,
218+
portIterator,
205219
resultFuture);
206220
return resultFuture;
207221
}
208222

209223
private void connect(
210224
EndPoint endPoint,
211-
NodeShardingInfo shardingInfo,
212-
Integer shardId,
213225
DriverChannelOptions options,
214226
NodeMetricUpdater nodeMetricUpdater,
215227
ProtocolVersion currentVersion,
216228
boolean isNegotiating,
217229
List<ProtocolVersion> attemptedVersions,
230+
PortIterator portIterator,
218231
CompletableFuture<DriverChannel> resultFuture) {
219232

220233
NettyOptions nettyOptions = context.getNettyOptions();
@@ -230,26 +243,14 @@ private void connect(
230243
nettyOptions.afterBootstrapInitialized(bootstrap);
231244

232245
ChannelFuture connectFuture;
233-
if (shardId == null || shardingInfo == null) {
234-
if (shardId != null) {
235-
LOG.debug(
236-
"Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port.",
237-
shardId,
238-
endPoint);
239-
}
246+
if (portIterator == null) {
240247
connectFuture = bootstrap.connect(endPoint.resolve());
241248
} else {
242-
int localPort =
243-
PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
244-
if (localPort == -1) {
245-
LOG.warn(
246-
"Could not find free port for shard {} at {}. Falling back to arbitrary local port.",
247-
shardId,
248-
endPoint);
249-
connectFuture = bootstrap.connect(endPoint.resolve());
250-
} else {
251-
connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
249+
if (!portIterator.hasNext()) {
250+
portIterator.reset();
252251
}
252+
connectFuture =
253+
bootstrap.connect(endPoint.resolve(), new InetSocketAddress(portIterator.next()));
253254
}
254255

255256
connectFuture.addListener(
@@ -299,19 +300,29 @@ private void connect(
299300
downgraded.get());
300301
connect(
301302
endPoint,
302-
shardingInfo,
303-
shardId,
304303
options,
305304
nodeMetricUpdater,
306305
downgraded.get(),
307306
true,
308307
attemptedVersions,
308+
portIterator,
309309
resultFuture);
310310
} else {
311311
resultFuture.completeExceptionally(
312312
UnsupportedProtocolVersionException.forNegotiation(
313313
endPoint, attemptedVersions));
314314
}
315+
} else if (error instanceof java.net.BindException
316+
|| error.getCause() instanceof java.net.BindException) {
317+
connect(
318+
endPoint,
319+
options,
320+
nodeMetricUpdater,
321+
currentVersion,
322+
true,
323+
attemptedVersions,
324+
portIterator,
325+
resultFuture);
315326
} else {
316327
// Note: might be completed already if the failure happened in initializer(), this is
317328
// fine
@@ -330,7 +341,7 @@ ChannelInitializer<Channel> initializer(
330341
CompletableFuture<DriverChannel> resultFuture) {
331342
return new ChannelFactoryInitializer(
332343
endPoint, protocolVersion, options, nodeMetricUpdater, resultFuture);
333-
};
344+
}
334345

335346
class ChannelFactoryInitializer extends ChannelInitializer<Channel> {
336347

@@ -440,87 +451,40 @@ protected void initChannel(Channel channel) {
440451
}
441452
}
442453

443-
static class PortAllocator {
444-
private static final AtomicInteger lastPort = new AtomicInteger(-1);
445-
private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class);
454+
static class PortIterator implements Iterator<Integer> {
455+
private final int highestPort;
456+
private final int startPort;
457+
private int currentPort;
458+
private final int shardCount;
459+
460+
PortIterator(int lowestPort, int highestPort, int shardCount, int shardId) {
461+
this.highestPort = highestPort;
462+
this.currentPort = lowestPort + (shardCount - lowestPort % shardCount) + shardId;
463+
this.startPort = currentPort;
464+
this.shardCount = shardCount;
465+
}
446466

447-
public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
448-
int lowPort =
449-
context
450-
.getConfig()
451-
.getDefaultProfile()
452-
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
453-
int highPort =
454-
context
455-
.getConfig()
456-
.getDefaultProfile()
457-
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
458-
if (highPort - lowPort < shardCount) {
459-
LOG.error(
460-
"There is not enough ports in range [{},{}] for {} shards. Update your configuration.",
461-
lowPort,
462-
highPort,
463-
shardCount);
467+
@Override
468+
public Integer next() {
469+
try {
470+
return currentPort;
471+
} finally {
472+
this.currentPort += shardCount;
464473
}
465-
int lastPortValue, foundPort = -1;
466-
do {
467-
lastPortValue = lastPort.get();
468-
469-
// We will scan from lastPortValue
470-
// (or lowPort is there was no lastPort or lastPort is too low)
471-
int scanStart = lastPortValue == -1 ? lowPort : lastPortValue;
472-
if (scanStart < lowPort) {
473-
scanStart = lowPort;
474-
}
475-
476-
// Round it up to "% shardCount == shardId"
477-
scanStart += (shardCount - scanStart % shardCount) + shardId;
478-
479-
// Scan from scanStart upwards to highPort.
480-
for (int port = scanStart; port <= highPort; port += shardCount) {
481-
if (isTcpPortAvailable(port)) {
482-
foundPort = port;
483-
break;
484-
}
485-
}
486-
487-
// If we started scanning from a high scanStart port
488-
// there might have been not enough ports left that are
489-
// smaller than highPort. Scan from the beginning
490-
// from the lowPort.
491-
if (foundPort == -1) {
492-
scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId;
493-
494-
for (int port = scanStart; port <= highPort; port += shardCount) {
495-
if (isTcpPortAvailable(port)) {
496-
foundPort = port;
497-
break;
498-
}
499-
}
500-
}
474+
}
501475

502-
// No luck! All ports taken!
503-
if (foundPort == -1) {
504-
return -1;
505-
}
506-
} while (!lastPort.compareAndSet(lastPortValue, foundPort));
476+
@Override
477+
public boolean hasNext() {
478+
return currentPort <= highestPort;
479+
}
507480

508-
return foundPort;
481+
public void reset() {
482+
this.currentPort = startPort;
509483
}
510484

511-
public static boolean isTcpPortAvailable(int port) {
512-
try {
513-
ServerSocket serverSocket = new ServerSocket();
514-
try {
515-
serverSocket.setReuseAddress(false);
516-
serverSocket.bind(new InetSocketAddress(port), 1);
517-
return true;
518-
} finally {
519-
serverSocket.close();
520-
}
521-
} catch (IOException ex) {
522-
return false;
523-
}
485+
public PortIterator iterator() {
486+
this.reset();
487+
return this;
524488
}
525489
}
526490
}

0 commit comments

Comments
 (0)