Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,10 @@ def make_query_plan(self, working_keyspace=None, query=None):
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))

if tablet is not None:
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
replica_dict = tablet._replica_dict
child_plan = child.make_query_plan(keyspace, query)

replicas = [host for host in child_plan if host.host_id in replicas_mapped]
replicas = [host for host in child_plan if host.host_id in replica_dict]
else:
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)

Expand Down
5 changes: 1 addition & 4 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t)

if tablet is not None:
for replica in tablet.replicas:
if replica[0] == self.host.host_id:
shard_id = replica[1]
break
shard_id = tablet._replica_dict.get(self.host.host_id)

if shard_id is None:
shard_id = self.host.sharding_info.shard_id_from_token(t.value)
Expand Down
99 changes: 59 additions & 40 deletions cassandra/tablets.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,84 @@
from bisect import bisect_left
from operator import attrgetter
from threading import Lock
from typing import Optional
from uuid import UUID

# C-accelerated attrgetter avoids per-call lambda allocation overhead.
# These callables are built once at import time and passed as the ``key=``
# argument of ``bisect_left`` when searching a list of Tablet objects by
# ``first_token`` or ``last_token``.
_get_first_token = attrgetter("first_token")
_get_last_token = attrgetter("last_token")


class Tablet(object):
    """
    Represents a single ScyllaDB tablet.
    It stores information about each replica, its host and shard,
    and the token interval in the format (first_token, last_token].

    ``replicas`` holds the replica entries as a tuple (or ``None`` when no
    replicas were given); each entry is indexed as ``entry[0]`` (host id)
    and ``entry[1]`` (shard).  ``_replica_dict`` is a ``host id -> shard``
    mapping derived from ``replicas`` so membership and shard lookups do not
    have to scan the tuple.
    """
    # NOTE: no class-level defaults for these names -- a class variable that
    # shares its name with a slot makes the class statement raise ValueError.
    __slots__ = ('first_token', 'last_token', 'replicas', '_replica_dict')

    def __init__(self, first_token=0, last_token=0, replicas=None):
        """
        :param first_token: exclusive lower bound of the token interval.
        :param last_token: inclusive upper bound of the token interval.
        :param replicas: iterable of ``(host_id, shard)`` entries, or ``None``.
        """
        self.first_token = first_token
        self.last_token = last_token
        if replicas is None:
            self.replicas = None
            self._replica_dict = {}
        else:
            # Materialize once so that a one-shot iterator can feed both the
            # stored tuple and the lookup dict.
            replicas_tuple = tuple(replicas)
            self.replicas = replicas_tuple
            self._replica_dict = {r[0]: r[1] for r in replicas_tuple}

    def __str__(self):
        return "<Tablet: first_token=%s last_token=%s replicas=%s>" \
               % (self.first_token, self.last_token, self.replicas)
    __repr__ = __str__

    @staticmethod
    def _is_valid_tablet(replicas):
        """Return True when ``replicas`` is a non-empty sized collection."""
        return replicas is not None and len(replicas) != 0

    @staticmethod
    def from_row(first_token, last_token, replicas):
        """
        Build a Tablet from row values.

        Returns ``None`` when ``replicas`` is ``None`` or empty, otherwise a
        new Tablet covering ``(first_token, last_token]``.
        """
        if not replicas:
            return None
        return Tablet(first_token, last_token, replicas)

    def replica_contains_host_id(self, uuid: UUID) -> bool:
        """Return True when one of this tablet's replicas has host id ``uuid``."""
        return uuid in self._replica_dict


class Tablets(object):
_lock = None
_tablets = {}
_tablets = {} # (keyspace, table) -> list[Tablet]
_first_tokens = {} # (keyspace, table) -> list[int]
_last_tokens = {} # (keyspace, table) -> list[int]

def __init__(self, tablets):
self._tablets = tablets
# Build parallel token index lists from any pre-populated data
self._first_tokens = {
key: [t.first_token for t in tlist]
for key, tlist in tablets.items()
}
self._last_tokens = {
key: [t.last_token for t in tlist]
for key, tlist in tablets.items()
}
self._lock = Lock()

def table_has_tablets(self, keyspace, table) -> bool:
return bool(self._tablets.get((keyspace, table), []))

