Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 45 additions & 43 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,54 +953,52 @@
return get_field_info_and_type_annotation(inner_type, value, False, True)


def _has_discriminator(field_info: FieldInfo) -> bool:
"""Check if a FieldInfo has a discriminator."""
return hasattr(field_info, "discriminator") and field_info.discriminator is not None


def _handle_discriminator_with_param(
def _split_location_marker_and_field(
annotations: list[FieldInfo],
annotation: Any,
) -> tuple[FieldInfo | None, Any, bool]:
) -> tuple[FieldInfo, FieldInfo]:
"""
Handle the special case of Field(discriminator) + Body() combination.
Split a pair of FieldInfo into the Powertools location marker and the plain Pydantic Field.

A parameter like ``Annotated[str, Field(gt=0), Query()]`` flattens to two FieldInfo: the
location marker (``Path``/``Query``/``Header``/``Body``, a ``Param`` or ``Body`` subclass) and
a plain Pydantic ``Field`` carrying constraints. We keep them apart so the location marker
drives where the value comes from, while the ``Field``'s constraints/metadata still apply.

Returns:
tuple of (powertools_annotation, type_annotation, has_discriminator_with_body)
tuple of (location_marker, plain_field)
"""
field_obj = None
body_obj = None
location_marker: FieldInfo | None = None
plain_field: FieldInfo | None = None

for ann in annotations:
if isinstance(ann, Body):
body_obj = ann
elif _has_discriminator(ann):
field_obj = ann
if isinstance(ann, (Param, Body)):
location_marker = ann
else:
plain_field = ann

if field_obj and body_obj:
# Use Body as the primary annotation, preserve full annotation for validation
return body_obj, annotation, True
if location_marker is None or plain_field is None:
raise AssertionError("Only one FieldInfo can be used per parameter")

raise AssertionError("Only one FieldInfo can be used per parameter")
return location_marker, plain_field


def _create_field_info(
powertools_annotation: FieldInfo,
type_annotation: Any,
has_discriminator_with_body: bool,
preserve_full_annotation: bool,
) -> FieldInfo:
"""Create or copy FieldInfo based on the annotation type."""
field_info: FieldInfo
if has_discriminator_with_body:
# For discriminator + Body case, create a new Body instance directly
field_info = Body()
# Copy field_info because we mutate field_info.default later.
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=type_annotation,
)
if preserve_full_annotation:
# The location marker is paired with a plain Field (constraints or discriminator).
# copy_field_info strips the Pydantic Field out of the metadata, so re-attach the full
# Annotated type here, that's what makes the Field's metadata flow into the TypeAdapter
# that ModelField builds for validation.
field_info.annotation = type_annotation
else:
# Copy field_info because we mutate field_info.default later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=type_annotation,
)
return field_info


Expand All @@ -1017,7 +1015,7 @@
field_info.default = Required


def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]:

Check failure on line 1018 in aws_lambda_powertools/event_handler/openapi/params.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this function to reduce its Cognitive Complexity from 17 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=aws-powertools_powertools-lambda-python&issues=AZ8SspNVG0HmfgRSavIe&open=AZ8SspNVG0HmfgRSavIe&pullRequest=8305
"""
Get the FieldInfo and type annotation from an Annotated type.
"""
Expand All @@ -1043,32 +1041,36 @@

# Determine which annotation to use
powertools_annotation: FieldInfo | None = None
has_discriminator_with_param = False
# When a plain Pydantic Field is paired with a location marker (e.g. Field(gt=0) + Query),
# we keep the full Annotated type as the annotation so the Field's metadata still validates.
preserve_full_annotation = False

if len(powertools_annotations) == 2:
powertools_annotation, type_annotation, has_discriminator_with_param = _handle_discriminator_with_param(
powertools_annotations,
annotation,
)
# A location marker (Path/Query/Header/Body) plus a plain Pydantic Field carrying
# constraints, a discriminator, or any other Field setting. The marker says where the
# value comes from; the Field says how to validate it.
powertools_annotation, _ = _split_location_marker_and_field(powertools_annotations)
preserve_full_annotation = True
elif len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")
else:
powertools_annotation = next(iter(powertools_annotations), None)

