Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,10 @@ private void generateDeserializer() {
writer.addImport("smithy_core.deserializers", "ShapeDeserializer");
writer.addImport("smithy_core.exceptions", "SerializationError");

// TODO: add in unknown handling

var symbol = symbolProvider.toSymbol(shape);
var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER);
var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA);
var unknownSymbol = symbol.expectProperty(SymbolProperties.UNION_UNKNOWN);
writer.putContext("schema", schemaSymbol);
writer.write("""
class $1L:
Expand All @@ -168,6 +167,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
${4C|}
case _:
logger.debug("Unexpected member schema: %s", schema)
self._set_result($5L(tag=schema.expect_member_name()))

def _set_result(self, value: $2T) -> None:
if self._result is not None:
Expand All @@ -177,7 +177,8 @@ raise SerializationError("Unions must have exactly one value, but found more tha
deserializerSymbol.getName(),
symbol,
schemaSymbol,
writer.consumer(w -> deserializeMembers()));
writer.consumer(w -> deserializeMembers()),
unknownSymbol.getName());
}

private void deserializeMembers() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def read_struct(
message_deserializer = self._create_deserializer(schema, headers)
message_deserializer.read_struct(schema, consumer)
else:
member_schema = schema.members[member_name]
member_schema = self._resolve_member_schema(schema, member_name)
message_deserializer = self._create_deserializer(
member_schema, headers
)
Expand All @@ -71,6 +71,21 @@ def read_struct(
case _:
raise EventError(f"Unknown event structure: {self._event}")

def _resolve_member_schema(self, schema: Schema, member_name: str) -> Schema:
if member_schema := schema.members.get(member_name):
return member_schema

logger.debug(
"Received unmodeled event stream member %s for union %s",
member_name,
schema.id,
)
return Schema.member(
id=schema.id.with_member(member_name),
target=schema,
index=-1,
)

def _create_deserializer(
self, schema: Schema, headers: HEADERS_DICT
) -> ShapeDeserializer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ async def receive(self) -> E | None:
)
result = self._deserializer(deserializer)
logger.debug("Successfully deserialized event: %s", result)
if isinstance(getattr(result, "value"), Exception):
raise result.value # type: ignore
value = getattr(result, "value", None)
if isinstance(value, Exception):
raise value
return result

async def close(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def serialize_members(self, serializer: ShapeSerializer):


@dataclass
class EventStreamUnknownEvent:
class EventStreamUnknown:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the name change here?

Suggested change
class EventStreamUnknown:
class EventStreamUnknownEvent:

tag: str

def serialize(self, serializer: ShapeSerializer):
Expand All @@ -396,7 +396,7 @@ def serialize_members(self, serializer: ShapeSerializer):
| EventStreamPayloadEvent
| EventStreamBlobPayloadEvent
| EventStreamErrorEvent
| EventStreamUnknownEvent
| EventStreamUnknown
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
| EventStreamUnknown
| EventStreamUnknownEvent

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified this because the generated code does not have "Event" suffix after "Unknown".

)


Expand Down Expand Up @@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de)))

case _:
raise SmithyError(f"Unexpected member schema: {schema}")
self._set_result(EventStreamUnknown(tag=schema.expect_member_name()))

def _set_result(self, value: EventStream) -> None:
if self._result is not None:
Expand Down Expand Up @@ -635,6 +635,19 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
]


UNKNOWN_EVENT_CASE = (
EventStreamUnknown(tag="unmodeledEvent"),
EventMessage(
headers={
":message-type": "event",
":event-type": "unmodeledEvent",
":content-type": "application/json",
},
payload=b"{}",
),
)


INITIAL_REQUEST_CASE = (
EventStreamOperationInputOutput(message="The initial request!"),
EventMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
EventStreamDeserializer,
EventStreamErrorEvent,
EventStreamOperationInputOutput,
EventStreamUnknown,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
EventStreamUnknown,
EventStreamUnknownEvent,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not add "Event" at the end because the generated code does not have "Event" suffix after "Unknown".

)


Expand Down Expand Up @@ -126,3 +127,20 @@ async def test_read_closed_receiver_source() -> None:
with pytest.raises(IOError):
await receiver.receive()
assert receiver.closed


def test_deserialize_unknown_event_type():
message = EventMessage(
headers={
":message-type": "event",
":event-type": "unmodeledEvent",
":content-type": "application/json",
},
payload=b"{}",
)
source = Event.decode(BytesIO(message.encode()))
assert source is not None
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
result = EventStreamDeserializer().deserialize(deserializer)
assert isinstance(result, EventStreamUnknown)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert isinstance(result, EventStreamUnknown)
assert isinstance(result, EventStreamUnknownEvent)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not add "Event" at the end because the generated code does not have "Event" suffix after "Unknown".

assert result.tag == "unmodeledEvent"
Loading