Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,19 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]],
self.enable_validation if self.enable_validation is not None else app._enable_validation
)

# If route needs validation but resolver didn't create the middlewares, create them now
if route_validation_enabled and not hasattr(app, "_request_validation_middleware"):
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import (
OpenAPIRequestValidationMiddleware,
OpenAPIResponseValidationMiddleware,
)

app._request_validation_middleware = OpenAPIRequestValidationMiddleware()
app._response_validation_middleware = OpenAPIResponseValidationMiddleware(
validation_serializer=app._serializer,
has_response_validation_error=app._has_response_validation_error,
)

# Add request validation middleware first if validation is enabled
if route_validation_enabled and hasattr(app, "_request_validation_middleware"):
all_middlewares.append(app._request_validation_middleware)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def invalid_response() -> TodoItem:


def test_per_route_validation_with_pydantic_v2():
"""Test that per-route validation works correctly with Pydantic v2 models"""
# GIVEN APIGatewayRestResolver with mixed validation
"""Test that per-route validation actually validates when resolver has validation disabled"""
# GIVEN APIGatewayRestResolver WITHOUT global validation
app = APIGatewayRestResolver()

class Task(BaseModel):
Expand All @@ -250,7 +250,8 @@ class Task(BaseModel):

@app.get("/task", enable_validation=True)
def get_task() -> Task:
return Task(title="Important", priority=1)
# Return invalid data — missing 'title' and 'priority'
return cast(Task, {"wrong": "data"})

@app.get("/unvalidated-task")
def get_unvalidated_task():
Expand All @@ -259,13 +260,12 @@ def get_unvalidated_task():
event = load_event("apiGatewayProxyEvent.json")
event["httpMethod"] = "GET"

# WHEN calling validated route
# WHEN calling validated route with invalid data
event["path"] = "/task"
result = app(event, {})

# THEN should validate and serialize correctly
assert result["statusCode"] == 200
assert "Important" in result["body"]
# THEN validation must reject it with 422
assert result["statusCode"] == 422

# WHEN calling unvalidated route
event["path"] = "/unvalidated-task"
Expand All @@ -274,3 +274,28 @@ def get_unvalidated_task():
# THEN should return as-is without validation
assert result["statusCode"] == 200
assert "extra" in result["body"]


def test_per_route_opt_in_validation_with_valid_data():
"""Test that per-route opt-in validation passes valid data and serializes correctly"""
# GIVEN APIGatewayRestResolver WITHOUT global validation
app = APIGatewayRestResolver()

class Task(BaseModel):
title: str
priority: int

@app.get("/task", enable_validation=True)
def get_task() -> Task:
return Task(title="Important", priority=1)

event = load_event("apiGatewayProxyEvent.json")
event["httpMethod"] = "GET"
event["path"] = "/task"

# WHEN calling validated route with valid data
result = app(event, {})

# THEN validation passes and response is serialized
assert result["statusCode"] == 200
assert "Important" in result["body"]