Skip to content

Commit ffc6392

Browse files
committed
Add unit-tests for ShardAwarePortGenerator
1. Make it testable 2. Add unit tests for it
1 parent e022378 commit ffc6392

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

cassandra/connection.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -667,17 +667,23 @@ def reset_cql_frame_buffer(self):
667667
self.reset_io_buffer()
668668

669669

670-
class ShardawarePortGenerator:
671-
@classmethod
672-
def generate(cls, shard_id, total_shards):
673-
start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
674-
available_ports = itertools.chain(range(start, DEFAULT_LOCAL_PORT_HIGH), range(DEFAULT_LOCAL_PORT_LOW, start))
670+
class ShardAwarePortGenerator:
671+
def __init__(self, start_port: int, end_port: int):
672+
self.start_port = start_port
673+
self.end_port = end_port
674+
675+
def generate(self, shard_id: int, total_shards: int):
676+
start = random.randrange(self.start_port, self.end_port)
677+
available_ports = itertools.chain(range(start, self.end_port), range(self.start_port, start))
675678

676679
for port in available_ports:
677680
if port % total_shards == shard_id:
678681
yield port
679682

680683

684+
DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
685+
686+
681687
class Connection(object):
682688

683689
CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -928,7 +934,7 @@ def _wrap_socket_from_context(self):
928934

929935
def _initiate_connection(self, sockaddr):
930936
if self.features.shard_id is not None:
931-
for port in ShardawarePortGenerator.generate(self.features.shard_id, self.total_shards):
937+
for port in DefaultShardAwarePortGenerator.generate(self.features.shard_id, self.total_shards):
932938
try:
933939
self._socket.bind(('', port))
934940
break

tests/unit/test_connection.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import itertools
1415
import unittest
1516
from io import BytesIO
1617
import time
@@ -21,7 +22,7 @@
2122
from cassandra.cluster import Cluster
2223
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
2324
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
24-
ConnectionException, DefaultEndPoint)
25+
ConnectionException, DefaultEndPoint, ShardAwarePortGenerator)
2526
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
2627
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
2728
SupportedMessage, ProtocolHandler)
@@ -478,3 +479,43 @@ def test_endpoint_resolve(self):
478479
DefaultEndPoint('10.0.0.1', 3232).resolve(),
479480
('10.0.0.1', 3232)
480481
)
482+
483+
484+
class TestShardawarePortGenerator(unittest.TestCase):
485+
@patch('random.randrange')
486+
def test_generate_ports_basic(self, mock_randrange):
487+
mock_randrange.return_value = 10005
488+
gen = ShardAwarePortGenerator(10000, 10020)
489+
ports = list(itertools.islice(gen.generate(shard_id=1, total_shards=3), 5))
490+
491+
# Starting from aligned 10005 + shard_id (1), step by 3
492+
self.assertEqual(ports, [10006, 10009, 10012, 10015, 10018])
493+
494+
@patch('random.randrange')
495+
def test_wraps_around_to_start(self, mock_randrange):
496+
mock_randrange.return_value = 10008
497+
gen = ShardAwarePortGenerator(10000, 10020)
498+
ports = list(itertools.islice(gen.generate(shard_id=2, total_shards=4), 5))
499+
500+
# Expected wrap-around from start_port after end_port is exceeded
501+
self.assertEqual(ports, [10010, 10014, 10018, 10002, 10006])
502+
503+
@patch('random.randrange')
504+
def test_all_ports_have_correct_modulo(self, mock_randrange):
505+
mock_randrange.return_value = 10012
506+
total_shards = 5
507+
shard_id = 3
508+
gen = ShardAwarePortGenerator(10000, 10020)
509+
510+
for port in gen.generate(shard_id=shard_id, total_shards=total_shards):
511+
self.assertEqual(port % total_shards, shard_id)
512+
513+
@patch('random.randrange')
514+
def test_generate_is_repeatable_with_same_mock(self, mock_randrange):
515+
mock_randrange.return_value = 10010
516+
gen = ShardAwarePortGenerator(10000, 10020)
517+
518+
first_run = list(itertools.islice(gen.generate(0, 2), 5))
519+
second_run = list(itertools.islice(gen.generate(0, 2), 5))
520+
521+
self.assertEqual(first_run, second_run)

0 commit comments

Comments
 (0)