Skip to content

Commit 8f96d39

Browse files
committed
(improvement) TokenAware round-robin policy and others - improved query planning.
Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 6282e6f commit 8f96d39

1 file changed

Lines changed: 117 additions & 28 deletions

File tree

cassandra/policies.py

Lines changed: 117 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import random
1515

1616
from collections import namedtuple
17-
from itertools import islice, cycle, groupby, repeat
17+
from itertools import islice, cycle, groupby, repeat, chain
1818
import logging
1919
from random import randint, shuffle
2020
from threading import Lock
@@ -157,6 +157,18 @@ def make_query_plan(self, working_keyspace=None, query=None):
157157
"""
158158
raise NotImplementedError()
159159

160+
def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()):
    """
    Identical to :meth:`make_query_plan` except that every host contained in
    `excluded` is filtered out of the resulting plan.

    `excluded` may be any membership-testable container (set, list, tuple).
    This default implementation just wraps :meth:`make_query_plan` with a
    filter; subclasses are free to override it with something faster.
    """
    plan = self.make_query_plan(working_keyspace, query)
    yield from (candidate for candidate in plan if candidate not in excluded)
160172
def check_supported(self):
161173
"""
162174
This will be called after the cluster Metadata has been initialized.
@@ -198,6 +210,20 @@ def make_query_plan(self, working_keyspace=None, query=None):
198210
else:
199211
return []
200212

213+
def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()):
    """
    Same as :meth:`make_query_plan`, but skips any host present in `excluded`.

    Walks the live hosts exactly once in round-robin order, starting from the
    shared position counter.
    """
    # Not thread-safe: a lost increment only perturbs the rotation start,
    # which is acceptable for load balancing.
    pos = self._position
    self._position += 1

    hosts = self._live_hosts
    length = len(hosts)
    # Guard clause instead of a dead trailing `else: return` — in a generator
    # an empty body already yields nothing.
    if not length:
        return
    pos %= length
    for host in islice(cycle(hosts), pos, pos + length):
        if host not in excluded:
            yield host
201227
def on_up(self, host):
202228
with self._hosts_lock:
203229
self._live_hosts = self._live_hosts.union((host, ))
@@ -289,6 +315,24 @@ def make_query_plan(self, working_keyspace=None, query=None):
289315
for host in self._remote_hosts:
290316
yield host
291317

318+
def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()):
    """
    Like :meth:`make_query_plan`: yields live local-DC hosts first, rotated
    round-robin, followed by the remote hosts — skipping anything found in
    `excluded`.
    """
    # Lost increments under concurrent callers are tolerated; the counter
    # only steers load balancing, not correctness.
    start = self._position
    self._position += 1

    local_live = self._dc_live_hosts.get(self.local_dc, ())
    count = len(local_live)
    offset = start % count if count else 0
    rotated_local = islice(cycle(local_live), offset, offset + count)

    for candidate in chain(rotated_local, self._remote_hosts):
        if excluded and candidate in excluded:
            continue
        yield candidate
292336
def on_up(self, host):
293337
# not worrying about threads because this will happen during
294338
# control connection startup/refresh
@@ -424,6 +468,33 @@ def make_query_plan(self, working_keyspace=None, query=None):
424468

425469
for host in self._remote_hosts:
426470
yield host
471+
472+
def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()):
    """
    Like :meth:`make_query_plan`: local-rack hosts first, then the other
    local-DC hosts, then remote hosts — the first two tiers rotated
    round-robin — while skipping any host found in `excluded`.
    """
    start = self._position
    self._position += 1

    def rotated(tier):
        # Rotate `tier` by the shared position counter; an empty tier
        # contributes nothing to the plan.
        size = len(tier)
        if not size:
            return iter(())
        begin = start % size
        return islice(cycle(tier), begin, begin + size)

    same_rack = self._live_hosts.get((self.local_dc, self.local_rack), ())
    other_racks = self._non_local_rack_hosts

    for candidate in chain(rotated(same_rack), rotated(other_racks), self._remote_hosts):
        if excluded and candidate in excluded:
            continue
        yield candidate
