Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pulsar/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@

from .schema import Schema, BytesSchema, StringSchema, JsonSchema
from .schema_avro import AvroSchema
from .schema_protobuf import ProtobufNativeSchema
127 changes: 127 additions & 0 deletions pulsar/schema/schema_protobuf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import base64
import _pulsar

from .schema import Schema

try:
from google.protobuf import descriptor_pb2
from google.protobuf.message import Message as ProtobufMessage
HAS_PROTOBUF = True
except ImportError:
HAS_PROTOBUF = False


def _collect_file_descriptors(file_descriptor, visited, file_descriptor_set):
    """Append ``file_descriptor`` and its transitive imports to ``file_descriptor_set``.

    The walk is depth-first and post-order: every entry of
    ``file_descriptor.dependencies`` is handled before the file itself, so
    imported files are appended ahead of the file that imports them.

    Parameters
    ----------
    file_descriptor:
        Descriptor of the file to start from; its ``name``, ``dependencies``
        and ``CopyToProto`` members are used.
    visited:
        Mutable set of file names that were already appended. It is updated
        in place and keeps a file shared by several importers from being
        appended twice.
    file_descriptor_set:
        Message whose repeated ``file`` field receives one
        ``FileDescriptorProto`` per collected file.
    """
    already_collected = file_descriptor.name in visited
    if not already_collected:
        # Imports first, so they precede the importing file in the output.
        for imported_file in file_descriptor.dependencies:
            _collect_file_descriptors(imported_file, visited, file_descriptor_set)
        visited.add(file_descriptor.name)
        file_proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(file_proto)
        file_descriptor_set.file.append(file_proto)


def _build_schema_definition(descriptor):
    """
    Create the schema definition dict for a protobuf message descriptor.

    The layout follows Java's ProtobufNativeSchemaData format, as produced by
    ProtobufNativeSchemaUtils.serialize() in the Java client:
    {
    "fileDescriptorSet": <base64-encoded FileDescriptorSet bytes>,
    "rootMessageTypeName": <full name of the root message>,
    "rootFileDescriptorName": <name of the root .proto file>
    }

    Parameters
    ----------
    descriptor:
        Descriptor of the root message; its ``file`` and ``full_name``
        attributes are read.

    Returns
    -------
    dict
        The three-key mapping described above, with the serialized
        FileDescriptorSet stored as base64 text.
    """
    root_file = descriptor.file

    # Gather the root file together with everything it (transitively) imports.
    descriptor_set = descriptor_pb2.FileDescriptorSet()
    _collect_file_descriptors(root_file, set(), descriptor_set)

    encoded_descriptor_set = base64.b64encode(descriptor_set.SerializeToString())

    definition = {}
    definition["fileDescriptorSet"] = encoded_descriptor_set.decode('utf-8')
    definition["rootMessageTypeName"] = descriptor.full_name
    definition["rootFileDescriptorName"] = root_file.name
    return definition


