Skip to content

Commit ea8afec

Browse files
Move sharding info to ProtocolFeatures
Sharding is a protocol extention, now sharing-related info is a part of ProtocolFeatures class, also _ShardingInfo.parse_sharding_info is moved to ProtocolFeatures to have all features strings in one place.
1 parent f36ba79 commit ea8afec

File tree

7 files changed

+67
-81
lines changed

7 files changed

+67
-81
lines changed

cassandra/c_shard_info.pyx

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,6 @@ cdef class ShardingInfo():
3636
self.shard_aware_port = int(shard_aware_port) if shard_aware_port else 0
3737
self.shard_aware_port_ssl = int(shard_aware_port_ssl) if shard_aware_port_ssl else 0
3838

39-
@staticmethod
40-
def parse_sharding_info(message):
41-
shard_id = message.options.get('SCYLLA_SHARD', [''])[0] or None
42-
shards_count = message.options.get('SCYLLA_NR_SHARDS', [''])[0] or None
43-
partitioner = message.options.get('SCYLLA_PARTITIONER', [''])[0] or None
44-
sharding_algorithm = message.options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
45-
sharding_ignore_msb = message.options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None
46-
shard_aware_port = message.options.get('SCYLLA_SHARD_AWARE_PORT', [''])[0] or None
47-
shard_aware_port_ssl = message.options.get('SCYLLA_SHARD_AWARE_PORT_SSL', [''])[0] or None
48-
49-
if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
50-
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
51-
return 0, None
52-
53-
return int(shard_id), ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb,
54-
shard_aware_port, shard_aware_port_ssl)
55-
56-
5739
def shard_id_from_token(self, int64_t token_input):
5840
cdef uint64_t biased_token = token_input + (<uint64_t>1 << 63);
5941
biased_token <<= self.sharding_ignore_msb;

cassandra/connection.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,6 @@ class Connection(object):
767767

768768
_owning_pool = None
769769

770-
shard_id = 0
771-
sharding_info = None
772-
773770
_is_checksumming_enabled = False
774771

775772
_on_orphaned_stream_released = None
@@ -835,7 +832,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
835832

836833
self.lock = RLock()
837834
self.connected_event = Event()
838-
self.shard_id = shard_id
835+
self.features = ProtocolFeatures(shard_id=shard_id)
839836
self.total_shards = total_shards
840837
self.original_endpoint = self.endpoint
841838

@@ -900,8 +897,8 @@ def _wrap_socket_from_context(self):
900897
self._socket = self.ssl_context.wrap_socket(self._socket, **ssl_options)
901898

902899
def _initiate_connection(self, sockaddr):
903-
if self.shard_id is not None:
904-
for port in ShardawarePortGenerator.generate(self.shard_id, self.total_shards):
900+
if self.features.shard_id is not None:
901+
for port in ShardawarePortGenerator.generate(self.features.shard_id, self.total_shards):
905902
try:
906903
self._socket.bind(('', port))
907904
break
@@ -1322,7 +1319,7 @@ def _send_options_message(self):
13221319

13231320
@defunct_on_error
13241321
def _handle_options_response(self, options_response):
1325-
self.shard_id, self.sharding_info = ShardingInfo.parse_sharding_info(options_response)
1322+
self.features = ProtocolFeatures.parse_from_supported(options_response.options)
13261323
if self.is_defunct:
13271324
return
13281325

@@ -1342,10 +1339,8 @@ def _handle_options_response(self, options_response):
13421339
remote_supported_compressions = options_response.options['COMPRESSION']
13431340
self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0]
13441341

1345-
protocol_features = ProtocolFeatures.parse_from_supported(options_response.options)
13461342
options = {}
1347-
protocol_features.add_startup_options(options)
1348-
self.features = protocol_features
1343+
self.features.add_startup_options(options)
13491344

13501345
if self.cql_version:
13511346
if self.cql_version not in supported_cql_versions:

