diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java index b3beb2cf3..4034c2732 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java @@ -144,11 +144,10 @@ private void generateDeserializer() { writer.addImport("smithy_core.deserializers", "ShapeDeserializer"); writer.addImport("smithy_core.exceptions", "SerializationError"); - // TODO: add in unknown handling - var symbol = symbolProvider.toSymbol(shape); var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER); var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA); + var unknownSymbol = symbol.expectProperty(SymbolProperties.UNION_UNKNOWN); writer.putContext("schema", schemaSymbol); writer.write(""" class $1L: @@ -168,6 +167,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: ${4C|} case _: logger.debug("Unexpected member schema: %s", schema) + self._set_result($5L(tag=schema.expect_member_name())) def _set_result(self, value: $2T) -> None: if self._result is not None: @@ -177,7 +177,8 @@ raise SerializationError("Unions must have exactly one value, but found more tha deserializerSymbol.getName(), symbol, schemaSymbol, - writer.consumer(w -> deserializeMembers())); + writer.consumer(w -> deserializeMembers()), + unknownSymbol.getName()); } private void deserializeMembers() { diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py index f52791450..3cb91d383 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py @@ -50,7 +50,7 @@ def read_struct( message_deserializer = self._create_deserializer(schema, headers) message_deserializer.read_struct(schema, consumer) else: - member_schema = schema.members[member_name] + member_schema = self._resolve_member_schema(schema, member_name) message_deserializer = self._create_deserializer( member_schema, headers ) @@ -71,6 +71,21 @@ def read_struct( case _: raise EventError(f"Unknown event structure: {self._event}") + def _resolve_member_schema(self, schema: Schema, member_name: str) -> Schema: + if member_schema := schema.members.get(member_name): + return member_schema + + logger.debug( + "Received unmodeled event stream member %s for union %s", + member_name, + schema.id, + ) + return Schema.member( + id=schema.id.with_member(member_name), + target=schema, + index=-1, + ) + def _create_deserializer( self, schema: Schema, headers: HEADERS_DICT ) -> ShapeDeserializer: diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py index 58ed7f184..401442130 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py @@ -132,8 +132,9 @@ async def receive(self) -> E | None: ) result = self._deserializer(deserializer) logger.debug("Successfully deserialized event: %s", result) - if isinstance(getattr(result, "value"), Exception): - raise result.value # type: ignore + value = getattr(result, "value", None) + if isinstance(value, Exception): + raise value return result async def close(self) -> None: diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py index 09162ecce..1fbb8cd02 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py @@ -381,7 +381,7 @@ def serialize_members(self, serializer: ShapeSerializer): @dataclass -class EventStreamUnknownEvent: +class EventStreamUnknown: tag: str def serialize(self, serializer: ShapeSerializer): @@ -396,7 +396,7 @@ def serialize_members(self, serializer: ShapeSerializer): | EventStreamPayloadEvent | EventStreamBlobPayloadEvent | EventStreamErrorEvent - | EventStreamUnknownEvent + | EventStreamUnknown ) @@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de))) case _: - raise SmithyError(f"Unexpected member schema: {schema}") + self._set_result(EventStreamUnknown(tag=schema.expect_member_name())) def _set_result(self, value: EventStream) -> None: if self._result is not None: @@ -635,6 +635,19 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: ] +UNKNOWN_EVENT_CASE = ( + EventStreamUnknown(tag="unmodeledEvent"), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "unmodeledEvent", + ":content-type": "application/json", + }, + payload=b"{}", + ), +) + + INITIAL_REQUEST_CASE = ( EventStreamOperationInputOutput(message="The initial request!"), EventMessage( diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py index 34081bbf8..7ae9c77ce 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py @@ -20,6 +20,7 @@ EventStreamDeserializer, EventStreamErrorEvent, EventStreamOperationInputOutput, + EventStreamUnknown, ) @@ -126,3 +127,20 @@ async def test_read_closed_receiver_source() -> None: with pytest.raises(IOError): await receiver.receive() assert receiver.closed + + +def test_deserialize_unknown_event_type(): + message = EventMessage( + headers={ + ":message-type": "event", + ":event-type": "unmodeledEvent", + ":content-type": "application/json", + }, + payload=b"{}", + ) + source = Event.decode(BytesIO(message.encode())) + assert source is not None + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) + result = EventStreamDeserializer().deserialize(deserializer) + assert isinstance(result, EventStreamUnknown) + assert result.tag == "unmodeledEvent"