if HAS_PROTOBUF:
    class ProtobufNativeSchema(Schema):
        """
        Schema that carries protobuf messages in their native binary encoding.

        The schema definition is a JSON-encoded ProtobufNativeSchemaData
        (fileDescriptorSet, rootMessageTypeName, rootFileDescriptorName), which
        is compatible with the Java client's ProtobufNativeSchema.

        Parameters
        ----------
        record_cls:
            A generated protobuf message class (subclass of
            google.protobuf.message.Message).

        Raises
        ------
        TypeError
            If ``record_cls`` is not a class derived from the protobuf
            ``Message`` base class.

        Example
        -------
        .. code-block:: python

            import pulsar
            from pulsar.schema import ProtobufNativeSchema
            from my_proto_pb2 import MyMessage

            client = pulsar.Client('pulsar://localhost:6650')
            producer = client.create_producer(
                'my-topic',
                schema=ProtobufNativeSchema(MyMessage)
            )
            producer.send(MyMessage(field='value'))
        """

        def __init__(self, record_cls):
            # Reject instances and unrelated classes before touching DESCRIPTOR.
            is_message_class = (
                isinstance(record_cls, type)
                and issubclass(record_cls, ProtobufMessage)
            )
            if not is_message_class:
                raise TypeError(
                    f'record_cls must be a protobuf Message subclass, got {record_cls!r}'
                )
            definition = _build_schema_definition(record_cls.DESCRIPTOR)
            super().__init__(
                record_cls,
                _pulsar.SchemaType.PROTOBUF_NATIVE,
                definition,
                'PROTOBUF_NATIVE',
            )

        def encode(self, obj):
            """Validate ``obj`` against the record class and return its serialized bytes."""
            self._validate_object_type(obj)
            payload = obj.SerializeToString()
            return payload

        def decode(self, data):
            """Parse ``data`` into an instance of the record class."""
            message_cls = self._record_cls
            return message_cls.FromString(data)

        def __str__(self):
            return f'ProtobufNativeSchema({self._record_cls.__name__})'

else:
    class ProtobufNativeSchema(Schema):
        """Stand-in defined when the protobuf package could not be imported.

        Instantiating it always raises, pointing the user at the missing
        dependency.
        """

        def __init__(self, _record_cls=None):
            message = (
                "protobuf library support was not found. "
                "Install it with: pip install protobuf"
            )
            raise Exception(message)

        def encode(self, obj):
            # No instance can exist because __init__ always raises.
            return None

        def decode(self, data):
            # No instance can exist because __init__ always raises.
            return None
