From 95675496d8a34e5e8047ad4b52c7c0e56b3addd5 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 3 Mar 2026 18:23:46 +0000 Subject: [PATCH] fix: add middleware validation per route --- .../event_handler/api_gateway.py | 13 +++++++ .../_pydantic/test_per_route_validation.py | 39 +++++++++++++++---- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 02682f01c59..b1e0c9ff16d 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -546,6 +546,19 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]], self.enable_validation if self.enable_validation is not None else app._enable_validation ) + # If route needs validation but resolver didn't create the middlewares, create them now + if route_validation_enabled and not hasattr(app, "_request_validation_middleware"): + from aws_lambda_powertools.event_handler.middlewares.openapi_validation import ( + OpenAPIRequestValidationMiddleware, + OpenAPIResponseValidationMiddleware, + ) + + app._request_validation_middleware = OpenAPIRequestValidationMiddleware() + app._response_validation_middleware = OpenAPIResponseValidationMiddleware( + validation_serializer=app._serializer, + has_response_validation_error=app._has_response_validation_error, + ) + # Add request validation middleware first if validation is enabled if route_validation_enabled and hasattr(app, "_request_validation_middleware"): all_middlewares.append(app._request_validation_middleware) diff --git a/tests/functional/event_handler/_pydantic/test_per_route_validation.py b/tests/functional/event_handler/_pydantic/test_per_route_validation.py index bd5c33ae0b3..f6742b960ee 100644 --- a/tests/functional/event_handler/_pydantic/test_per_route_validation.py +++ b/tests/functional/event_handler/_pydantic/test_per_route_validation.py @@ -240,8 +240,8 @@ def invalid_response() -> TodoItem: def test_per_route_validation_with_pydantic_v2(): - """Test that per-route validation works correctly with Pydantic v2 models""" - # GIVEN APIGatewayRestResolver with mixed validation + """Test that per-route validation actually validates when resolver has validation disabled""" + # GIVEN APIGatewayRestResolver WITHOUT global validation app = APIGatewayRestResolver() class Task(BaseModel): @@ -250,7 +250,8 @@ class Task(BaseModel): @app.get("/task", enable_validation=True) def get_task() -> Task: - return Task(title="Important", priority=1) + # Return invalid data — missing 'title' and 'priority' + return cast(Task, {"wrong": "data"}) @app.get("/unvalidated-task") def get_unvalidated_task(): @@ -259,13 +260,12 @@ def get_unvalidated_task(): event = load_event("apiGatewayProxyEvent.json") event["httpMethod"] = "GET" - # WHEN calling validated route + # WHEN calling validated route with invalid data event["path"] = "/task" result = app(event, {}) - # THEN should validate and serialize correctly - assert result["statusCode"] == 200 - assert "Important" in result["body"] + # THEN validation must reject it with 422 + assert result["statusCode"] == 422 # WHEN calling unvalidated route event["path"] = "/unvalidated-task" @@ -274,3 +274,28 @@ def get_unvalidated_task(): # THEN should return as-is without validation assert result["statusCode"] == 200 assert "extra" in result["body"] + + +def test_per_route_opt_in_validation_with_valid_data(): + """Test that per-route opt-in validation passes valid data and serializes correctly""" + # GIVEN APIGatewayRestResolver WITHOUT global validation + app = APIGatewayRestResolver() + + class Task(BaseModel): + title: str + priority: int + + @app.get("/task", enable_validation=True) + def get_task() -> Task: + return Task(title="Important", priority=1) + + event = load_event("apiGatewayProxyEvent.json") + event["httpMethod"] = "GET" + event["path"] = "/task" + + # WHEN calling validated route with valid data + result = app(event, {}) + + # THEN validation passes and response is serialized + assert result["statusCode"] == 200 + assert "Important" in result["body"]