diff --git a/benchmarks/micro/bench_bind_no_encryption.py b/benchmarks/micro/bench_bind_no_encryption.py new file mode 100644 index 0000000000..dfefc508e0 --- /dev/null +++ b/benchmarks/micro/bench_bind_no_encryption.py @@ -0,0 +1,106 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: BoundStatement.bind() fast path without column encryption. + +Measures the improvement from skipping ColDesc namedtuple creation and +ce_policy checks when column_encryption_policy is None (the common case). 
def make_prepared_statement(col_names, col_types):
    """Create a stand-in ``PreparedStatement`` exposing the given columns.

    Every column is a ``MagicMock`` carrying ``name``, ``keyspace_name``,
    ``table_name`` and ``type``; the statement itself is a ``MagicMock``
    restricted to the ``PreparedStatement`` spec, with
    ``column_encryption_policy`` set to ``None``.
    """

    def _column(col_name, col_type):
        column = MagicMock()
        column.name = col_name
        column.keyspace_name = 'ks'
        column.table_name = 'metrics'
        column.type = col_type
        return column

    statement = MagicMock(spec=PreparedStatement)
    statement.column_metadata = [
        _column(col_name, col_type)
        for col_name, col_type in zip(col_names, col_types)
    ]
    for attr, value in (
        ("routing_key_indexes", None),
        ("protocol_version", 4),
        ("column_encryption_policy", None),
        ("serial_consistency_level", None),
        ("retry_policy", None),
        ("consistency_level", None),
        ("fetch_size", None),
        ("custom_payload", None),
        ("is_idempotent", False),
    ):
        setattr(statement, attr, value)
    return statement


def bench():
    """Time ``BoundStatement(ps).bind(row)`` for three mock schemas.

    For each schema a mock prepared statement is built, the bind is run
    1000 times as warmup, then timed with ``timeit`` and reported in
    nanoseconds per call.
    """
    cases = (
        (
            "3-col (int, double, text)",
            ['id', 'value', 'tag'],
            [Int32Type, DoubleType, UTF8Type],
            [42, 3.14159, 'sensor-001'],
        ),
        (
            "5-col time-series",
            ['ts', 'sensor_id', 'value', 'quality', 'tag'],
            [DateType, Int32Type, DoubleType, FloatType, UTF8Type],
            [datetime.datetime(2025, 4, 5, 12, 0, 0, 123456), 42, 3.14, 0.95, 'alpha'],
        ),
        (
            "8-col wide row",
            ['ts', 'id', 'v1', 'v2', 'v3', 'v4', 'flag', 'name'],
            [DateType, LongType, DoubleType, DoubleType, FloatType, FloatType, BooleanType, UTF8Type],
            [datetime.datetime(2025, 1, 1), 12345678, 1.1, 2.2, 3.3, 4.4, True, 'test-row'],
        ),
    )

    iterations = 200_000
    print(f"=== BoundStatement.bind() no-encryption fast path ({iterations:,} iters) ===\n")

    for title, names, types, values in cases:
        prepared = make_prepared_statement(names, types)

        def bind_once():
            bound = BoundStatement(prepared)
            bound.bind(values)

        # Warmup
        for _ in range(1000):
            bind_once()

        elapsed = timeit.timeit(bind_once, number=iterations)
        per_call_ns = elapsed / iterations * 1e9
        print(f" {title}:")
        print(f" {per_call_ns:.1f} ns/call ({iterations:,} iters)")


if __name__ == "__main__":
    print(f"Python {sys.version}\n")
    bench()
def _collect_type_tree(typ):
    """Return *typ* plus every transitively nested subtype as a flat list."""
    collected = []
    pending = [typ]
    while pending:
        current = pending.pop()
        collected.append(current)
        pending.extend(getattr(current, "subtypes", None) or ())
    return collected


def bench():
    """Compare cached vs. uncached ``cql_parameterized_type()`` cost.

    For each sample type three things are timed with ``timeit``:

    * ``cached``   - the call with the memo populated;
    * ``reset``    - clearing ``_cql_type_str`` on the type and all of its
      nested subtypes (baseline, subtracted from the next figure);
    * ``uncached`` - reset followed by the call.

    Both call paths go through a one-line closure so they pay the same
    Python call overhead, and the reset walks the whole subtype tree so a
    nested type is measured genuinely cold rather than with its inner
    types still memoized.
    """
    # Create parameterized types
    map_type = MapType.apply_parameters([UTF8Type, Int32Type])
    set_type = SetType.apply_parameters([FloatType])
    list_type = ListType.apply_parameters([DoubleType])
    tuple_type = TupleType.apply_parameters([Int32Type, UTF8Type, BooleanType])
    nested_type = MapType.apply_parameters([
        UTF8Type,
        ListType.apply_parameters([
            TupleType.apply_parameters([Int32Type, FloatType, DoubleType])
        ])
    ])

    test_types = [
        ("Int32Type (simple)", Int32Type),
        ("MapType", map_type),
        ("SetType", set_type),
        ("ListType", list_type),
        ("TupleType", tuple_type),
        ("MapType<text, list<tuple<int, float, double>>>", nested_type),
    ]

    n = 500_000
    print(f"=== cql_parameterized_type ({n:,} iters) ===\n")

    for label, typ in test_types:
        tree = _collect_type_tree(typ)

        def reset_cache():
            for node in tree:
                node._cql_type_str = None

        def cached():
            return typ.cql_parameterized_type()

        def uncached():
            reset_cache()
            return typ.cql_parameterized_type()

        # Start cold, then make one call to populate the memo.
        reset_cache()
        result = cached()

        # Measure cached (warm)
        t_cached = timeit.timeit(cached, number=n)

        # Measure uncached (cold), net of the cost of clearing the memo.
        t_reset = timeit.timeit(reset_cache, number=n)
        t_uncached = max(timeit.timeit(uncached, number=n) - t_reset, 0.0)

        saving_ns = (t_uncached - t_cached) / n * 1e9
        speedup = t_uncached / t_cached if t_cached > 0 else float('inf')
        print(f" {label}:")
        print(f" result: {result}")
        print(f" uncached: {t_uncached / n * 1e9:.1f} ns, "
              f"cached: {t_cached / n * 1e9:.1f} ns, "
              f"saving: {saving_ns:.1f} ns ({speedup:.1f}x)")


if __name__ == "__main__":
    print(f"Python {sys.version}\n")
    bench()
WARMUP = 50_000
ITERATIONS = 500_000


def bench(label, stmt, setup="pass", number=ITERATIONS, warmup=WARMUP):
    """Time *stmt* (prepared by *setup*), print one report line, return ns/call.

    *setup* is executed once into a fresh namespace, *stmt* is then run
    *warmup* times in that namespace, and finally ``timeit`` measures
    *number* executions using the same namespace as its globals.
    """
    namespace = {}
    exec(setup, namespace)

    # warmup
    compiled_stmt = compile(stmt, "", "exec")
    for _ in range(warmup):
        exec(compiled_stmt, namespace)

    # measure
    total_seconds = timeit.timeit(stmt, setup, number=number, globals=namespace)
    ns = total_seconds / number * 1e9
    print(f" {label:.<60s} {ns:>9.1f} ns/call")
    return ns


# ---------------------------------------------------------------------------
# DateType.serialize / deserialize
# ---------------------------------------------------------------------------


def bench_datetype():
    """Benchmark ``DateType`` serialize/deserialize and, if built, ``SerDateType``."""
    print("\n=== DateType.serialize ===")
    setup = """\
from cassandra.cqltypes import DateType
import datetime
dt_now = datetime.datetime(2025, 4, 5, 12, 0, 0, 123456)
dt_epoch = datetime.datetime(1970, 1, 1, 0, 0, 1, 0)
dt_far = datetime.datetime(2300, 1, 1, 0, 0, 0, 1000)
d_only = datetime.date(2025, 4, 5)
ts_int = 1712318400000
"""
    for title, statement in (
        ("serialize datetime (2025)", "DateType.serialize(dt_now, 4)"),
        ("serialize datetime (epoch)", "DateType.serialize(dt_epoch, 4)"),
        ("serialize datetime (2300)", "DateType.serialize(dt_far, 4)"),
        ("serialize date object", "DateType.serialize(d_only, 4)"),
        ("serialize raw int timestamp", "DateType.serialize(ts_int, 4)"),
    ):
        bench(title, statement, setup)

    print("\n=== DateType.deserialize ===")
    setup_deser = setup + """\
packed_now = DateType.serialize(dt_now, 4)
packed_far = DateType.serialize(dt_far, 4)
"""
    for title, statement in (
        ("deserialize (2025)", "DateType.deserialize(packed_now, 4)"),
        ("deserialize (2300)", "DateType.deserialize(packed_far, 4)"),
    ):
        bench(title, statement, setup_deser)

    # Cython serializer (if available)
    try:
        from cassandra.serializers import SerDateType  # noqa: F401

        print("\n=== SerDateType (Cython) ===")
        setup_cy = """\
from cassandra.serializers import SerDateType
from cassandra.cqltypes import DateType
import datetime
ser = SerDateType(DateType)
dt_now = datetime.datetime(2025, 4, 5, 12, 0, 0, 123456)
dt_epoch = datetime.datetime(1970, 1, 1, 0, 0, 1, 0)
d_only = datetime.date(2025, 4, 5)
ts_int = 1712318400000
"""
        for title, statement in (
            ("Cython serialize datetime (2025)", "ser.serialize(dt_now, 4)"),
            ("Cython serialize datetime (epoch)", "ser.serialize(dt_epoch, 4)"),
            ("Cython serialize date object", "ser.serialize(d_only, 4)"),
            ("Cython serialize raw int", "ser.serialize(ts_int, 4)"),
        ):
            bench(title, statement, setup_cy)
    except ImportError:
        print("\n(serializers.pyx not compiled — skipping Cython benchmark)")
