diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 0a0fc3c8075..dbfc0a6f9d7 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -3,10 +3,11 @@ import dataclasses import json import logging -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, Union, cast from urllib.parse import parse_qs from pydantic import BaseModel +from typing_extensions import get_args, get_origin from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -25,6 +26,7 @@ ResponseValidationError, ) from aws_lambda_powertools.event_handler.openapi.params import Param +from aws_lambda_powertools.event_handler.openapi.types import UnionType if TYPE_CHECKING: from pydantic.fields import FieldInfo @@ -431,9 +433,41 @@ def _handle_missing_field_value( values[field.name] = field.get_default() +def _is_or_contains_sequence(annotation: Any) -> bool: + """ + Check if annotation is a sequence or Union/RootModel containing a sequence. + + This function handles complex type annotations like: + - List[Model] - direct sequence + - Union[Model, List[Model]] - checks if any Union member is a sequence + - Optional[List[Model]] - Union[List[Model], None] + - RootModel[List[Model]] - checks if the RootModel wraps a sequence + - Optional[RootModel[List[Model]]] - Union member that is a RootModel + - RootModel[Union[Model, List[Model]]] - RootModel wrapping a Union with a sequence + """ + # Direct sequence check + if field_annotation_is_sequence(annotation): + return True + + # Check Union members — recurse so we catch RootModel inside Union + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + for arg in get_args(annotation): + if _is_or_contains_sequence(arg): + return True + + # Check if it's a RootModel wrapping a sequence (or Union containing a sequence) + if lenient_issubclass(annotation, BaseModel) and getattr(annotation, "__pydantic_root_model__", False): + if hasattr(annotation, "model_fields") and "root" in annotation.model_fields: + root_annotation = annotation.model_fields["root"].annotation + return _is_or_contains_sequence(root_annotation) + + return False + + def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any: """Normalize field value, converting lists to single values for non-sequence fields.""" - if field_annotation_is_sequence(field_info.annotation): + if _is_or_contains_sequence(field_info.annotation): return value elif isinstance(value, list) and value: return value[0] diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 93baef283ba..21bc9b26e0a 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -7,7 +7,16 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union import pytest -from pydantic import AfterValidator, Base64UrlStr, BaseModel, ConfigDict, Field, StringConstraints, alias_generators +from pydantic import ( + AfterValidator, + Base64UrlStr, + BaseModel, + ConfigDict, + Field, + RootModel, + StringConstraints, + alias_generators, +) from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -2833,3 +2842,488 @@ def handler(query_dt: datetime.datetime): # THEN validation should fail because the encoded string is not a valid datetime result = app(raw_event, {}) assert result["statusCode"] == 422 + + +def test_validate_union_single_or_list_body_with_list(gw_event): + """Test that Union[Model, List[Model]] correctly handles a list of items""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Item(BaseModel): + name: str + value: int + + # WHEN a handler is defined with Union[Model, List[Model]] body parameter + @app.post("/items") + def handler(items: Annotated[Union[Item, List[Item]], Body()]) -> Dict[str, Any]: + # Should receive the full list, not just the first element + if isinstance(items, list): + return {"count": len(items), "items": [item.model_dump() for item in items]} + else: + return {"count": 1, "items": [items.model_dump()]} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/items" + # Send a list of items + gw_event["body"] = json.dumps( + [ + {"name": "item1", "value": 10}, + {"name": "item2", "value": 20}, + {"name": "item3", "value": 30}, + ], + ) + + # THEN the handler should receive all items in the list, not just the first one + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["count"] == 3 + assert len(body["items"]) == 3 + assert body["items"][0]["name"] == "item1" + assert body["items"][1]["name"] == "item2" + assert body["items"][2]["name"] == "item3" + + +def test_validate_union_single_or_list_body_with_single(gw_event): + """Test that Union[Model, List[Model]] correctly handles a single item""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Item(BaseModel): + name: str + value: int + + # WHEN a handler is defined with Union[Model, List[Model]] body parameter + @app.post("/items") + def handler(items: Annotated[Union[Item, List[Item]], Body()]) -> Dict[str, Any]: + if isinstance(items, list): + return {"count": len(items), "items": [item.model_dump() for item in items]} + else: + return {"count": 1, "items": [items.model_dump()]} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/items" + # Send a single item + gw_event["body"] = json.dumps({"name": "single_item", "value": 42}) + + # THEN the handler should receive the single item + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["count"] == 1 + assert len(body["items"]) == 1 + assert body["items"][0]["name"] == "single_item" + assert body["items"][0]["value"] == 42 + + +def test_validate_rootmodel_list_body(gw_event): + """Test that RootModel[List[Model]] correctly handles a list of items""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Item(BaseModel): + name: str + value: int + + class ItemCollection(RootModel[List[Item]]): + root: List[Item] + + # WHEN a handler is defined with RootModel[List[Model]] body parameter + @app.post("/items") + def handler(collection: Annotated[ItemCollection, Body()]) -> Dict[str, Any]: + # collection.root should contain the full list + items = collection.root + return {"count": len(items), "items": [item.model_dump() for item in items]} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/items" + # Send a list of items + gw_event["body"] = json.dumps( + [ + {"name": "item1", "value": 100}, + {"name": "item2", "value": 200}, + ], + ) + + # THEN the handler should receive all items in the collection + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["count"] == 2 + assert len(body["items"]) == 2 + assert body["items"][0]["name"] == "item1" + assert body["items"][0]["value"] == 100 + assert body["items"][1]["name"] == "item2" + assert body["items"][1]["value"] == 200 + + +def test_validate_nested_union_with_sequence(gw_event): + """Test that nested Union types containing sequences are handled correctly""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Person(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a complex Union including List + @app.post("/people") + def handler( + data: Annotated[Union[str, List[Person], Person], Body()], + ) -> Dict[str, Any]: + if isinstance(data, str): + return {"type": "string", "value": data} + elif isinstance(data, list): + return {"type": "list", "count": len(data)} + else: + return {"type": "person", "name": data.name} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/people" + # Send a list + gw_event["body"] = json.dumps( + [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + ], + ) + + # THEN the handler should receive the full list + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["type"] == "list" + assert body["count"] == 2 + + +# ──────────────────────────────────────────────────────────────────── +# Regression tests for Union / RootModel / Optional sequence body +# See: https://github.com/aws-powertools/powertools-lambda-python/issues/8057 +# ──────────────────────────────────────────────────────────────────── + + +class _Item(BaseModel): + name: str + value: int + + +class _ItemCollection(RootModel[List[_Item]]): + pass + + +_THREE_ITEMS = [ + {"name": "a", "value": 1}, + {"name": "b", "value": 2}, + {"name": "c", "value": 3}, +] + + +def _post_json(app, path, payload): + """Helper: build a minimal APIGW REST event, POST JSON, return parsed result.""" + from tests.functional.utils import load_event + + event = load_event("apiGatewayProxyEvent.json") + event["httpMethod"] = "POST" + event["path"] = path + event["body"] = json.dumps(payload) + result = app(event, {}) + return result["statusCode"], json.loads(result["body"]) + + +# ---------- Optional[List[Model]] ---------- + + +def test_optional_list_body_with_list(): + """Optional[List[Model]] must preserve the full list.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Optional[List[_Item]], Body()]) -> Dict[str, Any]: + assert isinstance(items, list) + return {"count": len(items)} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["count"] == 3 + + +def test_optional_list_body_with_none(): + """Optional[List[Model]] must accept a null body gracefully.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Optional[List[_Item]], Body()] = None) -> Dict[str, Any]: + return {"received_none": items is None} + + status, body = _post_json(app, "/items", None) + assert status == 200 + assert body["received_none"] is True + + +# ---------- Optional[Union[Model, List[Model]]] ---------- + + +def test_optional_union_model_or_list_with_list(): + """Optional[Union[Model, List[Model]]] — send list, get full list.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Optional[Union[_Item, List[_Item]]], Body()]) -> Dict[str, Any]: + assert isinstance(items, list) + return {"count": len(items)} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["count"] == 3 + + +def test_optional_union_model_or_list_with_single(): + """Optional[Union[Model, List[Model]]] — send single obj, get single obj.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Optional[Union[_Item, List[_Item]]], Body()]) -> Dict[str, Any]: + assert not isinstance(items, list) + return {"name": items.name} + + status, body = _post_json(app, "/items", {"name": "solo", "value": 99}) + assert status == 200 + assert body["name"] == "solo" + + +def test_optional_union_model_or_list_with_none(): + """Optional[Union[Model, List[Model]]] — send null, get None.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Optional[Union[_Item, List[_Item]]], Body()] = None) -> Dict[str, Any]: + return {"is_none": items is None} + + status, body = _post_json(app, "/items", None) + assert status == 200 + assert body["is_none"] is True + + +# ---------- List[Model] directly (no Union / Optional) ---------- + + +def test_plain_list_body_preserves_all_items(): + """List[Model] — baseline: must never truncate.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[List[_Item], Body()]) -> Dict[str, Any]: + return {"count": len(items)} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["count"] == 3 + + +# ---------- Empty list ---------- + + +def test_union_model_or_list_with_empty_list(): + """Union[Model, List[Model]] with [] — must not crash on value[0].""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Union[_Item, List[_Item]], Body()]) -> Dict[str, Any]: + if isinstance(items, list): + return {"count": len(items)} + return {"count": 1} + + status, body = _post_json(app, "/items", []) + assert status == 200 + assert body["count"] == 0 + + +def test_plain_list_with_empty_list(): + """List[Model] with [] — must accept empty list.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[List[_Item], Body()]) -> Dict[str, Any]: + return {"count": len(items)} + + status, body = _post_json(app, "/items", []) + assert status == 200 + assert body["count"] == 0 + + +# ---------- Single-element list (boundary) ---------- + + +def test_union_model_or_list_with_single_element_list(): + """Union[Model, List[Model]] with [single_item] — must NOT unwrap to scalar.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Union[_Item, List[_Item]], Body()]) -> Dict[str, Any]: + if isinstance(items, list): + return {"type": "list", "count": len(items)} + return {"type": "single"} + + status, body = _post_json(app, "/items", [{"name": "only", "value": 1}]) + assert status == 200 + # Pydantic may match as single Item or list — either is valid, + # but it must NOT crash or lose data + assert body.get("count", 1) == 1 + + +# ---------- Union with primitive sequences ---------- + + +def test_union_str_or_list_dict(): + """Union[str, List[dict]] — list of dicts must arrive intact.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/data") + def handler(data: Annotated[Union[str, List[Dict[str, Any]]], Body()]) -> Dict[str, Any]: + if isinstance(data, list): + return {"type": "list", "count": len(data)} + return {"type": "str"} + + payload = [{"key": "v1"}, {"key": "v2"}] + status, body = _post_json(app, "/data", payload) + assert status == 200 + assert body["type"] == "list" + assert body["count"] == 2 + + +# ---------- RootModel edge cases ---------- + + +def test_optional_rootmodel_list_body(): + """Optional[RootModel[List[Model]]] — list must not be truncated.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Optional[_ItemCollection], Body()]) -> Dict[str, Any]: + return {"count": len(items.root)} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["count"] == 3 + + +def test_union_rootmodel_and_model(): + """Union[RootModel[List[Model]], Model] — list must not be truncated.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Union[_ItemCollection, _Item], Body()]) -> Dict[str, Any]: + if isinstance(items, _ItemCollection): + return {"type": "collection", "count": len(items.root)} + return {"type": "single", "name": items.name} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["type"] == "collection" + assert body["count"] == 3 + + +# ---------- Python 3.10+ pipe Union syntax ---------- + + +def test_pipe_union_syntax_model_or_list(): + """Model | List[Model] (PEP 604 syntax) — list must not be truncated.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[_Item | List[_Item], Body()]) -> Dict[str, Any]: # noqa: FA102 + if isinstance(items, list): + return {"count": len(items)} + return {"count": 1} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["count"] == 3 + + +def test_pipe_union_optional_list(): + """List[Model] | None (PEP 604 Optional) — list must not be truncated.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[List[_Item] | None, Body()]) -> Dict[str, Any]: # noqa: FA102 + if items is None: + return {"count": 0} + return {"count": len(items)} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["count"] == 3 + + +# ---------- Deeply nested: RootModel[Union[Model, List[Model]]] ---------- + + +def test_rootmodel_wrapping_union_with_sequence(): + """RootModel[Union[Model, List[Model]]] — inner Union sequence must be detected.""" + app = APIGatewayRestResolver(enable_validation=True) + + class FlexiblePayload(RootModel[Union[_Item, List[_Item]]]): + pass + + @app.post("/items") + def handler(payload: Annotated[FlexiblePayload, Body()]) -> Dict[str, Any]: + data = payload.root + if isinstance(data, list): + return {"type": "list", "count": len(data)} + return {"type": "single", "name": data.name} + + status, body = _post_json(app, "/items", _THREE_ITEMS) + assert status == 200 + assert body["type"] == "list" + assert body["count"] == 3 + + +# ---------- Multiple resolvers (ALB, HTTP API, etc.) ---------- + + +def test_union_list_body_works_across_resolvers(): + """Regression: ensure fix works for ALB and HTTP API resolvers too.""" + for ResolverClass in [APIGatewayHttpResolver, ALBResolver]: + app = ResolverClass(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Union[_Item, List[_Item]], Body()]) -> Dict[str, Any]: + if isinstance(items, list): + return {"count": len(items)} + return {"count": 1} + + # Build event appropriate for resolver + if ResolverClass is APIGatewayHttpResolver: + event = load_event("apiGatewayProxyV2Event.json") + event["requestContext"]["http"]["method"] = "POST" + event["requestContext"]["http"]["path"] = "/items" + event["rawPath"] = "/items" + else: + event = load_event("albEvent.json") + event["httpMethod"] = "POST" + event["path"] = "/items" + + event["body"] = json.dumps(_THREE_ITEMS) + result = app(event, {}) + assert result["statusCode"] == 200 + body_result = json.loads(result["body"]) + assert body_result["count"] == 3, f"Failed for {ResolverClass.__name__}" + + +# ---------- Large list (stress boundary) ---------- + + +def test_union_list_body_large_payload(): + """Union[Model, List[Model]] with 100 items — no truncation.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/items") + def handler(items: Annotated[Union[_Item, List[_Item]], Body()]) -> Dict[str, Any]: + assert isinstance(items, list) + return {"count": len(items)} + + big_payload = [{"name": f"item-{i}", "value": i} for i in range(100)] + status, body = _post_json(app, "/items", big_payload) + assert status == 200 + assert body["count"] == 100