Skip to content

Commit 68efba2

Browse files
feat(event_handler): support any Pydantic Field annotation in parameter validation
1 parent 50541fe commit 68efba2

2 files changed

Lines changed: 117 additions & 44 deletions

File tree

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -953,54 +953,52 @@ def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, A
953953
return get_field_info_and_type_annotation(inner_type, value, False, True)
954954

955955

956-
def _has_discriminator(field_info: FieldInfo) -> bool:
957-
"""Check if a FieldInfo has a discriminator."""
958-
return hasattr(field_info, "discriminator") and field_info.discriminator is not None
959-
960-
961-
def _handle_discriminator_with_param(
956+
def _split_location_marker_and_field(
962957
annotations: list[FieldInfo],
963-
annotation: Any,
964-
) -> tuple[FieldInfo | None, Any, bool]:
958+
) -> tuple[FieldInfo, FieldInfo]:
965959
"""
966-
Handle the special case of Field(discriminator) + Body() combination.
960+
Split a pair of FieldInfo into the Powertools location marker and the plain Pydantic Field.
961+
962+
A parameter like ``Annotated[str, Field(gt=0), Query()]`` flattens to two FieldInfo: the
963+
location marker (``Path``/``Query``/``Header``/``Body``, a ``Param`` or ``Body`` subclass) and
964+
a plain Pydantic ``Field`` carrying constraints. We keep them apart so the location marker
965+
drives where the value comes from, while the ``Field``'s constraints/metadata still apply.
967966
968967
Returns:
969-
tuple of (powertools_annotation, type_annotation, has_discriminator_with_body)
968+
tuple of (location_marker, plain_field)
970969
"""
971-
field_obj = None
972-
body_obj = None
970+
location_marker: FieldInfo | None = None
971+
plain_field: FieldInfo | None = None
973972

974973
for ann in annotations:
975-
if isinstance(ann, Body):
976-
body_obj = ann
977-
elif _has_discriminator(ann):
978-
field_obj = ann
974+
if isinstance(ann, (Param, Body)):
975+
location_marker = ann
976+
else:
977+
plain_field = ann
979978

980-
if field_obj and body_obj:
981-
# Use Body as the primary annotation, preserve full annotation for validation
982-
return body_obj, annotation, True
979+
if location_marker is None or plain_field is None:
980+
raise AssertionError("Only one FieldInfo can be used per parameter")
983981

984-
raise AssertionError("Only one FieldInfo can be used per parameter")
982+
return location_marker, plain_field
985983

986984

987985
def _create_field_info(
988986
powertools_annotation: FieldInfo,
989987
type_annotation: Any,
990-
has_discriminator_with_body: bool,
988+
preserve_full_annotation: bool,
991989
) -> FieldInfo:
992990
"""Create or copy FieldInfo based on the annotation type."""
993-
field_info: FieldInfo
994-
if has_discriminator_with_body:
995-
# For discriminator + Body case, create a new Body instance directly
996-
field_info = Body()
991+
# Copy field_info because we mutate field_info.default later.
992+
field_info = copy_field_info(
993+
field_info=powertools_annotation,
994+
annotation=type_annotation,
995+
)
996+
if preserve_full_annotation:
997+
# The location marker is paired with a plain Field (constraints or discriminator).
998+
# copy_field_info strips the Pydantic Field out of the metadata, so re-attach the full
999+
# Annotated type here, that's what makes the Field's metadata flow into the TypeAdapter
1000+
# that ModelField builds for validation.
9971001
field_info.annotation = type_annotation
998-
else:
999-
# Copy field_info because we mutate field_info.default later
1000-
field_info = copy_field_info(
1001-
field_info=powertools_annotation,
1002-
annotation=type_annotation,
1003-
)
10041002
return field_info
10051003

10061004

@@ -1043,32 +1041,36 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
10431041

10441042
# Determine which annotation to use
10451043
powertools_annotation: FieldInfo | None = None
1046-
has_discriminator_with_param = False
1044+
# When a plain Pydantic Field is paired with a location marker (e.g. Field(gt=0) + Query),
1045+
# we keep the full Annotated type as the annotation so the Field's metadata still validates.
1046+
preserve_full_annotation = False
10471047

10481048
if len(powertools_annotations) == 2:
1049-
powertools_annotation, type_annotation, has_discriminator_with_param = _handle_discriminator_with_param(
1050-
powertools_annotations,
1051-
annotation,
1052-
)
1049+
# A location marker (Path/Query/Header/Body) plus a plain Pydantic Field carrying
1050+
# constraints, a discriminator, or any other Field setting. The marker says where the
1051+
# value comes from; the Field says how to validate it.
1052+
powertools_annotation, _ = _split_location_marker_and_field(powertools_annotations)
1053+
preserve_full_annotation = True
10531054
elif len(powertools_annotations) > 1:
10541055
raise AssertionError("Only one FieldInfo can be used per parameter")
10551056
else:
10561057
powertools_annotation = next(iter(powertools_annotations), None)
10571058

