Skip to content

Commit 4708b33

Browse files
committed
1
1 parent 33a9063 commit 4708b33

File tree

1 file changed

+61
-104
lines changed

1 file changed

+61
-104
lines changed

core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java

+61-104
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
3030
import com.datastax.oss.driver.api.core.config.DriverConfig;
3131
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
32-
import com.datastax.oss.driver.api.core.context.DriverContext;
3332
import com.datastax.oss.driver.api.core.metadata.EndPoint;
3433
import com.datastax.oss.driver.api.core.metadata.Node;
3534
import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo;
@@ -53,17 +52,14 @@
5352
import io.netty.channel.ChannelInitializer;
5453
import io.netty.channel.ChannelOption;
5554
import io.netty.channel.ChannelPipeline;
56-
import java.io.IOException;
5755
import java.net.InetSocketAddress;
58-
import java.net.ServerSocket;
5956
import java.util.List;
6057
import java.util.Map;
6158
import java.util.Optional;
6259
import java.util.concurrent.CompletableFuture;
6360
import java.util.concurrent.CompletionStage;
6461
import java.util.concurrent.CopyOnWriteArrayList;
6562
import java.util.concurrent.atomic.AtomicBoolean;
66-
import java.util.concurrent.atomic.AtomicInteger;
6763
import net.jcip.annotations.ThreadSafe;
6864
import org.slf4j.Logger;
6965
import org.slf4j.LoggerFactory;
@@ -193,28 +189,43 @@ CompletionStage<DriverChannel> connect(
193189
isNegotiating = true;
194190
}
195191

192+
PortIterator portIterator = null;
193+
194+
if (shardId != null && shardingInfo != null) {
195+
int lowestPort =
196+
context
197+
.getConfig()
198+
.getDefaultProfile()
199+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
200+
int highestPort =
201+
context
202+
.getConfig()
203+
.getDefaultProfile()
204+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
205+
portIterator =
206+
new PortIterator(lowestPort, highestPort, shardingInfo.getShardsCount(), shardId);
207+
}
208+
196209
connect(
197210
endPoint,
198-
shardingInfo,
199-
shardId,
200211
options,
201212
nodeMetricUpdater,
202213
currentVersion,
203214
isNegotiating,
204215
attemptedVersions,
216+
portIterator,
205217
resultFuture);
206218
return resultFuture;
207219
}
208220

