Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions benchmarks/bench_was_applied.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Micro-benchmark: was_applied fast path for known LWT statements.

Measures the speedup from skipping regex batch detection when the
query already knows it's an LWT statement (is_lwt() returns True).

Run:
python benchmarks/bench_was_applied.py
"""
import re
import timeit
from unittest.mock import Mock

from cassandra.query import named_tuple_factory, SimpleStatement, BatchStatement


def bench_was_applied():
"""Benchmark was_applied fast path vs slow path."""
batch_regex = re.compile(r'\s*BEGIN', re.IGNORECASE)

# Fast path: known LWT statement (BoundStatement-like, is_lwt=True)
lwt_query = Mock()
lwt_query.is_lwt.return_value = True

def fast_path():
query = lwt_query
if query.is_lwt() and not isinstance(query, BatchStatement):
# Fast path - known single LWT, skip batch detection
pass

# Slow path: non-LWT SimpleStatement (must check regex)
non_lwt_query = Mock(spec=SimpleStatement)
non_lwt_query.is_lwt.return_value = False
non_lwt_query.query_string = "INSERT INTO t (k, v) VALUES (1, 2) IF NOT EXISTS"

def slow_path():
query = non_lwt_query
if query.is_lwt() and not isinstance(query, BatchStatement):
pass
else:
isinstance(query, BatchStatement) or \
(isinstance(query, SimpleStatement) and batch_regex.match(query.query_string))

n = 500_000
t_fast = timeit.timeit(fast_path, number=n)
t_slow = timeit.timeit(slow_path, number=n)

print(f"Fast path (known LWT, {n} iters): {t_fast:.3f}s ({t_fast / n * 1e6:.2f} us/call)")
print(f"Slow path (regex check, {n} iters): {t_slow:.3f}s ({t_slow / n * 1e6:.2f} us/call)")
print(f"Speedup: {t_slow / t_fast:.1f}x")


def main():
bench_was_applied()


if __name__ == '__main__':
main()
21 changes: 15 additions & 6 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5430,13 +5430,22 @@ def was_applied(self):
if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory):
raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,))

is_batch_statement = isinstance(self.response_future.query, BatchStatement) \
or (isinstance(self.response_future.query, SimpleStatement) and self.batch_regex.match(self.response_future.query.query_string))
if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"):
raise RuntimeError("No LWT were present in the BatchStatement")
query = self.response_future.query

# Fast path: BoundStatement/PreparedStatement with known LWT status
# from the server PREPARE response avoids batch detection entirely.
if query.is_lwt() and not isinstance(query, BatchStatement):
# Known single LWT statement - skip batch detection
if len(self.current_rows) != 1:
raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows)))
else:
is_batch_statement = isinstance(query, BatchStatement) \
or (isinstance(query, SimpleStatement) and self.batch_regex.match(query.query_string))
if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"):
raise RuntimeError("No LWT were present in the BatchStatement")

if not is_batch_statement and len(self.current_rows) != 1:
raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows)))
if not is_batch_statement and len(self.current_rows) != 1:
raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows)))

row = self.current_rows[0]
if isinstance(row, tuple):
Expand Down
71 changes: 66 additions & 5 deletions tests/unit/test_resultset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from unittest.mock import Mock, PropertyMock, patch

from cassandra.cluster import ResultSet
from cassandra.query import named_tuple_factory, dict_factory, tuple_factory
from cassandra.query import named_tuple_factory, dict_factory, tuple_factory, SimpleStatement, BatchStatement

from tests.util import assertListEqual
import pytest
Expand Down Expand Up @@ -175,11 +175,18 @@ def test_bool(self):
assert ResultSet(Mock(has_more_pages=False), [1])

def test_was_applied(self):
# Create a non-LWT query so these assertions exercise the slow (regex) path.
# Without this, Mock().query.is_lwt() returns a truthy Mock, accidentally
# routing all checks through the fast path.
non_lwt_query = Mock(spec=SimpleStatement)
non_lwt_query.is_lwt.return_value = False
non_lwt_query.query_string = "INSERT INTO t (k) VALUES (1)"

# unknown row factory raises
with pytest.raises(RuntimeError):
ResultSet(Mock(), []).was_applied
ResultSet(Mock(query=non_lwt_query), []).was_applied

response_future = Mock(row_factory=named_tuple_factory)
response_future = Mock(row_factory=named_tuple_factory, query=non_lwt_query)

# no row
with pytest.raises(RuntimeError):
Expand All @@ -192,14 +199,68 @@ def test_was_applied(self):
# various internal row factories
for row_factory in (named_tuple_factory, tuple_factory):
for applied in (True, False):
rs = ResultSet(Mock(row_factory=row_factory), [(applied,)])
rs = ResultSet(Mock(row_factory=row_factory, query=non_lwt_query), [(applied,)])
assert rs.was_applied == applied

row_factory = dict_factory
for applied in (True, False):
rs = ResultSet(Mock(row_factory=row_factory), [{'[applied]': applied}])
rs = ResultSet(Mock(row_factory=row_factory, query=non_lwt_query), [{'[applied]': applied}])
assert rs.was_applied == applied


def test_was_applied_lwt_fast_path(self):
"""Test that was_applied uses fast path for known LWT statements."""
# BoundStatement-like query with is_lwt() = True (fast path)
lwt_query = Mock()
lwt_query.is_lwt.return_value = True
for row_factory in (named_tuple_factory, tuple_factory):
for applied in (True, False):
rf = Mock(row_factory=row_factory, query=lwt_query)
rs = ResultSet(rf, [(applied,)])
assert rs.was_applied == applied

for applied in (True, False):
rf = Mock(row_factory=dict_factory, query=lwt_query)
rs = ResultSet(rf, [{'[applied]': applied}])
assert rs.was_applied == applied

# Fast path with too many rows should raise
rf = Mock(row_factory=named_tuple_factory, query=lwt_query)
with pytest.raises(RuntimeError, match="exactly one row"):
ResultSet(rf, [tuple(), tuple()]).was_applied

def test_was_applied_non_lwt_fallback(self):
"""Test that was_applied falls back to slow path for non-LWT statements."""
# SimpleStatement-like query with is_lwt() = False (slow path, non-batch)
non_lwt_query = Mock(spec=SimpleStatement)
non_lwt_query.is_lwt.return_value = False
non_lwt_query.query_string = "INSERT INTO t (k) VALUES (1)"

for applied in (True, False):
rf = Mock(row_factory=tuple_factory, query=non_lwt_query)
rs = ResultSet(rf, [(applied,)])
assert rs.was_applied == applied

def test_was_applied_batch_statement(self):
"""Test that was_applied handles BatchStatement correctly (slow path)."""
# BatchStatement with LWT should check column_names
batch_query = Mock(spec=BatchStatement)
batch_query.is_lwt.return_value = True

# Batch with [applied] column -- pass _col_names so ResultSet.__init__
# sets column_names correctly (instead of post-construction override).
rf = Mock(row_factory=tuple_factory, query=batch_query,
_col_names=['[applied]'], _col_types=None)
rs = ResultSet(rf, [(True,)])
assert rs.was_applied == True

# Batch without [applied] column raises
rf = Mock(row_factory=tuple_factory, query=batch_query,
_col_names=['other'], _col_types=None)
rs = ResultSet(rf, [(True,)])
with pytest.raises(RuntimeError, match="No LWT were present"):
rs.was_applied

def test_one(self):
# no pages
first, second = Mock(), Mock()
Expand Down
Loading