Skip to content

Commit 69e0100

Browse files
authored
Merge pull request scylladb#251 from sylwiaszunejko/partition_rate_error
Add support for ScyllaDB's per partition rate limiting's new error
2 parents 9b12cc9 + 67d8b94 commit 69e0100

13 files changed

+263
-103
lines changed

cassandra/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from enum import Enum
1516
import logging
1617

1718

@@ -728,3 +729,21 @@ class UnresolvableContactPoints(DriverException):
728729
contact points, only when lookup fails for all hosts
729730
"""
730731
pass
732+
733+
734+
class OperationType(Enum):
735+
Read = 0
736+
Write = 1
737+
738+
class RateLimitReached(ConfigurationException):
739+
'''
740+
Rate limit was exceeded for a partition affected by the request.
741+
'''
742+
op_type = None
743+
rejected_by_coordinator = False
744+
745+
def __init__(self, op_type=None, rejected_by_coordinator=False):
746+
self.op_type = op_type
747+
self.rejected_by_coordinator = rejected_by_coordinator
748+
message = f"[request_error_rate_limit_reached OpType={op_type.name} RejectedByCoordinator={rejected_by_coordinator}]"
749+
Exception.__init__(self, message)

cassandra/c_shard_info.pyx

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,6 @@ cdef class ShardingInfo():
3636
self.shard_aware_port = int(shard_aware_port) if shard_aware_port else 0
3737
self.shard_aware_port_ssl = int(shard_aware_port_ssl) if shard_aware_port_ssl else 0
3838

39-
@staticmethod
40-
def parse_sharding_info(message):
41-
shard_id = message.options.get('SCYLLA_SHARD', [''])[0] or None
42-
shards_count = message.options.get('SCYLLA_NR_SHARDS', [''])[0] or None
43-
partitioner = message.options.get('SCYLLA_PARTITIONER', [''])[0] or None
44-
sharding_algorithm = message.options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
45-
sharding_ignore_msb = message.options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None
46-
shard_aware_port = message.options.get('SCYLLA_SHARD_AWARE_PORT', [''])[0] or None
47-
shard_aware_port_ssl = message.options.get('SCYLLA_SHARD_AWARE_PORT_SSL', [''])[0] or None
48-
49-
if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
50-
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
51-
return 0, None
52-
53-
return int(shard_id), ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb,
54-
shard_aware_port, shard_aware_port_ssl)
55-
56-
5739
def shard_id_from_token(self, int64_t token_input):
5840
cdef uint64_t biased_token = token_input + (<uint64_t>1 << 63);
5941
biased_token <<= self.sharding_ignore_msb;

cassandra/connection.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import random
3232
import itertools
3333

34+
from cassandra.protocol_features import ProtocolFeatures
35+
3436
if 'gevent.monkey' in sys.modules:
3537
from gevent.queue import Queue, Empty
3638
else:
@@ -765,13 +767,12 @@ class Connection(object):
765767

766768
_owning_pool = None
767769

768-
shard_id = 0
769-
sharding_info = None
770-
771770
_is_checksumming_enabled = False
772771

773772
_on_orphaned_stream_released = None
774773

774+
features = None
775+
775776
@property
776777
def _iobuf(self):
777778
# backward compatibility, to avoid any change in the reactors
@@ -831,7 +832,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
831832

832833
self.lock = RLock()
833834
self.connected_event = Event()
834-
self.shard_id = shard_id
835+
self.features = ProtocolFeatures(shard_id=shard_id)
835836
self.total_shards = total_shards
836837
self.original_endpoint = self.endpoint
837838

@@ -896,8 +897,8 @@ def _wrap_socket_from_context(self):
896897
self._socket = self.ssl_context.wrap_socket(self._socket, **ssl_options)
897898

898899
def _initiate_connection(self, sockaddr):
899-
if self.shard_id is not None:
900-
for port in ShardawarePortGenerator.generate(self.shard_id, self.total_shards):
900+
if self.features.shard_id is not None:
901+
for port in ShardawarePortGenerator.generate(self.features.shard_id, self.total_shards):
901902
try:
902903
self._socket.bind(('', port))
903904
break
@@ -1263,7 +1264,7 @@ def process_msg(self, header, body):
12631264
return
12641265

12651266
try:
1266-
response = decoder(header.version, self.user_type_map, stream_id,
1267+
response = decoder(header.version, self.features, self.user_type_map, stream_id,
12671268
header.flags, header.opcode, body, self.decompressor, result_metadata)
12681269
except Exception as exc:
12691270
log.exception("Error decoding response from Cassandra. "
@@ -1318,7 +1319,7 @@ def _send_options_message(self):
13181319

13191320
@defunct_on_error
13201321
def _handle_options_response(self, options_response):
1321-
self.shard_id, self.sharding_info = ShardingInfo.parse_sharding_info(options_response)
1322+
self.features = ProtocolFeatures.parse_from_supported(options_response.options)
13221323
if self.is_defunct:
13231324
return
13241325

@@ -1338,6 +1339,9 @@ def _handle_options_response(self, options_response):
13381339
remote_supported_compressions = options_response.options['COMPRESSION']
13391340
self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0]
13401341

1342+
options = {}
1343+
self.features.add_startup_options(options)
1344+
13411345
if self.cql_version:
13421346
if self.cql_version not in supported_cql_versions:
13431347
raise ProtocolError(
@@ -1388,13 +1392,14 @@ def _handle_options_response(self, options_response):
13881392
self._compressor, self.decompressor = \
13891393
locally_supported_compressions[compression_type]
13901394

1391-
self._send_startup_message(compression_type, no_compact=self.no_compact)
1395+
self._send_startup_message(compression_type, no_compact=self.no_compact, extra_options=options)
13921396

13931397
@defunct_on_error
1394-
def _send_startup_message(self, compression=None, no_compact=False):
1398+
def _send_startup_message(self, compression=None, no_compact=False, extra_options=None):
13951399
log.debug("Sending StartupMessage on %s", self)
13961400
opts = {'DRIVER_NAME': DRIVER_NAME,
1397-
'DRIVER_VERSION': DRIVER_VERSION}
1401+
'DRIVER_VERSION': DRIVER_VERSION,
1402+
**extra_options}
13981403
if compression:
13991404
opts['COMPRESSION'] = compression
14001405
if no_compact:

cassandra/pool.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -427,15 +427,15 @@ def __init__(self, host, host_distance, session):
427427

428428
log.debug("Initializing connection for host %s", self.host)
429429
first_connection = session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
430-
log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.shard_id)
431-
self._connections[first_connection.shard_id] = first_connection
430+
log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.features.shard_id)
431+
self._connections[first_connection.features.shard_id] = first_connection
432432
self._keyspace = session.keyspace
433433

434434
if self._keyspace:
435435
first_connection.set_keyspace_blocking(self._keyspace)
436-
if first_connection.sharding_info and not self._session.cluster.shard_aware_options.disable:
437-
self.host.sharding_info = first_connection.sharding_info
438-
self._open_connections_for_all_shards(first_connection.shard_id)
436+
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
437+
self.host.sharding_info = first_connection.features.sharding_info
438+
self._open_connections_for_all_shards(first_connection.features.shard_id)
439439

440440
log.debug("Finished initializing connection for host %s", self.host)
441441

@@ -556,7 +556,7 @@ def return_connection(self, connection, stream_was_orphaned=False):
556556
with self._lock:
557557
if self.is_shutdown:
558558
return
559-
self._connections.pop(connection.shard_id, None)
559+
self._connections.pop(connection.features.shard_id, None)
560560
if self._is_replacing:
561561
return
562562
self._is_replacing = True
@@ -587,17 +587,17 @@ def _replace(self, connection):
587587

588588
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
589589
try:
590-
if connection.shard_id in self._connections.keys():
591-
del self._connections[connection.shard_id]
590+
if connection.features.shard_id in self._connections.keys():
591+
del self._connections[connection.features.shard_id]
592592
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
593-
self._connecting.add(connection.shard_id)
594-
self._session.submit(self._open_connection_to_missing_shard, connection.shard_id)
593+
self._connecting.add(connection.features.shard_id)
594+
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
595595
else:
596596
connection = self._session.cluster.connection_factory(self.host.endpoint,
597597
on_orphaned_stream_released=self.on_orphaned_stream_released)
598598
if self._keyspace:
599599
connection.set_keyspace_blocking(self._keyspace)
600-
self._connections[connection.shard_id] = connection
600+
self._connections[connection.features.shard_id] = connection
601601
except Exception:
602602
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
603603
self._session.submit(self._replace, connection)
@@ -703,59 +703,59 @@ def _open_connection_to_missing_shard(self, shard_id):
703703
else:
704704
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
705705

706-
log.debug("Received a connection %s for shard_id=%i on host %s", id(conn), conn.shard_id, self.host)
706+
log.debug("Received a connection %s for shard_id=%i on host %s", id(conn), conn.features.shard_id, self.host)
707707
if self.is_shutdown:
708708
log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn))
709709
conn.close()
710710
return
711711

712-
if shard_aware_endpoint and shard_id != conn.shard_id:
712+
if shard_aware_endpoint and shard_id != conn.features.shard_id:
713713
# connection didn't land on expected shared
714714
# assuming behind a NAT, disabling advanced shard aware for a while
715715
self.disable_advanced_shard_aware(10 * 60)
716716

717-
old_conn = self._connections.get(conn.shard_id)
717+
old_conn = self._connections.get(conn.features.shard_id)
718718
if old_conn is None or old_conn.orphaned_threshold_reached:
719719
log.debug(
720720
"New connection (%s) created to shard_id=%i on host %s",
721721
id(conn),
722-
conn.shard_id,
722+
conn.features.shard_id,
723723
self.host
724724
)
725725
old_conn = None
726726
with self._lock:
727727
if self.is_shutdown:
728728
conn.close()
729729
return
730-
if conn.shard_id in self._connections.keys():
730+
if conn.features.shard_id in self._connections.keys():
731731
# Move the current connection to the trash and use the new one from now on
732-
old_conn = self._connections[conn.shard_id]
732+
old_conn = self._connections[conn.features.shard_id]
733733
log.debug(
734734
"Replacing overloaded connection (%s) with (%s) for shard %i for host %s",
735735
id(old_conn),
736736
id(conn),
737-
conn.shard_id,
737+
conn.features.shard_id,
738738
self.host
739739
)
740740
if self._keyspace:
741741
conn.set_keyspace_blocking(self._keyspace)
742742

743-
self._connections[conn.shard_id] = conn
743+
self._connections[conn.features.shard_id] = conn
744744
if old_conn is not None:
745745
remaining = old_conn.in_flight - len(old_conn.orphaned_request_ids)
746746
if remaining == 0:
747747
log.debug(
748748
"Immediately closing the old connection (%s) for shard %i on host %s",
749749
id(old_conn),
750-
old_conn.shard_id,
750+
old_conn.features.shard_id,
751751
self.host
752752
)
753753
old_conn.close()
754754
else:
755755
log.debug(
756756
"Moving the connection (%s) for shard %i to trash on host %s, %i requests remaining",
757757
id(old_conn),
758-
old_conn.shard_id,
758+
old_conn.features.shard_id,
759759
self.host,
760760
remaining,
761761
)
@@ -800,7 +800,7 @@ def _open_connection_to_missing_shard(self, shard_id):
800800
log.debug(
801801
"Putting a connection %s to shard %i to the excess pool of host %s",
802802
id(conn),
803-
conn.shard_id,
803+
conn.features.shard_id,
804804
self.host
805805
)
806806
close_connection = False

cassandra/protocol.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from six.moves import range
2323
import io
2424

25-
from cassandra import ProtocolVersion
25+
from cassandra import OperationType, ProtocolVersion
2626
from cassandra import type_codes, DriverException
27-
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
27+
from cassandra import (Unavailable, WriteTimeout, RateLimitReached, ReadTimeout,
2828
WriteFailure, ReadFailure, FunctionFailure,
2929
AlreadyExists, InvalidRequest, Unauthorized,
3030
UnsupportedOperation, UserFunctionDescriptor,
@@ -126,10 +126,13 @@ def __init__(self, code, message, info):
126126
self.info = info
127127

128128
@classmethod
129-
def recv_body(cls, f, protocol_version, *args):
129+
def recv_body(cls, f, protocol_version, protocol_features, *args):
130130
code = read_int(f)
131131
msg = read_string(f)
132-
subcls = error_classes.get(code, cls)
132+
if code == protocol_features.rate_limit_error:
133+
subcls = RateLimitReachedException
134+
else:
135+
subcls = error_classes.get(code, cls)
133136
extra_info = subcls.recv_error_info(f, protocol_version)
134137
return subcls(code=code, message=msg, info=extra_info)
135138

@@ -390,6 +393,19 @@ def recv_error_info(f, protocol_version):
390393
def to_exception(self):
391394
return AlreadyExists(**self.info)
392395

396+
class RateLimitReachedException(ConfigurationException):
397+
summary= 'Rate limit was exceeded for a partition affected by the request'
398+
error_code = 0x4321
399+
400+
@staticmethod
401+
def recv_error_info(f, protocol_version):
402+
return {
403+
'op_type': OperationType(read_byte(f)),
404+
'rejected_by_coordinator': read_byte(f) != 0
405+
}
406+
407+
def to_exception(self):
408+
return RateLimitReached(**self.info)
393409

394410
class ClientWriteError(RequestExecutionException):
395411
summary = 'Client write failure.'
@@ -738,7 +754,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata):
738754
raise DriverException("Unknown RESULT kind: %d" % self.kind)
739755

740756
@classmethod
741-
def recv_body(cls, f, protocol_version, user_type_map, result_metadata):
757+
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata):
742758
kind = read_int(f)
743759
msg = cls(kind)
744760
msg.recv(f, protocol_version, user_type_map, result_metadata)
@@ -1147,7 +1163,7 @@ def _write_header(f, version, flags, stream_id, opcode, length):
11471163
write_int(f, length)
11481164

11491165
@classmethod
1150-
def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
1166+
def decode_message(cls, protocol_version, protocol_features, user_type_map, stream_id, flags, opcode, body,
11511167
decompressor, result_metadata):
11521168
"""
11531169
Decodes a native protocol message body
@@ -1193,7 +1209,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod
11931209
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
11941210

11951211
msg_class = cls.message_types_by_opcode[opcode]
1196-
msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata)
1212+
msg = msg_class.recv_body(body, protocol_version, protocol_features, user_type_map, result_metadata)
11971213
msg.stream_id = stream_id
11981214
msg.trace_id = trace_id
11991215
msg.custom_payload = custom_payload

0 commit comments

Comments
 (0)