Skip to content

Commit 43e2ac6

Browse files
author
Amin Farjadi
committed
fix(openapi): correct response validation for falsy objects
1 parent 6b91f70 commit 43e2ac6

File tree

3 files changed

+100
-37
lines changed

3 files changed

+100
-37
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,27 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
213213
return self._handle_response(route=route, response=response)
214214

215215
def _handle_response(self, *, route: Route, response: Response):
216-
# Process the response body if it exists
217-
if response.body and response.is_json():
218-
response.body = self._serialize_response(
219-
field=route.dependant.return_param,
216+
field = route.dependant.return_param
217+
218+
if field is None:
219+
if not response.is_json():
220+
return response
221+
else:
222+
# JSON serialize the body without validation
223+
response.body = jsonable_encoder(response.body, custom_serializer=self._validation_serializer)
224+
else:
225+
response.body = self._serialize_response_with_validation(
226+
field=field,
220227
response_content=response.body,
221228
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
222229
)
223230

224231
return response
225232

226-
def _serialize_response(
233+
def _serialize_response_with_validation(
227234
self,
228235
*,
229-
field: ModelField | None = None,
236+
field: ModelField,
230237
response_content: Any,
231238
include: IncEx | None = None,
232239
exclude: IncEx | None = None,
@@ -239,45 +246,42 @@ def _serialize_response(
239246
"""
240247
Serialize the response content according to the field type.
241248
"""
242-
if field:
243-
errors: list[dict[str, Any]] = []
244-
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
245-
if errors:
246-
# route-level validation must take precedence over app-level
247-
if has_route_custom_response_validation:
248-
raise ResponseValidationError(
249-
errors=_normalize_errors(errors),
250-
body=response_content,
251-
source="route",
252-
)
253-
if self._has_response_validation_error:
254-
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
255-
256-
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
257-
258-
if hasattr(field, "serialize"):
259-
return field.serialize(
260-
value,
261-
include=include,
262-
exclude=exclude,
263-
by_alias=by_alias,
264-
exclude_unset=exclude_unset,
265-
exclude_defaults=exclude_defaults,
266-
exclude_none=exclude_none,
249+
errors: list[dict[str, Any]] = []
250+
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
251+
if errors:
252+
# route-level validation must take precedence over app-level
253+
if has_route_custom_response_validation:
254+
raise ResponseValidationError(
255+
errors=_normalize_errors(errors),
256+
body=response_content,
257+
source="route",
267258
)
268-
return jsonable_encoder(
259+
if self._has_response_validation_error:
260+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
261+
262+
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
263+
264+
if hasattr(field, "serialize"):
265+
return field.serialize(
269266
value,
270267
include=include,
271268
exclude=exclude,
272269
by_alias=by_alias,
273270
exclude_unset=exclude_unset,
274271
exclude_defaults=exclude_defaults,
275272
exclude_none=exclude_none,
276-
custom_serializer=self._validation_serializer,
277273
)
278-
else:
279-
# Just serialize the response content returned from the handler.
280-
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)
274+
275+
return jsonable_encoder(
276+
value,
277+
include=include,
278+
exclude=exclude,
279+
by_alias=by_alias,
280+
exclude_unset=exclude_unset,
281+
exclude_defaults=exclude_defaults,
282+
exclude_none=exclude_none,
283+
custom_serializer=self._validation_serializer,
284+
)
281285

282286
def _prepare_response_content(
283287
self,

tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def search(
209209
# =============================================================================
210210

211211

212+
@pytest.mark.skip("Due to issue #7981.")
212213
@pytest.mark.asyncio
213214
async def test_async_handler_with_validation():
214215
# GIVEN an app with async handler and validation

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1606,7 +1606,39 @@ def handler(user_id: int = 123):
16061606
assert result["statusCode"] == 200
16071607

16081608

1609-
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
1609+
def test_validate_list_response(gw_event):
1610+
# GIVEN an APIGatewayRestResolver with validation enabled
1611+
app = APIGatewayRestResolver(enable_validation=True)
1612+
1613+
class Model(BaseModel):
1614+
name: str
1615+
age: int
1616+
1617+
response_before_validation = [
1618+
{
1619+
"name": "Joe",
1620+
"age": 20,
1621+
},
1622+
{
1623+
"name": "Jane",
1624+
"age": 20,
1625+
},
1626+
]
1627+
1628+
@app.get("/list_response_with_same_element_types")
1629+
def handler_different_list() -> List[Model]:
1630+
return response_before_validation
1631+
1632+
# WHEN returning list with the same element type as the non-Optional return type
1633+
gw_event["path"] = "/list_response_with_same_element_types"
1634+
result = app(gw_event, {})
1635+
body = json.loads(result["body"])
1636+
1637+
# THEN it should return a validation error
1638+
assert result["statusCode"] == 200
1639+
assert body == response_before_validation
1640+
1641+
16101642
def test_validation_error_none_returned_non_optional_type(gw_event):
16111643
# GIVEN an APIGatewayRestResolver with validation enabled
16121644
app = APIGatewayRestResolver(enable_validation=True)
@@ -1630,6 +1662,32 @@ def handler_none_not_allowed() -> Model:
16301662
assert body["detail"][0]["loc"] == ["response"]
16311663

16321664

1665+
def test_validation_error_different_list_returned_non_optional_type(gw_event):
1666+
# GIVEN an APIGatewayRestResolver with validation enabled
1667+
app = APIGatewayRestResolver(enable_validation=True)
1668+
1669+
class Model(BaseModel):
1670+
name: str
1671+
age: int
1672+
1673+
different_list_response = ["a", "b", "c"]
1674+
1675+
@app.get("/list_response_with_different_element_types")
1676+
def handler_different_list() -> List[Model]:
1677+
return different_list_response
1678+
1679+
# WHEN returning list with the different element type as the non-Optional return type
1680+
gw_event["path"] = "/list_response_with_different_element_types"
1681+
result = app(gw_event, {})
1682+
1683+
# THEN it should return a validation error
1684+
assert result["statusCode"] == 422
1685+
body = json.loads(result["body"])
1686+
assert len(body["detail"]) == len(different_list_response)
1687+
assert body["detail"][0]["type"] == "model_attributes_type"
1688+
assert body["detail"][0]["loc"] == ["response", 0]
1689+
1690+
16331691
def test_validation_error_incomplete_model_returned_non_optional_type(gw_event):
16341692
# GIVEN an APIGatewayRestResolver with validation enabled
16351693
app = APIGatewayRestResolver(enable_validation=True)

0 commit comments

Comments
 (0)