Skip to content

Commit 1a08022

Browse files
committed
fix(openapi): validate response serialization when falsy
1 parent 6cc7983 commit 1a08022

File tree

2 files changed

+55
-31
lines changed

2 files changed

+55
-31
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -136,67 +136,47 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
136136
return self._handle_response(route=route, response=response)
137137

138138
def _handle_response(self, *, route: Route, response: Response):
139-
# Process the response body if it exists
140-
if response.body:
141-
# Validate and serialize the response, if it's JSON
142-
if response.is_json():
139+
# Check if we have a return type defined
140+
if route.dependant.return_param:
141+
try:
142+
# Validate all responses, including None
143143
response.body = self._serialize_response(
144144
field=route.dependant.return_param,
145145
response_content=response.body,
146146
)
147+
except RequestValidationError as e:
148+
logger.error(f"Response validation failed: {str(e)}")
149+
response.status_code = 422
150+
response.body = {"detail": e.errors()}
147151

148152
return response
149153

150154
def _serialize_response(
151155
self,
152156
*,
153-
field: ModelField | None = None,
157+
field: Any = None,
154158
response_content: Any,
155159
include: IncEx | None = None,
156160
exclude: IncEx | None = None,
157-
by_alias: bool = True,
161+
by_alias: bool = False,
158162
exclude_unset: bool = False,
159163
exclude_defaults: bool = False,
160164
exclude_none: bool = False,
161165
) -> Any:
162-
"""
163-
Serialize the response content according to the field type.
164-
"""
165166
if field:
166167
errors: list[dict[str, Any]] = []
167-
# MAINTENANCE: remove this when we drop pydantic v1
168-
if not hasattr(field, "serializable"):
169-
response_content = self._prepare_response_content(
170-
response_content,
171-
exclude_unset=exclude_unset,
172-
exclude_defaults=exclude_defaults,
173-
exclude_none=exclude_none,
174-
)
175-
176168
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
177169
if errors:
178170
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
179171

180-
if hasattr(field, "serialize"):
181-
return field.serialize(
182-
value,
183-
include=include,
184-
exclude=exclude,
185-
by_alias=by_alias,
186-
exclude_unset=exclude_unset,
187-
exclude_defaults=exclude_defaults,
188-
exclude_none=exclude_none,
189-
)
190-
191-
return jsonable_encoder(
172+
return field.serialize(
192173
value,
193174
include=include,
194175
exclude=exclude,
195176
by_alias=by_alias,
196177
exclude_unset=exclude_unset,
197178
exclude_defaults=exclude_defaults,
198179
exclude_none=exclude_none,
199-
custom_serializer=self._validation_serializer,
200180
)
201181
else:
202182
# Just serialize the response content returned from the handler

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,3 +1128,47 @@ def handler(user_id: int = 123):
11281128
# THEN the handler should be invoked and return 200
11291129
result = app(minimal_event, {})
11301130
assert result["statusCode"] == 200
1131+
1132+
1133+
def test_validate_optional_return_types(gw_event):
1134+
# GIVEN an APIGatewayRestResolver with validation enabled
1135+
app = APIGatewayRestResolver(enable_validation=True)
1136+
1137+
class Model(BaseModel):
1138+
name: str
1139+
age: int
1140+
1141+
# AND handlers defined with different Optional return types
1142+
@app.get("/none_not_allowed")
1143+
def handler_none_not_allowed() -> Model:
1144+
return None # type: ignore
1145+
1146+
@app.get("/none_allowed")
1147+
def handler_none_allowed() -> Optional[Model]:
1148+
return None
1149+
1150+
@app.get("/valid_optional")
1151+
def handler_valid_optional() -> Optional[Model]:
1152+
return Model(name="John", age=30)
1153+
1154+
# WHEN returning None for a non-Optional type
1155+
gw_event["path"] = "/none_not_allowed"
1156+
result = app(gw_event, {})
1157+
# THEN it should return a validation error
1158+
assert result["statusCode"] == 422
1159+
body = json.loads(result["body"])
1160+
assert "model_attributes_type" in body["detail"][0]["type"]
1161+
1162+
# WHEN returning None for an Optional type
1163+
gw_event["path"] = "/none_allowed"
1164+
result = app(gw_event, {})
1165+
# THEN it should succeed
1166+
assert result["statusCode"] == 200
1167+
assert result["body"] == "null"
1168+
1169+
# WHEN returning a valid model for an Optional type
1170+
gw_event["path"] = "/valid_optional"
1171+
result = app(gw_event, {})
1172+
# THEN it should succeed and return the serialized model
1173+
assert result["statusCode"] == 200
1174+
assert json.loads(result["body"]) == {"name": "John", "age": 30}

0 commit comments

Comments
 (0)