From 78a70a502021e82878b456203a96642553771147 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen Date: Tue, 7 Apr 2026 11:08:39 -0400 Subject: [PATCH 1/3] fix(event-stream): Handle unknown event types gracefully instead of crashing --- .../codegen/generators/UnionGenerator.java | 5 ++- .../_private/deserializers.py | 38 ++++++++++++++++--- .../smithy_aws_event_stream/aio/__init__.py | 2 +- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java index b3beb2cf3..34b6275d9 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java @@ -149,6 +149,7 @@ private void generateDeserializer() { var symbol = symbolProvider.toSymbol(shape); var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER); var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA); + var unknownSymbol = symbol.expectProperty(SymbolProperties.UNION_UNKNOWN); writer.putContext("schema", schemaSymbol); writer.write(""" class $1L: @@ -168,6 +169,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: ${4C|} case _: logger.debug("Unexpected member schema: %s", schema) + self._set_result($5L(tag=schema.member_name or "")) def _set_result(self, value: $2T) -> None: if self._result is not None: @@ -177,7 +179,8 @@ raise SerializationError("Unions must have exactly one value, but found more tha deserializerSymbol.getName(), symbol, schemaSymbol, - writer.consumer(w -> deserializeMembers())); + writer.consumer(w -> deserializeMembers()), + unknownSymbol.getName()); } private void deserializeMembers() { diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py index f52791450..cf0bf48cb 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py @@ -50,11 +50,39 @@ def read_struct( message_deserializer = self._create_deserializer(schema, headers) message_deserializer.read_struct(schema, consumer) else: - member_schema = schema.members[member_name] - message_deserializer = self._create_deserializer( - member_schema, headers - ) - consumer(member_schema, message_deserializer) + member_schema = schema.members.get(member_name) + if member_schema is None: + # Unknown event type. Call the consumer with a + # schema that carries the event type name as + # member_name and a member_index of -1 so the + # generated default branch constructs the unknown + # variant with the correct tag. + logger.debug( + "Unknown event type: %s", member_name + ) + from smithy_core.shapes import ShapeID + + _UNKNOWN_TARGET = Schema( + id=ShapeID("smithy.unknown#Unknown"), + shape_type=ShapeType.STRUCTURE, + ) + unknown_schema = Schema( + id=ShapeID( + f"smithy.unknown#Unknown${member_name}" + ), + shape_type=ShapeType.STRUCTURE, + member_target=_UNKNOWN_TARGET, + member_index=-1, + ) + consumer( + unknown_schema, + self._payload_codec.create_deserializer(b"{}"), + ) + else: + message_deserializer = self._create_deserializer( + member_schema, headers + ) + consumer(member_schema, message_deserializer) case "exception": member_name = expect_type(str, headers[":exception-type"]) member_schema = schema.members[member_name] diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py index 58ed7f184..a746fd61d 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py @@ -132,7 +132,7 @@ async def receive(self) -> E | None: ) result = self._deserializer(deserializer) logger.debug("Successfully deserialized event: %s", result) - if isinstance(getattr(result, "value"), Exception): + if isinstance(getattr(result, "value", None), Exception): raise result.value # type: ignore return result From e8f84ce8f8b07a7e8ad8310e320ee60a59f3c651 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen Date: Tue, 7 Apr 2026 14:53:00 -0400 Subject: [PATCH 2/3] Fix format issue and add unit tests for handling unknown events --- .../_private/deserializers.py | 11 +++-------- .../smithy_aws_event_stream/aio/__init__.py | 5 +++-- .../tests/unit/_private/__init__.py | 19 ++++++++++++++++--- .../tests/unit/_private/test_deserializers.py | 18 ++++++++++++++++++ 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py index cf0bf48cb..6b8783916 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py @@ -10,7 +10,7 @@ SpecificShapeDeserializer, ) from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeType +from smithy_core.shapes import ShapeID, ShapeType from smithy_core.traits import EventHeaderTrait from smithy_core.utils import expect_type @@ -57,19 +57,14 @@ def read_struct( # member_name and a member_index of -1 so the # generated default branch constructs the unknown # variant with the correct tag. - logger.debug( - "Unknown event type: %s", member_name - ) - from smithy_core.shapes import ShapeID + logger.debug("Unknown event type: %s", member_name) _UNKNOWN_TARGET = Schema( id=ShapeID("smithy.unknown#Unknown"), shape_type=ShapeType.STRUCTURE, ) unknown_schema = Schema( - id=ShapeID( - f"smithy.unknown#Unknown${member_name}" - ), + id=ShapeID(f"smithy.unknown#Unknown${member_name}"), shape_type=ShapeType.STRUCTURE, member_target=_UNKNOWN_TARGET, member_index=-1, diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py index a746fd61d..401442130 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py @@ -132,8 +132,9 @@ async def receive(self) -> E | None: ) result = self._deserializer(deserializer) logger.debug("Successfully deserialized event: %s", result) - if isinstance(getattr(result, "value", None), Exception): - raise result.value # type: ignore + value = getattr(result, "value", None) + if isinstance(value, Exception): + raise value return result async def close(self) -> None: diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py index 09162ecce..2c12808aa 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py @@ -381,7 +381,7 @@ def serialize_members(self, serializer: ShapeSerializer): @dataclass -class EventStreamUnknownEvent: +class EventStreamUnknown: tag: str def serialize(self, serializer: ShapeSerializer): @@ -396,7 +396,7 @@ def serialize_members(self, serializer: ShapeSerializer): | EventStreamPayloadEvent | EventStreamBlobPayloadEvent | EventStreamErrorEvent - | EventStreamUnknownEvent + | EventStreamUnknown ) @@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de))) case _: - raise SmithyError(f"Unexpected member schema: {schema}") + self._set_result(EventStreamUnknown(tag=schema.member_name or "")) def _set_result(self, value: EventStream) -> None: if self._result is not None: @@ -635,6 +635,19 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: ] +UNKNOWN_EVENT_CASE = ( + EventStreamUnknown(tag="intermediateGroupEvent"), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "intermediateGroupEvent", + ":content-type": "application/json", + }, + payload=b"{}", + ), +) + + INITIAL_REQUEST_CASE = ( EventStreamOperationInputOutput(message="The initial request!"), EventMessage( diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py index 34081bbf8..41c635e03 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py @@ -20,6 +20,7 @@ EventStreamDeserializer, EventStreamErrorEvent, EventStreamOperationInputOutput, + EventStreamUnknown, ) @@ -126,3 +127,20 @@ async def test_read_closed_receiver_source() -> None: with pytest.raises(IOError): await receiver.receive() assert receiver.closed + + +def test_deserialize_unknown_event_type(): + message = EventMessage( + headers={ + ":message-type": "event", + ":event-type": "intermediateGroupEvent", + ":content-type": "application/json", + }, + payload=b"{}", + ) + source = Event.decode(BytesIO(message.encode())) + assert source is not None + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) + result = EventStreamDeserializer().deserialize(deserializer) + assert isinstance(result, EventStreamUnknown) + assert result.tag == "intermediateGroupEvent" From d40868a95d2caca4842eabc0ab12053845b0014c Mon Sep 17 00:00:00 2001 From: Yuxuan Chen Date: Fri, 10 Apr 2026 15:15:32 -0400 Subject: [PATCH 3/3] Address PR review: refactor unknown event handling --- .../codegen/generators/UnionGenerator.java | 4 +- .../_private/deserializers.py | 50 ++++++++----------- .../tests/unit/_private/__init__.py | 6 +-- .../tests/unit/_private/test_deserializers.py | 4 +- 4 files changed, 27 insertions(+), 37 deletions(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java index 34b6275d9..4034c2732 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java @@ -144,8 +144,6 @@ private void generateDeserializer() { writer.addImport("smithy_core.deserializers", "ShapeDeserializer"); writer.addImport("smithy_core.exceptions", "SerializationError"); - // TODO: add in unknown handling - var symbol = symbolProvider.toSymbol(shape); var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER); var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA); @@ -169,7 +167,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: ${4C|} case _: logger.debug("Unexpected member schema: %s", schema) - self._set_result($5L(tag=schema.member_name or "")) + self._set_result($5L(tag=schema.expect_member_name())) def _set_result(self, value: $2T) -> None: if self._result is not None: diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py index 6b8783916..3cb91d383 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py @@ -10,7 +10,7 @@ SpecificShapeDeserializer, ) from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.shapes import ShapeType from smithy_core.traits import EventHeaderTrait from smithy_core.utils import expect_type @@ -50,34 +50,11 @@ def read_struct( message_deserializer = self._create_deserializer(schema, headers) message_deserializer.read_struct(schema, consumer) else: - member_schema = schema.members.get(member_name) - if member_schema is None: - # Unknown event type. Call the consumer with a - # schema that carries the event type name as - # member_name and a member_index of -1 so the - # generated default branch constructs the unknown - # variant with the correct tag. - logger.debug("Unknown event type: %s", member_name) - - _UNKNOWN_TARGET = Schema( - id=ShapeID("smithy.unknown#Unknown"), - shape_type=ShapeType.STRUCTURE, - ) - unknown_schema = Schema( - id=ShapeID(f"smithy.unknown#Unknown${member_name}"), - shape_type=ShapeType.STRUCTURE, - member_target=_UNKNOWN_TARGET, - member_index=-1, - ) - consumer( - unknown_schema, - self._payload_codec.create_deserializer(b"{}"), - ) - else: - message_deserializer = self._create_deserializer( - member_schema, headers - ) - consumer(member_schema, message_deserializer) + member_schema = self._resolve_member_schema(schema, member_name) + message_deserializer = self._create_deserializer( + member_schema, headers + ) + consumer(member_schema, message_deserializer) case "exception": member_name = expect_type(str, headers[":exception-type"]) member_schema = schema.members[member_name] @@ -94,6 +71,21 @@ def read_struct( case _: raise EventError(f"Unknown event structure: {self._event}") + def _resolve_member_schema(self, schema: Schema, member_name: str) -> Schema: + if member_schema := schema.members.get(member_name): + return member_schema + + logger.debug( + "Received unmodeled event stream member %s for union %s", + member_name, + schema.id, + ) + return Schema.member( + id=schema.id.with_member(member_name), + target=schema, + index=-1, + ) + def _create_deserializer( self, schema: Schema, headers: HEADERS_DICT ) -> ShapeDeserializer: diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py index 2c12808aa..1fbb8cd02 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py @@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de))) case _: - self._set_result(EventStreamUnknown(tag=schema.member_name or "")) + self._set_result(EventStreamUnknown(tag=schema.expect_member_name())) def _set_result(self, value: EventStream) -> None: if self._result is not None: @@ -636,11 +636,11 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: UNKNOWN_EVENT_CASE = ( - EventStreamUnknown(tag="intermediateGroupEvent"), + EventStreamUnknown(tag="unmodeledEvent"), EventMessage( headers={ ":message-type": "event", - ":event-type": "intermediateGroupEvent", + ":event-type": "unmodeledEvent", ":content-type": "application/json", }, payload=b"{}", diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py index 41c635e03..7ae9c77ce 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py @@ -133,7 +133,7 @@ def test_deserialize_unknown_event_type(): message = EventMessage( headers={ ":message-type": "event", - ":event-type": "intermediateGroupEvent", + ":event-type": "unmodeledEvent", ":content-type": "application/json", }, payload=b"{}", @@ -143,4 +143,4 @@ def test_deserialize_unknown_event_type(): deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamDeserializer().deserialize(deserializer) assert isinstance(result, EventStreamUnknown) - assert result.tag == "intermediateGroupEvent" + assert result.tag == "unmodeledEvent"