cassandra/pool.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -427,15 +427,15 @@ def __init__(self, host, host_distance, session):
427427

428428
log.debug("Initializing connection for host %s", self.host)
429429
first_connection = session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
430-
log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.shard_id)
431-
self._connections[first_connection.shard_id] = first_connection
430+
log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.features.shard_id)
431+
self._connections[first_connection.features.shard_id] = first_connection
432432
self._keyspace = session.keyspace
433433

434434
if self._keyspace:
435435
first_connection.set_keyspace_blocking(self._keyspace)
436-
if first_connection.sharding_info and not self._session.cluster.shard_aware_options.disable:
437-
self.host.sharding_info = first_connection.sharding_info
438-
self._open_connections_for_all_shards(first_connection.shard_id)
436+
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
437+
self.host.sharding_info = first_connection.features.sharding_info
438+
self._open_connections_for_all_shards(first_connection.features.shard_id)
439439

440440
log.debug("Finished initializing connection for host %s", self.host)
441441

@@ -556,7 +556,7 @@ def return_connection(self, connection, stream_was_orphaned=False):
556556
with self._lock:
557557
if self.is_shutdown:
558558
return
559-
self._connections.pop(connection.shard_id, None)
559+
self._connections.pop(connection.features.shard_id, None)
560560
if self._is_replacing:
561561
return
562562
self._is_replacing = True
@@ -587,17 +587,17 @@ def _replace(self, connection):
587587

588588
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
589589
try:
590-
if connection.shard_id in self._connections.keys():
591-
del self._connections[connection.shard_id]
590+
if connection.features.shard_id in self._connections.keys():
591+
del self._connections[connection.features.shard_id]
592592
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
593-
self._connecting.add(connection.shard_id)
594-
self._session.submit(self._open_connection_to_missing_shard, connection.shard_id)
593+
self._connecting.add(connection.features.shard_id)
594+
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
595595
else:
596596
connection = self._session.cluster.connection_factory(self.host.endpoint,
597597
on_orphaned_stream_released=self.on_orphaned_stream_released)
598598
if self._keyspace:
599599
connection.set_keyspace_blocking(self._keyspace)
600-
self._connections[connection.shard_id] = connection
600+
self._connections[connection.features.shard_id] = connection
601601
except Exception:
602602
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
603603
self._session.submit(self._replace, connection)
@@ -703,59 +703,59 @@ def _open_connection_to_missing_shard(self, shard_id):
703703
else:
704704
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
705705

706-
log.debug("Received a connection %s for shard_id=%i on host %s", id(conn), conn.shard_id, self.host)
706+
log.debug("Received a connection %s for shard_id=%i on host %s", id(conn), conn.features.shard_id, self.host)
707707
if self.is_shutdown:
708708
log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn))
709709
conn.close()
710710
return
711711

712-
if shard_aware_endpoint and shard_id != conn.shard_id:
712+
if shard_aware_endpoint and shard_id != conn.features.shard_id:
713713
# connection didn't land on expected shared
714714
# assuming behind a NAT, disabling advanced shard aware for a while
715715
self.disable_advanced_shard_aware(10 * 60)
716716

717-
old_conn = self._connections.get(conn.shard_id)
717+
old_conn = self._connections.get(conn.features.shard_id)
718718
if old_conn is None or old_conn.orphaned_threshold_reached:
719719
log.debug(
720720
"New connection (%s) created to shard_id=%i on host %s",
721721
id(conn),
722-
conn.shard_id,
722+
conn.features.shard_id,
723723
self.host
724724
)
725725
old_conn = None
726726
with self._lock:
727727
if self.is_shutdown:
728728
conn.close()
729729
return
730-
if conn.shard_id in self._connections.keys():
730+
if conn.features.shard_id in self._connections.keys():
731731
# Move the current connection to the trash and use the new one from now on
732-
old_conn = self._connections[conn.shard_id]
732+
old_conn = self._connections[conn.features.shard_id]
733733
log.debug(
734734
"Replacing overloaded connection (%s) with (%s) for shard %i for host %s",
735735
id(old_conn),
736736
id(conn),
737-
conn.shard_id,
737+
conn.features.shard_id,
738738
self.host
739739
)
740740
if self._keyspace:
741741
conn.set_keyspace_blocking(self._keyspace)
742742