10581059
# Reconstruct type_annotation with non-FieldInfo metadata if present
10591060
# This ensures constraints like Interval are preserved
1060-
if other_metadata and not has_discriminator_with_param:
1061+
if other_metadata and not preserve_full_annotation:
10611062
type_annotation = Annotated[(type_annotation, *other_metadata)]
10621063

10631064
# Process the annotation if it exists
10641065
field_info: FieldInfo | None = None
1065-
if isinstance(powertools_annotation, FieldInfo): # pragma: no cover
1066-
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param)
1066+
if isinstance(powertools_annotation, FieldInfo):
1067+
# When pairing a location marker with a plain Field, hand the full Annotated[...] to the
1068+
# field so the Field's constraints/discriminator flow into the validating TypeAdapter.
1069+
field_annotation = annotation if preserve_full_annotation else type_annotation
1070+
field_info = _create_field_info(powertools_annotation, field_annotation, preserve_full_annotation)
10671071
_set_field_default(field_info, value, is_path_param)
1068-
1069-
# Preserve full annotated type for discriminated unions
1070-
if _has_discriminator(powertools_annotation): # pragma: no cover
1071-
type_annotation = annotation # pragma: no cover
1072+
if preserve_full_annotation:
1073+
type_annotation = annotation
10721074

10731075
return field_info, type_annotation
10741076

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
VPCLatticeV2Resolver,
3131
)
3232
from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError
33-
from aws_lambda_powertools.event_handler.openapi.params import Body, Form, Header, Query
33+
from aws_lambda_powertools.event_handler.openapi.params import Body, Form, Header, Path, Query
3434
from tests.functional.utils import load_event
3535

3636

@@ -2654,6 +2654,77 @@ def create_action(action: Annotated[action_type, Body()]):
26542654
assert result["statusCode"] == 422
26552655

26562656

2657+
def test_field_annotation_with_all_param_types(gw_event):
2658+
"""A reusable Annotated type carrying a Pydantic Field works with every parameter location."""
2659+
app = APIGatewayRestResolver(enable_validation=True)
2660+
2661+
# Reusable annotated type, the same kind you'd use inside a model
2662+
str_field = Annotated[str, Field()]
2663+
2664+
@app.get("/header")
2665+
def get_header(h: Annotated[str_field, Header()]):
2666+
return {"value": h}
2667+
2668+
@app.get("/path/<p>")
2669+
def get_path(p: Annotated[str_field, Path()]):
2670+
return {"value": p}
2671+
2672+
@app.get("/query")
2673+
def get_query(q: Annotated[str_field, Query()]):
2674+
return {"value": q}
2675+
2676+
@app.post("/body")
2677+
def post_body(b: Annotated[str_field, Body()]):
2678+
return {"value": b}
2679+
2680+
del gw_event["multiValueHeaders"]
2681+
del gw_event["multiValueQueryStringParameters"]
2682+
2683+
# Header
2684+
gw_event["path"] = "/header"
2685+
gw_event["httpMethod"] = "GET"
2686+
gw_event["headers"] = {"h": "test"}
2687+
assert app(gw_event, {})["statusCode"] == 200
2688+
2689+
# Path
2690+
gw_event["path"] = "/path/test"
2691+
gw_event["pathParameters"] = {"p": "test"}
2692+
assert app(gw_event, {})["statusCode"] == 200
2693+
2694+
# Query
2695+
gw_event["path"] = "/query"
2696+
gw_event["pathParameters"] = None
2697+
gw_event["queryStringParameters"] = {"q": "test"}
2698+
assert app(gw_event, {})["statusCode"] == 200
2699+
2700+
# Body
2701+
gw_event["path"] = "/body"
2702+
gw_event["httpMethod"] = "POST"
2703+
gw_event["headers"]["content-type"] = "application/json"
2704+
gw_event["body"] = '"test"'
2705+
assert app(gw_event, {})["statusCode"] == 200
2706+
2707+
2708+
def test_field_constraints_apply_with_param_type(gw_event):
2709+
"""Constraints declared on a Field are enforced when paired with a location marker."""
2710+
app = APIGatewayRestResolver(enable_validation=True)
2711+
2712+
@app.get("/items")
2713+
def get_items(quantity: Annotated[int, Field(gt=0), Query()]):
2714+
return {"quantity": quantity}
2715+
2716+
gw_event["path"] = "/items"
2717+
gw_event["httpMethod"] = "GET"
2718+
2719+
# Passes the gt=0 constraint
2720+
gw_event["queryStringParameters"] = {"quantity": "5"}
2721+
assert app(gw_event, {})["statusCode"] == 200
2722+
2723+
# Violates gt=0
2724+
gw_event["queryStringParameters"] = {"quantity": "-1"}
2725+
assert app(gw_event, {})["statusCode"] == 422
2726+
2727+
26572728
def test_validate_pydantic_query_params_with_config_dict_and_validators(gw_event):
26582729
"""Test that Pydantic models with ConfigDict, aliases, and validators work correctly"""
26592730

0 commit comments

Comments
 (0)