diff --git a/aws_lambda_powertools/event_handler/http_resolver.py b/aws_lambda_powertools/event_handler/http_resolver.py index 0be443bd200..93e2fdc932e 100644 --- a/aws_lambda_powertools/event_handler/http_resolver.py +++ b/aws_lambda_powertools/event_handler/http_resolver.py @@ -3,6 +3,7 @@ import asyncio import base64 import inspect +import threading import warnings from typing import TYPE_CHECKING, Any, Callable from urllib.parse import parse_qs @@ -324,36 +325,65 @@ async def final_handler(app): return await next_handler(self) def _wrap_middleware_async(self, middleware: Callable, next_handler: Callable) -> Callable: - """Wrap a middleware to work in async context.""" + """Wrap a middleware to work in async context. + + For sync middlewares, we split execution into pre/post phases around the + call to next(). The sync middleware runs its pre-processing (e.g. request + validation), then we intercept the next() call, await the async handler, + and resume the middleware with the real response so post-processing + (e.g. response validation) sees the actual data. + """ async def wrapped(app): - # Create a next_middleware that the sync middleware can call - def sync_next(app): - # This will be called by sync middleware - # We need to run the async next_handler - loop = asyncio.get_event_loop() - if loop.is_running(): - # We're in an async context, create a task - future = asyncio.ensure_future(next_handler(app)) - # Store for later await - app.context["_async_next_result"] = future - return Response(status_code=200, body="") # Placeholder - else: # pragma: no cover - return loop.run_until_complete(next_handler(app)) - - # Check if middleware is async if inspect.iscoroutinefunction(middleware): - result = await middleware(app, next_handler) - else: - # Sync middleware - need special handling - result = middleware(app, sync_next) + return await middleware(app, next_handler) - # Check if we stored an async result - if "_async_next_result" in app.context: - future = app.context.pop("_async_next_result") - result = await future + # We use an Event to coordinate: the sync middleware runs in a thread, + # calls sync_next which signals us to resolve the async handler, + # then waits for the real response. + middleware_called_next = asyncio.Event() + next_app_holder: list = [] + real_response_holder: list = [] + middleware_result_holder: list = [] + middleware_error_holder: list = [] - return result + def sync_next(app): + next_app_holder.append(app) + middleware_called_next.set() + # Block this thread until the real response is available + event = threading.Event() + next_app_holder.append(event) + event.wait() + return real_response_holder[0] + + def run_middleware(): + try: + result = middleware(app, sync_next) + middleware_result_holder.append(result) + except Exception as e: + middleware_error_holder.append(e) + + thread = threading.Thread(target=run_middleware, daemon=True) + thread.start() + + # Wait for the middleware to call next() + await middleware_called_next.wait() + + # Now resolve the async next_handler + real_response = await next_handler(next_app_holder[0]) + real_response_holder.append(real_response) + + # Signal the thread that the response is ready + threading_event = next_app_holder[1] + threading_event.set() + + # Wait for the middleware thread to finish + thread.join() + + if middleware_error_holder: + raise middleware_error_holder[0] + + return middleware_result_holder[0] return wrapped diff --git a/tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py b/tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py index d31185f3239..7cab58a1b70 100644 --- a/tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py +++ b/tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py @@ -209,7 +209,6 @@ def search( # ============================================================================= -@pytest.mark.skip("Due to issue #7981.") @pytest.mark.asyncio async def test_async_handler_with_validation(): # GIVEN an app with async handler and validation @@ -241,6 +240,91 @@ async def create_user(user: UserModel) -> UserResponse: assert body["user"]["name"] == "AsyncUser" +@pytest.mark.asyncio +async def test_async_handler_invalid_response_returns_422(): + # GIVEN an app with async handler and validation + app = HttpResolverLocal(enable_validation=True) + + @app.get("/user") + async def get_user() -> UserResponse: + await asyncio.sleep(0.001) + return {"name": "John"} # type: ignore # Missing required fields + + scope = { + "type": "http", + "method": "GET", + "path": "/user", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + + receive = make_asgi_receive() + send, captured = make_asgi_send() + + # WHEN called via ASGI interface + await app(scope, receive, send) + + # THEN it returns 422 for invalid response + assert captured["status_code"] == 422 + + +@pytest.mark.asyncio +async def test_sync_handler_with_validation_via_asgi(): + # GIVEN an app with a sync handler and validation, called via ASGI + app = HttpResolverLocal(enable_validation=True) + + @app.post("/users") + def create_user(user: UserModel) -> UserResponse: + return UserResponse(id="sync-123", user=user) + + scope = { + "type": "http", + "method": "POST", + "path": "/users", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + + receive = make_asgi_receive(b'{"name": "SyncUser", "age": 30}') + send, captured = make_asgi_send() + + # WHEN called via ASGI interface + await app(scope, receive, send) + + # THEN validation works with sync handler + assert captured["status_code"] == 200 + body = json.loads(captured["body"]) + assert body["id"] == "sync-123" + assert body["user"]["name"] == "SyncUser" + + +@pytest.mark.asyncio +async def test_sync_handler_invalid_response_returns_422_via_asgi(): + # GIVEN an app with a sync handler and validation, called via ASGI + app = HttpResolverLocal(enable_validation=True) + + @app.get("/user") + def get_user() -> UserResponse: + return {"name": "John"} # type: ignore # Missing required fields + + scope = { + "type": "http", + "method": "GET", + "path": "/user", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + + receive = make_asgi_receive() + send, captured = make_asgi_send() + + # WHEN called via ASGI interface + await app(scope, receive, send) + + # THEN it returns 422 for invalid response + assert captured["status_code"] == 422 + + # ============================================================================= # OpenAPI Tests # =============================================================================