Skip to content

Commit 3b5b154

Browse files
committed
consolidate normalize field value functions
1 parent a1d2b0f commit 3b5b154

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
import json
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast, get_origin
6+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast
77
from urllib.parse import parse_qs
88

99
from pydantic import BaseModel
@@ -13,8 +13,8 @@
1313
_model_dump,
1414
_normalize_errors,
1515
_regenerate_error_with_loc,
16+
field_annotation_is_sequence,
1617
get_missing_field_error,
17-
is_sequence_field,
1818
lenient_issubclass,
1919
)
2020
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
@@ -369,7 +369,7 @@ def _request_body_to_args(
369369
_handle_missing_field_value(field, values, errors, loc)
370370
continue
371371

372-
value = _normalize_field_value(field, value)
372+
value = _normalize_field_value(value=value, field_info=field.field_info)
373373
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
374374

375375
return values, errors
@@ -412,10 +412,13 @@ def _handle_missing_field_value(
412412
values[field.name] = field.get_default()
413413

414414

415-
def _normalize_field_value(field: ModelField, value: Any) -> Any:
415+
def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any:
416416
"""Normalize field value, converting lists to single values for non-sequence fields."""
417-
if isinstance(value, list) and not is_sequence_field(field):
417+
if field_annotation_is_sequence(field_info.annotation):
418+
return value
419+
elif isinstance(value, list) and value:
418420
return value[0]
421+
419422
return value
420423

421424

@@ -504,7 +507,7 @@ def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField
504507
value = _get_param_value(input_dict, field_alias, field_name, model_class)
505508

506509
if value is not None:
507-
model_data[field_alias] = _normalize_field_value_model_param(value, field_info)
510+
model_data[field_alias] = _normalize_field_value(value=value, field_info=field_info)
508511

509512
input_dict[param.alias] = model_data
510513

@@ -524,13 +527,3 @@ def _get_param_value(
524527
value = input_dict.get(field_name)
525528

526529
return value
527-
528-
529-
def _normalize_field_value_model_param(value: Any, field_info: FieldInfo) -> Any:
530-
"""Normalize field value based on its type annotation."""
531-
if get_origin(field_info.annotation) is list:
532-
return value
533-
elif isinstance(value, list) and value:
534-
return value[0]
535-
else:
536-
return value

0 commit comments

Comments
 (0)