def get_tablet_for_key(self, keyspace, table, t):
tablet = self._tablets.get((keyspace, table), [])
if not tablet:
key = (keyspace, table)
last_tokens = self._last_tokens.get(key)
if not last_tokens:
return None

id = bisect_left(tablet, t.value, key=_get_last_token)
if id < len(tablet) and t.value > tablet[id].first_token:
return tablet[id]
token_value = t.value
id = bisect_left(last_tokens, token_value)
if id < len(last_tokens) and token_value > self._first_tokens[key][id]:
return self._tablets[key][id]
return None

def drop_tablets(self, keyspace: str, table: Optional[str] = None):
with self._lock:
if table is not None:
self._tablets.pop((keyspace, table), None)
key = (keyspace, table)
self._tablets.pop(key, None)
self._first_tokens.pop(key, None)
self._last_tokens.pop(key, None)
return

to_be_deleted = []
Expand All @@ -81,36 +88,48 @@ def drop_tablets(self, keyspace: str, table: Optional[str] = None):

for key in to_be_deleted:
del self._tablets[key]
self._first_tokens.pop(key, None)
self._last_tokens.pop(key, None)

def drop_tablets_by_host_id(self, host_id: Optional[UUID]):
if host_id is None:
return
with self._lock:
for key, tablets in self._tablets.items():
to_be_deleted = []
for tablet_id, tablet in enumerate(tablets):
if tablet.replica_contains_host_id(host_id):
to_be_deleted.append(tablet_id)

for tablet_id in reversed(to_be_deleted):
tablets.pop(tablet_id)
# Filter in one pass instead of popping one-by-one (O(n) vs O(k*n))
keep = [i for i, t in enumerate(tablets)
if not t.replica_contains_host_id(host_id)]
if len(keep) == len(tablets):
continue # nothing to drop
self._tablets[key] = [tablets[i] for i in keep]
first = self._first_tokens[key]
last = self._last_tokens[key]
self._first_tokens[key] = [first[i] for i in keep]
self._last_tokens[key] = [last[i] for i in keep]

def add_tablet(self, keyspace, table, tablet):
with self._lock:
tablets_for_table = self._tablets.setdefault((keyspace, table), [])
key = (keyspace, table)
tablets_for_table = self._tablets.setdefault(key, [])
first_tokens = self._first_tokens.setdefault(key, [])
last_tokens = self._last_tokens.setdefault(key, [])

# find first overlapping range
start = bisect_left(tablets_for_table, tablet.first_token, key=_get_first_token)
if start > 0 and tablets_for_table[start - 1].last_token > tablet.first_token:
start = bisect_left(first_tokens, tablet.first_token)
if start > 0 and last_tokens[start - 1] > tablet.first_token:
start = start - 1

# find last overlapping range
end = bisect_left(tablets_for_table, tablet.last_token, key=_get_last_token)
if end < len(tablets_for_table) and tablets_for_table[end].first_token >= tablet.last_token:
end = bisect_left(last_tokens, tablet.last_token)
if end < len(last_tokens) and first_tokens[end] >= tablet.last_token:
end = end - 1

if start <= end:
del tablets_for_table[start:end + 1]
del first_tokens[start:end + 1]
del last_tokens[start:end + 1]

tablets_for_table.insert(start, tablet)
first_tokens.insert(start, tablet.first_token)
last_tokens.insert(start, tablet.last_token)

95 changes: 95 additions & 0 deletions tests/unit/test_tablets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from uuid import UUID

from cassandra.tablets import Tablets, Tablet

Expand Down Expand Up @@ -124,3 +125,97 @@ def __init__(self, v):
# Token value 50 is not > first_token (100) of the tablet whose
# last_token (200) is >= 50, so no match.
self.assertIsNone(tablets.get_tablet_for_key("ks", "tb", Token(50)))


