Skip to content

Commit 67348e7

Browse files
committed
1
1 parent 33a9063 commit 67348e7

File tree

1 file changed

+71
-104
lines changed

1 file changed

+71
-104
lines changed

core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java

+71-104
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
3030
import com.datastax.oss.driver.api.core.config.DriverConfig;
3131
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
32-
import com.datastax.oss.driver.api.core.context.DriverContext;
3332
import com.datastax.oss.driver.api.core.metadata.EndPoint;
3433
import com.datastax.oss.driver.api.core.metadata.Node;
3534
import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo;
@@ -53,17 +52,15 @@
5352
import io.netty.channel.ChannelInitializer;
5453
import io.netty.channel.ChannelOption;
5554
import io.netty.channel.ChannelPipeline;
56-
import java.io.IOException;
5755
import java.net.InetSocketAddress;
58-
import java.net.ServerSocket;
56+
import java.util.Iterator;
5957
import java.util.List;
6058
import java.util.Map;
6159
import java.util.Optional;
6260
import java.util.concurrent.CompletableFuture;
6361
import java.util.concurrent.CompletionStage;
6462
import java.util.concurrent.CopyOnWriteArrayList;
6563
import java.util.concurrent.atomic.AtomicBoolean;
66-
import java.util.concurrent.atomic.AtomicInteger;
6764
import net.jcip.annotations.ThreadSafe;
6865
import org.slf4j.Logger;
6966
import org.slf4j.LoggerFactory;
@@ -193,28 +190,43 @@ CompletionStage<DriverChannel> connect(
193190
isNegotiating = true;
194191
}
195192

193+
PortIterator portIterator = null;
194+
195+
if (shardId != null && shardingInfo != null) {
196+
int lowestPort =
197+
context
198+
.getConfig()
199+
.getDefaultProfile()
200+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
201+
int highestPort =
202+
context
203+
.getConfig()
204+
.getDefaultProfile()
205+
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
206+
portIterator =
207+
new PortIterator(lowestPort, highestPort, shardingInfo.getShardsCount(), shardId);
208+
}
209+
196210
connect(
197211
endPoint,
198-
shardingInfo,
199-
shardId,
200212
options,
201213
nodeMetricUpdater,
202214
currentVersion,
203215
isNegotiating,
204216
attemptedVersions,
217+
portIterator,
205218
resultFuture);
206219
return resultFuture;
207220
}
208221