# Reconstruct type_annotation with non-FieldInfo metadata if present
# This ensures constraints like Interval are preserved
if other_metadata and not has_discriminator_with_param:
if other_metadata and not preserve_full_annotation:
type_annotation = Annotated[(type_annotation, *other_metadata)]

# Process the annotation if it exists
field_info: FieldInfo | None = None
if isinstance(powertools_annotation, FieldInfo): # pragma: no cover
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param)
if isinstance(powertools_annotation, FieldInfo):
# When pairing a location marker with a plain Field, hand the full Annotated[...] to the
# field so the Field's constraints/discriminator flow into the validating TypeAdapter.
field_annotation = annotation if preserve_full_annotation else type_annotation
field_info = _create_field_info(powertools_annotation, field_annotation, preserve_full_annotation)
_set_field_default(field_info, value, is_path_param)

# Preserve full annotated type for discriminated unions
if _has_discriminator(powertools_annotation): # pragma: no cover
type_annotation = annotation # pragma: no cover
if preserve_full_annotation:
type_annotation = annotation

return field_info, type_annotation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
VPCLatticeV2Resolver,
)
from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError
from aws_lambda_powertools.event_handler.openapi.params import Body, Form, Header, Query
from aws_lambda_powertools.event_handler.openapi.params import Body, Form, Header, Path, Query
from tests.functional.utils import load_event


Expand Down Expand Up @@ -2654,6 +2654,77 @@ def create_action(action: Annotated[action_type, Body()]):
assert result["statusCode"] == 422


def test_field_annotation_with_all_param_types(gw_event):
"""A reusable Annotated type carrying a Pydantic Field works with every parameter location."""
app = APIGatewayRestResolver(enable_validation=True)

# Reusable annotated type, the same kind you'd use inside a model
str_field = Annotated[str, Field()]

@app.get("/header")
def get_header(h: Annotated[str_field, Header()]):
return {"value": h}

@app.get("/path/<p>")
def get_path(p: Annotated[str_field, Path()]):
return {"value": p}

@app.get("/query")
def get_query(q: Annotated[str_field, Query()]):
return {"value": q}

@app.post("/body")
def post_body(b: Annotated[str_field, Body()]):
return {"value": b}

del gw_event["multiValueHeaders"]
del gw_event["multiValueQueryStringParameters"]

# Header
gw_event["path"] = "/header"
gw_event["httpMethod"] = "GET"
gw_event["headers"] = {"h": "test"}
assert app(gw_event, {})["statusCode"] == 200

# Path
gw_event["path"] = "/path/test"
gw_event["pathParameters"] = {"p": "test"}
assert app(gw_event, {})["statusCode"] == 200

# Query
gw_event["path"] = "/query"
gw_event["pathParameters"] = None
gw_event["queryStringParameters"] = {"q": "test"}
assert app(gw_event, {})["statusCode"] == 200

# Body
gw_event["path"] = "/body"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '"test"'
assert app(gw_event, {})["statusCode"] == 200


def test_field_constraints_apply_with_param_type(gw_event):
"""Constraints declared on a Field are enforced when paired with a location marker."""
app = APIGatewayRestResolver(enable_validation=True)

@app.get("/items")
def get_items(quantity: Annotated[int, Field(gt=0), Query()]):
return {"quantity": quantity}

gw_event["path"] = "/items"
gw_event["httpMethod"] = "GET"

# Passes the gt=0 constraint
gw_event["queryStringParameters"] = {"quantity": "5"}
assert app(gw_event, {})["statusCode"] == 200

# Violates gt=0
gw_event["queryStringParameters"] = {"quantity": "-1"}
assert app(gw_event, {})["statusCode"] == 422


def test_validate_pydantic_query_params_with_config_dict_and_validators(gw_event):
"""Test that Pydantic models with ConfigDict, aliases, and validators work correctly"""

Expand Down