427498

428499
def on_up(self, host):
429500
dc = self._dc(host)
@@ -495,16 +566,12 @@ class TokenAwarePolicy(LoadBalancingPolicy):
495566
policy's query plan will be used as is.
496567
"""
497568

498-
_child_policy = None
499-
_cluster_metadata = None
500-
shuffle_replicas = True
501-
"""
502-
Yield local replicas in a random order.
503-
"""
569+
__slots__ = ('_child_policy', '_cluster_metadata', 'shuffle_replicas')
504570

505571
def __init__(self, child_policy, shuffle_replicas=True):
    """
    :param child_policy: the wrapped load-balancing policy used for host
        distances and as the fallback query plan
    :param shuffle_replicas: when True, replicas are yielded in random order
    """
    self._cluster_metadata = None
    self._child_policy = child_policy
    self.shuffle_replicas = shuffle_replicas
508575

509576
def populate(self, cluster, hosts):
510577
self._cluster_metadata = cluster.metadata
@@ -527,35 +594,57 @@ def make_query_plan(self, working_keyspace=None, query=None):
527594

528595
child = self._child_policy
529596
if query is None or query.routing_key is None or keyspace is None:
530-
for host in child.make_query_plan(keyspace, query):
531-
yield host
597+
yield from child.make_query_plan(keyspace, query)
532598
return
533599

600+
cluster_metadata = self._cluster_metadata
601+
token_map = cluster_metadata.token_map
534602
replicas = []
535-
tablet = self._cluster_metadata._tablets.get_tablet_for_key(
536-
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))
537603

538-
if tablet is not None:
539-
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
540-
child_plan = child.make_query_plan(keyspace, query)
541-
542-
replicas = [host for host in child_plan if host.host_id in replicas_mapped]
543-
else:
544-
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)
604+
if token_map:
605+
token = token_map.token_class.from_key(query.routing_key)
606+
tablet = cluster_metadata._tablets.get_tablet_for_key(
607+
keyspace, query.table, token)
608+
609+
if tablet is not None:
610+
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
611+
for host_id in replicas_mapped:
612+
host = cluster_metadata.get_host_by_host_id(host_id)
613+
if host:
614+
replicas.append(host)
615+
else:
616+
replicas = token_map.get_replicas(keyspace, token)
545617

546618
if self.shuffle_replicas and not query.is_lwt():
547619
shuffle(replicas)
548620

549-
def yield_in_order(hosts):
550-
for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]:
551-
for replica in hosts:
552-
if replica.is_up and child.distance(replica) == distance:
553-
yield replica
554-
555-
# yield replicas: local_rack, local, remote
556-
yield from yield_in_order(replicas)
557-
# yield rest of the cluster: local_rack, local, remote
558-
yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas])
621+
local_rack = []
622+
local = []
623+
remote = []
624+
625+
child_distance = child.distance
626+
627+
for replica in replicas:
628+
if replica.is_up:
629+
d = child_distance(replica)
630+
if d == HostDistance.LOCAL_RACK:
631+
local_rack.append(replica)
632+
elif d == HostDistance.LOCAL:
633+
local.append(replica)
634+
elif d == HostDistance.REMOTE:
635+
remote.append(replica)
636+
637+
yielded_sequence = tuple(chain(local_rack, local, remote))
638+
639+
if yielded_sequence:
640+
yield from yielded_sequence
641+
642+
yielded = set(yielded_sequence)
643+
644+
# yield rest of the cluster
645+
yield from child.make_query_plan_with_exclusion(keyspace, query, yielded)
646+
else:
647+
yield from child.make_query_plan(keyspace, query)
559648

560649
def on_up(self, *args, **kwargs):
561650
return self._child_policy.on_up(*args, **kwargs)

0 commit comments

Comments
 (0)