diff --git a/pulsar/schema/__init__.py b/pulsar/schema/__init__.py index efa68066..e3fa49e8 100644 --- a/pulsar/schema/__init__.py +++ b/pulsar/schema/__init__.py @@ -22,3 +22,4 @@ from .schema import Schema, BytesSchema, StringSchema, JsonSchema from .schema_avro import AvroSchema +from .schema_protobuf import ProtobufNativeSchema diff --git a/pulsar/schema/schema_protobuf.py b/pulsar/schema/schema_protobuf.py new file mode 100644 index 00000000..e691e777 --- /dev/null +++ b/pulsar/schema/schema_protobuf.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import base64 +import _pulsar + +from .schema import Schema + +try: + from google.protobuf import descriptor_pb2 + from google.protobuf.message import Message as ProtobufMessage + HAS_PROTOBUF = True +except ImportError: + HAS_PROTOBUF = False + + +def _collect_file_descriptors(file_descriptor, visited, file_descriptor_set): + """Recursively collect all FileDescriptorProto objects into file_descriptor_set.""" + if file_descriptor.name in visited: + return + for dep in file_descriptor.dependencies: + _collect_file_descriptors(dep, visited, file_descriptor_set) + visited.add(file_descriptor.name) + proto = descriptor_pb2.FileDescriptorProto() + file_descriptor.CopyToProto(proto) + file_descriptor_set.file.append(proto) + + +def _build_schema_definition(descriptor): + """ + Build the schema definition dict matching Java's ProtobufNativeSchemaData format: + { + "fileDescriptorSet": , + "rootMessageTypeName": , + "rootFileDescriptorName": + } + This mirrors ProtobufNativeSchemaUtils.serialize() in the Java client. + """ + file_descriptor_set = descriptor_pb2.FileDescriptorSet() + _collect_file_descriptors(descriptor.file, set(), file_descriptor_set) + file_descriptor_set_bytes = file_descriptor_set.SerializeToString() + return { + "fileDescriptorSet": base64.b64encode(file_descriptor_set_bytes).decode('utf-8'), + "rootMessageTypeName": descriptor.full_name, + "rootFileDescriptorName": descriptor.file.name, + } + + +if HAS_PROTOBUF: + class ProtobufNativeSchema(Schema): + """ + Schema for protobuf messages using the native protobuf binary encoding. + + The schema definition is stored as a JSON-encoded ProtobufNativeSchemaData + (fileDescriptorSet, rootMessageTypeName, rootFileDescriptorName), which is + compatible with the Java client's ProtobufNativeSchema. + + Parameters + ---------- + record_cls: + A generated protobuf message class (subclass of google.protobuf.message.Message). + + Example + ------- + .. code-block:: python + + import pulsar + from pulsar.schema import ProtobufNativeSchema + from my_proto_pb2 import MyMessage + + client = pulsar.Client('pulsar://localhost:6650') + producer = client.create_producer( + 'my-topic', + schema=ProtobufNativeSchema(MyMessage) + ) + producer.send(MyMessage(field='value')) + """ + + def __init__(self, record_cls): + if not (isinstance(record_cls, type) and issubclass(record_cls, ProtobufMessage)): + raise TypeError( + f'record_cls must be a protobuf Message subclass, got {record_cls!r}' + ) + schema_definition = _build_schema_definition(record_cls.DESCRIPTOR) + super(ProtobufNativeSchema, self).__init__( + record_cls, _pulsar.SchemaType.PROTOBUF_NATIVE, schema_definition, 'PROTOBUF_NATIVE' + ) + + def encode(self, obj): + self._validate_object_type(obj) + return obj.SerializeToString() + + def decode(self, data): + return self._record_cls.FromString(data) + + def __str__(self): + return f'ProtobufNativeSchema({self._record_cls.__name__})' + +else: + class ProtobufNativeSchema(Schema): + def __init__(self, _record_cls=None): + raise Exception( + "protobuf library support was not found. " + "Install it with: pip install protobuf" + ) + + def encode(self, obj): + pass + + def decode(self, data): + pass diff --git a/src/enums.cc b/src/enums.cc index 447d013c..7ee28ea1 100644 --- a/src/enums.cc +++ b/src/enums.cc @@ -115,7 +115,8 @@ void export_enums(py::module_& m) { .value("AVRO", pulsar::AVRO) .value("AUTO_CONSUME", pulsar::AUTO_CONSUME) .value("AUTO_PUBLISH", pulsar::AUTO_PUBLISH) - .value("KEY_VALUE", pulsar::KEY_VALUE); + .value("KEY_VALUE", pulsar::KEY_VALUE) + .value("PROTOBUF_NATIVE", pulsar::PROTOBUF_NATIVE); enum_(m, "InitialPosition", "Supported initial position") .value("Latest", InitialPositionLatest) diff --git a/tests/schema_test.py b/tests/schema_test.py index 9d031d15..575ec19c 100755 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -18,7 +18,10 @@ # under the License. # +import base64 import math +import os +import sys import requests from typing import List from unittest import TestCase, main @@ -30,6 +33,10 @@ import json from fastavro.schema import load_schema +# Make generated protobuf test classes importable +sys.path.insert(0, os.path.dirname(__file__)) +from test_schema_pb2 import TestMessage, TestMessageWithNested, TestInner + class ExampleRecord(Record): str_field = String() int_field = Integer() @@ -1404,5 +1411,90 @@ def test_schema_type_promotion(self): client.close() +class ProtobufNativeSchemaTest(TestCase): + """Unit tests for ProtobufNativeSchema (no Pulsar broker required).""" + + def test_schema_type(self): + """Schema type must be PROTOBUF_NATIVE.""" + import _pulsar + schema = ProtobufNativeSchema(TestMessage) + self.assertEqual(schema.schema_info().schema_type(), _pulsar.SchemaType.PROTOBUF_NATIVE) + + def test_schema_definition_keys(self): + """Schema definition JSON must contain the three required keys.""" + schema = ProtobufNativeSchema(TestMessage) + schema_def = json.loads(schema.schema_info().schema()) + self.assertIn('fileDescriptorSet', schema_def) + self.assertIn('rootMessageTypeName', schema_def) + self.assertIn('rootFileDescriptorName', schema_def) + + def test_schema_definition_values(self): + """rootMessageTypeName and rootFileDescriptorName must match the descriptor.""" + schema = ProtobufNativeSchema(TestMessage) + schema_def = json.loads(schema.schema_info().schema()) + self.assertEqual(schema_def['rootMessageTypeName'], 'test.TestMessage') + self.assertEqual(schema_def['rootFileDescriptorName'], 'test_schema.proto') + + def test_file_descriptor_set_is_valid_base64_proto(self): + """fileDescriptorSet must be valid base64-encoded FileDescriptorSet bytes.""" + from google.protobuf import descriptor_pb2 + schema = ProtobufNativeSchema(TestMessage) + schema_def = json.loads(schema.schema_info().schema()) + raw = base64.b64decode(schema_def['fileDescriptorSet']) + fds = descriptor_pb2.FileDescriptorSet.FromString(raw) + file_names = [f.name for f in fds.file] + self.assertIn('test_schema.proto', file_names) + + def test_encode_decode_roundtrip(self): + """encode then decode must reproduce the original message.""" + schema = ProtobufNativeSchema(TestMessage) + original = TestMessage(name='hello', value=42) + encoded = schema.encode(original) + decoded = schema.decode(encoded) + self.assertEqual(decoded.name, 'hello') + self.assertEqual(decoded.value, 42) + + def test_encode_produces_protobuf_binary(self): + """Encoded bytes must be valid protobuf binary (parseable by the class directly).""" + schema = ProtobufNativeSchema(TestMessage) + msg = TestMessage(name='pulsar', value=100) + encoded = schema.encode(msg) + # Verify with protobuf's own parser + reparsed = TestMessage.FromString(encoded) + self.assertEqual(reparsed, msg) + + def test_encode_decode_nested_message(self): + """encode/decode round-trip works for messages containing nested message fields.""" + schema = ProtobufNativeSchema(TestMessageWithNested) + original = TestMessageWithNested( + str_field='test', + int_field=7, + double_field=3.14, + nested=TestInner(inner_str='inner', inner_int=999), + ) + decoded = schema.decode(schema.encode(original)) + self.assertEqual(decoded.str_field, 'test') + self.assertEqual(decoded.int_field, 7) + self.assertAlmostEqual(decoded.double_field, 3.14) + self.assertEqual(decoded.nested.inner_str, 'inner') + self.assertEqual(decoded.nested.inner_int, 999) + + def test_wrong_type_raises(self): + """Encoding an object of the wrong type must raise TypeError.""" + schema = ProtobufNativeSchema(TestMessage) + with self.assertRaises(TypeError): + schema.encode("not a protobuf message") + + def test_non_message_class_raises(self): + """Constructing with a non-Message class must raise TypeError.""" + with self.assertRaises(TypeError): + ProtobufNativeSchema(str) + + def test_schema_name(self): + """Schema name must be 'PROTOBUF_NATIVE'.""" + schema = ProtobufNativeSchema(TestMessage) + self.assertEqual(schema.schema_info().name(), 'PROTOBUF_NATIVE') + + if __name__ == '__main__': main() diff --git a/tests/test_schema.proto b/tests/test_schema.proto new file mode 100644 index 00000000..42ef6595 --- /dev/null +++ b/tests/test_schema.proto @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax = "proto3"; + +package test; + +message TestMessage { + string name = 1; + int32 value = 2; +} + +message TestMessageWithNested { + string str_field = 1; + int32 int_field = 2; + double double_field = 3; + TestInner nested = 4; +} + +message TestInner { + string inner_str = 1; + int64 inner_int = 2; +} diff --git a/tests/test_schema_pb2.py b/tests/test_schema_pb2.py new file mode 100644 index 00000000..0434b193 --- /dev/null +++ b/tests/test_schema_pb2.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: test_schema.proto +# Protobuf Python Version: 6.32.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 32, + 0, + '', + 'test_schema.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11test_schema.proto\x12\x04test\"*\n\x0bTestMessage\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05\"t\n\x15TestMessageWithNested\x12\x11\n\tstr_field\x18\x01 \x01(\t\x12\x11\n\tint_field\x18\x02 \x01(\x05\x12\x14\n\x0c\x64ouble_field\x18\x03 \x01(\x01\x12\x1f\n\x06nested\x18\x04 \x01(\x0b\x32\x0f.test.TestInner\"1\n\tTestInner\x12\x11\n\tinner_str\x18\x01 \x01(\t\x12\x11\n\tinner_int\x18\x02 \x01(\x03\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'test_schema_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TESTMESSAGE']._serialized_start=27 + _globals['_TESTMESSAGE']._serialized_end=69 + _globals['_TESTMESSAGEWITHNESTED']._serialized_start=71 + _globals['_TESTMESSAGEWITHNESTED']._serialized_end=187 + _globals['_TESTINNER']._serialized_start=189 + _globals['_TESTINNER']._serialized_end=238 +# @@protoc_insertion_point(module_scope)