Skip to content

Commit ec43291

Browse files
committed
feat(policy): add shard reconnection policies
Add abstract classes `ShardReconnectionPolicy` and `ShardReconnectionScheduler`, and two implementations: `NoDelayShardReconnectionPolicy`, a policy that preserves the old behavior of having no delay and no concurrency restriction; and `NoConcurrentShardReconnectionPolicy`, a policy that limits concurrent reconnections to one per scope and introduces a delay between reconnections within the scope.
1 parent 3f7bcbb commit ec43291

File tree

4 files changed

+634
-23
lines changed

4 files changed

+634
-23
lines changed

cassandra/policies.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,31 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
15+
1416
import random
17+
import threading
18+
import time
19+
import weakref
20+
from abc import ABC, abstractmethod
1521

1622
from collections import namedtuple
23+
from enum import Enum
1724
from functools import lru_cache
1825
from itertools import islice, cycle, groupby, repeat
1926
import logging
2027
from random import randint, shuffle
2128
from threading import Lock
2229
import socket
2330
import warnings
31+
from typing import TYPE_CHECKING, Callable, Any, List, Tuple, Iterator, Optional, Dict
2432

2533
log = logging.getLogger(__name__)
2634

2735
from cassandra import WriteType as WT
2836

37+
if TYPE_CHECKING:
38+
from cluster import Session
2939

3040
# This is done this way because WriteType was originally
3141
# defined here and in order not to break the API.
@@ -864,6 +874,302 @@ def _add_jitter(self, value):
864874
return min(max(self.base_delay, delay), self.max_delay)
865875

866876

877+
class ShardConnectionScheduler(ABC):
    """
    A base class for a scheduler for a shard connection backoff policy.

    ``ShardConnectionScheduler`` is a per-Session instance that implements the logic
    described by the ``ShardConnectionBackoffPolicy`` that creates it.
    """

    @abstractmethod
    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Schedule a shard connection according to the policy, or execute it right away.

        This is a non-blocking call: even if the policy executes ``method`` right away,
        it is executed in a separate thread.

        ``host_id`` - an id of the host of the shard

        ``shard_id`` - an id of the shard

        ``method`` - a callable that creates a connection and stores it in the connection pool.
        Currently, it is `HostConnection._open_connection_to_missing_shard`

        ``*args`` and ``**kwargs`` are passed to ``method`` when the policy executes it
        """
        # Variadic parameters are annotated with the type of EACH argument,
        # hence ``Any`` rather than a container type.
        raise NotImplementedError()
902+
903+
904+
class ShardConnectionBackoffPolicy(ABC):
    """
    Abstract base for shard connection backoff policies.

    Such policies let the user control the pace at which new connections
    to shards are established.

    A policy acts as a factory: ``new_scheduler`` is expected to return a
    scheduler that behaves according to the policy.
    """

    @abstractmethod
    def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
        """
        Create a scheduler for ``session``. Concrete policies must override this.
        """
        raise NotImplementedError()
915+
916+
917+
class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that adds no delay between attempts
    and places no restriction on concurrency.

    It ensures there is at most one pending connection per (host, shard) pair:
    a connection attempt for a pair that is already pending is silently dropped.

    ``new_scheduler`` instantiates a scheduler that behaves according to this policy.
    """

    def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
        """
        Return a fresh no-delay scheduler bound to ``session``.
        """
        scheduler = _NoDelayShardConnectionBackoffScheduler(session)
        return scheduler
928+
929+
930+
class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler):
    """
    A shard connection backoff scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``.

    It does not introduce any delay or concurrency restrictions.
    It only ensures that there is only one pending or scheduled connection per host+shard.
    """
    session: Session
    # Maps "<host_id>-<shard_id>" to True while a connection for that pair is
    # scheduled or running, and to False once it has finished.
    already_scheduled: dict[str, bool]
    # Guards ``already_scheduled``.
    lock: threading.Lock

    def __init__(self, session: Session):
        # Only a weak proxy to the session is kept, so this scheduler does not
        # hold a strong reference to it.
        self.session = weakref.proxy(session)
        self.already_scheduled = {}
        self.lock = threading.Lock()

    def _execute(
            self,
            scheduled_key: str,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Run ``method(*args, **kwargs)`` and then release ``scheduled_key``,
        whether or not ``method`` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[..., None],
            *args: Any,
            **kwargs: Any) -> None:
        """
        Submit ``method`` to the session for execution unless a connection for the
        same host+shard is already scheduled or the session is shut down.

        Duplicate requests for a host+shard that is still pending are silently dropped.
        Exceptions raised while submitting are propagated to the caller.
        """
        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return
            self.already_scheduled[scheduled_key] = True

        submitted = False
        try:
            if not self.session.is_shutdown:
                self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
                submitted = True
        finally:
            if not submitted:
                # Nothing was handed to the session (it is shut down, or submitting
                # failed), so ``_execute`` will never run to release the key.
                # Release it here so the host+shard is not blocked forever.
                with self.lock:
                    self.already_scheduled[scheduled_key] = False