209222
private void connect(
210223
EndPoint endPoint,
211-
NodeShardingInfo shardingInfo,
212-
Integer shardId,
213224
DriverChannelOptions options,
214225
NodeMetricUpdater nodeMetricUpdater,
215226
ProtocolVersion currentVersion,
216227
boolean isNegotiating,
217228
List<ProtocolVersion> attemptedVersions,
229+
PortIterator portIterator,
218230
CompletableFuture<DriverChannel> resultFuture) {
219231

220232
NettyOptions nettyOptions = context.getNettyOptions();
@@ -230,26 +242,14 @@ private void connect(
230242
nettyOptions.afterBootstrapInitialized(bootstrap);
231243

232244
ChannelFuture connectFuture;
233-
if (shardId == null || shardingInfo == null) {
234-
if (shardId != null) {
235-
LOG.debug(
236-
"Requested connection to shard {} but shardingInfo is currently missing for Node at endpoint {}. Falling back to arbitrary local port.",
237-
shardId,
238-
endPoint);
239-
}
245+
if (portIterator == null) {
240246
connectFuture = bootstrap.connect(endPoint.resolve());
241247
} else {
242-
int localPort =
243-
PortAllocator.getNextAvailablePort(shardingInfo.getShardsCount(), shardId, context);
244-
if (localPort == -1) {
245-
LOG.warn(
246-
"Could not find free port for shard {} at {}. Falling back to arbitrary local port.",
247-
shardId,
248-
endPoint);
249-
connectFuture = bootstrap.connect(endPoint.resolve());
250-
} else {
251-
connectFuture = bootstrap.connect(endPoint.resolve(), new InetSocketAddress(localPort));
248+
if (!portIterator.hasNext()) {
249+
portIterator.reset();
252250
}
251+
connectFuture =
252+
bootstrap.connect(endPoint.resolve(), new InetSocketAddress(portIterator.next()));
253253
}
254254

255255
connectFuture.addListener(
@@ -299,19 +299,28 @@ private void connect(
299299
downgraded.get());
300300
connect(
301301
endPoint,
302-
shardingInfo,
303-
shardId,
304302
options,
305303
nodeMetricUpdater,
306304
downgraded.get(),
307305
true,
308306
attemptedVersions,
307+
portIterator,
309308
resultFuture);
310309
} else {
311310
resultFuture.completeExceptionally(
312311
UnsupportedProtocolVersionException.forNegotiation(
313312
endPoint, attemptedVersions));
314313
}
314+
} else if (isBindException(error)) {
315+
connect(
316+
endPoint,
317+
options,
318+
nodeMetricUpdater,
319+
currentVersion,
320+
true,
321+
attemptedVersions,
322+
portIterator,
323+
resultFuture);
315324
} else {
316325
// Note: might be completed already if the failure happened in initializer(), this is
317326
// fine
@@ -321,6 +330,16 @@ private void connect(
321330
});
322331
}
323332

333+
private static boolean isBindException(Throwable error) {
334+
while (error != null) {
335+
if (error instanceof java.net.BindException) {
336+
return true;
337+
}
338+
error = error.getCause();
339+
}
340+
return false;
341+
}
342+
324343
@VisibleForTesting
325344
ChannelInitializer<Channel> initializer(
326345
EndPoint endPoint,
@@ -330,7 +349,7 @@ ChannelInitializer<Channel> initializer(
330349
CompletableFuture<DriverChannel> resultFuture) {
331350
return new ChannelFactoryInitializer(
332351
endPoint, protocolVersion, options, nodeMetricUpdater, resultFuture);
333-
};
352+
}
334353

335354
class ChannelFactoryInitializer extends ChannelInitializer<Channel> {
336355

@@ -440,87 +459,35 @@ protected void initChannel(Channel channel) {
440459
}
441460
}
442461

443-
static class PortAllocator {
444-
private static final AtomicInteger lastPort = new AtomicInteger(-1);
445-
private static final Logger LOG = LoggerFactory.getLogger(PortAllocator.class);
462+
static class PortIterator implements Iterator<Integer> {
463+
private final int highestPort;
464+
private final int startPort;
465+
private int currentPort;
466+
private final int shardCount;
467+
468+
PortIterator(int lowestPort, int highestPort, int shardCount, int shardId) {
469+
this.highestPort = highestPort;
470+
this.currentPort = lowestPort + (shardCount - lowestPort % shardCount) + shardId;
471+
this.startPort = currentPort;
472+
this.shardCount = shardCount;
473+
}
446474

447-
public static int getNextAvailablePort(int shardCount, int shardId, DriverContext context) {
448-
int lowPort =
449-
context
450-
.getConfig()
451-
.getDefaultProfile()
452-
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW);
453-
int highPort =
454-
context
455-
.getConfig()
456-
.getDefaultProfile()
457-
.getInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH);
458-
if (highPort - lowPort < shardCount) {
459-
LOG.error(
460-
"There is not enough ports in range [{},{}] for {} shards. Update your configuration.",
461-
lowPort,
462-
highPort,
463-
shardCount);
475+
@Override
476+
public Integer next() {
477+
try {
478+
return currentPort;
479+
} finally {
480+
this.currentPort += shardCount;
464481
}
465-
int lastPortValue, foundPort = -1;
466-
do {
467-
lastPortValue = lastPort.get();
468-
469-
// We will scan from lastPortValue
470-
// (or lowPort is there was no lastPort or lastPort is too low)
471-
int scanStart = lastPortValue == -1 ? lowPort : lastPortValue;
472-
if (scanStart < lowPort) {
473-
scanStart = lowPort;
474-
}
475-
476-
// Round it up to "% shardCount == shardId"
477-
scanStart += (shardCount - scanStart % shardCount) + shardId;
478-
479-
// Scan from scanStart upwards to highPort.
480-
for (int port = scanStart; port <= highPort; port += shardCount) {
481-
if (isTcpPortAvailable(port)) {
482-
foundPort = port;
483-
break;
484-
}
485-
}
486-
487-
// If we started scanning from a high scanStart port
488-
// there might have been not enough ports left that are
489-
// smaller than highPort. Scan from the beginning
490-
// from the lowPort.
491-
if (foundPort == -1) {
492-
scanStart = lowPort + (shardCount - lowPort % shardCount) + shardId;
493-
494-
for (int port = scanStart; port <= highPort; port += shardCount) {
495-
if (isTcpPortAvailable(port)) {
496-
foundPort = port;
497-
break;
498-
}
499-
}
500-
}
501-
502-
// No luck! All ports taken!
503-
if (foundPort == -1) {
504-
return -1;
505-
}
506-
} while (!lastPort.compareAndSet(lastPortValue, foundPort));
482+
}
507483

508-
return foundPort;
484+
@Override
485+
public boolean hasNext() {
486+
return currentPort <= highestPort;
509487
}
510488

511-
public static boolean isTcpPortAvailable(int port) {
512-
try {
513-
ServerSocket serverSocket = new ServerSocket();
514-
try {
515-
serverSocket.setReuseAddress(false);
516-
serverSocket.bind(new InetSocketAddress(port), 1);
517-
return true;
518-
} finally {
519-
serverSocket.close();
520-
}
521-
} catch (IOException ex) {
522-
return false;
523-
}
489+
public void reset() {
490+
this.currentPort = startPort;
524491
}
525492
}
526493
}

0 commit comments

Comments
 (0)