209221
private void connect(
210222
EndPoint endPoint,
211-
NodeShardingInfo shardingInfo,
212-
Integer shardId,
213223
DriverChannelOptions options,
214224
NodeMetricUpdater nodeMetricUpdater,
215225
ProtocolVersion currentVersion,
216226
boolean isNegotiating,
217227
List<ProtocolVersion> attemptedVersions,
228+
PortIterator portIterator,
218229
CompletableFuture<DriverChannel> resultFuture) {
219230

220231
NettyOptions nettyOptions = context.getNettyOptions();
@@ -230,26 +241,11 @@ private void connect(
230241
nettyOptions.afterBootstrapInitialized(bootstrap);
231242

232243
ChannelFuture connectFuture;
233-
if (shardId == null || shardingInfo == null) {
234-
if (shardId != null) {
235-
LOG.debug(
236-
"Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port.",
237-
shardId,
238-
endPoint);
239-
}
244+
if (portIterator == null) {
240245
connectFuture = bootstrap.connect(endPoint.resolve());
241246
} else {
242-
int localPort =
243-
PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
244-
if (localPort == -1) {
245-
LOG.warn(
246-
"Could not find free port for shard {} at {}. Falling back to arbitrary local port.",
247-
shardId,
248-
endPoint);
249-
connectFuture = bootstrap.connect(endPoint.resolve());
250-
} else {
251-
connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
252-
}
247+
connectFuture =
248+
bootstrap.connect(endPoint.resolve(), new InetSocketAddress(portIterator.get()));
253249
}
254250

255251
connectFuture.addListener(
@@ -299,19 +295,30 @@ private void connect(
299295
downgraded.get());
300296
connect(
301297
endPoint,
302-
shardingInfo,
303-
shardId,
304298
options,
305299
nodeMetricUpdater,
306300
downgraded.get(),
307301
true,
308302
attemptedVersions,
303+
portIterator,
309304
resultFuture);
310305
} else {
311306
resultFuture.completeExceptionally(
312307
UnsupportedProtocolVersionException.forNegotiation(
313308
endPoint, attemptedVersions));
314309
}
310+
} else if (portIterator != null
311+
&& isBindException(error)
312+
&& portIterator.bumpPortUp()) {
313+
connect(
314+
endPoint,
315+
options,
316+
nodeMetricUpdater,
317+
currentVersion,
318+
true,
319+
attemptedVersions,
320+
portIterator,
321+
resultFuture);
315322
} else {
316323
// Note: might be completed already if the failure happened in initializer(), this is
317324
// fine
@@ -321,6 +328,16 @@ private void connect(
321328
});
322329
}
323330

331+
private static boolean isBindException(Throwable error) {
332+
while (error != null) {
333+
if (error instanceof java.net.BindException) {
334+
return true;
335+
}
336+
error = error.getCause();
337+
}
338+
return false;
339+
}
340+
324341
@VisibleForTesting
325342
ChannelInitializer<Channel> initializer(
326343
EndPoint endPoint,
@@ -330,7 +347,7 @@ ChannelInitializer<Channel> initializer(
330347
CompletableFuture<DriverChannel> resultFuture) {
331348
return new ChannelFactoryInitializer(
332349
endPoint, protocolVersion, options, nodeMetricUpdater, resultFuture);
333-
};
350+
}
334351

335352
class ChannelFactoryInitializer extends ChannelInitializer<Channel> {
336353

@@ -440,87 +457,27 @@ protected void initChannel(Channel channel) {
440457
}
441458
}
442459

443-
static class PortAllocator {
444-
private static final AtomicInteger lastPort = new AtomicInteger(-1);
445-
private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class);
460+
static class PortIterator {
461+
private final int highestPort;
462+
private int currentPort;
463+
private final int shardCount;
446464

447-
public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
448-
int lowPort =
449-
context
450-
.getConfig()
451-
.getDefaultProfile()
452-
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
453-
int highPort =
454-
context
455-
.getConfig()
456-
.getDefaultProfile()
457-
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
458-
if (highPort - lowPort < shardCount) {
459-
LOG.error(
460-
"There is not enough ports in range [{},{}] for {} shards. Update your configuration.",
461-
lowPort,
462-
highPort,
463-
shardCount);
464-
}
465-
int lastPortValue, foundPort = -1;
466-
do {
467-
lastPortValue = lastPort.get();
468-
469-
// We will scan from lastPortValue
470-
// (or lowPort is there was no lastPort or lastPort is too low)
471-
int scanStart = lastPortValue == -1 ? lowPort : lastPortValue;
472-
if (scanStart < lowPort) {
473-
scanStart = lowPort;
474-
}
475-
476-
// Round it up to "% shardCount == shardId"
477-
scanStart += (shardCount - scanStart % shardCount) + shardId;
478-
479-
// Scan from scanStart upwards to highPort.
480-
for (int port = scanStart; port <= highPort; port += shardCount) {
481-
if (isTcpPortAvailable(port)) {
482-
foundPort = port;
483-
break;
484-
}
485-
}
486-
487-
// If we started scanning from a high scanStart port
488-
// there might have been not enough ports left that are
489-
// smaller than highPort. Scan from the beginning
490-
// from the lowPort.
491-
if (foundPort == -1) {
492-
scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId;
493-
494-
for (int port = scanStart; port <= highPort; port += shardCount) {
495-
if (isTcpPortAvailable(port)) {
496-
foundPort = port;
497-
break;
498-
}
499-
}
500-
}
501-
502-
// No luck! All ports taken!
503-
if (foundPort == -1) {
504-
return -1;
505-
}
506-
} while (!lastPort.compareAndSet(lastPortValue, foundPort));
465+
PortIterator(int lowestPort, int highestPort, int shardCount, int shardId) {
466+
this.highestPort = highestPort;
467+
this.currentPort = lowestPort + (shardCount - lowestPort % shardCount) + shardId;
468+
this.shardCount = shardCount;
469+
}
507470

508-
return foundPort;
471+
public Integer get() {
472+
return currentPort;
509473
}
510474

511-
public static boolean isTcpPortAvailable(int port) {
512-
try {
513-
ServerSocket serverSocket = new ServerSocket();
514-
try {
515-
serverSocket.setReuseAddress(false);
516-
serverSocket.bind(new InetSocketAddress(port), 1);
517-
return true;
518-
} finally {
519-
serverSocket.close();
520-
}
521-
} catch (IOException ex) {
475+
public boolean bumpPortUp() {
476+
if (this.currentPort + shardCount > highestPort) {
522477
return false;
523478
}
479+
this.currentPort += shardCount;
480+
return true;
524481
}
525482
}
526483
}

0 commit comments

Comments
 (0)