From 1e1e7093aeb3549e49c8831ac75a2041c09eba6a Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sat, 4 Apr 2026 17:21:14 +0300 Subject: [PATCH 1/3] perf: buffer accumulation in _write_query_params() Replace the per-parameter write_value(f, param) loop in _QueryMessage._write_query_params() with a buffer accumulation approach: list.append + b"".join + single f.write(). This reduces the number of f.write() calls from 2*N+1 to 1, which is significant for vector workloads with large parameters. Also removes the redundant ExecuteMessage._write_query_params() pass-through override to avoid extra MRO lookup per call. Includes 14 unit tests covering normal, NULL, UNSET, empty, large vector, and mixed parameter scenarios for both ExecuteMessage and QueryMessage. Includes a benchmark script (benchmarks/bench_execute_write_params.py). --- benchmarks/bench_execute_write_params.py | 538 +++++++++++++++++++++++ cassandra/protocol.py | 19 +- tests/unit/test_protocol.py | 188 +++++++- 3 files changed, 739 insertions(+), 6 deletions(-) create mode 100644 benchmarks/bench_execute_write_params.py diff --git a/benchmarks/bench_execute_write_params.py b/benchmarks/bench_execute_write_params.py new file mode 100644 index 0000000000..cf02a438a0 --- /dev/null +++ b/benchmarks/bench_execute_write_params.py @@ -0,0 +1,538 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark: ExecuteMessage._write_query_params() and send_body() for vector +INSERT workloads. + +Compares five approaches for the parameter serialization hot loop: + + 1. baseline – current code (calling write_value() per param) + 2. pr788_inline – PR #788 inlining (local aliases, inline write_value) + 3. buf_accum – buffer accumulation (collect parts in list, single join) + 4. combined – inlining + buffer accumulation + 5. module_current – whatever the loaded module provides (.so or .py) + +Variants 1-4 are standalone pure-Python functions that call into +Cython-compiled helpers (write_value, write_string, etc.) when the .so is +loaded. Variant 5 calls the actual module method directly. + +NOTE: To compare Cython vs pure-Python for variant 5, move the .so aside: + mv cassandra/protocol.cpython-*-linux-gnu.so{,.bak} + +Usage: + python benchmarks/bench_execute_write_params.py +""" + +import io +import struct +import time +import timeit +import sys +import os + +# Ensure the repo root is importable +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import cassandra.protocol +from cassandra.protocol import ( + ExecuteMessage, + _QueryMessage, + ProtocolHandler, + write_consistency_level, + write_byte, + write_uint, + write_short, + write_int, + write_long, + write_string, + write_value, + _UNSET_VALUE, + _VALUES_FLAG, + _WITH_SERIAL_CONSISTENCY_FLAG, + _PAGE_SIZE_FLAG, + _WITH_PAGING_STATE_FLAG, + _PROTOCOL_TIMESTAMP_FLAG, + _WITH_KEYSPACE_FLAG, +) +from cassandra import ProtocolVersion +from cassandra.marshal import int32_pack, uint16_pack, uint8_pack, uint64_pack + +# --------------------------------------------------------------------------- +# Pre-computed constants (as in PR #788) +# --------------------------------------------------------------------------- +_INT32_NEG1 = int32_pack(-1) # NULL marker +_INT32_NEG2 = int32_pack(-2) # UNSET marker + + +# =================================================================== +# Variant 1: baseline – mirrors current _write_query_params exactly +# =================================================================== + + +def baseline_write_query_params(msg, f, protocol_version): + write_consistency_level(f, msg.consistency_level) + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_uint(f, flags) + else: + write_byte(f, flags) + if msg.query_params is not None: + write_short(f, len(msg.query_params)) + for param in msg.query_params: + write_value(f, param) + if msg.fetch_size: + write_int(f, msg.fetch_size) + if msg.paging_state: + write_string(f, msg.paging_state) + if msg.serial_consistency_level: + write_consistency_level(f, msg.serial_consistency_level) + if msg.timestamp is not None: + write_long(f, msg.timestamp) + + +def baseline_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + baseline_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Variant 2: pr788_inline – inline write_value with local aliases +# =================================================================== + + +def pr788_write_query_params(msg, f, protocol_version): + write_consistency_level(f, msg.consistency_level) + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_uint(f, flags) + else: + write_byte(f, flags) + if msg.query_params is not None: + write_short(f, len(msg.query_params)) + _fw = f.write + _i32 = int32_pack + for param in msg.query_params: + if param is None: + _fw(_INT32_NEG1) + elif param is _UNSET_VALUE: + _fw(_INT32_NEG2) + else: + _fw(_i32(len(param))) + _fw(param) + if msg.fetch_size: + write_int(f, msg.fetch_size) + if msg.paging_state: + write_string(f, msg.paging_state) + if msg.serial_consistency_level: + write_consistency_level(f, msg.serial_consistency_level) + if msg.timestamp is not None: + write_long(f, msg.timestamp) + + +def pr788_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + pr788_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Variant 3: buf_accum – collect all writes in a list, single join +# =================================================================== + + +def bufaccum_write_query_params(msg, f, protocol_version): + parts = [] + _p = parts.append + _i32 = int32_pack + _u16 = uint16_pack + _u8 = uint8_pack + _u64 = uint64_pack + + _p(_u16(msg.consistency_level)) + + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + + if ProtocolVersion.uses_int_query_flags(protocol_version): + from cassandra.marshal import uint32_pack + + _p(uint32_pack(flags)) + else: + _p(_u8(flags)) + + if msg.query_params is not None: + _p(_u16(len(msg.query_params))) + for param in msg.query_params: + if param is None: + _p(_INT32_NEG1) + elif param is _UNSET_VALUE: + _p(_INT32_NEG2) + else: + _p(_i32(len(param))) + _p(param) + + if msg.fetch_size: + _p(_i32(msg.fetch_size)) + if msg.paging_state: + ps = msg.paging_state + if isinstance(ps, str): + ps = ps.encode("utf8") + _p(_u16(len(ps))) + _p(ps) + if msg.serial_consistency_level: + _p(_u16(msg.serial_consistency_level)) + if msg.timestamp is not None: + _p(_u64(msg.timestamp)) + + f.write(b"".join(parts)) + + +def bufaccum_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + bufaccum_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Variant 4: combined – inline write_value + buffer accumulation +# (single len+data concat per param, then single join) +# =================================================================== + + +def combined_write_query_params(msg, f, protocol_version): + parts = [] + _p = parts.append + _i32 = int32_pack + _u16 = uint16_pack + _u8 = uint8_pack + _u64 = uint64_pack + + _p(_u16(msg.consistency_level)) + + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + + if ProtocolVersion.uses_int_query_flags(protocol_version): + from cassandra.marshal import uint32_pack + + _p(uint32_pack(flags)) + else: + _p(_u8(flags)) + + if msg.query_params is not None: + _p(_u16(len(msg.query_params))) + for param in msg.query_params: + if param is None: + _p(_INT32_NEG1) + elif param is _UNSET_VALUE: + _p(_INT32_NEG2) + else: + _p(_i32(len(param)) + param) # single concat per param + + if msg.fetch_size: + _p(_i32(msg.fetch_size)) + if msg.paging_state: + ps = msg.paging_state + if isinstance(ps, str): + ps = ps.encode("utf8") + _p(_u16(len(ps))) + _p(ps) + if msg.serial_consistency_level: + _p(_u16(msg.serial_consistency_level)) + if msg.timestamp is not None: + _p(_u64(msg.timestamp)) + + f.write(b"".join(parts)) + + +def combined_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + combined_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Test scenarios +# =================================================================== + + +def make_vector_params(dim): + """Simulate a prepared INSERT with (int32_key, float_vector) params. + + Returns a list of pre-serialized bytes, as BoundStatement.bind() would + produce *after* calling col_type.serialize() on each value. + """ + int_key = int32_pack(42) # 4 bytes + vector_bytes = struct.pack(f">{dim}f", *([0.1] * dim)) # dim * 4 bytes + return [int_key, vector_bytes] + + +def make_scalar_params(n, size=20): + """Simulate n text columns of `size` bytes each.""" + return [b"\x41" * size for _ in range(n)] + + +PROTO_VERSION = 4 +ITERATIONS = 500_000 +REPEATS = 5 + +SCENARIOS = [ + ("128D vector INSERT (2 params)", make_vector_params(128)), + ("768D vector INSERT (2 params)", make_vector_params(768)), + ("1536D vector INSERT (2 params)", make_vector_params(1536)), + ("scalar 10 text cols (10 params)", make_scalar_params(10, 20)), +] + +# _write_query_params variants (the core hot path) +WQP_VARIANTS = [ + ("1_baseline", baseline_write_query_params), + ("2_pr788_inline", pr788_write_query_params), + ("3_buf_accum", bufaccum_write_query_params), + ("4_combined", combined_write_query_params), +] + +# send_body variants (includes query_id framing) +SB_VARIANTS = [ + ("1_baseline", baseline_send_body), + ("2_pr788_inline", pr788_send_body), + ("3_buf_accum", bufaccum_send_body), + ("4_combined", combined_send_body), +] + + +# =================================================================== +# Benchmark helpers +# =================================================================== + + +def verify_output(ref_fn, test_fn, msg, pv): + """Verify two functions produce byte-identical output.""" + f1 = io.BytesIO() + ref_fn(msg, f1, pv) + ref_bytes = f1.getvalue() + + f2 = io.BytesIO() + test_fn(msg, f2, pv) + test_bytes = f2.getvalue() + + if ref_bytes != test_bytes: + for i, (a, b) in enumerate(zip(ref_bytes, test_bytes)): + if a != b: + return False, f"diff at byte {i}: ref=0x{a:02x}, test=0x{b:02x}" + if len(ref_bytes) != len(test_bytes): + return False, f"len diff: ref={len(ref_bytes)}, test={len(test_bytes)}" + return True, "" + + +def bench_fn(fn, msg, pv, iterations, repeats): + """Benchmark a single function, return best ns/call.""" + f = io.BytesIO() + + def run(): + f.seek(0) + f.truncate() + fn(msg, f, pv) + + t = timeit.repeat(run, number=iterations, repeat=repeats, timer=time.process_time) + return min(t) / iterations * 1e9 + + +def make_execute_msg(params): + """Create a realistic ExecuteMessage for a prepared INSERT.""" + return ExecuteMessage( + query_id=b"\x01\x02\x03\x04\x05\x06\x07\x08", # 8-byte prepared query ID + query_params=params, + consistency_level=1, # ONE + timestamp=1234567890123456, # typical microsecond timestamp + # No serial CL, no fetch_size, no paging — typical INSERT + ) + + +# =================================================================== +# Main +# =================================================================== + + +def main(): + is_cython = cassandra.protocol.__file__.endswith(".so") + print(f"Python: {sys.version.split()[0]}") + print(f"Module: {cassandra.protocol.__file__}") + print(f"Cython: {'YES (.so loaded)' if is_cython else 'NO (pure Python .py)'}") + print(f"Config: proto v{PROTO_VERSION}, {ITERATIONS:,} iters, best of {REPEATS}") + print() + print("NOTE: Variants 1-4 are standalone pure-Python functions.") + print( + " They call Cython-compiled helpers (write_value, etc.) when .so is loaded." + ) + print(" 'module' calls the actual loaded module method directly.") + print() + + # Grab the base-class _write_query_params to bypass ExecuteMessage's + # super() overhead — gives a fair comparison with standalone functions. + _module_wqp = _QueryMessage._write_query_params + + for scenario_label, params in SCENARIOS: + msg = make_execute_msg(params) + total_param_bytes = sum(len(p) for p in params) + print(f"=== {scenario_label} (payload: {total_param_bytes:,} bytes) ===") + print() + + # ---- _write_query_params benchmarks ---- + print(" _write_query_params() [core hot path]:") + print(f" {'variant':20s} {'ns/call':>10s} {'vs baseline':>11s}") + print(f" {'-------':20s} {'-------':>10s} {'-----------':>11s}") + + baseline_wqp_ns = None + for var_label, var_fn in WQP_VARIANTS: + ok, err = verify_output( + baseline_write_query_params, var_fn, msg, PROTO_VERSION + ) + if not ok: + print(f" {var_label:20s} MISMATCH: {err}") + continue + ns = bench_fn(var_fn, msg, PROTO_VERSION, ITERATIONS, REPEATS) + if baseline_wqp_ns is None: + baseline_wqp_ns = ns + speedup = baseline_wqp_ns / ns + print(f" {var_label:20s} {ns:8.1f} {speedup:5.2f}x") + + # Module variant (bypass super() for fair comparison) + def module_wqp(m, f, pv): + _module_wqp(m, f, pv) + + ok, err = verify_output( + baseline_write_query_params, module_wqp, msg, PROTO_VERSION + ) + if ok: + ns = bench_fn(module_wqp, msg, PROTO_VERSION, ITERATIONS, REPEATS) + speedup = baseline_wqp_ns / ns if baseline_wqp_ns else 0 + label = "5_module" + (" (cython)" if is_cython else " (py)") + print(f" {label:20s} {ns:8.1f} {speedup:5.2f}x") + else: + print(f" 5_module MISMATCH: {err}") + + print() + + # ---- send_body benchmarks ---- + print(" send_body() [includes query_id framing]:") + print(f" {'variant':20s} {'ns/call':>10s} {'vs baseline':>11s}") + print(f" {'-------':20s} {'-------':>10s} {'-----------':>11s}") + + baseline_sb_ns = None + for var_label, var_fn in SB_VARIANTS: + ok, err = verify_output(baseline_send_body, var_fn, msg, PROTO_VERSION) + if not ok: + print(f" {var_label:20s} MISMATCH: {err}") + continue + ns = bench_fn(var_fn, msg, PROTO_VERSION, ITERATIONS, REPEATS) + if baseline_sb_ns is None: + baseline_sb_ns = ns + speedup = baseline_sb_ns / ns + print(f" {var_label:20s} {ns:8.1f} {speedup:5.2f}x") + + # Module send_body (direct method call, no lambda) + def module_sb(m, f, pv): + m.send_body(f, pv) + + ok, err = verify_output(baseline_send_body, module_sb, msg, PROTO_VERSION) + if ok: + ns = bench_fn(module_sb, msg, PROTO_VERSION, ITERATIONS, REPEATS) + speedup = baseline_sb_ns / ns if baseline_sb_ns else 0 + label = "5_module" + (" (cython)" if is_cython else " (py)") + print(f" {label:20s} {ns:8.1f} {speedup:5.2f}x") + else: + print(f" 5_module MISMATCH: {err}") + + print() + + # ---- encode_message benchmark (full wire frame) ---- + print(" encode_message() [full wire frame]:") + + def run_encode(): + return ProtocolHandler.encode_message( + msg, + stream_id=1, + protocol_version=PROTO_VERSION, + compressor=None, + allow_beta_protocol_version=False, + ) + + ref_frame = run_encode() + t = timeit.repeat( + run_encode, number=ITERATIONS, repeat=REPEATS, timer=time.process_time + ) + enc_ns = min(t) / ITERATIONS * 1e9 + print(f" {'current':20s} {enc_ns:8.1f} (frame: {len(ref_frame)} bytes)") + print() + print() + + +if __name__ == "__main__": + main() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..ab27c89ead 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -587,9 +587,20 @@ def _write_query_params(self, f, protocol_version): write_byte(f, flags) if self.query_params is not None: - write_short(f, len(self.query_params)) + # Accumulate param bytes in a list and write once instead of + # 2*N+1 separate f.write() calls via write_value(). + _int32_pack = int32_pack + parts = [uint16_pack(len(self.query_params))] + _parts_append = parts.append for param in self.query_params: - write_value(f, param) + if param is None: + _parts_append(_int32_pack(-1)) + elif param is _UNSET_VALUE: + _parts_append(_int32_pack(-2)) + else: + _parts_append(_int32_pack(len(param))) + _parts_append(param) + f.write(b"".join(parts)) if self.fetch_size: write_int(f, self.fetch_size) if self.paging_state: @@ -635,8 +646,8 @@ def __init__(self, query_id, query_params, consistency_level, super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, skip_meta, continuous_paging_options) - def _write_query_params(self, f, protocol_version): - super(ExecuteMessage, self)._write_query_params(f, protocol_version) + # _write_query_params inherited from _QueryMessage; removed redundant + # pass-through override to avoid extra MRO lookup per call. def send_body(self, f, protocol_version): write_string(f, self.query_id) diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..b165153b3b 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -15,16 +15,19 @@ import unittest from unittest.mock import Mock +import io +import struct from cassandra import ProtocolVersion, UnsupportedOperation from cassandra.protocol import ( PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + BatchMessage, + _UNSET_VALUE, write_value, ProtocolHandler ) from cassandra.query import BatchType -from cassandra.marshal import uint32_unpack +from cassandra.marshal import uint32_unpack, int32_pack, uint16_pack from cassandra.cluster import ContinuousPagingOptions import pytest @@ -189,3 +192,184 @@ def test_batch_message_with_keyspace(self): (b'\x00\x03',), (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) ) + +class WriteQueryParamsBufferAccumulationTest(unittest.TestCase): + """ + Tests for the buffer accumulation optimization in + _QueryMessage._write_query_params(). + + The optimization replaces per-parameter write_value(f, param) calls with + list.append + b"".join + single f.write(). These tests verify the + serialized bytes are identical to the original write_value() behaviour. + """ + + # -- helpers ---------------------------------------------------------- + + @staticmethod + def _reference_write_value_bytes(params): + """Build expected bytes using the original write_value() function.""" + buf = io.BytesIO() + buf.write(uint16_pack(len(params))) + for p in params: + write_value(buf, p) + return buf.getvalue() + + @staticmethod + def _execute_msg_bytes(msg, protocol_version): + """Serialize an ExecuteMessage and return the raw bytes.""" + buf = io.BytesIO() + msg.send_body(buf, protocol_version) + return buf.getvalue() + + # -- basic write_value parity ----------------------------------------- + + def test_normal_params(self): + """Normal (non-NULL, non-UNSET) byte-string parameters.""" + params = [b'hello', b'world', b'\x00\x01\x02'] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_null_params(self): + """NULL parameters must serialize as int32(-1).""" + params = [None, None] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_unset_params(self): + """UNSET parameters must serialize as int32(-2).""" + params = [_UNSET_VALUE, _UNSET_VALUE] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=4) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_mixed_params(self): + """Mix of normal, NULL and UNSET params in one message.""" + params = [b'data', None, _UNSET_VALUE, b'more', None] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_empty_bytes_param(self): + """An empty bytes value (length 0) must differ from NULL (length -1).""" + params = [b''] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + # Verify it's NOT serialized as NULL + null_bytes = int32_pack(-1) + param_section_start = raw.find(expected) + param_section = raw[param_section_start:param_section_start + len(expected)] + self.assertNotIn(null_bytes, param_section[2:]) # skip the uint16 count + + def test_empty_query_params_list(self): + """An empty params list should write count=0 and nothing else.""" + params = [] + expected = self._reference_write_value_bytes(params) + self.assertEqual(expected, uint16_pack(0)) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_none_query_params(self): + """When query_params is None, no param block should be written.""" + msg1 = ExecuteMessage(query_id=b'qid', query_params=None, + consistency_level=1) + msg2 = ExecuteMessage(query_id=b'qid', query_params=[b'x'], + consistency_level=1) + raw1 = self._execute_msg_bytes(msg1, protocol_version=4) + raw2 = self._execute_msg_bytes(msg2, protocol_version=4) + # raw1 should be shorter (no param section) + self.assertLess(len(raw1), len(raw2)) + + def test_large_vector_param(self): + """Large parameter simulating a high-dimensional vector embedding.""" + # 768-dimensional float32 vector = 3072 bytes + vector_bytes = struct.pack('768f', *([0.123456] * 768)) + params = [vector_bytes] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_query_message_with_params(self): + """QueryMessage (not just ExecuteMessage) uses the same code path.""" + params = [b'val1', None, b'val2'] + expected = self._reference_write_value_bytes(params) + msg = QueryMessage(query='SELECT * FROM t WHERE k=? AND v=? AND w=?', + consistency_level=1, + query_params=params) + raw = io.BytesIO() + msg.send_body(raw, protocol_version=4) + self.assertIn(expected, raw.getvalue()) + + def test_proto_v3_vs_v4_params(self): + """The param encoding should be identical across protocol versions.""" + params = [b'abc', None, b'xyz'] + msg_v3 = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + msg_v4 = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw_v3 = self._execute_msg_bytes(msg_v3, protocol_version=3) + raw_v4 = self._execute_msg_bytes(msg_v4, protocol_version=4) + expected = self._reference_write_value_bytes(params) + self.assertIn(expected, raw_v3) + self.assertIn(expected, raw_v4) + + def test_encode_message_roundtrip(self): + """Full encode_message path exercises header + body framing.""" + params = [b'roundtrip'] + msg = QueryMessage(query='SELECT 1', + consistency_level=1, + query_params=params) + # encode_message returns the full on-wire frame + frame = ProtocolHandler.encode_message(msg, stream_id=1, + protocol_version=4, + compressor=None, + allow_beta_protocol_version=False) + # The frame should contain the param bytes somewhere inside + expected_param_bytes = self._reference_write_value_bytes(params) + # frame may be memoryview/bytearray; convert to bytes for assertIn + frame_bytes = bytes(frame) + self.assertIn(expected_param_bytes, frame_bytes) + + def test_many_params(self): + """50 parameters to exercise the accumulation loop at scale.""" + params = [b'param_%03d' % i for i in range(50)] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_single_null_param(self): + """Regression: a single NULL param should serialize correctly.""" + params = [None] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_single_unset_param(self): + """Regression: a single UNSET param should serialize correctly.""" + params = [_UNSET_VALUE] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=4) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + From 76cae837b2b08afa6f2abb230243e4b0725d7347 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sat, 4 Apr 2026 18:50:15 +0300 Subject: [PATCH 2/3] perf: buffer accumulation in BatchMessage.send_body() Replace per-write_value()/write_byte()/write_short() calls in BatchMessage.send_body() with buffer accumulation (list.append + b"".join + single f.write()), reducing f.write() calls from Q*(4 + 2*P) + footer to 1 for Q queries with P params each. Benchmark results (Python 3.14, Cython .so, 50K iters, best of 3, quiet machine): Scenario Before After Speedup 10 queries x 2 params (128D vec) 8364 ns 4475 ns 1.87x 10 queries x 2 params (768D vec) 8081 ns 5516 ns 1.47x 50 queries x 2 params (128D vec) 32368 ns 16271 ns 1.99x 10 queries x 10 text params 19138 ns 9051 ns 2.11x 50 queries x 10 text params 86845 ns 40020 ns 2.17x 10 unprepared x 2 params 8666 ns 4252 ns 2.04x Also updates test_batch_message_with_keyspace to use BytesIO for byte-level verification (compatible with single-write output). Adds 7 batch-specific unit tests covering prepared, unprepared, mixed, empty, many-query, NULL/UNSET, and vector parameter scenarios. Includes benchmark script benchmarks/bench_batch_send_body.py. --- benchmarks/bench_batch_send_body.py | 146 ++++++++++++++++++++++++++++ cassandra/protocol.py | 51 +++++++--- tests/unit/test_protocol.py | 121 ++++++++++++++++++++--- 3 files changed, 291 insertions(+), 27 deletions(-) create mode 100644 benchmarks/bench_batch_send_body.py diff --git a/benchmarks/bench_batch_send_body.py b/benchmarks/bench_batch_send_body.py new file mode 100644 index 0000000000..aa27d7c888 --- /dev/null +++ b/benchmarks/bench_batch_send_body.py @@ -0,0 +1,146 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark: BatchMessage.send_body() for vector and scalar workloads. + +Measures the actual loaded module's BatchMessage.send_body() method. +Run this before and after optimization to compare. + +Usage: + # Build baseline .so, then: + python benchmarks/bench_batch_send_body.py + # Apply optimization, rebuild .so, then: + python benchmarks/bench_batch_send_body.py +""" + +import io +import struct +import time +import timeit +import sys +import os + +# Ensure the repo root is importable +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import cassandra.protocol +from cassandra.protocol import BatchMessage +from cassandra.query import BatchType +from cassandra.marshal import int32_pack + + +# --------------------------------------------------------------------------- +# Scenario builders +# --------------------------------------------------------------------------- + + +def make_batch_vector_queries(num_queries, dim): + """Batch of prepared INSERT with (int32_key, float_vector) params.""" + vector_bytes = struct.pack(f">{dim}f", *([0.1] * dim)) + key_bytes = int32_pack(42) + return [ + (True, b"\x01\x02\x03\x04\x05\x06\x07\x08", [key_bytes, vector_bytes]) + for _ in range(num_queries) + ] + + +def make_batch_scalar_queries(num_queries, num_params, param_size=20): + """Batch of prepared INSERT with N text columns of param_size bytes.""" + params = [b"\x41" * param_size for _ in range(num_params)] + return [ + (True, b"\x01\x02\x03\x04\x05\x06\x07\x08", list(params)) + for _ in range(num_queries) + ] + + +def make_batch_unprepared_queries(num_queries, num_params, param_size=20): + """Batch of unprepared INSERT statements.""" + stmt = "INSERT INTO ks.tbl (k, v) VALUES (?, ?)" + params = [b"\x41" * param_size for _ in range(num_params)] + return [(False, stmt, list(params)) for _ in range(num_queries)] + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +PROTO_VERSION = 4 +ITERATIONS = 50_000 +REPEATS = 3 + +SCENARIOS = [ + ("10 queries x 2 params (128D vec)", make_batch_vector_queries(10, 128)), + ("10 queries x 2 params (768D vec)", make_batch_vector_queries(10, 768)), + ("50 queries x 2 params (128D vec)", make_batch_vector_queries(50, 128)), + ("10 queries x 10 text params", make_batch_scalar_queries(10, 10, 20)), + ("50 queries x 10 text params", make_batch_scalar_queries(50, 10, 20)), + ("10 unprepared x 2 params", make_batch_unprepared_queries(10, 2, 20)), +] + + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + + +def bench_batch(queries, iterations, repeats): + """Benchmark BatchMessage.send_body(), return best ns/call.""" + msg = BatchMessage( + batch_type=BatchType.LOGGED, + queries=queries, + consistency_level=1, + timestamp=1234567890123456, + ) + f = io.BytesIO() + + def run(): + f.seek(0) + f.truncate() + msg.send_body(f, PROTO_VERSION) + + t = timeit.repeat(run, number=iterations, repeat=repeats, timer=time.process_time) + return min(t) / iterations * 1e9 + + +def main(): + is_cython = cassandra.protocol.__file__.endswith(".so") + print(f"Python: {sys.version.split()[0]}") + print(f"Module: {cassandra.protocol.__file__}") + print(f"Cython: {'YES (.so loaded)' if is_cython else 'NO (pure Python .py)'}") + print(f"Config: proto v{PROTO_VERSION}, {ITERATIONS:,} iters, best of {REPEATS}") + print() + print(f"{'Scenario':45s} {'ns/call':>10s} {'bytes':>8s}") + print(f"{'-' * 45} {'-' * 10} {'-' * 8}") + + for label, queries in SCENARIOS: + # Measure output size + msg = BatchMessage( + batch_type=BatchType.LOGGED, + queries=queries, + consistency_level=1, + timestamp=1234567890123456, + ) + f = io.BytesIO() + msg.send_body(f, PROTO_VERSION) + nbytes = len(f.getvalue()) + + ns = bench_batch(queries, ITERATIONS, REPEATS) + print(f"{label:45s} {ns:8.1f} {nbytes:>6d}") + + print() + + +if __name__ == "__main__": + main() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index ab27c89ead..2121d9aa48 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -923,21 +923,36 @@ def __init__(self, batch_type, queries, consistency_level, self.keyspace = keyspace def send_body(self, f, protocol_version): - write_byte(f, self.batch_type.value) - write_short(f, len(self.queries)) + # Buffer accumulation: collect all bytes and write once. + _i32 = int32_pack + _u16 = uint16_pack + _u8 = uint8_pack + parts = [_u8(self.batch_type.value), _u16(len(self.queries))] + _p = parts.append for prepared, string_or_query_id, params in self.queries: if not prepared: - write_byte(f, 0) - write_longstring(f, string_or_query_id) + _p(_u8(0)) + if isinstance(string_or_query_id, str): + string_or_query_id = string_or_query_id.encode('utf8') + _p(_i32(len(string_or_query_id))) + _p(string_or_query_id) else: - write_byte(f, 1) - write_short(f, len(string_or_query_id)) - f.write(string_or_query_id) - write_short(f, len(params)) + _p(_u8(1)) + _p(_u16(len(string_or_query_id))) + _p(string_or_query_id) + _p(_u16(len(params))) for param in params: - write_value(f, param) + if param is None: + _p(_i32(-1)) + elif param is _UNSET_VALUE: + _p(_i32(-2)) + else: + if isinstance(param, str): + param = param.encode('utf8') + _p(_i32(len(param))) + _p(param) - write_consistency_level(f, self.consistency_level) + _p(_u16(self.consistency_level)) flags = 0 if self.serial_consistency_level: flags |= _WITH_SERIAL_CONSISTENCY_FLAG @@ -951,18 +966,24 @@ def send_body(self, f, protocol_version): "Keyspaces may only be set on queries with protocol version " "5 or higher. Consider setting Cluster.protocol_version to 5.") if ProtocolVersion.uses_int_query_flags(protocol_version): - write_int(f, flags) + _p(_i32(flags)) else: - write_byte(f, flags) + _p(_u8(flags)) if self.serial_consistency_level: - write_consistency_level(f, self.serial_consistency_level) + _p(_u16(self.serial_consistency_level)) if self.timestamp is not None: - write_long(f, self.timestamp) + _p(uint64_pack(self.timestamp)) if ProtocolVersion.uses_keyspace_flag(protocol_version): if self.keyspace is not None: - write_string(f, self.keyspace) + ks = self.keyspace + if isinstance(ks, str): + ks = ks.encode('utf8') + _p(_u16(len(ks))) + _p(ks) + + f.write(b"".join(parts)) known_event_types = frozenset(( diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index b165153b3b..c1c5cda4fc 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -170,7 +170,7 @@ def test_keyspace_written_with_length(self): def test_batch_message_with_keyspace(self): self.maxDiff = None - io = Mock(name='io') + buf = io.BytesIO() batch = BatchMessage( batch_type=BatchType.LOGGED, queries=((False, 'stmt a', ('param a',)), @@ -180,18 +180,28 @@ def test_batch_message_with_keyspace(self): consistency_level=3, keyspace='ks' ) - batch.send_body(io, protocol_version=5) - self._check_calls(io, - ((b'\x00',), (b'\x00\x03',), (b'\x00',), - (b'\x00\x00\x00\x06',), (b'stmt a',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',), - (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',), - (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',), - (b'\x00\x03',), - (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) + batch.send_body(buf, protocol_version=5) + expected = ( + b'\x00' # batch type LOGGED + b'\x00\x03' # 3 queries + b'\x00' # not prepared + b'\x00\x00\x00\x06' b'stmt a' # longstring 'stmt a' + b'\x00\x01' # 1 param + b'\x00\x00\x00\x07' b'param a' # write_value 'param a' + b'\x00' # not prepared + b'\x00\x00\x00\x06' b'stmt b' + b'\x00\x01' + b'\x00\x00\x00\x07' b'param b' + b'\x00' + b'\x00\x00\x00\x06' b'stmt c' + b'\x00\x01' + b'\x00\x00\x00\x07' b'param c' + b'\x00\x03' # consistency level + b'\x00\x00\x00\x80' # flags (keyspace) + b'\x00\x02' b'ks' # keyspace ) + self.assertEqual(buf.getvalue(), expected) + class WriteQueryParamsBufferAccumulationTest(unittest.TestCase): """ @@ -373,3 +383,90 @@ def test_single_unset_param(self): raw = self._execute_msg_bytes(msg, protocol_version=4) self.assertIn(expected, raw) + # -- BatchMessage buffer accumulation tests --------------------------- + + @staticmethod + def _batch_msg_bytes(queries, protocol_version=4, **kwargs): + """Serialize a BatchMessage and return the raw bytes.""" + msg = BatchMessage(batch_type=BatchType.LOGGED, queries=queries, + consistency_level=1, **kwargs) + buf = io.BytesIO() + msg.send_body(buf, protocol_version) + return buf.getvalue() + + def test_batch_prepared_queries_with_params(self): + """Batch of prepared queries with byte params serializes correctly.""" + queries = [ + (True, b'\x01\x02\x03\x04', [b'val1', b'val2']), + (True, b'\x01\x02\x03\x04', [b'val3', None]), + ] + raw = self._batch_msg_bytes(queries) + self.assertIn(b'val1', raw) + self.assertIn(b'val2', raw) + self.assertIn(b'val3', raw) + self.assertIn(int32_pack(-1), raw) # NULL + + def test_batch_unprepared_queries(self): + """Batch of unprepared (string) queries serializes correctly.""" + queries = [ + (False, 'INSERT INTO t (k) VALUES (?)', [b'\x01']), + (False, 'INSERT INTO t (k) VALUES (?)', [b'\x02']), + ] + raw = self._batch_msg_bytes(queries) + self.assertIn(b'INSERT INTO t (k) VALUES (?)', raw) + + def test_batch_mixed_prepared_unprepared(self): + """Batch mixing prepared and unprepared queries.""" + queries = [ + (False, 'SELECT 1', []), + (True, b'\xab\xcd', [b'data']), + ] + raw = self._batch_msg_bytes(queries) + self.assertIn(b'SELECT 1', raw) + self.assertIn(b'data', raw) + + def test_batch_empty_queries(self): + """Batch with zero queries.""" + raw = self._batch_msg_bytes([]) + self.assertIn(uint16_pack(0), raw) + + def test_batch_many_queries(self): + """Batch with 50 queries to exercise accumulation at scale.""" + queries = [ + (True, b'\x01\x02', [b'param_%03d' % i]) + for i in range(50) + ] + raw = self._batch_msg_bytes(queries) + self.assertIn(uint16_pack(50), raw) + for i in range(50): + self.assertIn(b'param_%03d' % i, raw) + + def test_batch_null_and_unset_params(self): + """Batch params with NULL and UNSET values.""" + queries = [ + (True, b'\x01', [None, _UNSET_VALUE, b'ok']), + ] + raw = self._batch_msg_bytes(queries, protocol_version=4) + self.assertIn(int32_pack(-1), raw) # NULL + self.assertIn(int32_pack(-2), raw) # UNSET + self.assertIn(b'ok', raw) + + def test_batch_vector_params(self): + """Batch with large vector params (simulating bulk vector INSERT).""" + vector = struct.pack('128f', *([0.5] * 128)) + queries = [ + (True, b'\x01\x02', [int32_pack(i), vector]) + for i in range(10) + ] + raw = self._batch_msg_bytes(queries) + # 10 copies of the vector should appear + count = 0 + start = 0 + while True: + idx = raw.find(vector, start) + if idx == -1: + break + count += 1 + start = idx + 1 + self.assertEqual(count, 10) + From 62f91ebb97786f9d6f58bcb3e7aa2d0ae731c7a5 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 5 Apr 2026 18:29:51 +0300 Subject: [PATCH 3/3] perf: pre-compute int32_pack constants for null/unset markers Replace per-call int32_pack(-1) and int32_pack(-2) with module-level _INT32_NEG1 and _INT32_NEG2 constants. Avoids redundant struct packing on every null or unset parameter in the inner write_value loop. Benchmark: ~11% speedup on the parameter serialization loop for a typical 12-param mix of values, nulls, and unsets. --- benchmarks/bench_protocol_write_value.py | 89 ++++++++++++++++++++++++ cassandra/protocol.py | 12 ++-- 2 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 benchmarks/bench_protocol_write_value.py diff --git a/benchmarks/bench_protocol_write_value.py b/benchmarks/bench_protocol_write_value.py new file mode 100644 index 0000000000..f16ebfa752 --- /dev/null +++ b/benchmarks/bench_protocol_write_value.py @@ -0,0 +1,89 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: pre-computed int32_pack constants for null/unset markers. + +Measures the speedup from using pre-computed _INT32_NEG1/_INT32_NEG2 +constants vs calling int32_pack(-1)/int32_pack(-2) per parameter. + +Run: + python benchmarks/bench_protocol_write_value.py +""" +import timeit +import struct + +int32_pack = struct.Struct('>i').pack + +# Pre-computed constants (the optimization) +_INT32_NEG1 = int32_pack(-1) +_INT32_NEG2 = int32_pack(-2) + +_UNSET_VALUE = object() + + +def bench_write_value_constants(): + """Benchmark pre-computed constants vs per-call int32_pack.""" + # Simulate a typical parameter list: mix of values, nulls, and unsets + params = [b'\x00\x00\x00\x01'] * 8 + [None] * 3 + [_UNSET_VALUE] * 1 + + def run_precomputed(): + parts = [] + _parts_append = parts.append + _i32 = int32_pack + for param in params: + if param is None: + _parts_append(_INT32_NEG1) + elif param is _UNSET_VALUE: + _parts_append(_INT32_NEG2) + else: + _parts_append(_i32(len(param))) + _parts_append(param) + return b"".join(parts) + + def run_pack_each_time(): + parts = [] + _parts_append = parts.append + _i32 = int32_pack + for param in params: + if param is None: + _parts_append(_i32(-1)) + elif param is _UNSET_VALUE: + _parts_append(_i32(-2)) + else: + _parts_append(_i32(len(param))) + _parts_append(param) + return b"".join(parts) + + # Verify identical output + assert run_precomputed() == run_pack_each_time() + + n = 500_000 + t_precomputed = timeit.timeit(run_precomputed, number=n) + t_pack = timeit.timeit(run_pack_each_time, number=n) + + print(f"Pre-computed constants ({n} iters, {len(params)} params): " + f"{t_precomputed:.3f}s ({t_precomputed / n * 1e6:.2f} us/call)") + print(f"Pack each time ({n} iters, {len(params)} params): " + f"{t_pack:.3f}s ({t_pack / n * 1e6:.2f} us/call)") + speedup = t_pack / t_precomputed + print(f"Speedup: {speedup:.2f}x") + + +def main(): + bench_write_value_constants() + + +if __name__ == '__main__': + main() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 2121d9aa48..4ffc096c27 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -69,6 +69,10 @@ class InternalError(Exception): _UNSET_VALUE = object() +# Pre-computed packed constants for null/unset markers +_INT32_NEG1 = int32_pack(-1) # null value marker +_INT32_NEG2 = int32_pack(-2) # unset value marker + def register_class(cls): _message_types_by_opcode[cls.opcode] = cls @@ -594,9 +598,9 @@ def _write_query_params(self, f, protocol_version): _parts_append = parts.append for param in self.query_params: if param is None: - _parts_append(_int32_pack(-1)) + _parts_append(_INT32_NEG1) elif param is _UNSET_VALUE: - _parts_append(_int32_pack(-2)) + _parts_append(_INT32_NEG2) else: _parts_append(_int32_pack(len(param))) _parts_append(param) @@ -943,9 +947,9 @@ def send_body(self, f, protocol_version): _p(_u16(len(params))) for param in params: if param is None: - _p(_i32(-1)) + _p(_INT32_NEG1) elif param is _UNSET_VALUE: - _p(_i32(-2)) + _p(_INT32_NEG2) else: if isinstance(param, str): param = param.encode('utf8')