# ---------------------------------------------------------------------------
# varint_pack / varint_unpack
# ---------------------------------------------------------------------------


def bench_varint():
    """Benchmark ``varint_pack`` / ``varint_unpack`` over five sample integers."""
    print("\n=== varint_pack ===")
    setup = """\
from cassandra.marshal import varint_pack, varint_unpack
small = 42
medium = 2**62
large = 2**127
negative = -(2**62)
zero = 0
"""
    sample_names = ("zero", "small", "medium", "large", "negative")
    for sample in sample_names:
        bench(f"varint_pack {sample}", f"varint_pack({sample})", setup)

    print("\n=== varint_unpack ===")
    setup_u = setup + """\
packed_small = varint_pack(small)
packed_medium = varint_pack(medium)
packed_large = varint_pack(large)
packed_negative = varint_pack(negative)
packed_zero = varint_pack(zero)
"""
    for sample in sample_names:
        bench(f"varint_unpack {sample}", f"varint_unpack(packed_{sample})", setup_u)


# ---------------------------------------------------------------------------
# MonotonicTimestampGenerator
# ---------------------------------------------------------------------------


def bench_timestamp_generator():
    """Benchmark ``MonotonicTimestampGenerator`` single-threaded and with 4 threads.

    The threaded part releases all workers through a barrier, lets each
    time its own loop with ``time.perf_counter_ns`` and reports figures
    based on the slowest worker.
    """
    print("\n=== MonotonicTimestampGenerator (single-thread) ===")
    setup = """\
from cassandra.timestamps import MonotonicTimestampGenerator
gen = MonotonicTimestampGenerator()
"""
    bench("generator call", "gen()", setup)

    print("\n=== MonotonicTimestampGenerator (4-thread contention) ===")
    from cassandra.timestamps import MonotonicTimestampGenerator

    generator = MonotonicTimestampGenerator()
    thread_count = 4
    per_thread = ITERATIONS // thread_count
    # One extra party for the main thread, which opens and closes the run.
    gate = threading.Barrier(thread_count + 1)

    durations = []

    def run():
        gate.wait()
        started = time.perf_counter_ns()
        for _ in range(per_thread):
            generator()
        durations.append(time.perf_counter_ns() - started)
        gate.wait()

    pool = [threading.Thread(target=run) for _ in range(thread_count)]
    for thread in pool:
        thread.start()
    gate.wait()  # release all workers
    gate.wait()  # wait for all to finish
    for thread in pool:
        thread.join()

    total_calls = thread_count * per_thread
    slowest_ns = max(durations)
    ns_per_call = slowest_ns / per_thread  # per-thread throughput
    print(f" {'contended (4 threads, per-thread)':.<60s} {ns_per_call:>9.1f} ns/call")
    throughput = total_calls / (slowest_ns / 1e9)
    print(f" {'aggregate throughput':.<60s} {throughput:>9.0f} calls/sec")


# ---------------------------------------------------------------------------
# BoundStatement.bind() — typical time-series schema
# ---------------------------------------------------------------------------


def bench_bind():
    """Benchmark ``BoundStatement.bind`` against a mock 5-column statement."""
    print("\n=== BoundStatement.bind (time-series schema) ===")
    setup = """\
import datetime
from cassandra.query import BoundStatement, PreparedStatement
from cassandra.cqltypes import (
    DateType, Int32Type, DoubleType, FloatType, UTF8Type,
)
from cassandra.protocol import ProtocolVersion
from unittest.mock import MagicMock

# Build a mock PreparedStatement with 5 columns:
# (ts timestamp, sensor_id int, value double, quality float, tag text)
col_types = [DateType, Int32Type, DoubleType, FloatType, UTF8Type]
col_names = ['ts', 'sensor_id', 'value', 'quality', 'tag']

col_meta = []
for name, ctype in zip(col_names, col_types):
    cm = MagicMock()
    cm.name = name
    cm.keyspace_name = 'ks'
    cm.table_name = 'metrics'
    cm.type = ctype
    col_meta.append(cm)

ps = MagicMock(spec=PreparedStatement)
ps.column_metadata = col_meta
ps.routing_key_indexes = None
ps.protocol_version = 4
ps.column_encryption_policy = None
ps.serial_consistency_level = None
ps.retry_policy = None
ps.consistency_level = None
ps.fetch_size = None
ps.custom_payload = None
ps.is_idempotent = False

dt = datetime.datetime(2025, 4, 5, 12, 0, 0, 123456)
row = [dt, 42, 3.14159, 0.95, 'sensor-alpha-001']
"""
    statement = """\
bs = BoundStatement(ps)
bs.bind(row)
"""
    bench("bind 5-col time-series row", statement, setup)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    print(f"Python {sys.version}")
    print(f"Iterations per benchmark: {ITERATIONS:,}")

    for run_benchmark in (
        bench_datetype,
        bench_varint,
        bench_timestamp_generator,
        bench_bind,
    ):
        run_benchmark()

    print("\nDone.")
def cqltype_to_python(cql_string):
    """
    Turn a CQL type string into a nested Python list of type names.

    The string is tokenized, rewritten into the text of a Python literal
    (identifiers quoted, ``<``/``>`` turned into list brackets) and then
    evaluated with ``ast.literal_eval``.

    Example:
        int -> ['int']
        frozen<tuple<text, int>> -> ['frozen', ['tuple', ['text', 'int']]]
    """

    def _quote(_scanner, token):
        return "'{}'".format(token)

    lexicon = (
        (r"[a-zA-Z0-9_]+", _quote),
        (r"<", lambda _scanner, _token: ", ["),
        (r">", lambda _scanner, _token: "]"),
        (r"[, ]", lambda _scanner, token: token),
        (r'".*?"', _quote),
    )
    tokens, _unscanned = re.Scanner(lexicon).scan(cql_string)
    parsed = ast.literal_eval("".join(tokens))
    if isinstance(parsed, str):
        return [parsed]
    return list(parsed)


def python_to_cqltype(types):
    """
    Render a nested Python list of type names back into a CQL type string.

    Works on ``repr(types)``: quotes around names are stripped, the
    outermost brackets are dropped and inner brackets become ``<``/``>``.

    Example:
        ['int'] -> int
        ['frozen', ['tuple', ['text', 'int']]] -> frozen<tuple<text, int>>
    """

    def _unquote(_scanner, token):
        return token[1:-1]

    def _discard(_scanner, _token):
        return None

    lexicon = (
        (r"'[a-zA-Z0-9_]+'", _unquote),
        (r"^\[", _discard),
        (r"\]$", _discard),
        (r",\s*\[", lambda _scanner, _token: "<"),
        (r"\]", lambda _scanner, _token: ">"),
        (r"[, ]", lambda _scanner, token: token),
        (r'\'".*?"\'', _unquote),
    )
    tokens, _unscanned = re.Scanner(lexicon).scan(repr(types))
    return "".join(tokens).replace("\\\\", "\\")


def _strip_frozen_from_python(types):
    """
    Remove every 'frozen' marker from a nested list of type names,
    splicing the list that follows each marker into its place.

    Example:
        ['frozen', ['tuple', ['text', 'int']]] -> ['tuple', ['text', 'int']]
    """
    while True:
        try:
            position = types.index("frozen")
        except ValueError:
            break
        types = types[:position] + types[position + 1] + types[position + 2 :]
    return [
        _strip_frozen_from_python(entry) if isinstance(entry, list) else entry
        for entry in types
    ]
@classmethod
def cql_parameterized_type(cls):
    """
    Return a CQL type specifier for this type. If this type has parameters,
    they are included in standard CQL <> notation.

    The string is memoized per class in ``_cql_type_str``. The memo is read
    from the class's *own* namespace only: a plain ``cls._cql_type_str``
    lookup walks the MRO, so once a parent class had cached its string,
    every subclass (parameterized classes built by ``apply_parameters``,
    or subclasses that override ``typename``) would return the parent's
    string instead of its own.
    """
    cached = cls.__dict__.get("_cql_type_str")
    if cached is not None:
        return cached
    if not cls.subtypes:
        result = cls.typename
    else:
        result = "%s<%s>" % (
            cls.typename,
            ", ".join(styp.cql_parameterized_type() for styp in cls.subtypes),
        )
    # Assigning through the class always writes to this class's own
    # namespace, so the value never leaks into siblings or parents.
    cls._cql_type_str = result
    return result
