Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions benchmarks/micro/bench_orphan_lock_skip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Micro-benchmark: orphaned request lock skip in process_msg.

Measures the cost of always acquiring a lock vs checking the set first.

Run:
    python benchmarks/micro/bench_orphan_lock_skip.py
"""

import sys
import timeit
from threading import Lock


def bench(n=2_000_000):
    """Benchmark the orphaned-request lock-skip optimization.

    Compares two strategies for the per-message orphan check in
    ``process_msg``:

    * old: always acquire the lock before testing set membership
    * new: test the set's truthiness first and only lock when non-empty

    Both the common case (empty orphaned set) and the rare case
    (non-empty set) are timed; results are printed to stdout.

    :param n: number of timed iterations per variant (default 2,000,000).
    """
    lock = Lock()
    stream_id = 42

    def make_checks(orphaned_set):
        # Build one (old, new) closure pair over *orphaned_set*.
        # ``state`` stands in for the connection's in_flight counter so
        # the closures mutate real state, mirroring the driver code.
        state = {'in_flight': 100}

        def old_check():
            # Old behavior: unconditionally acquire the lock.
            with lock:
                if stream_id in orphaned_set:
                    state['in_flight'] -= 1
                    orphaned_set.remove(stream_id)

        def new_check():
            # New behavior: skip the lock entirely when the set is empty.
            if orphaned_set:
                with lock:
                    if stream_id in orphaned_set:
                        state['in_flight'] -= 1
                        orphaned_set.remove(stream_id)

        return old_check, new_check

    def run_pair(old_check, new_check):
        # Warm up both closures, then time each; returns (ns_old, ns_new)
        # per-iteration costs.
        for _ in range(10000):
            old_check()
            new_check()
        t_old = timeit.timeit(old_check, number=n)
        t_new = timeit.timeit(new_check, number=n)
        return t_old / n * 1e9, t_new / n * 1e9

    print(f"=== orphaned request lock skip ({n:,} iters) ===\n")

    # Empty set — the common case the optimization targets.
    ns_old, ns_new = run_pair(*make_checks(set()))
    saving = ns_old - ns_new
    speedup = ns_old / ns_new if ns_new > 0 else float('inf')
    print(" Empty orphaned set (common case):")
    print(f" Old (always lock): {ns_old:.1f} ns")
    print(f" New (check first): {ns_new:.1f} ns")
    print(f" Saving: {saving:.1f} ns ({speedup:.2f}x)")

    # Non-empty set (rare case) — the new path pays an extra truthiness
    # check but must still behave correctly.
    ns_old, ns_new = run_pair(*make_checks({99, 100, 101}))
    diff = ns_new - ns_old
    print("\n Non-empty orphaned set (rare case):")
    print(f" Old (always lock): {ns_old:.1f} ns")
    print(f" New (check first): {ns_new:.1f} ns")
    print(f" Overhead: {diff:.1f} ns (extra truthiness check)")


if __name__ == "__main__":
    # Print the interpreter version first so benchmark results can be
    # compared across Python builds, then run the micro-benchmark.
    print(f"Python {sys.version}\n")
    bench()
71 changes: 71 additions & 0 deletions benchmarks/micro/bench_rlock_vs_lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Micro-benchmark: RLock vs Lock acquire/release overhead.

Measures the performance difference between threading.RLock and
threading.Lock for non-recursive lock acquisition patterns.

Run:
    python benchmarks/micro/bench_rlock_vs_lock.py
"""
import timeit
from threading import Lock, RLock


def bench_lock_types(n=2_000_000):
    """Compare Lock vs RLock acquire/release cycles.

    Times both explicit ``acquire()``/``release()`` pairs and
    ``with``-statement usage for ``threading.Lock`` and
    ``threading.RLock``, printing per-cycle cost and the relative RLock
    overhead for each pattern.

    :param n: number of timed iterations per variant (default 2,000,000).
    """
    lock = Lock()
    rlock = RLock()

    def use_lock():
        lock.acquire()
        lock.release()

    def use_rlock():
        rlock.acquire()
        rlock.release()

    def use_lock_with():
        with lock:
            pass

    def use_rlock_with():
        with rlock:
            pass

    # Explicit acquire/release pattern.
    t_lock = timeit.timeit(use_lock, number=n)
    t_rlock = timeit.timeit(use_rlock, number=n)

    print(f"Lock acquire/release ({n} iters): {t_lock:.3f}s ({t_lock / n * 1e9:.1f} ns/cycle)")
    print(f"RLock acquire/release ({n} iters): {t_rlock:.3f}s ({t_rlock / n * 1e9:.1f} ns/cycle)")
    print(f"RLock overhead: {(t_rlock / t_lock - 1) * 100:.0f}% ({t_rlock / t_lock:.2f}x)")

    # Context-manager pattern (what the driver code actually uses).
    t_lock_with = timeit.timeit(use_lock_with, number=n)
    t_rlock_with = timeit.timeit(use_rlock_with, number=n)

    print(f"\nLock 'with' stmt ({n} iters): {t_lock_with:.3f}s ({t_lock_with / n * 1e9:.1f} ns/cycle)")
    print(f"RLock 'with' stmt ({n} iters): {t_rlock_with:.3f}s ({t_rlock_with / n * 1e9:.1f} ns/cycle)")
    print(f"RLock overhead: {(t_rlock_with / t_lock_with - 1) * 100:.0f}% ({t_rlock_with / t_lock_with:.2f}x)")