3 changes: 2 additions & 1 deletion src/enums.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ void export_enums(py::module_& m) {
.value("AVRO", pulsar::AVRO)
.value("AUTO_CONSUME", pulsar::AUTO_CONSUME)
.value("AUTO_PUBLISH", pulsar::AUTO_PUBLISH)
.value("KEY_VALUE", pulsar::KEY_VALUE);
.value("KEY_VALUE", pulsar::KEY_VALUE)
.value("PROTOBUF_NATIVE", pulsar::PROTOBUF_NATIVE);

enum_<InitialPosition>(m, "InitialPosition", "Supported initial position")
.value("Latest", InitialPositionLatest)
Expand Down
92 changes: 92 additions & 0 deletions tests/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
# under the License.
#

import base64
import math
import os
import sys
import requests
from typing import List
from unittest import TestCase, main
Expand All @@ -30,6 +33,10 @@
import json
from fastavro.schema import load_schema

# Make generated protobuf test classes importable
sys.path.insert(0, os.path.dirname(__file__))
from test_schema_pb2 import TestMessage, TestMessageWithNested, TestInner

class ExampleRecord(Record):
str_field = String()
int_field = Integer()
Expand Down Expand Up @@ -1404,5 +1411,90 @@ def test_schema_type_promotion(self):
client.close()


class ProtobufNativeSchemaTest(TestCase):
    """Broker-free unit tests for ProtobufNativeSchema."""

    @staticmethod
    def _parsed_definition(record_cls):
        """Build a schema for ``record_cls`` and return its definition parsed from JSON."""
        info = ProtobufNativeSchema(record_cls).schema_info()
        return json.loads(info.schema())

    def test_schema_type(self):
        """The schema info reports the PROTOBUF_NATIVE schema type."""
        import _pulsar
        info = ProtobufNativeSchema(TestMessage).schema_info()
        self.assertEqual(info.schema_type(), _pulsar.SchemaType.PROTOBUF_NATIVE)

    def test_schema_definition_keys(self):
        """The definition JSON exposes all three required keys."""
        definition = self._parsed_definition(TestMessage)
        required_keys = (
            'fileDescriptorSet',
            'rootMessageTypeName',
            'rootFileDescriptorName',
        )
        for key in required_keys:
            self.assertIn(key, definition)

    def test_schema_definition_values(self):
        """Root message and root file names are taken from the descriptor."""
        definition = self._parsed_definition(TestMessage)
        self.assertEqual(definition['rootMessageTypeName'], 'test.TestMessage')
        self.assertEqual(definition['rootFileDescriptorName'], 'test_schema.proto')

    def test_file_descriptor_set_is_valid_base64_proto(self):
        """fileDescriptorSet decodes from base64 into a parseable FileDescriptorSet."""
        from google.protobuf import descriptor_pb2
        definition = self._parsed_definition(TestMessage)
        descriptor_set = descriptor_pb2.FileDescriptorSet.FromString(
            base64.b64decode(definition['fileDescriptorSet'])
        )
        collected_names = [entry.name for entry in descriptor_set.file]
        self.assertIn('test_schema.proto', collected_names)

    def test_encode_decode_roundtrip(self):
        """Decoding an encoded message yields the original field values."""
        schema = ProtobufNativeSchema(TestMessage)
        message = TestMessage(name='hello', value=42)
        restored = schema.decode(schema.encode(message))
        self.assertEqual(restored.name, 'hello')
        self.assertEqual(restored.value, 42)

    def test_encode_produces_protobuf_binary(self):
        """Encoded bytes are plain protobuf binary, parseable by the class itself."""
        schema = ProtobufNativeSchema(TestMessage)
        message = TestMessage(name='pulsar', value=100)
        payload = schema.encode(message)
        # Parse with protobuf's own parser rather than the schema under test.
        self.assertEqual(TestMessage.FromString(payload), message)

    def test_encode_decode_nested_message(self):
        """Round-trip also works when the message embeds another message."""
        schema = ProtobufNativeSchema(TestMessageWithNested)
        inner = TestInner(inner_str='inner', inner_int=999)
        message = TestMessageWithNested(
            str_field='test',
            int_field=7,
            double_field=3.14,
            nested=inner,
        )
        payload = schema.encode(message)
        restored = schema.decode(payload)
        self.assertEqual(restored.str_field, 'test')
        self.assertEqual(restored.int_field, 7)
        self.assertAlmostEqual(restored.double_field, 3.14)
        self.assertEqual(restored.nested.inner_str, 'inner')
        self.assertEqual(restored.nested.inner_int, 999)

    def test_wrong_type_raises(self):
        """Encoding an object that is not of the record class raises TypeError."""
        schema = ProtobufNativeSchema(TestMessage)
        self.assertRaises(TypeError, schema.encode, "not a protobuf message")

    def test_non_message_class_raises(self):
        """Building a schema from a non-Message class raises TypeError."""
        self.assertRaises(TypeError, ProtobufNativeSchema, str)

    def test_schema_name(self):
        """The schema info reports the name 'PROTOBUF_NATIVE'."""
        info = ProtobufNativeSchema(TestMessage).schema_info()
        self.assertEqual(info.name(), 'PROTOBUF_NATIVE')


if __name__ == '__main__':
main()
37 changes: 37 additions & 0 deletions tests/test_schema.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

syntax = "proto3";

package test;

// Flat test message with one string field and one 32-bit integer field.
// The ProtobufNativeSchema unit tests use it for the schema-definition
// checks and the basic encode/decode round-trip.
message TestMessage {
string name = 1;
int32 value = 2;
}

// Test message mixing scalar fields (string, int32, double) with an embedded
// TestInner message; used to check that encode/decode round-trips nested
// message fields.
message TestMessageWithNested {
string str_field = 1;
int32 int_field = 2;
double double_field = 3;
// Embedded message declared further down in this file.
TestInner nested = 4;
}

// Message embedded in TestMessageWithNested.nested; carries a string and a
// 64-bit integer.
message TestInner {
string inner_str = 1;
int64 inner_int = 2;
}
40 changes: 40 additions & 0 deletions tests/test_schema_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading