diff --git a/benchmarks/bench_was_applied.py b/benchmarks/bench_was_applied.py new file mode 100644 index 0000000000..88edf5a3ae --- /dev/null +++ b/benchmarks/bench_was_applied.py @@ -0,0 +1,72 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: was_applied fast path for known LWT statements. + +Measures the speedup from skipping regex batch detection when the +query already knows it's an LWT statement (is_lwt() returns True). + +Run: + python benchmarks/bench_was_applied.py +""" +import re +import timeit +from unittest.mock import Mock + +from cassandra.query import named_tuple_factory, SimpleStatement, BatchStatement + + +def bench_was_applied(): + """Benchmark was_applied fast path vs slow path.""" + batch_regex = re.compile(r'\s*BEGIN', re.IGNORECASE) + + # Fast path: known LWT statement (BoundStatement-like, is_lwt=True) + lwt_query = Mock() + lwt_query.is_lwt.return_value = True + + def fast_path(): + query = lwt_query + if query.is_lwt() and not isinstance(query, BatchStatement): + # Fast path - known single LWT, skip batch detection + pass + + # Slow path: non-LWT SimpleStatement (must check regex) + non_lwt_query = Mock(spec=SimpleStatement) + non_lwt_query.is_lwt.return_value = False + non_lwt_query.query_string = "INSERT INTO t (k, v) VALUES (1, 2) IF NOT EXISTS" + + def slow_path(): + query = non_lwt_query + if query.is_lwt() and not isinstance(query, BatchStatement): + pass + else: + isinstance(query, BatchStatement) or \ + (isinstance(query, SimpleStatement) and batch_regex.match(query.query_string)) + + n = 500_000 + t_fast = timeit.timeit(fast_path, number=n) + t_slow = timeit.timeit(slow_path, number=n) + + print(f"Fast path (known LWT, {n} iters): {t_fast:.3f}s ({t_fast / n * 1e6:.2f} us/call)") + print(f"Slow path (regex check, {n} iters): {t_slow:.3f}s ({t_slow / n * 1e6:.2f} us/call)") + print(f"Speedup: {t_slow / t_fast:.1f}x") + + +def main(): + bench_was_applied() + + +if __name__ == '__main__': + main() diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..962f2eca36 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -5430,13 +5430,22 @@ def was_applied(self): if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) - is_batch_statement = isinstance(self.response_future.query, BatchStatement) \ - or (isinstance(self.response_future.query, SimpleStatement) and self.batch_regex.match(self.response_future.query.query_string)) - if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): - raise RuntimeError("No LWT were present in the BatchStatement") + query = self.response_future.query + + # Fast path: BoundStatement/PreparedStatement with known LWT status + # from the server PREPARE response avoids batch detection entirely. + if query.is_lwt() and not isinstance(query, BatchStatement): + # Known single LWT statement - skip batch detection + if len(self.current_rows) != 1: + raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) + else: + is_batch_statement = isinstance(query, BatchStatement) \ + or (isinstance(query, SimpleStatement) and self.batch_regex.match(query.query_string)) + if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): + raise RuntimeError("No LWT were present in the BatchStatement") - if not is_batch_statement and len(self.current_rows) != 1: - raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) + if not is_batch_statement and len(self.current_rows) != 1: + raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) row = self.current_rows[0] if isinstance(row, tuple): diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py index 80e9c21ff9..b811fd7fe3 100644 --- a/tests/unit/test_resultset.py +++ b/tests/unit/test_resultset.py @@ -16,7 +16,7 @@ from unittest.mock import Mock, PropertyMock, patch from cassandra.cluster import ResultSet -from cassandra.query import named_tuple_factory, dict_factory, tuple_factory +from cassandra.query import named_tuple_factory, dict_factory, tuple_factory, SimpleStatement, BatchStatement from tests.util import assertListEqual import pytest @@ -175,11 +175,18 @@ def test_bool(self): assert ResultSet(Mock(has_more_pages=False), [1]) def test_was_applied(self): + # Create a non-LWT query so these assertions exercise the slow (regex) path. + # Without this, Mock().query.is_lwt() returns a truthy Mock, accidentally + # routing all checks through the fast path. + non_lwt_query = Mock(spec=SimpleStatement) + non_lwt_query.is_lwt.return_value = False + non_lwt_query.query_string = "INSERT INTO t (k) VALUES (1)" + # unknown row factory raises with pytest.raises(RuntimeError): - ResultSet(Mock(), []).was_applied + ResultSet(Mock(query=non_lwt_query), []).was_applied - response_future = Mock(row_factory=named_tuple_factory) + response_future = Mock(row_factory=named_tuple_factory, query=non_lwt_query) # no row with pytest.raises(RuntimeError): @@ -192,14 +199,68 @@ def test_was_applied(self): # various internal row factories for row_factory in (named_tuple_factory, tuple_factory): for applied in (True, False): - rs = ResultSet(Mock(row_factory=row_factory), [(applied,)]) + rs = ResultSet(Mock(row_factory=row_factory, query=non_lwt_query), [(applied,)]) assert rs.was_applied == applied row_factory = dict_factory for applied in (True, False): - rs = ResultSet(Mock(row_factory=row_factory), [{'[applied]': applied}]) + rs = ResultSet(Mock(row_factory=row_factory, query=non_lwt_query), [{'[applied]': applied}]) assert rs.was_applied == applied + + def test_was_applied_lwt_fast_path(self): + """Test that was_applied uses fast path for known LWT statements.""" + # BoundStatement-like query with is_lwt() = True (fast path) + lwt_query = Mock() + lwt_query.is_lwt.return_value = True + for row_factory in (named_tuple_factory, tuple_factory): + for applied in (True, False): + rf = Mock(row_factory=row_factory, query=lwt_query) + rs = ResultSet(rf, [(applied,)]) + assert rs.was_applied == applied + + for applied in (True, False): + rf = Mock(row_factory=dict_factory, query=lwt_query) + rs = ResultSet(rf, [{'[applied]': applied}]) + assert rs.was_applied == applied + + # Fast path with too many rows should raise + rf = Mock(row_factory=named_tuple_factory, query=lwt_query) + with pytest.raises(RuntimeError, match="exactly one row"): + ResultSet(rf, [tuple(), tuple()]).was_applied + + def test_was_applied_non_lwt_fallback(self): + """Test that was_applied falls back to slow path for non-LWT statements.""" + # SimpleStatement-like query with is_lwt() = False (slow path, non-batch) + non_lwt_query = Mock(spec=SimpleStatement) + non_lwt_query.is_lwt.return_value = False + non_lwt_query.query_string = "INSERT INTO t (k) VALUES (1)" + + for applied in (True, False): + rf = Mock(row_factory=tuple_factory, query=non_lwt_query) + rs = ResultSet(rf, [(applied,)]) + assert rs.was_applied == applied + + def test_was_applied_batch_statement(self): + """Test that was_applied handles BatchStatement correctly (slow path).""" + # BatchStatement with LWT should check column_names + batch_query = Mock(spec=BatchStatement) + batch_query.is_lwt.return_value = True + + # Batch with [applied] column -- pass _col_names so ResultSet.__init__ + # sets column_names correctly (instead of post-construction override). + rf = Mock(row_factory=tuple_factory, query=batch_query, + _col_names=['[applied]'], _col_types=None) + rs = ResultSet(rf, [(True,)]) + assert rs.was_applied == True + + # Batch without [applied] column raises + rf = Mock(row_factory=tuple_factory, query=batch_query, + _col_names=['other'], _col_types=None) + rs = ResultSet(rf, [(True,)]) + with pytest.raises(RuntimeError, match="No LWT were present"): + rs.was_applied + def test_one(self): # no pages first, second = Mock(), Mock()