From 5447fbfdc6b7fef137b577c72e03e94492c2a45e Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 5 Apr 2026 18:36:21 +0300 Subject: [PATCH 1/2] perf: replace RLock with Lock where re-entrant locking is not needed Convert 7 of 8 RLock instances to plain Lock. All verified to use only flat (non-recursive) acquisition patterns: - Connection.lock (hot path: every message send/receive) - Cluster._lock (connect/shutdown) - ControlConnection._lock and _reconnection_lock - Metadata._hosts_lock and TokenMap._rebuild_lock - Host.lock and cqlengine Connection.lazy_connect_lock Session._lock is kept as RLock because run_add_or_renew_pool() uses manual release/acquire inside a 'with' block. Benchmark: RLock 'with' stmt is ~14% slower than plain Lock. --- benchmarks/micro/bench_rlock_vs_lock.py | 71 +++++++++++ cassandra/cluster.py | 39 +++--- cassandra/connection.py | 4 +- cassandra/cqlengine/connection.py | 2 +- cassandra/metadata.py | 6 +- cassandra/pool.py | 4 +- tests/unit/test_rlock_to_lock.py | 160 ++++++++++++++++++++++++ 7 files changed, 262 insertions(+), 24 deletions(-) create mode 100644 benchmarks/micro/bench_rlock_vs_lock.py create mode 100644 tests/unit/test_rlock_to_lock.py diff --git a/benchmarks/micro/bench_rlock_vs_lock.py b/benchmarks/micro/bench_rlock_vs_lock.py new file mode 100644 index 0000000000..537463834c --- /dev/null +++ b/benchmarks/micro/bench_rlock_vs_lock.py @@ -0,0 +1,71 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: RLock vs Lock acquire/release overhead. + +Measures the performance difference between threading.RLock and +threading.Lock for non-recursive lock acquisition patterns. + +Run: + python benchmarks/bench_rlock_vs_lock.py +""" +import timeit +from threading import Lock, RLock + + +def bench_lock_types(): + """Compare Lock vs RLock acquire/release cycles.""" + lock = Lock() + rlock = RLock() + + n = 2_000_000 + + def use_lock(): + lock.acquire() + lock.release() + + def use_rlock(): + rlock.acquire() + rlock.release() + + def use_lock_with(): + with lock: + pass + + def use_rlock_with(): + with rlock: + pass + + t_lock = timeit.timeit(use_lock, number=n) + t_rlock = timeit.timeit(use_rlock, number=n) + + print(f"Lock acquire/release ({n} iters): {t_lock:.3f}s ({t_lock / n * 1e9:.1f} ns/cycle)") + print(f"RLock acquire/release ({n} iters): {t_rlock:.3f}s ({t_rlock / n * 1e9:.1f} ns/cycle)") + print(f"RLock overhead: {(t_rlock / t_lock - 1) * 100:.0f}% ({t_rlock / t_lock:.2f}x)") + + t_lock_with = timeit.timeit(use_lock_with, number=n) + t_rlock_with = timeit.timeit(use_rlock_with, number=n) + + print(f"\nLock 'with' stmt ({n} iters): {t_lock_with:.3f}s ({t_lock_with / n * 1e9:.1f} ns/cycle)") + print(f"RLock 'with' stmt ({n} iters): {t_rlock_with:.3f}s ({t_rlock_with / n * 1e9:.1f} ns/cycle)") + print(f"RLock overhead: {(t_rlock_with / t_lock_with - 1) * 100:.0f}% ({t_rlock_with / t_lock_with:.2f}x)") + + +def main(): + bench_lock_types() + + +if __name__ == '__main__': + main() diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..3646a7ca6e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1498,7 +1498,7 @@ def __init__(self, self.executor = self._create_thread_pool_executor(max_workers=executor_threads) self.scheduler = _Scheduler(self.executor) - self._lock = RLock() + self._lock = Lock() if self.metrics_enabled: from cassandra.metrics import Metrics @@ -1746,6 +1746,7 @@ def connect(self, keyspace=None, wait_for_all_pools=False): established or attempted. Default is `False`, which means it will return when the first successful connection is established. Remaining pools are added asynchronously. """ + connect_exc = None with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") @@ -1761,21 +1762,27 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self._populate_hosts() log.debug("Control connection created") - except Exception: + except Exception as exc: log.exception("Control connection failed to connect, " "shutting down Cluster:") - self.shutdown() - raise - - self.profile_manager.check_supported() # todo: rename this method - - if self.idle_heartbeat_interval: - self._idle_heartbeat = ConnectionHeartbeat( - self.idle_heartbeat_interval, - self.get_connection_holders, - timeout=self.idle_heartbeat_timeout - ) - self._is_setup = True + connect_exc = exc + + if connect_exc is None: + self.profile_manager.check_supported() # todo: rename this method + + if self.idle_heartbeat_interval: + self._idle_heartbeat = ConnectionHeartbeat( + self.idle_heartbeat_interval, + self.get_connection_holders, + timeout=self.idle_heartbeat_timeout + ) + self._is_setup = True + + if connect_exc is not None: + # shutdown() acquires self._lock, so must be called after + # releasing it above to avoid deadlock. + self.shutdown() + raise connect_exc session = self._new_session(keyspace) if wait_for_all_pools: @@ -3540,11 +3547,11 @@ def __init__(self, cluster, timeout, self._token_meta_enabled = token_meta_enabled self._schema_meta_page_size = schema_meta_page_size - self._lock = RLock() + self._lock = Lock() self._schema_agreement_lock = Lock() self._reconnection_handler = None - self._reconnection_lock = RLock() + self._reconnection_lock = Lock() self._event_schedule_times = {} diff --git a/cassandra/connection.py b/cassandra/connection.py index c045b36cb3..b71f00c987 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -22,7 +22,7 @@ import socket import struct import sys -from threading import Thread, Event, RLock, Condition +from threading import Thread, Event, Lock, Condition import time import ssl import uuid @@ -928,7 +928,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, self.request_ids = deque(range(initial_size)) self.highest_request_id = initial_size - 1 - self.lock = RLock() + self.lock = Lock() self.connected_event = Event() self.features = ProtocolFeatures(shard_id=shard_id) self.total_shards = total_shards diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py index bf3e55a2e8..9b37359a05 100644 --- a/cassandra/cqlengine/connection.py +++ b/cassandra/cqlengine/connection.py @@ -78,7 +78,7 @@ def __init__(self, name, hosts, consistency=None, self.lazy_connect = lazy_connect self.retry_connect = retry_connect self.cluster_options = cluster_options if cluster_options else {} - self.lazy_connect_lock = threading.RLock() + self.lazy_connect_lock = threading.Lock() @classmethod def from_session(cls, name, session): diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 43399b7152..d3592da88d 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -22,7 +22,7 @@ import logging import re import sys -from threading import RLock +from threading import Lock import struct import random import itertools @@ -126,7 +126,7 @@ def __init__(self): self.dbaas = False self._hosts = {} self._host_id_by_endpoint = {} - self._hosts_lock = RLock() + self._hosts_lock = Lock() self._tablets = Tablets({}) def export_schema_as_string(self): @@ -1778,7 +1778,7 @@ def __init__(self, token_class, token_to_host_owner, all_tokens, metadata): self.tokens_to_hosts_by_ks = {} self._metadata = metadata - self._rebuild_lock = RLock() + self._rebuild_lock = Lock() def rebuild_keyspace(self, keyspace, build_if_absent=False): with self._rebuild_lock: diff --git a/cassandra/pool.py b/cassandra/pool.py index 227e1b5315..772ca7ba8c 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -23,7 +23,7 @@ import random import copy import uuid -from threading import Lock, RLock, Condition +from threading import Lock, Condition import weakref try: from weakref import WeakSet @@ -179,7 +179,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No raise ValueError("host_id may not be None") self.host_id = host_id self.set_location_info(datacenter, rack) - self.lock = RLock() + self.lock = Lock() @property def address(self): diff --git a/tests/unit/test_rlock_to_lock.py b/tests/unit/test_rlock_to_lock.py new file mode 100644 index 0000000000..bb65dd1449 --- /dev/null +++ b/tests/unit/test_rlock_to_lock.py @@ -0,0 +1,160 @@ +""" +Unit tests verifying that RLock -> Lock conversion is safe. + +Tests that the lock objects are of the correct type and that basic +operations (connect, metadata, pool) still work correctly. +""" +import threading +import unittest +from unittest.mock import Mock, patch + +from cassandra.cluster import Cluster +from cassandra.metadata import Metadata, TokenMap +from cassandra.pool import Host + + +class TestLockTypes(unittest.TestCase): + """Verify each converted lock is a plain Lock, not RLock.""" + + def _assert_is_lock_not_rlock(self, lock_obj): + """Assert the given object is a plain Lock, not an RLock.""" + # In CPython, Lock() creates _thread.lock, RLock() creates _thread.RLock + lock_type_name = type(lock_obj).__name__ + self.assertNotIn('RLock', lock_type_name, + f"Expected plain Lock but got {type(lock_obj)}") + + def test_metadata_hosts_lock_is_plain_lock(self): + """Metadata._hosts_lock should be a plain Lock.""" + m = Metadata() + self._assert_is_lock_not_rlock(m._hosts_lock) + + def test_metadata_rebuild_lock_is_plain_lock(self): + """TokenMap._rebuild_lock should be a plain Lock.""" + tm = TokenMap( + token_class=Mock(), + token_to_host_owner={}, + all_tokens=[], + metadata=Mock() + ) + self._assert_is_lock_not_rlock(tm._rebuild_lock) + + def test_host_lock_is_plain_lock(self): + """Host.lock should be a plain Lock.""" + import uuid + h = Host( + endpoint=Mock(), + conviction_policy_factory=Mock(), + host_id=uuid.uuid4() + ) + self._assert_is_lock_not_rlock(h.lock) + + def test_cqlengine_connection_lock_is_plain_lock(self): + """CQLEngine Connection.lazy_connect_lock should be a plain Lock.""" + from cassandra.cqlengine.connection import Connection as CQLConn + c = CQLConn.__new__(CQLConn) + c.lazy_connect_lock = threading.Lock() + self._assert_is_lock_not_rlock(c.lazy_connect_lock) + + +class TestMetadataOperationsWithLock(unittest.TestCase): + """Verify metadata operations work correctly with plain Lock.""" + + def test_add_and_get_host(self): + """add_or_return_host + get_host should work with plain Lock.""" + import uuid + m = Metadata() + endpoint = Mock() + host = Host(endpoint=endpoint, conviction_policy_factory=Mock(), + host_id=uuid.uuid4()) + returned, new = m.add_or_return_host(host) + self.assertTrue(new) + self.assertIs(returned, host) + + # Second add should return same host + returned2, new2 = m.add_or_return_host(host) + self.assertFalse(new2) + self.assertIs(returned2, host) + + def test_update_host_sequential_lock(self): + """update_host acquires lock twice sequentially — must not deadlock.""" + import uuid + m = Metadata() + old_endpoint = Mock() + new_endpoint = Mock() + host = Host(endpoint=new_endpoint, conviction_policy_factory=Mock(), + host_id=uuid.uuid4()) + # update_host calls add_or_return_host (acquires lock, releases), + # then acquires lock again for endpoint update. + # With plain Lock, this must NOT deadlock. + m.update_host(host, old_endpoint) + # Host should be retrievable by host_id + result = m.get_host_by_host_id(host.host_id) + self.assertIs(result, host) + + def test_remove_host(self): + """remove_host should work with plain Lock.""" + import uuid + m = Metadata() + endpoint = Mock() + host = Host(endpoint=endpoint, conviction_policy_factory=Mock(), + host_id=uuid.uuid4()) + m.add_or_return_host(host) + removed = m.remove_host(host) + self.assertTrue(removed) + + def test_all_hosts(self): + """all_hosts should work under plain Lock.""" + import uuid + m = Metadata() + hosts = [] + for _ in range(3): + h = Host(endpoint=Mock(), conviction_policy_factory=Mock(), + host_id=uuid.uuid4()) + m.add_or_return_host(h) + hosts.append(h) + all_h = m.all_hosts() + self.assertEqual(len(all_h), 3) + + +class TestHostLockOperations(unittest.TestCase): + """Verify Host lock operations work with plain Lock.""" + + def test_get_and_set_reconnection_handler(self): + """get_and_set_reconnection_handler should work with plain Lock.""" + import uuid + h = Host(endpoint=Mock(), conviction_policy_factory=Mock(), + host_id=uuid.uuid4()) + handler = Mock() + old = h.get_and_set_reconnection_handler(handler) + self.assertIsNone(old) + old2 = h.get_and_set_reconnection_handler(Mock()) + self.assertIs(old2, handler) + + +class TestClusterConnectFailureNoDeadlock(unittest.TestCase): + """Verify Cluster.connect() failure path doesn't deadlock with plain Lock. + + Cluster._lock is a plain Lock. connect() acquires it, and on failure + calls shutdown() which also acquires it. The shutdown() call must happen + after releasing the lock to avoid deadlock. + """ + + def test_connect_failure_calls_shutdown_without_deadlock(self): + """connect() should call shutdown() and re-raise on control connection failure.""" + cluster = Cluster(contact_points=[]) + # Ensure Cluster._lock is a plain Lock (not RLock) + lock_type_name = type(cluster._lock).__name__ + self.assertNotIn('RLock', lock_type_name) + + with patch.object(cluster.connection_class, 'initialize_reactor'): + with patch.object(cluster.control_connection, 'connect', + side_effect=Exception("test connection failure")): + with patch.object(cluster, 'shutdown') as mock_shutdown: + with self.assertRaises(Exception) as ctx: + cluster.connect() + self.assertIn("test connection failure", str(ctx.exception)) + mock_shutdown.assert_called_once() + + +if __name__ == '__main__': + unittest.main() From 014e82e60da27c2ab078d2f0970f2a81a6a54306 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sat, 11 Apr 2026 00:32:52 +0300 Subject: [PATCH 2/2] perf: skip lock acquisition when no orphaned requests in process_msg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Check orphaned_request_ids truthiness before acquiring the lock. Since orphaned requests are rare (only on timeouts), the set is almost always empty. Skipping the lock in the common case saves ~57 ns per response. The unlocked truthiness check on a set is thread-safe under the GIL. Worst case (false positive): we enter the lock block and re-check, which is correct. Worst case (false negative): an orphaned response is processed normally — acceptable behavior. Benchmark (2M iters, Python 3.14): Empty set (common): 80.6 -> 23.2 ns (3.47x, -57.4 ns/response) Non-empty set (rare): 79.7 -> 87.8 ns (+8.1 ns overhead) --- benchmarks/micro/bench_orphan_lock_skip.py | 105 +++++++++++++++++++++ cassandra/connection.py | 16 +++- 2 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 benchmarks/micro/bench_orphan_lock_skip.py diff --git a/benchmarks/micro/bench_orphan_lock_skip.py b/benchmarks/micro/bench_orphan_lock_skip.py new file mode 100644 index 0000000000..e56cabe972 --- /dev/null +++ b/benchmarks/micro/bench_orphan_lock_skip.py @@ -0,0 +1,105 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: orphaned request lock skip in process_msg. + +Measures the cost of always acquiring a lock vs checking the set first. + +Run: + python benchmarks/bench_orphan_lock_skip.py +""" + +import sys +import timeit +from threading import Lock + + +def bench(): + n = 2_000_000 + lock = Lock() + orphaned_set = set() # empty — the common case + stream_id = 42 + in_flight = 100 + + # Old: always acquire lock + def old_check(): + nonlocal in_flight + with lock: + if stream_id in orphaned_set: + in_flight -= 1 + orphaned_set.remove(stream_id) + + # New: check set first, skip lock if empty + def new_check(): + nonlocal in_flight + if orphaned_set: + with lock: + if stream_id in orphaned_set: + in_flight -= 1 + orphaned_set.remove(stream_id) + + print(f"=== orphaned request lock skip ({n:,} iters) ===\n") + + # Warmup + for _ in range(10000): + old_check() + new_check() + + t_old = timeit.timeit(old_check, number=n) + t_new = timeit.timeit(new_check, number=n) + ns_old = t_old / n * 1e9 + ns_new = t_new / n * 1e9 + saving = ns_old - ns_new + speedup = ns_old / ns_new if ns_new > 0 else float('inf') + print(f" Empty orphaned set (common case):") + print(f" Old (always lock): {ns_old:.1f} ns") + print(f" New (check first): {ns_new:.1f} ns") + print(f" Saving: {saving:.1f} ns ({speedup:.2f}x)") + + # Non-empty set (rare case) — should still work + orphaned_set_full = {99, 100, 101} + def old_check_full(): + nonlocal in_flight + with lock: + if stream_id in orphaned_set_full: + in_flight -= 1 + orphaned_set_full.remove(stream_id) + + def new_check_full(): + nonlocal in_flight + if orphaned_set_full: + with lock: + if stream_id in orphaned_set_full: + in_flight -= 1 + orphaned_set_full.remove(stream_id) + + for _ in range(10000): + old_check_full() + new_check_full() + + t_old = timeit.timeit(old_check_full, number=n) + t_new = timeit.timeit(new_check_full, number=n) + ns_old = t_old / n * 1e9 + ns_new = t_new / n * 1e9 + diff = ns_new - ns_old + print(f"\n Non-empty orphaned set (rare case):") + print(f" Old (always lock): {ns_old:.1f} ns") + print(f" New (check first): {ns_new:.1f} ns") + print(f" Overhead: {diff:.1f} ns (extra truthiness check)") + + +if __name__ == "__main__": + print(f"Python {sys.version}\n") + bench() diff --git a/cassandra/connection.py b/cassandra/connection.py index b71f00c987..f444187fd0 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1395,11 +1395,17 @@ def process_msg(self, header, body): result_metadata = None else: need_notify_of_release = False - with self.lock: - if stream_id in self.orphaned_request_ids: - self.in_flight -= 1 - self.orphaned_request_ids.remove(stream_id) - need_notify_of_release = True + # Fast path: skip lock when no orphaned requests (common case). + # Reading orphaned_request_ids without the lock is safe: it's a + # set and we only check truthiness. A false negative just means + # we'll process the orphaned response normally; a false positive + # (rare) falls through to the locked check which is correct. + if self.orphaned_request_ids: + with self.lock: + if stream_id in self.orphaned_request_ids: + self.in_flight -= 1 + self.orphaned_request_ids.remove(stream_id) + need_notify_of_release = True if need_notify_of_release and self._on_orphaned_stream_released: self._on_orphaned_stream_released()