
Commit 54d0ffc

Add RackAwareRoundRobinPolicy for host selection
1 parent c9b24b7 commit 54d0ffc
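
In short, the commit teaches the driver to prefer coordinators in the client's own rack. A minimal usage sketch follows; the contact point and the DC/rack names are illustrative, not taken from the commit:

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RackAwareRoundRobinPolicy

    # Prefer nodes in DC1/RC1, then the rest of DC1; ignore remote DCs.
    profile = ExecutionProfile(
        load_balancing_policy=RackAwareRoundRobinPolicy(
            "DC1", "RC1", used_hosts_per_remote_dc=0))
    cluster = Cluster(contact_points=["127.0.0.1"],
                      execution_profiles={EXEC_PROFILE_DEFAULT: profile})
    session = cluster.connect()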

6 files changed: +354 -83 lines changed

cassandra/cluster.py

Lines changed: 7 additions & 2 deletions

@@ -496,7 +496,8 @@ def _profiles_without_explicit_lbps(self):

     def distance(self, host):
         distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
-        return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
+        return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
+            HostDistance.LOCAL if HostDistance.LOCAL in distances else \
             HostDistance.REMOTE if HostDistance.REMOTE in distances else \
             HostDistance.IGNORED

@@ -613,7 +614,7 @@ class Cluster(object):

     Defaults to loopback interface.

-    Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
+    Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
     local_dc set (as is the default), the DC is chosen from an arbitrary
     host in contact_points. In this case, contact_points should contain
     only nodes from a single, local DC.

@@ -1373,21 +1374,25 @@ def __init__(self,
         self._user_types = defaultdict(dict)

         self._min_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
         }

         self._max_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
         }

         self._core_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
         }

         self._max_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
         }
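
The distance() change above makes the execution-profile merge prefer the closest distance any profile reports. A standalone sketch of the equivalent selection logic (not driver API, just the same cascade spelled out):

    from cassandra.policies import HostDistance

    def merged_distance(distances):
        # Prefer LOCAL_RACK, then LOCAL, then REMOTE; otherwise IGNORED,
        # mirroring the cascaded conditionals in Cluster.distance() above.
        for d in (HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE):
            if d in distances:
                return d
        return HostDistance.IGNORED

    assert merged_distance({HostDistance.REMOTE, HostDistance.LOCAL_RACK}) == HostDistance.LOCAL_RACK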

cassandra/metadata.py

Lines changed: 1 addition & 1 deletion

@@ -3440,7 +3440,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
         all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
         # First check if there are local replicas
         valid_replicas = [host for host in all_replicas if
-                          host.is_up and distance(host) == HostDistance.LOCAL]
+                          host.is_up and (distance(host) == HostDistance.LOCAL or distance(host) == HostDistance.LOCAL_RACK)]
         if not valid_replicas:
             valid_replicas = [host for host in all_replicas if host.is_up]
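
With this change, group_keys_by_replica treats same-rack replicas as local when grouping keys by owner. A hedged usage sketch, where the keyspace, table, and keys are illustrative and session is assumed to be a connected Session as in the sketch near the top:

    from cassandra.metadata import group_keys_by_replica

    # Maps each up LOCAL/LOCAL_RACK replica to the partition keys it owns;
    # keys are tuples of partition-key components.
    keys_by_host = group_keys_by_replica(session, "test1", "table1",
                                         [(1,), (2,), (3,)])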

cassandra/policies.py

Lines changed: 132 additions & 6 deletions

@@ -46,7 +46,18 @@ class HostDistance(object):
     connections opened to it.
     """

-    LOCAL = 0
+    LOCAL_RACK = 0
+    """
+    Nodes with ``LOCAL_RACK`` distance will be preferred for operations
+    under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
+    and will have a greater number of connections opened against
+    them by default.
+
+    This distance is typically used for nodes within the same
+    datacenter and the same rack as the client.
+    """
+
+    LOCAL = 1
     """
     Nodes with ``LOCAL`` distance will be preferred for operations
     under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)

@@ -57,12 +68,12 @@ class HostDistance(object):
     datacenter as the client.
     """

-    REMOTE = 1
+    REMOTE = 2
     """
     Nodes with ``REMOTE`` distance will be treated as a last resort
-    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
-    and will have a smaller number of connections opened against
-    them by default.
+    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
+    and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
+    connections opened against them by default.

     This distance is typically used for nodes outside of the
     datacenter that the client is running in.
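
Note that the renumbering keeps the integer values ordered from closest to farthest, so numeric comparisons between distances remain meaningful:

    from cassandra.policies import HostDistance

    # Values from the hunk above: LOCAL_RACK=0, LOCAL=1, REMOTE=2.
    assert HostDistance.LOCAL_RACK < HostDistance.LOCAL < HostDistance.REMOTE
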
@@ -316,6 +327,121 @@ def on_add(self, host):
     def on_remove(self, host):
         self.on_down(host)

+class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
+    """
+    Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
+    in the local rack, before hosts in the local datacenter but a
+    different rack, before hosts in all other datacenters.
+    """
+
+    local_dc = None
+    local_rack = None
+    used_hosts_per_remote_dc = 0
+
+    def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
+        """
+        The `local_dc` and `local_rack` parameters should be the names of the
+        datacenter and rack (such as is reported by ``nodetool ring``) that
+        should be considered local.
+
+        `used_hosts_per_remote_dc` controls how many nodes in
+        each remote datacenter will have connections opened
+        against them. In other words, `used_hosts_per_remote_dc` hosts
+        will be considered :attr:`~.HostDistance.REMOTE` and the
+        rest will be considered :attr:`~.HostDistance.IGNORED`.
+        By default, all remote hosts are ignored.
+        """
+        self.local_rack = local_rack
+        self.local_dc = local_dc
+        self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
+        self._live_hosts = {}
+        self._dc_live_hosts = {}
+        self._position = 0
+        self._endpoints = []
+        LoadBalancingPolicy.__init__(self)
+
+    def _rack(self, host):
+        return host.rack or self.local_rack
+
+    def _dc(self, host):
+        return host.datacenter or self.local_dc
+
+    def populate(self, cluster, hosts):
+        for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
+            self._live_hosts[(dc, rack)] = list(rack_hosts)
+        for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
+            self._dc_live_hosts[dc] = list(dc_hosts)
+
+        # as in the other policies, choose a random starting position to better
+        # distribute queries across hosts
+        self._position = randint(0, len(hosts) - 1) if hosts else 0
+
+    def distance(self, host):
+        rack = self._rack(host)
+        dc = self._dc(host)
+        if rack == self.local_rack and dc == self.local_dc:
+            return HostDistance.LOCAL_RACK
+
+        if dc == self.local_dc:
+            return HostDistance.LOCAL
+
+        if not self.used_hosts_per_remote_dc:
+            return HostDistance.IGNORED
+        else:
+            dc_hosts = self._dc_live_hosts.get(dc, ())
+            if not dc_hosts:
+                return HostDistance.IGNORED
+
+            if host in dc_hosts[:self.used_hosts_per_remote_dc]:
+                return HostDistance.REMOTE
+            else:
+                return HostDistance.IGNORED
+
+    def make_query_plan(self, working_keyspace=None, query=None):
+        pos = self._position
+        self._position += 1
+
+        local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), [])
+        pos = (pos % len(local_rack_live)) if local_rack_live else 0
+        # Slice the cyclic iterator to start from pos and include the next
+        # len(local_rack_live) elements. This ensures we get exactly one
+        # full cycle starting from pos.
+        for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
+            yield host
+
+        local_live = [host for host in self._dc_live_hosts.get(self.local_dc, []) if host.rack != self.local_rack]
+        pos = (pos % len(local_live)) if local_live else 0
+        for host in islice(cycle(local_live), pos, pos + len(local_live)):
+            yield host
+
+        # the dict can change, so get candidate DCs by iterating over the keys of a copy
+        for dc, remote_live in self._dc_live_hosts.copy().items():
+            if dc != self.local_dc:
+                for host in remote_live[:self.used_hosts_per_remote_dc]:
+                    yield host
+
+    def on_up(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            self._live_hosts.setdefault((dc, rack), []).append(host)
+            self._dc_live_hosts.setdefault(dc, []).append(host)
+
+    def on_down(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            self._live_hosts[(dc, rack)].remove(host)
+            self._dc_live_hosts[dc].remove(host)
+            if not self._live_hosts[(dc, rack)]:
+                del self._live_hosts[(dc, rack)]
+            if not self._dc_live_hosts[dc]:
+                del self._dc_live_hosts[dc]
+
+    def on_add(self, host):
+        self.on_up(host)
+
+    def on_remove(self, host):
+        self.on_down(host)

 class TokenAwarePolicy(LoadBalancingPolicy):
     """
@@ -396,7 +522,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
             shuffle(replicas)
             for replica in replicas:
                 if replica.is_up and \
-                        child.distance(replica) == HostDistance.LOCAL:
+                        (child.distance(replica) == HostDistance.LOCAL or child.distance(replica) == HostDistance.LOCAL_RACK):
                     yield replica

             for host in child.make_query_plan(keyspace, query):
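
Since TokenAwarePolicy now also yields LOCAL_RACK replicas up front, the new policy composes with token awareness. A minimal sketch, again with illustrative DC and rack names:

    from cassandra.policies import TokenAwarePolicy, RackAwareRoundRobinPolicy

    # Replica-first routing: same-rack replicas, then the rest of DC1,
    # then up to one host in each remote DC.
    policy = TokenAwarePolicy(
        RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=1))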

docs/api/cassandra/policies.rst

Lines changed: 3 additions & 0 deletions

@@ -18,6 +18,9 @@ Load Balancing
 .. autoclass:: DCAwareRoundRobinPolicy
    :members:

+.. autoclass:: RackAwareRoundRobinPolicy
+   :members:
+
 .. autoclass:: WhiteListRoundRobinPolicy
    :members:

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+import logging
+import unittest
+
+from cassandra.cluster import Cluster
+from cassandra.policies import ConstantReconnectionPolicy, RackAwareRoundRobinPolicy
+
+from tests.integration import PROTOCOL_VERSION, get_cluster, use_multidc
+
+LOGGER = logging.getLogger(__name__)
+
+
+def setup_module():
+    use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 3}})
+
+
+class RackAwareRoundRobinPolicyTests(unittest.TestCase):
+    @classmethod
+    def setup_class(cls):
+        cls.cluster = Cluster(contact_points=[node.address() for node in get_cluster().nodelist()],
+                              protocol_version=PROTOCOL_VERSION,
+                              load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0),
+                              reconnection_policy=ConstantReconnectionPolicy(1))
+        cls.session = cls.cluster.connect()
+        cls.create_ks_and_cf(cls)
+        cls.create_data(cls.session)
+        cls.node1, cls.node2, cls.node3, cls.node4, cls.node5, cls.node6, cls.node7 = get_cluster().nodes.values()
+
+    @classmethod
+    def teardown_class(cls):
+        cls.cluster.shutdown()
+
+    def create_ks_and_cf(self):
+        self.session.execute(
+            """
+            DROP KEYSPACE IF EXISTS test1
+            """)
+        self.session.execute(
+            """
+            CREATE KEYSPACE test1
+            WITH replication = {
+                'class': 'NetworkTopologyStrategy',
+                'replication_factor': 3
+            }
+            """)
+
+        self.session.execute(
+            """
+            CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
+            """)
+
+    @staticmethod
+    def create_data(session):
+        prepared = session.prepare(
+            """
+            INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
+            """)
+
+        for i in range(50):
+            bound = prepared.bind((i, i % 5, i % 2))
+            session.execute(bound)
+
+    def test_rack_aware(self):
+        prepared = self.session.prepare(
+            """
+            SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
+            """)
+
+        # With all nodes up, queries should be coordinated by the two
+        # local-rack nodes (DC1/RC1).
+        for i in range(10):
+            bound = prepared.bind([i])
+            results = self.session.execute(bound)
+            self.assertEqual(results, [(i, i % 5, i % 2)])
+            coordinator = str(results.response_future.coordinator_host.endpoint)
+            self.assertTrue(coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"]))
+
+        self.node2.stop(wait_other_notice=True, gently=True)
+
+        # With one rack node down, the remaining rack node takes all queries.
+        for i in range(10):
+            bound = prepared.bind([i])
+            results = self.session.execute(bound)
+            self.assertEqual(results, [(i, i % 5, i % 2)])
+            coordinator = str(results.response_future.coordinator_host.endpoint)
+            self.assertEqual(coordinator, "127.0.0.1:9042")
+
+        self.node1.stop(wait_other_notice=True, gently=True)
+
+        # With the whole local rack down, coordination falls back to the
+        # other rack in the local DC (DC1/RC2).
+        for i in range(10):
+            bound = prepared.bind([i])
+            results = self.session.execute(bound)
+            self.assertEqual(results, [(i, i % 5, i % 2)])
+            coordinator = str(results.response_future.coordinator_host.endpoint)
+            self.assertTrue(coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"]))
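
For reference, the coordinator assertions follow from how ccm lays out the seven nodes for the topology passed to use_multidc; this mapping is inferred from the assertions, not stated in the commit:

    # Inferred layout for use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 3}}):
    #   node1, node2  -> DC1/RC1 (127.0.0.1-2): LOCAL_RACK for the policy under test
    #   node3, node4  -> DC1/RC2 (127.0.0.3-4): LOCAL
    #   node5..node7  -> DC2/RC1 (127.0.0.5-7): IGNORED (used_hosts_per_remote_dc=0)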
