11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import itertools
14
15
import unittest
15
16
from io import BytesIO
16
17
import time
21
22
from cassandra .cluster import Cluster
22
23
from cassandra .connection import (Connection , HEADER_DIRECTION_TO_CLIENT , ProtocolError ,
23
24
locally_supported_compressions , ConnectionHeartbeat , _Frame , Timer , TimerManager ,
24
- ConnectionException , DefaultEndPoint )
25
+ ConnectionException , DefaultEndPoint , ShardAwarePortGenerator )
25
26
from cassandra .marshal import uint8_pack , uint32_pack , int32_pack
26
27
from cassandra .protocol import (write_stringmultimap , write_int , write_string ,
27
28
SupportedMessage , ProtocolHandler )
@@ -478,3 +479,43 @@ def test_endpoint_resolve(self):
478
479
DefaultEndPoint ('10.0.0.1' , 3232 ).resolve (),
479
480
('10.0.0.1' , 3232 )
480
481
)
482
+
483
+
484
+ class TestShardawarePortGenerator (unittest .TestCase ):
485
+ @patch ('random.randrange' )
486
+ def test_generate_ports_basic (self , mock_randrange ):
487
+ mock_randrange .return_value = 10005
488
+ gen = ShardAwarePortGenerator (10000 , 10020 )
489
+ ports = list (itertools .islice (gen .generate (shard_id = 1 , total_shards = 3 ), 5 ))
490
+
491
+ # Starting from aligned 10005 + shard_id (1), step by 3
492
+ self .assertEqual (ports , [10006 , 10009 , 10012 , 10015 , 10018 ])
493
+
494
+ @patch ('random.randrange' )
495
+ def test_wraps_around_to_start (self , mock_randrange ):
496
+ mock_randrange .return_value = 10008
497
+ gen = ShardAwarePortGenerator (10000 , 10020 )
498
+ ports = list (itertools .islice (gen .generate (shard_id = 2 , total_shards = 4 ), 5 ))
499
+
500
+ # Expected wrap-around from start_port after end_port is exceeded
501
+ self .assertEqual (ports , [10010 , 10014 , 10018 , 10002 , 10006 ])
502
+
503
+ @patch ('random.randrange' )
504
+ def test_all_ports_have_correct_modulo (self , mock_randrange ):
505
+ mock_randrange .return_value = 10012
506
+ total_shards = 5
507
+ shard_id = 3
508
+ gen = ShardAwarePortGenerator (10000 , 10020 )
509
+
510
+ for port in gen .generate (shard_id = shard_id , total_shards = total_shards ):
511
+ self .assertEqual (port % total_shards , shard_id )
512
+
513
+ @patch ('random.randrange' )
514
+ def test_generate_is_repeatable_with_same_mock (self , mock_randrange ):
515
+ mock_randrange .return_value = 10010
516
+ gen = ShardAwarePortGenerator (10000 , 10020 )
517
+
518
+ first_run = list (itertools .islice (gen .generate (0 , 2 ), 5 ))
519
+ second_run = list (itertools .islice (gen .generate (0 , 2 ), 5 ))
520
+
521
+ self .assertEqual (first_run , second_run )
0 commit comments