class TabletReplicaDictTest(unittest.TestCase):
    """Checks the host-id -> shard lookup that Tablet caches in _replica_dict."""

    HOST_A = UUID('12345678-1234-5678-1234-567812345678')
    HOST_B = UUID('87654321-4321-8765-4321-876543218765')
    HOST_UNKNOWN = UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee')

    def _two_replica_tablet(self):
        # Shared fixture: HOST_A lives on shard 3, HOST_B on shard 7.
        return Tablet(0, 100, [(self.HOST_A, 3), (self.HOST_B, 7)])

    def test_replica_dict_built_from_replicas(self):
        tablet = self._two_replica_tablet()
        expected = {self.HOST_A: 3, self.HOST_B: 7}
        self.assertEqual(tablet._replica_dict, expected)

    def test_replica_dict_empty_when_no_replicas(self):
        tablet = Tablet(0, 100, None)
        self.assertEqual(tablet._replica_dict, {})

    def test_replica_dict_contains_host(self):
        lookup = self._two_replica_tablet()._replica_dict
        self.assertIn(self.HOST_A, lookup)
        self.assertIn(self.HOST_B, lookup)
        self.assertNotIn(self.HOST_UNKNOWN, lookup)

    def test_replica_dict_shard_lookup(self):
        lookup = self._two_replica_tablet()._replica_dict
        self.assertEqual(lookup.get(self.HOST_A), 3)
        self.assertEqual(lookup.get(self.HOST_B), 7)
        self.assertIsNone(lookup.get(self.HOST_UNKNOWN))

    def test_replica_contains_host_id_uses_dict(self):
        tablet = self._two_replica_tablet()
        self.assertTrue(tablet.replica_contains_host_id(self.HOST_A))
        self.assertTrue(tablet.replica_contains_host_id(self.HOST_B))
        self.assertFalse(tablet.replica_contains_host_id(self.HOST_UNKNOWN))

    def test_replicas_stored_as_tuple(self):
        tablet = Tablet(0, 100, [("host1", 0), ("host2", 1)])
        self.assertIsInstance(tablet.replicas, tuple)

    def test_replica_dict_from_iterator(self):
        """A one-shot iterator (generator) of replicas must still populate
        both the stored tuple and _replica_dict."""
        pairs = [(self.HOST_A, 3), (self.HOST_B, 7)]
        one_shot = (pair for pair in pairs)

        tablet = Tablet(0, 100, one_shot)

        self.assertEqual(tablet.replicas, tuple(pairs))
        self.assertEqual(tablet._replica_dict, dict(pairs))
        self.assertTrue(tablet.replica_contains_host_id(self.HOST_A))
        self.assertTrue(tablet.replica_contains_host_id(self.HOST_B))


class DropTabletsByHostIdTest(unittest.TestCase):
    """Exercises the batch-filter path of Tablets.drop_tablets_by_host_id."""

    TABLE_KEY = ("ks", "tb")
    HOST_A = UUID('12345678-1234-5678-1234-567812345678')
    HOST_B = UUID('87654321-4321-8765-4321-876543218765')
    HOST_UNKNOWN = UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee')

    def test_drop_removes_matching_tablets(self):
        only_a = Tablet(0, 100, [(self.HOST_A, 0)])
        only_b = Tablet(100, 200, [(self.HOST_B, 0)])
        both = Tablet(200, 300, [(self.HOST_A, 1), (self.HOST_B, 1)])
        tablets = Tablets({self.TABLE_KEY: [only_a, only_b, both]})

        tablets.drop_tablets_by_host_id(self.HOST_A)

        survivors = tablets._tablets[self.TABLE_KEY]
        self.assertEqual(len(survivors), 1)
        self.assertIs(survivors[0], only_b)
        # The token index lists must shrink together with the tablet list.
        self.assertEqual(tablets._first_tokens[self.TABLE_KEY], [100])
        self.assertEqual(tablets._last_tokens[self.TABLE_KEY], [200])

    def test_drop_none_host_id_is_noop(self):
        tablets = Tablets({self.TABLE_KEY: [Tablet(0, 100, [("host1", 0)])]})
        tablets.drop_tablets_by_host_id(None)
        self.assertEqual(len(tablets._tablets[self.TABLE_KEY]), 1)

    def test_drop_nonexistent_host_id_is_noop(self):
        tablets = Tablets({self.TABLE_KEY: [Tablet(0, 100, [(self.HOST_A, 0)])]})
        tablets.drop_tablets_by_host_id(self.HOST_UNKNOWN)
        self.assertEqual(len(tablets._tablets[self.TABLE_KEY]), 1)
Loading