743-
self._connections[conn.shard_id] = conn
743+
self._connections[conn.features.shard_id] = conn
744744
if old_conn is not None:
745745
remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids)
746746
if remaining == 0:
747747
log.debug(
748748
"Immediately closing the old connection (%s) for shard %i on host %s",
749749
id(old_conn),
750-
old_conn.shard_id,
750+
old_conn.features.shard_id,
751751
self.host
752752
)
753753
old_conn.close()
754754
else:
755755
log.debug(
756756
"Moving the connection (%s) for shard %i to trash on host %s, %i requests remaining",
757757
id(old_conn),
758-
old_conn.shard_id,
758+
old_conn.features.shard_id,
759759
self.host,
760760
remaining,
761761
)
@@ -800,7 +800,7 @@ def _open_connection_to_missing_shard(self, shard_id):
800800
log.debug(
801801
"Putting a connection %s to shard %i to the excess pool of host %s",
802802
id(conn),
803-
conn.shard_id,
803+
conn.features.shard_id,
804804
self.host
805805
)
806806
close_connection = False

cassandra/protocol_features.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
import logging
22

3+
from cassandra.shard_info import _ShardingInfo
4+
35
log = logging.getLogger(__name__)
46

57

68
RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
79

810
class ProtocolFeatures(object):
911
rate_limit_error = None
12+
shard_id = 0
13+
sharding_info = None
1014

11-
def __init__(self, rate_limit_error=None):
15+
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None):
1216
self.rate_limit_error = rate_limit_error
17+
self.shard_id = shard_id
18+
self.sharding_info = sharding_info
1319

1420
@staticmethod
1521
def parse_from_supported(supported):
16-
return ProtocolFeatures(rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported))
22+
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
23+
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
24+
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info)
1725

1826
@staticmethod
1927
def maybe_parse_rate_limit_error(supported):
@@ -36,3 +44,22 @@ def add_startup_options(self, options):
3644
if self.rate_limit_error is not None:
3745
options[RATE_LIMIT_ERROR_EXTENSION] = ""
3846

47+
@staticmethod
48+
def parse_sharding_info(options):
49+
shard_id = options.get('SCYLLA_SHARD', [''])[0] or None
50+
shards_count = options.get('SCYLLA_NR_SHARDS', [''])[0] or None
51+
partitioner = options.get('SCYLLA_PARTITIONER', [''])[0] or None
52+
sharding_algorithm = options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
53+
sharding_ignore_msb = options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None
54+
shard_aware_port = options.get('SCYLLA_SHARD_AWARE_PORT', [''])[0] or None
55+
shard_aware_port_ssl = options.get('SCYLLA_SHARD_AWARE_PORT_SSL', [''])[0] or None
56+
log.debug("Parsing sharding info from message options %s", options)
57+
58+
if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
59+
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
60+
return 0, None
61+
62+
return int(shard_id), _ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb,
63+
shard_aware_port, shard_aware_port_ssl)
64+
65+

cassandra/shard_info.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,6 @@ def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, shar
2828
self.shard_aware_port = int(shard_aware_port) if shard_aware_port else None
2929
self.shard_aware_port_ssl = int(shard_aware_port_ssl) if shard_aware_port_ssl else None
3030