- typename = 'blob' + typename = "blob" empty_binary_ok = True @staticmethod @@ -423,13 +484,13 @@ def serialize(val, protocol_version): class DecimalType(_CassandraType): - typename = 'decimal' + typename = "decimal" @staticmethod def deserialize(byts, protocol_version): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) - return Decimal('%de%d' % (unscaled, -scale)) + return Decimal("%de%d" % (unscaled, -scale)) @staticmethod def serialize(dec, protocol_version): @@ -440,7 +501,7 @@ def serialize(dec, protocol_version): sign, digits, exponent = Decimal(dec).as_tuple() except Exception: raise TypeError("Invalid type for Decimal value: %r", dec) - unscaled = int(''.join([str(digit) for digit in digits])) + unscaled = int("".join([str(digit) for digit in digits])) if sign: unscaled *= -1 scale = int32_pack(-exponent) @@ -449,7 +510,7 @@ def serialize(dec, protocol_version): class UUIDType(_CassandraType): - typename = 'uuid' + typename = "uuid" @staticmethod def deserialize(byts, protocol_version): @@ -466,8 +527,9 @@ def serialize(uuid, protocol_version): def serial_size(cls): return 16 + class BooleanType(_CassandraType): - typename = 'boolean' + typename = "boolean" @staticmethod def deserialize(byts, protocol_version): @@ -481,8 +543,9 @@ def serialize(truth, protocol_version): def serial_size(cls): return 1 + class ByteType(_CassandraType): - typename = 'tinyint' + typename = "tinyint" @staticmethod def deserialize(byts, protocol_version): @@ -494,23 +557,23 @@ def serialize(byts, protocol_version): class AsciiType(_CassandraType): - typename = 'ascii' + typename = "ascii" empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): - return byts.decode('ascii') + return byts.decode("ascii") @staticmethod def serialize(var, protocol_version): try: - return var.encode('ascii') + return var.encode("ascii") except UnicodeDecodeError: return var class FloatType(_CassandraType): - typename = 'float' + typename = "float" @staticmethod 
def deserialize(byts, protocol_version): @@ -524,8 +587,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 4 + class DoubleType(_CassandraType): - typename = 'double' + typename = "double" @staticmethod def deserialize(byts, protocol_version): @@ -539,8 +603,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 8 + class LongType(_CassandraType): - typename = 'bigint' + typename = "bigint" @staticmethod def deserialize(byts, protocol_version): @@ -554,8 +619,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 8 + class Int32Type(_CassandraType): - typename = 'int' + typename = "int" @staticmethod def deserialize(byts, protocol_version): @@ -569,8 +635,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 4 + class IntegerType(_CassandraType): - typename = 'varint' + typename = "varint" @staticmethod def deserialize(byts, protocol_version): @@ -582,7 +649,7 @@ def serialize(byts, protocol_version): class InetAddressType(_CassandraType): - typename = 'inet' + typename = "inet" @staticmethod def deserialize(byts, protocol_version): @@ -596,7 +663,7 @@ def deserialize(byts, protocol_version): @staticmethod def serialize(addr, protocol_version): try: - if ':' in addr: + if ":" in addr: return util.inet_pton(socket.AF_INET6, addr) else: # util.inet_pton could also handle, but this is faster @@ -609,26 +676,27 @@ def serialize(addr, protocol_version): class CounterColumnType(LongType): - typename = 'counter' + typename = "counter" + cql_timestamp_formats = ( - '%Y-%m-%d %H:%M', - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%dT%H:%M', - '%Y-%m-%dT%H:%M:%S', - '%Y-%m-%d' + "%Y-%m-%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%dT%H:%M", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d", ) _have_warned_about_timestamps = False class DateType(_CassandraType): - typename = 'timestamp' + typename = "timestamp" @staticmethod def interpret_datestring(val): - if val[-5] in ('+', '-'): - offset = (int(val[-4:-2]) * 3600 + 
int(val[-2:]) * 60) * int(val[-5] + '1') + if val[-5] in ("+", "-"): + offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + "1") val = val[:-5] else: offset = -time.timezone @@ -650,16 +718,25 @@ def deserialize(byts, protocol_version): @staticmethod def serialize(v, protocol_version): try: - # v is datetime - timestamp_seconds = calendar.timegm(v.utctimetuple()) - timestamp = timestamp_seconds * 1000 + getattr(v, 'microsecond', 0) // 1000 + # v is a datetime; use integer arithmetic instead of + # calendar.timegm(v.utctimetuple()) to avoid allocating + # an intermediate struct_time object on every call. + utcoffset = v.utcoffset() + if utcoffset is not None: + v = v - utcoffset + v = v.replace(tzinfo=None) + td = v - _EPOCH_NAIVE + timestamp = (td.days * 86400 + td.seconds) * 1000 + td.microseconds // 1000 except AttributeError: try: - timestamp = calendar.timegm(v.timetuple()) * 1000 - except AttributeError: + td = v - _EPOCH_DATE + timestamp = td.days * 86400000 + except (AttributeError, TypeError): # Ints and floats are valid timestamps too if type(v) not in _number_types: - raise TypeError('DateType arguments must be a datetime, date, or timestamp') + raise TypeError( + "DateType arguments must be a datetime, date, or timestamp" + ) timestamp = v return int64_pack(int(timestamp)) @@ -668,12 +745,13 @@ def serialize(v, protocol_version): def serial_size(cls): return 8 + class TimestampType(DateType): pass class TimeUUIDType(DateType): - typename = 'timeuuid' + typename = "timeuuid" def my_timestamp(self): return util.unix_time_from_uuid1(self.val) @@ -693,14 +771,15 @@ def serialize(timeuuid, protocol_version): def serial_size(cls): return 16 + class SimpleDateType(_CassandraType): - typename = 'date' + typename = "date" date_format = "%Y-%m-%d" # Values of the 'date'` type are encoded as 32-bit unsigned integers # representing a number of days with epoch (January 1st, 1970) at the center of the # range (2^31). 
- EPOCH_OFFSET_DAYS = 2 ** 31 + EPOCH_OFFSET_DAYS = 2**31 @staticmethod def deserialize(byts, protocol_version): @@ -722,7 +801,7 @@ def serialize(val, protocol_version): class ShortType(_CassandraType): - typename = 'smallint' + typename = "smallint" @staticmethod def deserialize(byts, protocol_version): @@ -732,13 +811,14 @@ def deserialize(byts, protocol_version): def serialize(byts, protocol_version): return int16_pack(byts) + class TimeType(_CassandraType): - typename = 'time' + typename = "time" # Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as # variable size... and we have to match what the server expects since the server # uses that specification to encode data of that type. - #@classmethod - #def serial_size(cls): + # @classmethod + # def serial_size(cls): # return 8 @staticmethod @@ -755,7 +835,7 @@ def serialize(val, protocol_version): class DurationType(_CassandraType): - typename = 'duration' + typename = "duration" @staticmethod def deserialize(byts, protocol_version): @@ -767,65 +847,67 @@ def serialize(duration, protocol_version): try: m, d, n = duration.months, duration.days, duration.nanoseconds except AttributeError: - raise TypeError('DurationType arguments must be a Duration.') + raise TypeError("DurationType arguments must be a Duration.") return vints_pack([m, d, n]) class UTF8Type(_CassandraType): - typename = 'text' + typename = "text" empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): - return byts.decode('utf8') + return byts.decode("utf8") @staticmethod def serialize(ustr, protocol_version): try: - return ustr.encode('utf-8') + return ustr.encode("utf-8") except UnicodeDecodeError: # already utf-8 return ustr class VarcharType(UTF8Type): - typename = 'varchar' + typename = "varchar" class _ParameterizedType(_CassandraType): - num_subtypes = 'UNKNOWN' + num_subtypes = "UNKNOWN" @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: - raise 
NotImplementedError("can't deserialize unparameterized %s" - % cls.typename) + raise NotImplementedError( + "can't deserialize unparameterized %s" % cls.typename + ) return cls.deserialize_safe(byts, protocol_version) @classmethod def serialize(cls, val, protocol_version): if not cls.subtypes: - raise NotImplementedError("can't serialize unparameterized %s" - % cls.typename) + raise NotImplementedError( + "can't serialize unparameterized %s" % cls.typename + ) return cls.serialize_safe(val, protocol_version) class _SimpleParameterizedType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes length = 4 numelements = int32_unpack(byts[:length]) p = length result = [] inner_proto = max(3, protocol_version) for _ in range(numelements): - itemlen = int32_unpack(byts[p:p + length]) + itemlen = int32_unpack(byts[p : p + length]) p += length if itemlen < 0: result.append(None) else: - item = byts[p:p + itemlen] + item = byts[p : p + itemlen] p += itemlen result.append(subtype.from_binary(item, inner_proto)) return cls.adapter(result) @@ -835,7 +917,7 @@ def serialize_safe(cls, items, protocol_version): if isinstance(items, str): raise TypeError("Received a string for a type that expects a sequence") - subtype, = cls.subtypes + (subtype,) = cls.subtypes buf = io.BytesIO() buf.write(int32_pack(len(items))) inner_proto = max(3, protocol_version) @@ -850,19 +932,19 @@ def serialize_safe(cls, items, protocol_version): class ListType(_SimpleParameterizedType): - typename = 'list' + typename = "list" num_subtypes = 1 adapter = list class SetType(_SimpleParameterizedType): - typename = 'set' + typename = "set" num_subtypes = 1 adapter = util.sortedset class MapType(_ParameterizedType): - typename = 'map' + typename = "map" num_subtypes = 2 @classmethod @@ -874,22 +956,22 @@ def deserialize_safe(cls, byts, protocol_version): themap = util.OrderedMapSerializedKey(key_type, protocol_version) 
inner_proto = max(3, protocol_version) for _ in range(numelements): - key_len = int32_unpack(byts[p:p + length]) + key_len = int32_unpack(byts[p : p + length]) p += length if key_len < 0: keybytes = None key = None else: - keybytes = byts[p:p + key_len] + keybytes = byts[p : p + key_len] p += key_len key = key_type.from_binary(keybytes, inner_proto) - val_len = int32_unpack(byts[p:p + length]) + val_len = int32_unpack(byts[p : p + length]) p += length if val_len < 0: val = None else: - valbytes = byts[p:p + val_len] + valbytes = byts[p : p + val_len] p += val_len val = value_type.from_binary(valbytes, inner_proto) @@ -923,7 +1005,7 @@ def serialize_safe(cls, themap, protocol_version): class TupleType(_ParameterizedType): - typename = 'tuple' + typename = "tuple" @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -933,10 +1015,10 @@ def deserialize_safe(cls, byts, protocol_version): for col_type in cls.subtypes: if p == len(byts): break - itemlen = int32_unpack(byts[p:p + 4]) + itemlen = int32_unpack(byts[p : p + 4]) p += 4 if itemlen >= 0: - item = byts[p:p + itemlen] + item = byts[p : p + itemlen] p += itemlen else: item = None @@ -953,8 +1035,10 @@ def deserialize_safe(cls, byts, protocol_version): @classmethod def serialize_safe(cls, val, protocol_version): if len(val) > len(cls.subtypes): - raise ValueError("Expected %d items in a tuple, but got %d: %s" % - (len(cls.subtypes), len(val), val)) + raise ValueError( + "Expected %d items in a tuple, but got %d: %s" + % (len(cls.subtypes), len(val), val) + ) proto_version = max(3, protocol_version) buf = io.BytesIO() @@ -969,8 +1053,15 @@ def serialize_safe(cls, val, protocol_version): @classmethod def cql_parameterized_type(cls): - subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes) - return 'frozen>' % (subtypes_string,) + result = cls._cql_type_str + if result is not None: + return result + subtypes_string = ", ".join( + sub.cql_parameterized_type() for sub in 
cls.subtypes + ) + result = "frozen>" % (subtypes_string,) + cls._cql_type_str = result + return result class UserType(TupleType): @@ -984,14 +1075,26 @@ def make_udt_class(cls, keyspace, udt_name, field_names, field_types): assert len(field_names) == len(field_types) instance = cls._cache.get((keyspace, udt_name)) - if not instance or instance.fieldnames != field_names or instance.subtypes != field_types: - instance = type(udt_name, (cls,), {'subtypes': field_types, - 'cassname': cls.cassname, - 'typename': udt_name, - 'fieldnames': field_names, - 'keyspace': keyspace, - 'mapped_class': None, - 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) + if ( + not instance + or instance.fieldnames != field_names + or instance.subtypes != field_types + ): + instance = type( + udt_name, + (cls,), + { + "subtypes": field_types, + "cassname": cls.cassname, + "typename": udt_name, + "fieldnames": field_names, + "keyspace": keyspace, + "mapped_class": None, + "tuple_type": cls._make_registered_udt_namedtuple( + keyspace, udt_name, field_names + ), + }, + ) cls._cache[(keyspace, udt_name)] = instance return instance @@ -1004,14 +1107,23 @@ def evict_udt_class(cls, keyspace, udt_name): @classmethod def apply_parameters(cls, subtypes, names): - keyspace = subtypes[0].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back + keyspace = subtypes[ + 0 + ].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back udt_name = _name_from_hex_string(subtypes[1].cassname) - field_names = tuple(_name_from_hex_string(encoded_name) for encoded_name in names[2:]) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test) + field_names = tuple( + _name_from_hex_string(encoded_name) for encoded_name in names[2:] + ) # using tuple here to match what comes 
into make_udt_class from other sources (for caching equality test) return cls.make_udt_class(keyspace, udt_name, field_names, tuple(subtypes[2:])) @classmethod def cql_parameterized_type(cls): - return "frozen<%s>" % (cls.typename,) + result = cls._cql_type_str + if result is not None: + return result + result = "frozen<%s>" % (cls.typename,) + cls._cql_type_str = result + return result @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -1034,7 +1146,9 @@ def serialize_safe(cls, val, protocol_version): except TypeError: item = getattr(val, fieldname, None) if item is None and not hasattr(val, fieldname): - log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}") + log.warning( + f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}" + ) if item is not None: packed_item = subtype.to_binary(item, proto_version) @@ -1063,15 +1177,21 @@ def _make_udt_tuple_type(cls, name, field_names): t = namedtuple(name, field_names) except ValueError: try: - t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names)) - log.warning("could not create a namedtuple for '%s' because one or more " - "field names are not valid Python identifiers (%s); " - "returning positionally-named fields" % (name, field_names)) + t = namedtuple( + name, util._positional_rename_invalid_identifiers(field_names) + ) + log.warning( + "could not create a namedtuple for '%s' because one or more " + "field names are not valid Python identifiers (%s); " + "returning positionally-named fields" % (name, field_names) + ) except ValueError: t = None - log.warning("could not create a namedtuple for '%s' because the name is " - "not a valid Python identifier; will return tuples in " - "its place" % (name,)) + log.warning( + "could not create a namedtuple for '%s' because the name is " + "not a valid Python identifier; will return tuples in " + "its place" % (name,) + ) return t @@ -1083,8 
+1203,13 @@ def cql_parameterized_type(cls): """ There is no CQL notation for Composites, so we override this. """ + result = cls._cql_type_str + if result is not None: + return result typestring = cls.cass_parameterized_type(full=True) - return "'%s'" % (typestring,) + result = "'%s'" % (typestring,) + cls._cql_type_str = result + return result @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -1095,10 +1220,10 @@ def deserialize_safe(cls, byts, protocol_version): break element_length = uint16_unpack(byts[:2]) - element = byts[2:2 + element_length] + element = byts[2 : 2 + element_length] # skip element length, element, and the EOC (one byte) - byts = byts[2 + element_length + 1:] + byts = byts[2 + element_length + 1 :] result.append(subtype.from_binary(element, protocol_version)) return tuple(result) @@ -1109,8 +1234,16 @@ class DynamicCompositeType(_ParameterizedType): @classmethod def cql_parameterized_type(cls): - sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) - return "'%s(%s)'" % (cls.typename, sublist) + result = cls._cql_type_str + if result is not None: + return result + sublist = ", ".join( + "%s=>%s" % (alias, typ.cass_parameterized_type(full=True)) + for alias, typ in zip(cls.fieldnames, cls.subtypes) + ) + result = "'%s(%s)'" % (cls.typename, sublist) + cls._cql_type_str = result + return result class ColumnToCollectionType(_ParameterizedType): @@ -1119,6 +1252,7 @@ class ColumnToCollectionType(_ParameterizedType): Cassandra includes this. We don't actually need or want the extra information. 
""" + typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" @@ -1128,12 +1262,12 @@ class ReversedType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.to_binary(val, protocol_version) @@ -1143,12 +1277,12 @@ class FrozenType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.to_binary(val, protocol_version) @@ -1179,9 +1313,9 @@ class WKBGeometryType(object): class PointType(CassandraType): - typename = 'PointType' + typename = "PointType" - _type = struct.pack('[[]] type_ = int8_unpack(byts[0:1]) - if type_ in (BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE), - BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN)): + if type_ in ( + BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE), + BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN), + ): time0 = precision0 = None else: time0 = int64_unpack(byts[1:9]) @@ -1350,32 +1517,34 @@ def deserialize(cls, byts, protocol_version): if time0 is not None: date_range_bound0 = util.DateRangeBound( - time0, - cls._decode_precision(precision0) + time0, cls._decode_precision(precision0) ) if time1 is not None: date_range_bound1 = util.DateRangeBound( - time1, - cls._decode_precision(precision1) + time1, cls._decode_precision(precision1) ) if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE): return util.DateRange(value=date_range_bound0) if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE): - return util.DateRange(lower_bound=date_range_bound0, - upper_bound=date_range_bound1) + return util.DateRange( + 
lower_bound=date_range_bound0, upper_bound=date_range_bound1 + ) if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_HIGH): - return util.DateRange(lower_bound=date_range_bound0, - upper_bound=util.OPEN_BOUND) + return util.DateRange( + lower_bound=date_range_bound0, upper_bound=util.OPEN_BOUND + ) if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_LOW): - return util.DateRange(lower_bound=util.OPEN_BOUND, - upper_bound=date_range_bound0) + return util.DateRange( + lower_bound=util.OPEN_BOUND, upper_bound=date_range_bound0 + ) if type_ == BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE): - return util.DateRange(lower_bound=util.OPEN_BOUND, - upper_bound=util.OPEN_BOUND) + return util.DateRange( + lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND + ) if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN): return util.DateRange(value=util.OPEN_BOUND) - raise ValueError('Could not deserialize %r' % (byts,)) + raise ValueError("Could not deserialize %r" % (byts,)) @classmethod def serialize(cls, v, protocol_version): @@ -1386,8 +1555,8 @@ def serialize(cls, v, protocol_version): value = v.value except AttributeError: raise ValueError( - '%s.serialize expects an object with a value attribute; got' - '%r' % (cls.__name__, v) + "%s.serialize expects an object with a value attribute; got" + "%r" % (cls.__name__, v) ) if value is None: @@ -1395,8 +1564,8 @@ def serialize(cls, v, protocol_version): lower_bound, upper_bound = v.lower_bound, v.upper_bound except AttributeError: raise ValueError( - '%s.serialize expects an object with lower_bound and ' - 'upper_bound attributes; got %r' % (cls.__name__, v) + "%s.serialize expects an object with lower_bound and " + "upper_bound attributes; got %r" % (cls.__name__, v) ) if lower_bound == util.OPEN_BOUND and upper_bound == util.OPEN_BOUND: bound_kind = BoundKind.BOTH_OPEN_RANGE @@ -1417,9 +1586,7 @@ def serialize(cls, v, protocol_version): bounds = (value,) if bound_kind is None: - raise ValueError( - 'Cannot serialize %r; could 
not find bound kind' % (v,) - ) + raise ValueError("Cannot serialize %r; could not find bound kind" % (v,)) buf.write(int8_pack(BoundKind.to_int(bound_kind))) for bound in bounds: @@ -1428,22 +1595,29 @@ def serialize(cls, v, protocol_version): return buf.getvalue() + class VectorType(_CassandraType): - typename = 'org.apache.cassandra.db.marshal.VectorType' + typename = "org.apache.cassandra.db.marshal.VectorType" vector_size = 0 subtype = None @classmethod def serial_size(cls): serialized_size = cls.subtype.serial_size() - return cls.vector_size * serialized_size if serialized_size is not None else None + return ( + cls.vector_size * serialized_size if serialized_size is not None else None + ) @classmethod def apply_parameters(cls, params, names): assert len(params) == 2 subtype = lookup_casstype(params[0]) vsize = params[1] - return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype}) + return type( + "%s(%s)" % (cls.cass_parameterized_type_with([]), vsize), + (cls,), + {"vector_size": vsize, "subtype": subtype}, + ) @classmethod def deserialize(cls, byts, protocol_version): @@ -1452,26 +1626,43 @@ def deserialize(cls, byts, protocol_version): expected_byte_size = serialized_size * cls.vector_size if len(byts) != expected_byte_size: raise ValueError( - "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\ - .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts))) + "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead".format( + cls.subtype.typename, + cls.vector_size, + expected_byte_size, + len(byts), + ) + ) indexes = (serialized_size * x for x in range(0, cls.vector_size)) - return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] + return [ + cls.subtype.deserialize( + byts[idx : idx + serialized_size], 
protocol_version + ) + for idx in indexes + ] idx = 0 rv = [] - while (len(rv) < cls.vector_size): + while len(rv) < cls.vector_size: try: size, bytes_read = uvint_unpack(byts[idx:]) idx += bytes_read - rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) + rv.append( + cls.subtype.deserialize(byts[idx : idx + size], protocol_version) + ) idx += size except: - raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\ - .format(len(rv))) + raise ValueError( + "Error reading additional data during vector deserialization after successfully adding {} elements".format( + len(rv) + ) + ) # If we have any additional data in the serialized vector treat that as an error as well if idx < len(byts): - raise ValueError("Additional bytes remaining after vector deserialization completed") + raise ValueError( + "Additional bytes remaining after vector deserialization completed" + ) return rv @classmethod @@ -1479,8 +1670,10 @@ def serialize(cls, v, protocol_version): v_length = len(v) if cls.vector_size != v_length: raise ValueError( - "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\ - .format(cls.vector_size, cls.subtype.typename, v_length)) + "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}".format( + cls.vector_size, cls.subtype.typename, v_length + ) + ) serialized_size = cls.subtype.serial_size() buf = io.BytesIO() @@ -1493,4 +1686,13 @@ def serialize(cls, v, protocol_version): @classmethod def cql_parameterized_type(cls): - return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size) + result = cls._cql_type_str + if result is not None: + return result + result = "%s<%s, %s>" % ( + cls.typename, + cls.subtype.cql_parameterized_type(), + cls.vector_size, + ) + cls._cql_type_str = result + return result diff --git a/cassandra/cython_marshal.pyx 
b/cassandra/cython_marshal.pyx index 0a926b6eef..2d8d22bd8c 100644 --- a/cassandra/cython_marshal.pyx +++ b/cassandra/cython_marshal.pyx @@ -55,16 +55,5 @@ cdef varint_unpack(Buffer *term): """Unpack a variable-sized integer""" return varint_unpack_py3(to_bytes(term)) -# TODO: Optimize these two functions cdef varint_unpack_py3(bytes term): - val = int(''.join(["%02x" % i for i in term]), 16) - if (term[0] & 128) != 0: - shift = len(term) * 8 # * Note below - val -= 1 << shift - return val - -# * Note * -# '1 << (len(term) * 8)' Cython tries to do native -# integer shifts, which overflows. We need this to -# emulate Python shifting, which will expand the long -# to accommodate + return int.from_bytes(term, byteorder='big', signed=True) diff --git a/cassandra/encoder.py b/cassandra/encoder.py index d803c087ba..a836dd899f 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -18,6 +18,7 @@ """ import logging + log = logging.getLogger(__name__) from binascii import hexlify @@ -26,12 +27,23 @@ import datetime import math import sys + +_EPOCH_NAIVE = datetime.datetime(1970, 1, 1) import types from uuid import UUID import ipaddress -from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey, - sortedset, Time, Date, Point, LineString, Polygon) +from cassandra.util import ( + OrderedDict, + OrderedMap, + OrderedMapSerializedKey, + sortedset, + Time, + Date, + Point, + LineString, + Polygon, +) def cql_quote(term): @@ -83,28 +95,30 @@ def __init__(self): ValueSequence: self.cql_encode_sequence, Point: self.cql_encode_str_quoted, LineString: self.cql_encode_str_quoted, - Polygon: self.cql_encode_str_quoted + Polygon: self.cql_encode_str_quoted, } - self.mapping.update({ - memoryview: self.cql_encode_bytes, - bytes: self.cql_encode_bytes, - type(None): self.cql_encode_none, - ipaddress.IPv4Address: self.cql_encode_ipaddress, - ipaddress.IPv6Address: self.cql_encode_ipaddress - }) + self.mapping.update( + { + memoryview: self.cql_encode_bytes, + 
bytes: self.cql_encode_bytes, + type(None): self.cql_encode_none, + ipaddress.IPv4Address: self.cql_encode_ipaddress, + ipaddress.IPv6Address: self.cql_encode_ipaddress, + } + ) def cql_encode_none(self, val): """ Converts :const:`None` to the string 'NULL'. """ - return 'NULL' + return "NULL" def cql_encode_unicode(self, val): """ Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping. """ - return cql_quote(val.encode('utf-8')) + return cql_quote(val.encode("utf-8")) def cql_encode_str(self, val): """ @@ -116,7 +130,7 @@ def cql_encode_str_quoted(self, val): return "'%s'" % val def cql_encode_bytes(self, val): - return (b'0x' + hexlify(val)).decode('utf-8') + return (b"0x" + hexlify(val)).decode("utf-8") def cql_encode_object(self, val): """ @@ -130,9 +144,9 @@ def cql_encode_float(self, val): Encode floats using repr to preserve precision """ if math.isinf(val): - return 'Infinity' if val > 0 else '-Infinity' + return "Infinity" if val > 0 else "-Infinity" elif math.isnan(val): - return 'NaN' + return "NaN" else: return repr(val) @@ -141,15 +155,19 @@ def cql_encode_datetime(self, val): Converts a :class:`datetime.datetime` object to a (string) integer timestamp with millisecond precision. """ - timestamp = calendar.timegm(val.utctimetuple()) - return str(timestamp * 1000 + getattr(val, 'microsecond', 0) // 1000) + utcoffset = val.utcoffset() + if utcoffset is not None: + val = val - utcoffset + val = val.replace(tzinfo=None) + td = val - _EPOCH_NAIVE + return str((td.days * 86400 + td.seconds) * 1000 + td.microseconds // 1000) def cql_encode_date(self, val): """ Converts a :class:`datetime.date` object to a string with format ``YYYY-MM-DD``. 
""" - return "'%s'" % val.strftime('%Y-%m-%d') + return "'%s'" % val.strftime("%Y-%m-%d") def cql_encode_time(self, val): """ @@ -163,15 +181,16 @@ def cql_encode_date_ext(self, val): Encodes a :class:`cassandra.util.Date` object as an integer """ # using the int form in case the Date exceeds datetime.[MIN|MAX]YEAR - return str(val.days_from_epoch + 2 ** 31) + return str(val.days_from_epoch + 2**31) def cql_encode_sequence(self, val): """ Converts a sequence to a string of the form ``(item1, item2, ...)``. This is suitable for ``IN`` value lists. """ - return '(%s)' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) - for v in val) + return "(%s)" % ", ".join( + self.mapping.get(type(v), self.cql_encode_object)(v) for v in val + ) cql_encode_tuple = cql_encode_sequence """ @@ -184,24 +203,32 @@ def cql_encode_map_collection(self, val): Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. This is suitable for ``map`` type columns. """ - return '{%s}' % ', '.join('%s: %s' % ( - self.mapping.get(type(k), self.cql_encode_object)(k), - self.mapping.get(type(v), self.cql_encode_object)(v) - ) for k, v in val.items()) + return "{%s}" % ", ".join( + "%s: %s" + % ( + self.mapping.get(type(k), self.cql_encode_object)(k), + self.mapping.get(type(v), self.cql_encode_object)(v), + ) + for k, v in val.items() + ) def cql_encode_list_collection(self, val): """ Converts a sequence to a string of the form ``[item1, item2, ...]``. This is suitable for ``list`` type columns. """ - return '[%s]' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + return "[%s]" % ", ".join( + self.mapping.get(type(v), self.cql_encode_object)(v) for v in val + ) def cql_encode_set_collection(self, val): """ Converts a sequence to a string of the form ``{item1, item2, ...}``. This is suitable for ``set`` type columns. 
""" - return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + return "{%s}" % ", ".join( + self.mapping.get(type(v), self.cql_encode_object)(v) for v in val + ) def cql_encode_all_types(self, val, as_text_type=False): """ @@ -210,7 +237,7 @@ def cql_encode_all_types(self, val, as_text_type=False): """ encoded = self.mapping.get(type(val), self.cql_encode_object)(val) if as_text_type and not isinstance(encoded, str): - return encoded.decode('utf-8') + return encoded.decode("utf-8") return encoded def cql_encode_ipaddress(self, val): @@ -221,4 +248,4 @@ def cql_encode_ipaddress(self, val): return "'%s'" % val.compressed def cql_encode_decimal(self, val): - return self.cql_encode_float(float(val)) \ No newline at end of file + return self.cql_encode_float(float(val)) diff --git a/cassandra/marshal.py b/cassandra/marshal.py index 413e1831d4..ec1409c632 100644 --- a/cassandra/marshal.py +++ b/cassandra/marshal.py @@ -21,30 +21,27 @@ def _make_packer(format_string): unpack = lambda s: packer.unpack(s)[0] return pack, unpack -int64_pack, int64_unpack = _make_packer('>q') -int32_pack, int32_unpack = _make_packer('>i') -int16_pack, int16_unpack = _make_packer('>h') -int8_pack, int8_unpack = _make_packer('>b') -uint64_pack, uint64_unpack = _make_packer('>Q') -uint32_pack, uint32_unpack = _make_packer('>I') -uint32_le_pack, uint32_le_unpack = _make_packer('H') -uint8_pack, uint8_unpack = _make_packer('>B') -float_pack, float_unpack = _make_packer('>f') -double_pack, double_unpack = _make_packer('>d') + +int64_pack, int64_unpack = _make_packer(">q") +int32_pack, int32_unpack = _make_packer(">i") +int16_pack, int16_unpack = _make_packer(">h") +int8_pack, int8_unpack = _make_packer(">b") +uint64_pack, uint64_unpack = _make_packer(">Q") +uint32_pack, uint32_unpack = _make_packer(">I") +uint32_le_pack, uint32_le_unpack = _make_packer("H") +uint8_pack, uint8_unpack = _make_packer(">B") +float_pack, float_unpack = _make_packer(">f") 
+double_pack, double_unpack = _make_packer(">d") # in protocol version 3 and higher, the stream ID is two bytes -v3_header_struct = struct.Struct('>BBhB') +v3_header_struct = struct.Struct(">BBhB") v3_header_pack = v3_header_struct.pack v3_header_unpack = v3_header_struct.unpack def varint_unpack(term): - val = int(''.join("%02x" % i for i in term), 16) - if (term[0] & 128) != 0: - len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code - val -= 1 << (len_term * 8) - return val + return int.from_bytes(term, byteorder="big", signed=True) def bit_length(n): @@ -52,28 +49,20 @@ def bit_length(n): def varint_pack(big): - pos = True if big == 0: - return b'\x00' + return b"\x00" if big < 0: - bytelength = bit_length(abs(big) - 1) // 8 + 1 - big = (1 << bytelength * 8) + big - pos = False - revbytes = bytearray() - while big > 0: - revbytes.append(big & 0xff) - big >>= 8 - if pos and revbytes[-1] & 0x80: - revbytes.append(0) - revbytes.reverse() - return bytes(revbytes) + byte_length = (-big - 1).bit_length() // 8 + 1 + else: + byte_length = (big.bit_length() + 8) // 8 + return big.to_bytes(byte_length, byteorder="big", signed=True) -point_be = struct.Struct('>dd') -point_le = struct.Struct('dd") +point_le = struct.Struct("ddd') -circle_le = struct.Struct('ddd") +circle_le = struct.Struct("> num_extra_bytes) + num_extra_bytes = 8 - (~first_byte & 0xFF).bit_length() + val = first_byte & (0xFF >> num_extra_bytes) end = n + num_extra_bytes while n < end: n += 1 val <<= 8 - val |= term[n] & 0xff + val |= term[n] & 0xFF n += 1 values.append(decode_zig_zag(val)) return tuple(values) + def vints_pack(values): revbytes = bytearray() values = [int(v) for v in values[::-1]] @@ -120,39 +110,43 @@ def vints_pack(values): # ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved # ie. 
with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved reserved_bits = num_extra_bytes + 1 - while num_bits > (8-(reserved_bits)): + while num_bits > (8 - (reserved_bits)): num_extra_bytes += 1 num_bits -= 8 reserved_bits = min(num_extra_bytes + 1, 8) - revbytes.append(v & 0xff) + revbytes.append(v & 0xFF) v >>= 8 if num_extra_bytes > 8: - raise ValueError('Value %d is too big and cannot be encoded as vint' % value) + raise ValueError( + "Value %d is too big and cannot be encoded as vint" % value + ) # We can now store the last bits in the first byte n = 8 - num_extra_bytes - v |= (0xff >> n << n) + v |= 0xFF >> n << n revbytes.append(abs(v)) revbytes.reverse() return bytes(revbytes) + def uvint_unpack(bytes): first_byte = bytes[0] if (first_byte & 128) == 0: - return (first_byte,1) + return (first_byte, 1) - num_extra_bytes = 8 - (~first_byte & 0xff).bit_length() - rv = first_byte & (0xff >> num_extra_bytes) - for idx in range(1,num_extra_bytes + 1): + num_extra_bytes = 8 - (~first_byte & 0xFF).bit_length() + rv = first_byte & (0xFF >> num_extra_bytes) + for idx in range(1, num_extra_bytes + 1): new_byte = bytes[idx] rv <<= 8 - rv |= new_byte & 0xff + rv |= new_byte & 0xFF return (rv, num_extra_bytes + 1) + def uvint_pack(val): rv = bytearray() if val < 128: @@ -165,19 +159,19 @@ def uvint_pack(val): # ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved # ie. 
with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved reserved_bits = num_extra_bytes + 1 - while num_bits > (8-(reserved_bits)): + while num_bits > (8 - (reserved_bits)): num_extra_bytes += 1 num_bits -= 8 reserved_bits = min(num_extra_bytes + 1, 8) - rv.append(v & 0xff) + rv.append(v & 0xFF) v >>= 8 if num_extra_bytes > 8: - raise ValueError('Value %d is too big and cannot be encoded as vint' % val) + raise ValueError("Value %d is too big and cannot be encoded as vint" % val) # We can now store the last bits in the first byte n = 8 - num_extra_bytes - v |= (0xff >> n << n) + v |= 0xFF >> n << n rv.append(abs(v)) rv.reverse() diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..8e5f886eb6 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -636,28 +636,48 @@ def bind(self, values): self.raw_values = values self.values = [] - for value, col_spec in zip(values, col_meta): - if value is None: - self.values.append(None) - elif value is UNSET_VALUE: - if proto_version >= 4: - self._append_unset_value() + if ce_policy: + # Column encryption enabled — need ColDesc per column + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) - else: - try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) - uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) - if uses_ce: - col_bytes = ce_policy.encrypt(col_desc, col_bytes) - self.values.append(col_bytes) - except (TypeError, struct.error) as exc: - actual_type 
= type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) - raise TypeError(message) + try: + col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + uses_ce = ce_policy.contains_column(col_desc) + col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type + col_bytes = col_type.serialize(value, proto_version) + if uses_ce: + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + self.values.append(col_bytes) + except (TypeError, struct.error) as exc: + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) + else: + # Fast path — no column encryption (common case) + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + elif value is UNSET_VALUE: + if proto_version >= 4: + self._append_unset_value() + else: + raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + else: + try: + self.values.append(col_spec.type.serialize(value, proto_version)) + except (TypeError, struct.error) as exc: + actual_type = type(value) + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + raise TypeError(message) if proto_version >= 4: diff = col_meta_len - len(self.values) diff --git a/cassandra/serializers.pxd b/cassandra/serializers.pxd new file mode 100644 index 0000000000..ad34f1d162 --- /dev/null +++ b/cassandra/serializers.pxd @@ -0,0 +1,20 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +cdef class Serializer: + # The cqltypes._CassandraType corresponding to this serializer + cdef object cqltype + + cpdef bytes serialize(self, object value, int protocol_version) diff --git a/cassandra/serializers.pyx b/cassandra/serializers.pyx new file mode 100644 index 0000000000..08d14f0046 --- /dev/null +++ b/cassandra/serializers.pyx @@ -0,0 +1,456 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cython-optimized serializers for CQL types. + +Mirrors the architecture of deserializers.pyx. Currently implements +optimized serialization for: +- FloatType (4-byte big-endian float) +- DoubleType (8-byte big-endian double) +- Int32Type (4-byte big-endian signed int) +- DateType (8-byte big-endian int64 ms timestamp) +- VectorType (type-specialized for float/double/int32, generic fallback) + +For all other types, GenericSerializer delegates to the Python-level +cqltype.serialize() classmethod. 
+""" + +from libc.stdint cimport int32_t, int64_t +from libc.string cimport memcpy +from libc.float cimport FLT_MAX +from libc.math cimport isinf, isnan +from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING + +from cassandra import cqltypes +import datetime as _datetime_mod +import io +from cassandra.marshal import uvint_pack + +cdef bint is_little_endian +from cassandra.util import is_little_endian + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + +cdef class Serializer: + """Cython-based serializer class for a cqltype""" + + def __init__(self, cqltype): + self.cqltype = cqltype + + cpdef bytes serialize(self, object value, int protocol_version): + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Float range check +# --------------------------------------------------------------------------- + +cdef inline void _check_float_range(double value) except *: + """Raise OverflowError for finite values outside float32 range. + + Matches the behaviour of struct.pack('>f', value), which raises + OverflowError for values that cannot be represented as a 32-bit + IEEE 754 float. inf, -inf, and nan pass through unchanged. + """ + if not isinf(value) and not isnan(value): + if value > FLT_MAX or value < -FLT_MAX: + raise OverflowError( + "Value %r too large for float32 (max %r)" % (value, FLT_MAX) + ) + + +# --------------------------------------------------------------------------- +# Int32 range check +# --------------------------------------------------------------------------- + +cdef inline void _check_int32_range(object value) except *: + """Raise OverflowError for values outside the signed int32 range. 
+ + Mirrors ``_check_float_range``: we intentionally raise OverflowError + (not struct.error) so callers only need to catch one exception type + for out-of-range values. The check must be done on the Python int + *before* the C-level cast, which would silently truncate. + """ + if value > 2147483647 or value < -2147483648: + raise OverflowError( + "'i' format requires -2147483648 <= number <= 2147483647" + ) + + +# --------------------------------------------------------------------------- +# Scalar serializers +# --------------------------------------------------------------------------- + +cdef class SerFloatType(Serializer): + """Serialize a Python float to 4-byte big-endian IEEE 754.""" + + cpdef bytes serialize(self, object value, int protocol_version): + _check_float_range(value) + cdef float val = value + cdef char out[4] + cdef char *src = &val + + if is_little_endian: + out[0] = src[3] + out[1] = src[2] + out[2] = src[1] + out[3] = src[0] + else: + memcpy(out, src, 4) + + return PyBytes_FromStringAndSize(out, 4) + + +cdef class SerDoubleType(Serializer): + """Serialize a Python float to 8-byte big-endian IEEE 754.""" + + cpdef bytes serialize(self, object value, int protocol_version): + cdef double val = value + cdef char out[8] + cdef char *src = &val + + if is_little_endian: + out[0] = src[7] + out[1] = src[6] + out[2] = src[5] + out[3] = src[4] + out[4] = src[3] + out[5] = src[2] + out[6] = src[1] + out[7] = src[0] + else: + memcpy(out, src, 8) + + return PyBytes_FromStringAndSize(out, 8) + + +cdef class SerInt32Type(Serializer): + """Serialize a Python int to 4-byte big-endian signed int32.""" + + cpdef bytes serialize(self, object value, int protocol_version): + _check_int32_range(value) + cdef int32_t val = value + cdef char out[4] + cdef char *src = &val + + if is_little_endian: + out[0] = src[3] + out[1] = src[2] + out[2] = src[1] + out[3] = src[0] + else: + memcpy(out, src, 4) + + return PyBytes_FromStringAndSize(out, 4) + + +# 
--------------------------------------------------------------------------- +# DateType (timestamp) serializer +# --------------------------------------------------------------------------- + +cdef object _EPOCH_NAIVE = _datetime_mod.datetime(1970, 1, 1) +cdef object _EPOCH_DATE = _datetime_mod.date(1970, 1, 1) +cdef frozenset _number_types = frozenset((int, float)) + +cdef class SerDateType(Serializer): + """Serialize a datetime/date/numeric to 8-byte big-endian int64 (ms timestamp). + + Mirrors cqltypes.DateType.serialize() using integer arithmetic, + but avoids the Python-level struct.pack('>q', ...) overhead by + doing the byte-swap in C. + """ + + cpdef bytes serialize(self, object value, int protocol_version): + cdef int64_t timestamp + cdef object td, utcoffset + + try: + utcoffset = value.utcoffset() + if utcoffset is not None: + value = value - utcoffset + value = value.replace(tzinfo=None) + td = value - _EPOCH_NAIVE + timestamp = (td.days * 86400 + td.seconds) * 1000 + td.microseconds // 1000 + except AttributeError: + try: + td = value - _EPOCH_DATE + timestamp = td.days * 86400000 + except (AttributeError, TypeError): + if type(value) not in _number_types: + raise TypeError( + "DateType arguments must be a datetime, date, or timestamp" + ) + timestamp = int(value) + + cdef char out[8] + cdef char *src = ×tamp + if is_little_endian: + out[0] = src[7] + out[1] = src[6] + out[2] = src[5] + out[3] = src[4] + out[4] = src[3] + out[5] = src[2] + out[6] = src[1] + out[7] = src[0] + else: + memcpy(out, src, 8) + return PyBytes_FromStringAndSize(out, 8) + + +# --------------------------------------------------------------------------- +# Type detection helpers +# --------------------------------------------------------------------------- + +cdef inline bint _is_float_type(object subtype): + return subtype is cqltypes.FloatType or issubclass(subtype, cqltypes.FloatType) + +cdef inline bint _is_double_type(object subtype): + return subtype is cqltypes.DoubleType 
or issubclass(subtype, cqltypes.DoubleType) + +cdef inline bint _is_int32_type(object subtype): + return subtype is cqltypes.Int32Type or issubclass(subtype, cqltypes.Int32Type) + + +# --------------------------------------------------------------------------- +# VectorType serializer +# --------------------------------------------------------------------------- + +cdef class SerVectorType(Serializer): + """ + Optimized Cython serializer for VectorType. + + For float, double, and int32 vectors, pre-allocates a contiguous buffer + and uses C-level byte swapping. For other subtypes, falls back to + per-element Python serialization. + """ + + cdef int vector_size + cdef object subtype + # 0 = generic, 1 = float, 2 = double, 3 = int32 + cdef int type_code + + def __init__(self, cqltype): + super().__init__(cqltype) + self.vector_size = cqltype.vector_size + self.subtype = cqltype.subtype + + if _is_float_type(self.subtype): + self.type_code = 1 + elif _is_double_type(self.subtype): + self.type_code = 2 + elif _is_int32_type(self.subtype): + self.type_code = 3 + else: + self.type_code = 0 + + cpdef bytes serialize(self, object value, int protocol_version): + cdef int v_length = len(value) + if v_length != self.vector_size: + raise ValueError( + "Expected sequence of size %d for vector of type %s and " + "dimension %d, observed sequence of length %d" % ( + self.vector_size, self.subtype.typename, + self.vector_size, v_length)) + + if self.type_code == 1: + return self._serialize_float(value) + elif self.type_code == 2: + return self._serialize_double(value) + elif self.type_code == 3: + return self._serialize_int32(value) + else: + return self._serialize_generic(value, protocol_version) + + cdef inline bytes _serialize_float(self, object values): + """Serialize a sequence of floats into a contiguous big-endian buffer. + + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). 
This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. + """ + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 4 + if buf_size == 0: + return b"" + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) + + cdef float val + cdef char *src + cdef char *dst + + for i in range(self.vector_size): + _check_float_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return result + + cdef inline bytes _serialize_double(self, object values): + """Serialize a sequence of doubles into a contiguous big-endian buffer. + + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. + """ + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 8 + if buf_size == 0: + return b"" + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) + + cdef double val + cdef char *src + cdef char *dst + + for i in range(self.vector_size): + val = values[i] + src = &val + dst = buf + i * 8 + + if is_little_endian: + dst[0] = src[7] + dst[1] = src[6] + dst[2] = src[5] + dst[3] = src[4] + dst[4] = src[3] + dst[5] = src[2] + dst[6] = src[1] + dst[7] = src[0] + else: + memcpy(dst, src, 8) + + return result + + cdef inline bytes _serialize_int32(self, object values): + """Serialize a sequence of int32 values into a contiguous big-endian buffer. + + Uses index-based access (values[i]) rather than iteration for + performance — the input must support ``__getitem__`` (list, tuple, + etc.). 
This is intentional: index access lets Cython emit a single + ``PyObject_GetItem`` call per element instead of iterator protocol + overhead. + """ + cdef Py_ssize_t i + cdef Py_ssize_t buf_size = self.vector_size * 4 + if buf_size == 0: + return b"" + + cdef object result = PyBytes_FromStringAndSize(NULL, buf_size) + cdef char *buf = PyBytes_AS_STRING(result) + + cdef int32_t val + cdef char *src + cdef char *dst + + for i in range(self.vector_size): + _check_int32_range(values[i]) + val = values[i] + src = &val + dst = buf + i * 4 + + if is_little_endian: + dst[0] = src[3] + dst[1] = src[2] + dst[2] = src[1] + dst[3] = src[0] + else: + memcpy(dst, src, 4) + + return result + + cdef inline bytes _serialize_generic(self, object values, int protocol_version): + """Fallback: element-by-element Python serialization for non-optimized types.""" + serialized_size = self.subtype.serial_size() + buf = io.BytesIO() + for item in values: + item_bytes = self.subtype.serialize(item, protocol_version) + if serialized_size is None: + buf.write(uvint_pack(len(item_bytes))) + buf.write(item_bytes) + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# Generic serializer (fallback for all other types) +# --------------------------------------------------------------------------- + +cdef class GenericSerializer(Serializer): + """ + Wraps a generic cqltype for serialization, delegating to the Python-level + cqltype.serialize() classmethod. 
+ """ + + cpdef bytes serialize(self, object value, int protocol_version): + return self.cqltype.serialize(value, protocol_version) + + def __repr__(self): + return "GenericSerializer(%s)" % (self.cqltype,) + + +# --------------------------------------------------------------------------- +# Lookup and factory +# --------------------------------------------------------------------------- + +cdef dict _ser_classes = {} + +cpdef Serializer find_serializer(cqltype): + """Find a serializer for a cqltype.""" + + # For VectorType, always use SerVectorType (it handles generic subtypes internally) + if issubclass(cqltype, cqltypes.VectorType): + return SerVectorType(cqltype) + + # For scalar types with dedicated serializers, look up by name + name = 'Ser' + cqltype.__name__ + cls = _ser_classes.get(name) + if cls is not None: + return cls(cqltype) + + # Fallback to generic + return GenericSerializer(cqltype) + + +def make_serializers(cqltypes_list): + """Create a list of Serializer objects for each given cqltype.""" + return [find_serializer(ct) for ct in cqltypes_list] + + +# Build the lookup dict for scalar serializers at module load time +_ser_classes['SerFloatType'] = SerFloatType +_ser_classes['SerDoubleType'] = SerDoubleType +_ser_classes['SerInt32Type'] = SerInt32Type +_ser_classes['SerDateType'] = SerDateType +_ser_classes['SerTimestampType'] = SerDateType diff --git a/cassandra/timestamps.py b/cassandra/timestamps.py index d11359cf13..6ce727b080 100644 --- a/cassandra/timestamps.py +++ b/cassandra/timestamps.py @@ -23,10 +23,11 @@ log = logging.getLogger(__name__) + class MonotonicTimestampGenerator(object): """ - An object that, when called, returns ``int(time.time() * 1e6)`` when - possible, but, if the value returned by ``time.time`` doesn't increase, + An object that, when called, returns ``time.time_ns() // 1000`` when + possible, but, if the value returned by ``time.time_ns`` doesn't increase, drifts into the future and logs warnings. 
Exposed configuration attributes can be configured with arguments to ``__init__`` or by changing attributes on an initialized object. @@ -55,9 +56,8 @@ class MonotonicTimestampGenerator(object): def __init__(self, warn_on_drift=True, warning_threshold=1, warning_interval=1): self.lock = Lock() - with self.lock: - self.last = 0 - self._last_warn = 0 + self.last = 0 + self._last_warn = 0 self.warn_on_drift = warn_on_drift self.warning_threshold = warning_threshold self.warning_interval = warning_interval @@ -88,22 +88,25 @@ def __call__(self): internally to _next_timestamp. """ with self.lock: - return self._next_timestamp(now=int(time.time() * 1e6), - last=self.last) + return self._next_timestamp(now=time.time_ns() // 1000, last=self.last) def _maybe_warn(self, now): # should be called from inside the self.lock. diff = self.last - now since_last_warn = now - self._last_warn - warn = (self.warn_on_drift and - (diff >= self.warning_threshold * 1e6) and - (since_last_warn >= self.warning_interval * 1e6)) + warn = ( + self.warn_on_drift + and (diff >= self.warning_threshold * 1_000_000) + and (since_last_warn >= self.warning_interval * 1_000_000) + ) if warn: log.warning( "Clock skew detected: current tick ({now}) was {diff} " "microseconds behind the last generated timestamp " "({last}), returned timestamps will be artificially " "incremented to guarantee monotonicity.".format( - now=now, diff=diff, last=self.last)) + now=now, diff=diff, last=self.last + ) + ) self._last_warn = now diff --git a/tests/unit/test_timestamps.py b/tests/unit/test_timestamps.py index 8ef747d515..335230a18d 100644 --- a/tests/unit/test_timestamps.py +++ b/tests/unit/test_timestamps.py @@ -22,23 +22,24 @@ class _TimestampTestMixin(object): - - @mock.patch('cassandra.timestamps.time') - def _call_and_check_results(self, - patched_time_module, - system_time_expected_stamp_pairs, - timestamp_generator=None): + @mock.patch("cassandra.timestamps.time") + def _call_and_check_results( + self, + 
patched_time_module, + system_time_expected_stamp_pairs, + timestamp_generator=None, + ): """ - For each element in an iterable of (system_time, expected_timestamp) + For each element in an iterable of (system_time_ns, expected_timestamp) pairs, call a :class:`cassandra.timestamps.MonotonicTimestampGenerator` - with system_times as the underlying time.time() result, then assert + with system_times as the underlying time.time_ns() result, then assert that the result is expected_timestamp. Skips the check if expected_timestamp is None. """ - patched_time_module.time = mock.Mock() + patched_time_module.time_ns = mock.Mock() system_times, expected_timestamps = zip(*system_time_expected_stamp_pairs) - patched_time_module.time.side_effect = system_times + patched_time_module.time_ns.side_effect = system_times tsg = timestamp_generator or timestamps.MonotonicTimestampGenerator() for expected in expected_timestamps: @@ -46,14 +47,14 @@ def _call_and_check_results(self, if expected is not None: assert actual == expected - # assert we patched timestamps.time.time correctly + # assert we patched timestamps.time.time_ns correctly with pytest.raises(StopIteration): tsg() class TestTimestampGeneratorOutput(unittest.TestCase, _TimestampTestMixin): """ - Mock time.time and test the output of MonotonicTimestampGenerator.__call__ + Mock time.time_ns and test the output of MonotonicTimestampGenerator.__call__ given different patterns of changing results. 
""" @@ -71,10 +72,11 @@ def test_timestamps_during_and_after_same_system_time(self): """ self._call_and_check_results( system_time_expected_stamp_pairs=( - (15.0, 15 * 1e6), - (15.0, 15 * 1e6 + 1), - (15.0, 15 * 1e6 + 2), - (15.01, 15.01 * 1e6)) + (15_000_000_000, 15_000_000), + (15_000_000_000, 15_000_001), + (15_000_000_000, 15_000_002), + (15_010_000_000, 15_010_000), + ) ) def test_timestamps_during_and_after_backwards_system_time(self): @@ -87,18 +89,18 @@ def test_timestamps_during_and_after_backwards_system_time(self): """ self._call_and_check_results( system_time_expected_stamp_pairs=( - (15.0, 15 * 1e6), - (13.0, 15 * 1e6 + 1), - (14.0, 15 * 1e6 + 2), - (13.5, 15 * 1e6 + 3), - (15.01, 15.01 * 1e6)) + (15_000_000_000, 15_000_000), + (13_000_000_000, 15_000_001), + (14_000_000_000, 15_000_002), + (13_500_000_000, 15_000_003), + (15_010_000_000, 15_010_000), + ) ) class TestTimestampGeneratorLogging(unittest.TestCase): - def setUp(self): - self.log_patcher = mock.patch('cassandra.timestamps.log') + self.log_patcher = mock.patch("cassandra.timestamps.log") self.addCleanup(self.log_patcher.stop) self.patched_timestamp_log = self.log_patcher.start() @@ -119,10 +121,9 @@ def test_basic_log_content(self): @test_category timing """ tsg = timestamps.MonotonicTimestampGenerator( - warning_threshold=1e-6, - warning_interval=1e-6 + warning_threshold=1e-6, warning_interval=1e-6 ) - #The units of _last_warn is seconds + # The units of _last_warn is seconds tsg._last_warn = 12 tsg._next_timestamp(20, tsg.last) @@ -132,7 +133,7 @@ def test_basic_log_content(self): assert len(self.patched_timestamp_log.warning.call_args_list) == 1 self.assertLastCallArgRegex( self.patched_timestamp_log.warning.call_args, - r'Clock skew detected:.*\b16\b.*\b4\b.*\b20\b' + r"Clock skew detected:.*\b16\b.*\b4\b.*\b20\b", ) def test_disable_logging(self): @@ -179,8 +180,7 @@ def test_warning_threshold_respected_logs(self): @test_category timing """ tsg = timestamps.MonotonicTimestampGenerator( 
- warning_threshold=1e-6, - warning_interval=1e-6 + warning_threshold=1e-6, warning_interval=1e-6 ) tsg.last, tsg._last_warn = 100, 97 tsg._next_timestamp(98, tsg.last) @@ -197,8 +197,7 @@ def test_warning_interval_respected_no_logging(self): @test_category timing """ tsg = timestamps.MonotonicTimestampGenerator( - warning_threshold=1e-6, - warning_interval=2e-6 + warning_threshold=1e-6, warning_interval=2e-6 ) tsg.last = 100 tsg._next_timestamp(70, tsg.last) @@ -231,7 +230,6 @@ def test_warning_interval_respected_logs(self): class TestTimestampGeneratorMultipleThreads(unittest.TestCase): - def test_should_generate_incrementing_timestamps_for_all_threads(self): """ Tests when time is "stopped", values are assigned incrementally @@ -251,13 +249,13 @@ def request_time(): generated_timestamps.append(timestamp) tsg = timestamps.MonotonicTimestampGenerator() - fixed_time = 1 + fixed_time_ns = 1_000_000_000 num_threads = 5 timestamp_to_generate = 1000 generated_timestamps = [] - with mock.patch('time.time', new=mock.Mock(return_value=fixed_time)): + with mock.patch.object(timestamps.time, "time_ns", return_value=fixed_time_ns): threads = [] for _ in range(num_threads): threads.append(Thread(target=request_time)) @@ -270,4 +268,4 @@ def request_time(): assert len(generated_timestamps) == num_threads * timestamp_to_generate for i, timestamp in enumerate(sorted(generated_timestamps)): - assert int(i + 1e6) == timestamp + assert i + 1_000_000 == timestamp