973+
974+
975+
class ShardConnectionBackoffScope(Enum):
    """
    Concurrency-limitation scope for a `ShardConnectionBackoffPolicy`,
    in particular for ``LimitedConcurrencyShardConnectionBackoffPolicy``.

    The scope defines what the concurrency limit applies to. For instance,
    ``LimitedConcurrencyShardConnectionBackoffPolicy`` allows only a limited number
    of pending connections per scope: with ``Cluster`` the limit is shared by the
    whole cluster, with ``Host`` every host gets its own limit.
    """
    # One shared bucket for the whole cluster.
    Cluster = 0
    # One bucket per host.
    Host = 1
985+
986+
987+
class ShardConnectionBackoffSchedule(ABC):
    """
    Abstract source of backoff delays for shard connection scheduling.
    """

    @abstractmethod
    def new_schedule(self) -> Iterator[float]:
        """
        Return a finite or infinite iterable of delays, each one a floating
        point number of seconds.

        Note that when the iterable is finite, the schedule is recreated right
        after the iterable is exhausted.
        """
        raise NotImplementedError()
996+
997+
998+
class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that allows only ``max_concurrent``
    concurrent connections per scope.

    The scope can be ``Host`` or ``Cluster``.
    For backoff calculation it needs a ``ShardConnectionBackoffSchedule`` or a
    ``ReconnectionPolicy``, since both share the same API.
    When there are no more scheduled connections, the backoff schedule is reset.

    It also does not allow multiple pending or scheduled connections for the same
    host+shard; attempts to schedule a duplicate are silently dropped.

    ``new_scheduler`` instantiates a scheduler that behaves according to this policy.
    """
    scope: ShardConnectionBackoffScope
    backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy
    max_concurrent: int

    def __init__(
            self,
            scope: ShardConnectionBackoffScope,
            backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
            max_concurrent: int = 1,
    ):
        """
        Validate and store the policy settings.

        Raises ``ValueError`` when ``scope`` is not a ``ShardConnectionBackoffScope``,
        when ``backoff_policy`` is neither a ``ShardConnectionBackoffSchedule`` nor a
        ``ReconnectionPolicy``, or when ``max_concurrent`` is less than 1.
        """
        if not isinstance(scope, ShardConnectionBackoffScope):
            raise ValueError("scope must be a ShardConnectionBackoffScope")

        accepted_policy_types = (ShardConnectionBackoffSchedule, ReconnectionPolicy)
        if not isinstance(backoff_policy, accepted_policy_types):
            raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")

        if max_concurrent < 1:
            raise ValueError("max_concurrent must be a positive integer")

        self.scope, self.backoff_policy, self.max_concurrent = scope, backoff_policy, max_concurrent

    def new_scheduler(self, session: Session) -> ShardConnectionScheduler:
        """
        Return a limited-concurrency scheduler for ``session`` configured with
        this policy's scope, backoff policy and concurrency limit.
        """
        return _LimitedConcurrencyShardReconnectionScheduler(
            session,
            self.scope,
            self.backoff_policy,
            self.max_concurrent,
        )
1032+
1033+
1034+
class Callback:
    """
    A deferred call: a callable bundled with the positional and keyword
    arguments it is to be invoked with later.
    """
    method: Callable[..., None]
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]

    def __init__(self, method: Callable[..., None], *args, **kwargs) -> None:
        # Stored as-is; nothing is invoked at construction time.
        self.method, self.args, self.kwargs = method, args, kwargs
1043+
1044+
1045+
class _ScopeBucket:
1046+
"""
1047+
Holds information for a shard reconnection scope, schedules and executes reconnections.
1048+
"""
1049+
items: List[Callback]
1050+
session: Session
1051+
backoff_policy: ShardConnectionBackoffSchedule
1052+
lock: threading.Lock
1053+
schedule: Optional[Iterator[float]]
1054+
max_concurrent: int
1055+
currently_pending: int
1056+
1057+
def __init__(
1058+
self,
1059+
session: Session,
1060+
backoff_policy: ShardConnectionBackoffSchedule,
1061+
max_concurrent: int,
1062+
):
1063+
self.items = []
1064+
self.session = session
1065+
self.backoff_policy = backoff_policy
1066+
self.lock = threading.Lock()
1067+
self.schedule = self.backoff_policy.new_schedule()
1068+
self.max_concurrent = max_concurrent
1069+
self.currently_pending = 0
1070+
1071+
def _get_delay(self) -> float:
1072+
try:
1073+
return next(self.schedule)
1074+
except StopIteration:
1075+
# A bit of trickery to avoid having lock around self.schedule
1076+
schedule = self.backoff_policy.new_schedule()
1077+
delay = next(schedule)
1078+
self.schedule = schedule
1079+
return delay
1080+
1081+
def _schedule(self):
1082+
if self.session.is_shutdown:
1083+
return
1084+
delay = self._get_delay()
1085+
if delay:
1086+
self.session.cluster.scheduler.schedule(delay, self._run)
1087+
else:
1088+
self.session.submit(self._run)
1089+
1090+
def _run(self):
1091+
if self.session.is_shutdown:
1092+
return
1093+
1094+
with self.lock:
1095+
try:
1096+
cb = self.items.pop()
1097+
except IndexError:
1098+
# Just in case
1099+
if self.currently_pending > 0:
1100+
self.currently_pending -= 1
1101+
# When items are exhausted reset schedule to ensure that new items going to get another schedule
1102+
# It is important for exponential policy
1103+
self.schedule = self.backoff_policy.new_schedule()
1104+
return
1105+
1106+
try:
1107+
cb.method(*cb.args, **cb.kwargs)
1108+
finally:
1109+
self._schedule()
1110+
1111+
def schedule_new_connection(self, cb: Callback):
1112+
with self.lock:
1113+
self.items.append(cb)
1114+
if self.currently_pending < self.max_concurrent:
1115+
self.currently_pending += 1
1116+
self._schedule()
1117+
1118+
1119+
class _LimitedConcurrencyShardReconnectionScheduler(ShardConnectionScheduler):
    """
    Scheduler that routes every connection request into a per-scope
    ``_ScopeBucket`` and drops requests for a host+shard that is already
    scheduled or running.
    """
    # Maps "<host_id>-<shard_id>" to True while a connection for that pair is
    # scheduled or running, and to False once it has finished.
    already_scheduled: dict[str, bool]
    # One bucket per scope key, created on first use.
    scopes: dict[str, _ScopeBucket]
    scope: ShardConnectionBackoffScope
    backoff_policy: ShardConnectionBackoffSchedule
    session: Session
    # Guards ``already_scheduled`` and ``scopes``.
    lock: threading.Lock
    max_concurrent: int

    def __init__(
            self,
            session: Session,
            scope: ShardConnectionBackoffScope,
            backoff_policy: ShardConnectionBackoffSchedule,
            max_concurrent: int,
    ):
        self.session = session
        self.scope = scope
        self.backoff_policy = backoff_policy
        self.max_concurrent = max_concurrent
        self.lock = threading.Lock()
        self.already_scheduled = {}
        self.scopes = {}

    def _execute(self, scheduled_key: str, method: Callable[..., None], *args, **kwargs):
        """
        Run ``method(*args, **kwargs)`` and then release ``scheduled_key``,
        whether or not ``method`` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id: str, shard_id: int, method: Callable[..., None], *args, **kwargs):
        """
        Queue ``method`` in the bucket of the scope that ``host_id`` belongs to.

        Returns ``False`` when a connection for the same host+shard is already
        scheduled or running (the request is dropped), ``True`` otherwise.
        Raises ``ValueError`` when the configured scope is neither Cluster nor Host.
        """
        if self.scope == ShardConnectionBackoffScope.Cluster:
            # A single bucket shared by every host.
            bucket_key = "global-cluster-scope"
        elif self.scope == ShardConnectionBackoffScope.Host:
            bucket_key = host_id
        else:
            raise ValueError("scope must be Cluster or Host")

        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return False
            self.already_scheduled[scheduled_key] = True

            bucket = self.scopes.get(bucket_key)
            if not bucket:
                bucket = _ScopeBucket(self.session, self.backoff_policy, self.max_concurrent)
                self.scopes[bucket_key] = bucket

        # The bucket has its own lock; enqueue after releasing this one.
        request = Callback(self._execute, scheduled_key, method, *args, **kwargs)
        bucket.schedule_new_connection(request)
        return True
1171+
1172+
8671173
class RetryPolicy(object):
8681174
"""
8691175
A policy that describes whether to retry, rethrow, or ignore coordinator

tests/unit/test_host_connection_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
29-
from cassandra.policies import HostDistance, SimpleConvictionPolicy
29+
from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler
3030

3131
LOGGER = logging.getLogger(__name__)
3232

@@ -41,6 +41,8 @@ def make_session(self):
4141
session.cluster.get_core_connections_per_host.return_value = 1
4242
session.cluster.get_max_requests_per_connection.return_value = 1
4343
session.cluster.get_max_connections_per_host.return_value = 1
44+
session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(session)
45+
session.is_shutdown = False
4446
return session
4547

4648
def test_borrow_and_return(self):

0 commit comments

Comments
 (0)