diff --git a/cassandra/policies.py b/cassandra/policies.py index ceb5ebdc45..65845e55a7 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -507,10 +507,10 @@ def make_query_plan(self, working_keyspace=None, query=None): keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) + replica_dict = tablet._replica_dict child_plan = child.make_query_plan(keyspace, query) - replicas = [host for host in child_plan if host.host_id in replicas_mapped] + replicas = [host for host in child_plan if host.host_id in replica_dict] else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) diff --git a/cassandra/pool.py b/cassandra/pool.py index 227e1b5315..5b370f36d3 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -462,10 +462,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t) if tablet is not None: - for replica in tablet.replicas: - if replica[0] == self.host.host_id: - shard_id = replica[1] - break + shard_id = tablet._replica_dict.get(self.host.host_id) if shard_id is None: shard_id = self.host.sharding_info.shard_id_from_token(t.value) diff --git a/cassandra/tablets.py b/cassandra/tablets.py index 96e61a50c2..3f81334688 100644 --- a/cassandra/tablets.py +++ b/cassandra/tablets.py @@ -1,13 +1,8 @@ from bisect import bisect_left -from operator import attrgetter from threading import Lock from typing import Optional from uuid import UUID -# C-accelerated attrgetter avoids per-call lambda allocation overhead -_get_first_token = attrgetter("first_token") -_get_last_token = attrgetter("last_token") - class Tablet(object): """ @@ -15,63 +10,75 @@ class Tablet(object): It stores information about each replica, its host and shard, and the token interval in the format (first_token, last_token]. 
""" - first_token = 0 - last_token = 0 - replicas = None + __slots__ = ('first_token', 'last_token', 'replicas', '_replica_dict') def __init__(self, first_token=0, last_token=0, replicas=None): self.first_token = first_token self.last_token = last_token - self.replicas = replicas + if replicas is not None: + replicas_tuple = tuple(replicas) + self.replicas = replicas_tuple + self._replica_dict = {r[0]: r[1] for r in replicas_tuple} + else: + self.replicas = None + self._replica_dict = {} def __str__(self): return "<Tablet: first_token=%s last_token=%s replicas=%s>" \ % (self.first_token, self.last_token, self.replicas) __repr__ = __str__ - @staticmethod - def _is_valid_tablet(replicas): - return replicas is not None and len(replicas) != 0 - @staticmethod def from_row(first_token, last_token, replicas): - if Tablet._is_valid_tablet(replicas): - tablet = Tablet(first_token, last_token, replicas) - return tablet - return None + if not replicas: + return None + return Tablet(first_token, last_token, replicas) def replica_contains_host_id(self, uuid: UUID) -> bool: - for replica in self.replicas: - if replica[0] == uuid: - return True - return False + return uuid in self._replica_dict class Tablets(object): _lock = None - _tablets = {} + _tablets = {} # (keyspace, table) -> list[Tablet] + _first_tokens = {} # (keyspace, table) -> list[int] + _last_tokens = {} # (keyspace, table) -> list[int] def __init__(self, tablets): self._tablets = tablets + # Build parallel token index lists from any pre-populated data + self._first_tokens = { + key: [t.first_token for t in tlist] + for key, tlist in tablets.items() + } + self._last_tokens = { + key: [t.last_token for t in tlist] + for key, tlist in tablets.items() + } self._lock = Lock() def table_has_tablets(self, keyspace, table) -> bool: return bool(self._tablets.get((keyspace, table), [])) def get_tablet_for_key(self, keyspace, table, t): - tablet = self._tablets.get((keyspace, table), []) - if not tablet: + key = (keyspace, table) + last_tokens = self._last_tokens.get(key) +
if not last_tokens: return None - id = bisect_left(tablet, t.value, key=_get_last_token) - if id < len(tablet) and t.value > tablet[id].first_token: - return tablet[id] + token_value = t.value + id = bisect_left(last_tokens, token_value) + if id < len(last_tokens) and token_value > self._first_tokens[key][id]: + return self._tablets[key][id] return None def drop_tablets(self, keyspace: str, table: Optional[str] = None): with self._lock: if table is not None: - self._tablets.pop((keyspace, table), None) + key = (keyspace, table) + self._tablets.pop(key, None) + self._first_tokens.pop(key, None) + self._last_tokens.pop(key, None) return to_be_deleted = [] @@ -81,36 +88,48 @@ def drop_tablets(self, keyspace: str, table: Optional[str] = None): for key in to_be_deleted: del self._tablets[key] + self._first_tokens.pop(key, None) + self._last_tokens.pop(key, None) def drop_tablets_by_host_id(self, host_id: Optional[UUID]): if host_id is None: return with self._lock: for key, tablets in self._tablets.items(): - to_be_deleted = [] - for tablet_id, tablet in enumerate(tablets): - if tablet.replica_contains_host_id(host_id): - to_be_deleted.append(tablet_id) - - for tablet_id in reversed(to_be_deleted): - tablets.pop(tablet_id) + # Filter in one pass instead of popping one-by-one (O(n) vs O(k*n)) + keep = [i for i, t in enumerate(tablets) + if not t.replica_contains_host_id(host_id)] + if len(keep) == len(tablets): + continue # nothing to drop + self._tablets[key] = [tablets[i] for i in keep] + first = self._first_tokens[key] + last = self._last_tokens[key] + self._first_tokens[key] = [first[i] for i in keep] + self._last_tokens[key] = [last[i] for i in keep] def add_tablet(self, keyspace, table, tablet): with self._lock: - tablets_for_table = self._tablets.setdefault((keyspace, table), []) + key = (keyspace, table) + tablets_for_table = self._tablets.setdefault(key, []) + first_tokens = self._first_tokens.setdefault(key, []) + last_tokens = self._last_tokens.setdefault(key, 
[]) # find first overlapping range - start = bisect_left(tablets_for_table, tablet.first_token, key=_get_first_token) - if start > 0 and tablets_for_table[start - 1].last_token > tablet.first_token: + start = bisect_left(first_tokens, tablet.first_token) + if start > 0 and last_tokens[start - 1] > tablet.first_token: start = start - 1 # find last overlapping range - end = bisect_left(tablets_for_table, tablet.last_token, key=_get_last_token) - if end < len(tablets_for_table) and tablets_for_table[end].first_token >= tablet.last_token: + end = bisect_left(last_tokens, tablet.last_token) + if end < len(last_tokens) and first_tokens[end] >= tablet.last_token: end = end - 1 if start <= end: del tablets_for_table[start:end + 1] + del first_tokens[start:end + 1] + del last_tokens[start:end + 1] tablets_for_table.insert(start, tablet) + first_tokens.insert(start, tablet.first_token) + last_tokens.insert(start, tablet.last_token) diff --git a/tests/unit/test_tablets.py b/tests/unit/test_tablets.py index 7a40e7de4d..d0e09527bc 100644 --- a/tests/unit/test_tablets.py +++ b/tests/unit/test_tablets.py @@ -1,4 +1,5 @@ import unittest +from uuid import UUID from cassandra.tablets import Tablets, Tablet @@ -124,3 +125,97 @@ def __init__(self, v): # Token value 50 is not > first_token (100) of the tablet whose # last_token (200) is >= 50, so no match. 
self.assertIsNone(tablets.get_tablet_for_key("ks", "tb", Token(50))) + + +class TabletReplicaDictTest(unittest.TestCase): + """Tests for Tablet._replica_dict cached lookup.""" + + def test_replica_dict_built_from_replicas(self): + u1 = UUID('12345678-1234-5678-1234-567812345678') + u2 = UUID('87654321-4321-8765-4321-876543218765') + t = Tablet(0, 100, [(u1, 3), (u2, 7)]) + self.assertEqual(t._replica_dict, {u1: 3, u2: 7}) + + def test_replica_dict_empty_when_no_replicas(self): + t = Tablet(0, 100, None) + self.assertEqual(t._replica_dict, {}) + + def test_replica_dict_contains_host(self): + u1 = UUID('12345678-1234-5678-1234-567812345678') + u2 = UUID('87654321-4321-8765-4321-876543218765') + u3 = UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee') + t = Tablet(0, 100, [(u1, 3), (u2, 7)]) + self.assertIn(u1, t._replica_dict) + self.assertIn(u2, t._replica_dict) + self.assertNotIn(u3, t._replica_dict) + + def test_replica_dict_shard_lookup(self): + u1 = UUID('12345678-1234-5678-1234-567812345678') + u2 = UUID('87654321-4321-8765-4321-876543218765') + t = Tablet(0, 100, [(u1, 3), (u2, 7)]) + self.assertEqual(t._replica_dict.get(u1), 3) + self.assertEqual(t._replica_dict.get(u2), 7) + self.assertIsNone(t._replica_dict.get(UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'))) + + def test_replica_contains_host_id_uses_dict(self): + u1 = UUID('12345678-1234-5678-1234-567812345678') + u2 = UUID('87654321-4321-8765-4321-876543218765') + t = Tablet(0, 100, [(u1, 3), (u2, 7)]) + self.assertTrue(t.replica_contains_host_id(u1)) + self.assertTrue(t.replica_contains_host_id(u2)) + self.assertFalse(t.replica_contains_host_id(UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'))) + + def test_replicas_stored_as_tuple(self): + t = Tablet(0, 100, [("host1", 0), ("host2", 1)]) + self.assertIsInstance(t.replicas, tuple) + + def test_replica_dict_from_iterator(self): + """Ensure _replica_dict is correctly built even when replicas is a + one-shot iterator (generator), not a reusable list.""" + u1 = 
UUID('12345678-1234-5678-1234-567812345678') + u2 = UUID('87654321-4321-8765-4321-876543218765') + + def gen(): + yield (u1, 3) + yield (u2, 7) + + t = Tablet(0, 100, gen()) + self.assertEqual(t.replicas, ((u1, 3), (u2, 7))) + self.assertEqual(t._replica_dict, {u1: 3, u2: 7}) + self.assertTrue(t.replica_contains_host_id(u1)) + self.assertTrue(t.replica_contains_host_id(u2)) + + +class DropTabletsByHostIdTest(unittest.TestCase): + """Tests for Tablets.drop_tablets_by_host_id batch-filter path.""" + + def test_drop_removes_matching_tablets(self): + u1 = UUID('12345678-1234-5678-1234-567812345678') + u2 = UUID('87654321-4321-8765-4321-876543218765') + t1 = Tablet(0, 100, [(u1, 0)]) + t2 = Tablet(100, 200, [(u2, 0)]) + t3 = Tablet(200, 300, [(u1, 1), (u2, 1)]) + tablets = Tablets({("ks", "tb"): [t1, t2, t3]}) + + tablets.drop_tablets_by_host_id(u1) + + remaining = tablets._tablets[("ks", "tb")] + self.assertEqual(len(remaining), 1) + self.assertIs(remaining[0], t2) + # Verify token index lists are in sync + self.assertEqual(tablets._first_tokens[("ks", "tb")], [100]) + self.assertEqual(tablets._last_tokens[("ks", "tb")], [200]) + + def test_drop_none_host_id_is_noop(self): + t1 = Tablet(0, 100, [("host1", 0)]) + tablets = Tablets({("ks", "tb"): [t1]}) + tablets.drop_tablets_by_host_id(None) + self.assertEqual(len(tablets._tablets[("ks", "tb")]), 1) + + def test_drop_nonexistent_host_id_is_noop(self): + u1 = UUID('12345678-1234-5678-1234-567812345678') + u_missing = UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee') + t1 = Tablet(0, 100, [(u1, 0)]) + tablets = Tablets({("ks", "tb"): [t1]}) + tablets.drop_tablets_by_host_id(u_missing) + self.assertEqual(len(tablets._tablets[("ks", "tb")]), 1)