def main():
    """Entry point: run the Lock-vs-RLock micro-benchmark."""
    bench_lock_types()


if __name__ == '__main__':
    # Script entry point — run the benchmark when invoked directly.
    main()
39 changes: 23 additions & 16 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ def __init__(self,
self.executor = self._create_thread_pool_executor(max_workers=executor_threads)
self.scheduler = _Scheduler(self.executor)

self._lock = RLock()
self._lock = Lock()

if self.metrics_enabled:
from cassandra.metrics import Metrics
Expand Down Expand Up @@ -1746,6 +1746,7 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
established or attempted. Default is `False`, which means it will return when the first
successful connection is established. Remaining pools are added asynchronously.
"""
connect_exc = None
with self._lock:
if self.is_shutdown:
raise DriverException("Cluster is already shut down")
Expand All @@ -1761,21 +1762,27 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self._populate_hosts()

log.debug("Control connection created")
except Exception:
except Exception as exc:
log.exception("Control connection failed to connect, "
"shutting down Cluster:")
self.shutdown()
raise

self.profile_manager.check_supported() # todo: rename this method

if self.idle_heartbeat_interval:
self._idle_heartbeat = ConnectionHeartbeat(
self.idle_heartbeat_interval,
self.get_connection_holders,
timeout=self.idle_heartbeat_timeout
)
self._is_setup = True
connect_exc = exc

if connect_exc is None:
self.profile_manager.check_supported() # todo: rename this method

if self.idle_heartbeat_interval:
self._idle_heartbeat = ConnectionHeartbeat(
self.idle_heartbeat_interval,
self.get_connection_holders,
timeout=self.idle_heartbeat_timeout
)
self._is_setup = True

if connect_exc is not None:
# shutdown() acquires self._lock, so must be called after
# releasing it above to avoid deadlock.
self.shutdown()
raise connect_exc

session = self._new_session(keyspace)
if wait_for_all_pools:
Expand Down Expand Up @@ -3540,11 +3547,11 @@ def __init__(self, cluster, timeout,
self._token_meta_enabled = token_meta_enabled
self._schema_meta_page_size = schema_meta_page_size

self._lock = RLock()
self._lock = Lock()
self._schema_agreement_lock = Lock()

self._reconnection_handler = None
self._reconnection_lock = RLock()
self._reconnection_lock = Lock()

self._event_schedule_times = {}

Expand Down
20 changes: 13 additions & 7 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import socket
import struct
import sys
from threading import Thread, Event, RLock, Condition
from threading import Thread, Event, Lock, Condition
import time
import ssl
import uuid
Expand Down Expand Up @@ -928,7 +928,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
self.request_ids = deque(range(initial_size))
self.highest_request_id = initial_size - 1

self.lock = RLock()
self.lock = Lock()
self.connected_event = Event()
self.features = ProtocolFeatures(shard_id=shard_id)
self.total_shards = total_shards
Expand Down Expand Up @@ -1395,11 +1395,17 @@ def process_msg(self, header, body):
result_metadata = None
else:
need_notify_of_release = False
with self.lock:
if stream_id in self.orphaned_request_ids:
self.in_flight -= 1
self.orphaned_request_ids.remove(stream_id)
need_notify_of_release = True
# Fast path: skip lock when no orphaned requests (common case).
# Reading orphaned_request_ids without the lock is safe: it's a
# set and we only check truthiness. A false negative just means
# we'll process the orphaned response normally; a false positive
# (rare) falls through to the locked check which is correct.
if self.orphaned_request_ids:
with self.lock:
if stream_id in self.orphaned_request_ids:
self.in_flight -= 1
self.orphaned_request_ids.remove(stream_id)
need_notify_of_release = True
if need_notify_of_release and self._on_orphaned_stream_released:
self._on_orphaned_stream_released()

Expand Down
2 changes: 1 addition & 1 deletion cassandra/cqlengine/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, name, hosts, consistency=None,
self.lazy_connect = lazy_connect
self.retry_connect = retry_connect
self.cluster_options = cluster_options if cluster_options else {}
self.lazy_connect_lock = threading.RLock()
self.lazy_connect_lock = threading.Lock()

@classmethod
def from_session(cls, name, session):
Expand Down
6 changes: 3 additions & 3 deletions cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
import re
import sys
from threading import RLock
from threading import Lock
import struct
import random
import itertools
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self):
self.dbaas = False
self._hosts = {}
self._host_id_by_endpoint = {}
self._hosts_lock = RLock()
self._hosts_lock = Lock()
self._tablets = Tablets({})

def export_schema_as_string(self):
Expand Down Expand Up @@ -1778,7 +1778,7 @@ def __init__(self, token_class, token_to_host_owner, all_tokens, metadata):

self.tokens_to_hosts_by_ks = {}
self._metadata = metadata
self._rebuild_lock = RLock()
self._rebuild_lock = Lock()

def rebuild_keyspace(self, keyspace, build_if_absent=False):
with self._rebuild_lock:
Expand Down
4 changes: 2 additions & 2 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import random
import copy
import uuid
from threading import Lock, RLock, Condition
from threading import Lock, Condition
import weakref
try:
from weakref import WeakSet
Expand Down Expand Up @@ -179,7 +179,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No
raise ValueError("host_id may not be None")
self.host_id = host_id
self.set_location_info(datacenter, rack)
self.lock = RLock()
self.lock = Lock()

@property
def address(self):
Expand Down
Loading
Loading