31-
@staticmethod
32-
def parse_sharding_info(message):
33-
shard_id = message.options.get('SCYLLA_SHARD', [''])[0] or None
34-
shards_count = message.options.get('SCYLLA_NR_SHARDS', [''])[0] or None
35-
partitioner = message.options.get('SCYLLA_PARTITIONER', [''])[0] or None
36-
sharding_algorithm = message.options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
37-
sharding_ignore_msb = message.options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None
38-
shard_aware_port = message.options.get('SCYLLA_SHARD_AWARE_PORT', [''])[0] or None
39-
shard_aware_port_ssl = message.options.get('SCYLLA_SHARD_AWARE_PORT_SSL', [''])[0] or None
40-
log.debug("Parsing sharding info from message options %s", message.options)
41-
42-
if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
43-
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
44-
return 0, None
45-
46-
return int(shard_id), _ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb,
47-
shard_aware_port, shard_aware_port_ssl)
48-
4931
def shard_id_from_token(self, token):
5032
"""
5133
Convert a Murmur3 token to shard_id based on the number of shards on the host

tests/unit/test_host_connection_pool.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from concurrent.futures import ThreadPoolExecutor
1515
import logging
1616
import time
17+
from cassandra.protocol_features import ProtocolFeatures
1718

1819
from cassandra.shard_info import _ShardingInfo
1920

@@ -300,11 +301,11 @@ def mock_connection_factory(self, *args, **kwargs):
300301
connection.is_shutdown = False
301302
connection.is_defunct = False
302303
connection.is_closed = False
303-
connection.shard_id = self.connection_counter
304+
connection.features = ProtocolFeatures(shard_id=self.connection_counter,
305+
sharding_info=_ShardingInfo(shard_id=1, shards_count=14,
306+
partitioner="", sharding_algorithm="", sharding_ignore_msb=0,
307+
shard_aware_port="", shard_aware_port_ssl=""))
304308
self.connection_counter += 1
305-
connection.sharding_info = _ShardingInfo(shard_id=1, shards_count=14,
306-
partitioner="", sharding_algorithm="", sharding_ignore_msb=0,
307-
shard_aware_port="", shard_aware_port_ssl="")
308309

309310
return connection
310311

tests/unit/test_shard_aware.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from cassandra.pool import HostConnection, HostDistance
2626
from cassandra.connection import ShardingInfo, DefaultEndPoint
2727
from cassandra.metadata import Murmur3Token
28+
from cassandra.protocol_features import ProtocolFeatures
2829

2930
LOGGER = logging.getLogger(__name__)
3031

@@ -43,7 +44,7 @@ class OptionsHolder(object):
4344
'SCYLLA_SHARDING_ALGORITHM': ['biased-token-round-robin'],
4445
'SCYLLA_SHARDING_IGNORE_MSB': ['12']
4546
}
46-
shard_id, shard_info = ShardingInfo.parse_sharding_info(OptionsHolder())
47+
shard_id, shard_info = ProtocolFeatures.parse_sharding_info(OptionsHolder().options)
4748

4849
self.assertEqual(shard_id, 1)
4950
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"a").value), 4)
@@ -88,12 +89,10 @@ def mock_connection_factory(self, *args, **kwargs):
8889
connection.is_defunct = False
8990
connection.is_closed = False
9091
connection.orphaned_threshold_reached = False
91-
connection.endpoint = args[0]
92-
connection.shard_id = kwargs.get('shard_id', self.connection_counter)
92+
connection.endpoint = args[0]
93+
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
94+
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
9395
self.connection_counter += 1
94-
connection.sharding_info = ShardingInfo(shard_id=1, shards_count=4,
95-
partitioner="", sharding_algorithm="", sharding_ignore_msb=0,
96-
shard_aware_port=19042, shard_aware_port_ssl=19045)
9796

9897
return connection
9998

@@ -107,7 +106,7 @@ def mock_connection_factory(self, *args, **kwargs):
107106
f.result()
108107
assert len(pool._connections) == 4
109108
for shard_id, connection in pool._connections.items():
110-
assert connection.shard_id == shard_id
109+
assert connection.features.shard_id == shard_id
111110
if shard_id == 0:
112111
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
113112
else:

0 commit comments

Comments
 (0)