Commit 05e59a5

Add RackAwareRoundRobinPolicy for host selection
1 parent c9b24b7 commit 05e59a5

6 files changed, +264 -48 lines changed

cassandra/cluster.py

Lines changed: 7 additions & 2 deletions
@@ -496,7 +496,8 @@ def _profiles_without_explicit_lbps(self):
 
     def distance(self, host):
         distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
-        return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
+        return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
+            HostDistance.LOCAL if HostDistance.LOCAL in distances else \
             HostDistance.REMOTE if HostDistance.REMOTE in distances else \
             HostDistance.IGNORED
 
@@ -613,7 +614,7 @@ class Cluster(object):
 
     Defaults to loopback interface.
 
-    Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
+    Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
     local_dc set (as is the default), the DC is chosen from an arbitrary
     host in contact_points. In this case, contact_points should contain
     only nodes from a single, local DC.
@@ -1373,21 +1374,25 @@ def __init__(self,
         self._user_types = defaultdict(dict)
 
         self._min_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
         }
 
         self._max_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
         }
 
         self._core_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
         }
 
         self._max_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
         }
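
Note (not part of the commit): a minimal sketch of how a client could opt into the new rack-aware behavior via the execution-profile API; the contact point, datacenter, and rack names below are placeholders:

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RackAwareRoundRobinPolicy

    # Prefer same-rack hosts, then other hosts in the local DC, then one host per remote DC.
    profile = ExecutionProfile(
        load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RACK1", used_hosts_per_remote_dc=1))
    cluster = Cluster(contact_points=["127.0.0.1"],
                      execution_profiles={EXEC_PROFILE_DEFAULT: profile})
    session = cluster.connect()

The new dictionary entries above are what give hosts at ``LOCAL_RACK`` distance the same default pool sizing as ``LOCAL`` hosts.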

cassandra/metadata.py

Lines changed: 1 addition & 1 deletion
@@ -3440,7 +3440,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
         all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
         # First check if there are local replicas
         valid_replicas = [host for host in all_replicas if
-                          host.is_up and distance(host) == HostDistance.LOCAL]
+                          host.is_up and (distance(host) == HostDistance.LOCAL or distance(host) == HostDistance.LOCAL_RACK)]
         if not valid_replicas:
             valid_replicas = [host for host in all_replicas if host.is_up]
 

cassandra/policies.py

Lines changed: 129 additions & 6 deletions
@@ -46,7 +46,18 @@ class HostDistance(object):
     connections opened to it.
     """
 
-    LOCAL = 0
+    LOCAL_RACK = 0
+    """
+    Nodes with ``LOCAL_RACK`` distance will be preferred for operations
+    under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
+    and will have a greater number of connections opened against
+    them by default.
+
+    This distance is typically used for nodes within the same
+    datacenter and the same rack as the client.
+    """
+
+    LOCAL = 1
     """
     Nodes with ``LOCAL`` distance will be preferred for operations
     under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
@@ -57,12 +68,12 @@ class HostDistance(object):
     datacenter as the client.
     """
 
-    REMOTE = 1
+    REMOTE = 2
     """
     Nodes with ``REMOTE`` distance will be treated as a last resort
-    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
-    and will have a smaller number of connections opened against
-    them by default.
+    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
+    and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
+    connections opened against them by default.
 
     This distance is typically used for nodes outside of the
     datacenter that the client is running in.
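
Note (not part of the commit): the renumbering keeps the driver's convention that a smaller distance value means a more preferred host; ``LOCAL_RACK`` simply slots in ahead of ``LOCAL``:

    from cassandra.policies import HostDistance

    # After this change: LOCAL_RACK == 0, LOCAL == 1, REMOTE == 2.
    assert HostDistance.LOCAL_RACK < HostDistance.LOCAL < HostDistance.REMOTE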
@@ -316,6 +327,118 @@ def on_add(self, host):
     def on_remove(self, host):
         self.on_down(host)
 
+class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
+    """
+    Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
+    in the local rack, before hosts in the local datacenter but a
+    different rack, before hosts in all other datacenters.
+    """
+
+    local_dc = None
+    local_rack = None
+    used_hosts_per_remote_dc = 0
+
+    def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
+        """
+        The `local_dc` and `local_rack` parameters should be the name of the
+        datacenter and rack (such as is reported by ``nodetool ring``) that
+        should be considered local.
+
+        `used_hosts_per_remote_dc` controls how many nodes in
+        each remote datacenter will have connections opened
+        against them. In other words, `used_hosts_per_remote_dc` hosts
+        will be considered :attr:`~.HostDistance.REMOTE` and the
+        rest will be considered :attr:`~.HostDistance.IGNORED`.
+        By default, all remote hosts are ignored.
+        """
+        self.local_rack = local_rack
+        self.local_dc = local_dc
+        self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
+        self._live_hosts = {}
+        self._dc_live_hosts = {}
+        self._position = 0
+        self._endpoints = []
+        LoadBalancingPolicy.__init__(self)
+
+    def _rack(self, host):
+        return host.rack or self.local_rack
+
+    def _dc(self, host):
+        return host.datacenter or self.local_dc
+
+    def populate(self, cluster, hosts):
+        for (dc, rack), dc_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
+            self._live_hosts[(dc, rack)] = list(dc_hosts)
+        for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
+            self._dc_live_hosts[dc] = list(dc_hosts)
+
+        # as in other policies, choose a random position for better distribution of queries across hosts
+        self._position = randint(0, len(hosts) - 1) if hosts else 0
+
+    def distance(self, host):
+        rack = self._rack(host)
+        dc = self._dc(host)
+        if rack == self.local_rack and dc == self.local_dc:
+            return HostDistance.LOCAL_RACK
+
+        if dc == self.local_dc:
+            return HostDistance.LOCAL
+
+        if not self.used_hosts_per_remote_dc:
+            return HostDistance.IGNORED
+        else:
+            dc_hosts = self._dc_live_hosts.get(dc, ())
+            if not dc_hosts:
+                return HostDistance.IGNORED
+
+            if host in dc_hosts[:self.used_hosts_per_remote_dc]:
+                return HostDistance.REMOTE
+            else:
+                return HostDistance.IGNORED
+
+    def make_query_plan(self, working_keyspace=None, query=None):
+        pos = self._position
+        self._position += 1
+
+        local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
+        pos = (pos % len(local_rack_live)) if local_rack_live else 0
+        # Slice the cyclic iterator to start from pos and include the next len(local_rack_live) elements.
+        # This ensures we get exactly one full cycle starting from pos.
+        for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
+            yield host
+
+        local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
+        pos = (pos % len(local_live)) if local_live else 0
+        for host in islice(cycle(local_live), pos, pos + len(local_live)):
+            yield host
+
+        # the dict can change, so get candidate DCs iterating over keys of a copy
+        other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc]
+        for dc in other_dcs:
+            remote_live = self._dc_live_hosts.get(dc, ())
+            for host in remote_live[:self.used_hosts_per_remote_dc]:
+                yield host
+
+    def on_up(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            self._live_hosts[(dc, rack)].append(host)
+            self._dc_live_hosts[dc].append(host)
+
+    def on_down(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            self._live_hosts[(dc, rack)].remove(host)
+            self._dc_live_hosts[dc].remove(host)
+
+    def on_add(self, host):
+        self.on_up(host)
+
+    def on_remove(self, host):
+        self.on_down(host)
 
 class TokenAwarePolicy(LoadBalancingPolicy):
     """
@@ -396,7 +519,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
                 shuffle(replicas)
                 for replica in replicas:
                     if replica.is_up and \
-                            child.distance(replica) == HostDistance.LOCAL:
+                            (child.distance(replica) == HostDistance.LOCAL or child.distance(replica) == HostDistance.LOCAL_RACK):
                         yield replica
 
             for host in child.make_query_plan(keyspace, query):
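
Note (not part of the commit): a short sketch of the selection order the new policy produces, driven with stub host objects that expose only the ``datacenter`` and ``rack`` attributes the policy reads; all names and addresses are made up:

    from cassandra.policies import RackAwareRoundRobinPolicy

    class StubHost(object):
        # stand-in exposing only what the policy inspects
        def __init__(self, address, datacenter, rack):
            self.address = address
            self.datacenter = datacenter
            self.rack = rack

    hosts = [StubHost("10.0.0.1", "DC1", "RACK1"),
             StubHost("10.0.0.2", "DC1", "RACK2"),
             StubHost("10.0.0.3", "DC2", "RACK1")]

    policy = RackAwareRoundRobinPolicy("DC1", "RACK1", used_hosts_per_remote_dc=1)
    policy.populate(None, hosts)   # the cluster argument is unused by this policy

    print([policy.distance(h) for h in hosts])
    # -> [0, 1, 2], i.e. LOCAL_RACK, LOCAL, REMOTE

    print([h.address for h in policy.make_query_plan()])
    # -> ['10.0.0.1', '10.0.0.2', '10.0.0.3']: same rack first, then same DC, then one remote host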

docs/api/cassandra/policies.rst

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ Load Balancing
 .. autoclass:: DCAwareRoundRobinPolicy
     :members:
 
+.. autoclass:: RackAwareRoundRobinPolicy
+    :members:
+
 .. autoclass:: WhiteListRoundRobinPolicy
     :members:
 

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import unittest
+
+from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT, NoHostAvailable
+from cassandra.policies import ConstantReconnectionPolicy, HostFilterPolicy, RackAwareRoundRobinPolicy, RoundRobinPolicy, SimpleConvictionPolicy, \
+    WhiteListRoundRobinPolicy, ExponentialBackoffRetryPolicy
+from cassandra.pool import Host
+from cassandra.connection import DefaultEndPoint
+
+from tests.integration import PROTOCOL_VERSION, local, use_multidc, use_singledc, TestCluster
+
+from concurrent.futures import wait as wait_futures
+
+
+def setup_module():
+    use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 2}})
+
+class RackAwareRoundRobinPolicyTests(unittest.TestCase):
+    @classmethod
+    def setup_class(cls):
+        cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4", "127.0.0.5", "127.0.0.6"], protocol_version=PROTOCOL_VERSION,
+                              load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RAC1", used_hosts_per_remote_dc=1),
+                              reconnection_policy=ConstantReconnectionPolicy(1))
+        cls.session = cls.cluster.connect()
+
+    @classmethod
+    def teardown_class(cls):
+        cls.cluster.shutdown()
+
+    def test_rack_aware(self):
+        self.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4", "127.0.0.5", "127.0.0.6"], protocol_version=PROTOCOL_VERSION,
+                               load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RAC1", used_hosts_per_remote_dc=1),
+                               reconnection_policy=ConstantReconnectionPolicy(1))
+        self.session = self.cluster.connect()
+        queried_hosts = set()
+        for _ in range(10):
+            response = self.session.execute('SELECT * from system.local')
+            queried_hosts.update(response.response_future.attempted_hosts)
+        queried_hosts = set(host.address for host in queried_hosts)
+        # Perform some checks
+        self.cluster.shutdown()
+
+    def test_rack_aware_wrong_dc(self):
+        self.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4", "127.0.0.5", "127.0.0.6"], protocol_version=PROTOCOL_VERSION,
+                               load_balancing_policy=RackAwareRoundRobinPolicy("WRONG_DC1", "RAC1", used_hosts_per_remote_dc=0),
+                               reconnection_policy=ConstantReconnectionPolicy(1))
+        self.assertRaises(NoHostAvailable, self.cluster.connect)
+        self.cluster.shutdown()
