From 41f35ed4b000950fd42140c029cfa069607eae02 Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:38:01 -0400 Subject: [PATCH 1/3] smithy-aws-core: add support for awsQuery protocol --- ...ture-ec24e8dbe26b4ee58029f89da5b6526e.json | 4 + packages/smithy-aws-core/pyproject.toml | 3 + .../src/smithy_aws_core/_private/__init__.py | 2 + .../_private/query/__init__.py | 10 + .../smithy_aws_core/_private/query/errors.py | 106 +++++++ .../_private/query/serializers.py | 243 +++++++++++++++ .../src/smithy_aws_core/aio/protocols.py | 180 ++++++++++- .../src/smithy_aws_core/traits.py | 22 ++ .../tests/unit/aio/test_protocols.py | 175 ++++++++++- .../smithy-aws-core/tests/unit/test_query.py | 282 ++++++++++++++++++ .../smithy-aws-core/tests/unit/test_traits.py | 14 +- .../smithy-core/src/smithy_core/schemas.py | 3 + uv.lock | 6 +- 13 files changed, 1033 insertions(+), 17 deletions(-) create mode 100644 packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py create mode 100644 packages/smithy-aws-core/tests/unit/test_query.py diff --git a/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json new file mode 100644 index 000000000..377903034 --- /dev/null +++ b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-ec24e8dbe26b4ee58029f89da5b6526e.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Add `awsQuery` protocol support for Smithy clients." +} \ No newline at end of file diff --git a/packages/smithy-aws-core/pyproject.toml b/packages/smithy-aws-core/pyproject.toml index 7dbd5ffbc..f9c0cec4c 100644 --- a/packages/smithy-aws-core/pyproject.toml +++ b/packages/smithy-aws-core/pyproject.toml @@ -50,6 +50,9 @@ eventstream = [ json = [ "smithy-json~=0.2.0" ] +xml = [ + "smithy-xml~=0.1.0" +] [tool.hatch.build] exclude = [ diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py new file mode 100644 index 000000000..33cbe867a --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py new file mode 100644 index 000000000..694350d1f --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .errors import create_aws_query_error +from .serializers import QueryShapeSerializer + +__all__ = ( + "QueryShapeSerializer", + "create_aws_query_error", +) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py new file mode 100644 index 000000000..ed0dd97d4 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Any +from xml.etree.ElementTree import Element, ParseError, fromstring + +from smithy_core.documents import TypeRegistry +from smithy_core.exceptions import CallError, ExpectationNotMetError, ModeledError +from smithy_core.schemas import APIOperation +from smithy_core.shapes import ShapeID +from smithy_xml import XMLCodec + +from ...traits import AwsQueryErrorTrait + + +def _local_name(tag: str) -> str: + """Strip namespace URI from an element tag: {uri}local -> local.""" + if tag.startswith("{"): + return tag.split("}", 1)[1] + return tag + + +def _find_child(element: Element, name: str) -> Element | None: + """Return the first child element whose local name matches ``name``.""" + for child in element: + if _local_name(child.tag) == name: + return child + return None + + +def _parse_aws_query_error_code( + body: bytes, wrapper_elements: tuple[str, ...] +) -> str | None: + """Parse the ``Code`` field from a wrapped awsQuery error response.""" + try: + element = fromstring(body) # noqa: S314 + except ParseError: + return None + + if wrapper_elements: + if _local_name(element.tag) != wrapper_elements[0]: + return None + for wrapper in wrapper_elements[1:]: + next_element = _find_child(element, wrapper) + if next_element is None: + return None + element = next_element + + code_element = _find_child(element, "Code") + return code_element.text if code_element is not None else None + + +def _resolve_aws_query_error_shape_id( + *, + code: str, + operation: APIOperation[Any, Any], + error_registry: TypeRegistry, + default_namespace: str, +) -> ShapeID | None: + """Resolve an awsQuery error code to a modeled error shape ID.""" + for error_schema in operation.error_schemas: + trait = error_schema.get_trait(AwsQueryErrorTrait) + if trait is not None and trait.code == code: + if error_schema.id in error_registry: + return error_schema.id + break + + fallback_id = ShapeID.from_parts(namespace=default_namespace, name=code) + return fallback_id if fallback_id in error_registry else None + + +def create_aws_query_error( + *, + body: bytes, + operation: APIOperation[Any, Any], + error_registry: TypeRegistry, + default_namespace: str, + wrapper_elements: tuple[str, ...], + status: int, +) -> CallError: + """Create a modeled or generic CallError from an awsQuery error response.""" + code = _parse_aws_query_error_code(body, wrapper_elements) + if code is not None: + shape_id = _resolve_aws_query_error_shape_id( + code=code, + operation=operation, + error_registry=error_registry, + default_namespace=default_namespace, + ) + if shape_id is not None: + error_shape = error_registry.get(shape_id) + if not issubclass(error_shape, ModeledError): + raise ExpectationNotMetError( + "Modeled errors must be derived from 'ModeledError', " + f"but got {error_shape}" + ) + + deserializer = XMLCodec().create_deserializer( + body, wrapper_elements=wrapper_elements + ) + return error_shape.deserialize(deserializer) + + message = f"Unknown error for operation {operation.schema.id} - status: {status}" + if code is not None: + message += f", code: {code}" + fault = "client" if 400 <= status < 500 else "server" + return CallError(message=message, fault=fault) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py new file mode 100644 index 000000000..03b585321 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py @@ -0,0 +1,243 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from base64 import b64encode +from collections.abc import Callable +from contextlib import AbstractContextManager +from datetime import datetime +from decimal import Decimal +from types import TracebackType +from typing import Self +from urllib.parse import quote + +from smithy_core.documents import Document +from smithy_core.exceptions import SerializationError +from smithy_core.interfaces import BytesWriter +from smithy_core.schemas import Schema +from smithy_core.serializers import ( + InterceptingSerializer, + MapSerializer, + ShapeSerializer, +) +from smithy_core.traits import TimestampFormatTrait, XmlFlattenedTrait, XmlNameTrait +from smithy_core.types import TimestampFormat +from smithy_core.utils import serialize_float + + +def _percent_encode_query(value: str) -> str: + """Encode a query key or value using RFC 3986 percent-encoding.""" + return quote(value, safe="-_.~") + + +def _resolve_name(schema: Schema, default: str) -> str: + """Return ``@xmlName`` when present, otherwise ``default``.""" + if (xml_name := schema.get_trait(XmlNameTrait)) is not None: + return xml_name.value + return default + + +def _is_flattened(schema: Schema) -> bool: + """Return whether a collection is ``@xmlFlattened``.""" + return schema.get_trait(XmlFlattenedTrait) is not None + + +class QueryShapeSerializer(ShapeSerializer): + """Serializes Smithy shapes into AWS Query form parameters. + + Tracks a dotted key path and accumulates ``(key, value)`` pairs in a + shared buffer. Struct/list/map serializers create children that extend the + path, and primitives append terminal values at the current path. ``flush`` + emits the buffered pairs as the query payload. + """ + + def __init__( + self, + *, + sink: BytesWriter, + action: str | None = None, + version: str | None = None, + path: tuple[str, ...] = (), + params: list[tuple[str, str]] | None = None, + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + ) -> None: + self._sink = sink + self._action = action + self._version = version + self._path = path + self._params = [] if params is None else params + self._default_timestamp_format = default_timestamp_format + + def child(self, *segments: str) -> "QueryShapeSerializer": + return QueryShapeSerializer( + sink=self._sink, + path=(*self._path, *segments), + params=self._params, + default_timestamp_format=self._default_timestamp_format, + ) + + def append(self, value: str) -> None: + if not self._path: + raise SerializationError( + "Unable to serialize AWS Query value without a key path." + ) + self._params.append((".".join(self._path), value)) + + def begin_struct(self, schema: Schema) -> AbstractContextManager[ShapeSerializer]: + return QueryStructSerializer(self) + + def begin_list( + self, schema: Schema, size: int + ) -> AbstractContextManager[ShapeSerializer]: + if size == 0: + self.append("") + return QueryListSerializer(self, schema) + + def begin_map( + self, schema: Schema, size: int + ) -> AbstractContextManager[MapSerializer]: + return QueryMapSerializer(self, schema) + + def write_null(self, schema: Schema) -> None: + return None + + def write_boolean(self, schema: Schema, value: bool) -> None: + self.append("true" if value else "false") + + def write_integer(self, schema: Schema, value: int) -> None: + self.append(str(value)) + + def write_float(self, schema: Schema, value: float) -> None: + self.append(serialize_float(value)) + + def write_big_decimal(self, schema: Schema, value: Decimal) -> None: + self.append(serialize_float(value)) + + def write_string(self, schema: Schema, value: str) -> None: + self.append(value) + + def write_blob(self, schema: Schema, value: bytes) -> None: + self.append(b64encode(value).decode("utf-8")) + + def write_timestamp(self, schema: Schema, value: datetime) -> None: + format = self._default_timestamp_format + if (trait := schema.get_trait(TimestampFormatTrait)) is not None: + format = trait.format + self.append(str(format.serialize(value))) + + def write_document(self, schema: Schema, value: Document) -> None: + raise SerializationError("Query protocols do not support document types.") + + def flush(self) -> None: + serialized: list[tuple[str, str]] = [] + if self._action is not None and self._version is not None: + serialized.extend( + [ + ("Action", self._action), + ("Version", self._version), + ] + ) + serialized.extend(self._params) + body = "&".join( + f"{_percent_encode_query(key)}={_percent_encode_query(value)}" + for key, value in serialized + ).encode("utf-8") + self._sink.write(body) + + +class QueryStructSerializer(InterceptingSerializer): + """Serializes struct members as child query paths. + + ``before`` creates a child serializer rooted at the member name, honoring + ``@xmlName``. + """ + + def __init__(self, parent: QueryShapeSerializer) -> None: + self._parent = parent + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def before(self, schema: Schema) -> ShapeSerializer: + return self._parent.child(_resolve_name(schema, schema.expect_member_name())) + + def after(self, schema: Schema) -> None: + pass + + +class QueryListSerializer(InterceptingSerializer): + """Serializes list entries as indexed child query paths. + + ``before`` increments a 1-based index and creates the item path as either + ``.`` or ```` when the list is flattened. + """ + + def __init__(self, parent: QueryShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._is_flattened = _is_flattened(schema) + self._item_name = _resolve_name(schema.members["member"], "member") + self._index = 0 + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def before(self, schema: Schema) -> ShapeSerializer: + self._index += 1 + if self._is_flattened: + return self._parent.child(str(self._index)) + return self._parent.child(self._item_name, str(self._index)) + + def after(self, schema: Schema) -> None: + pass + + +class QueryMapSerializer(MapSerializer): + """Serializes map entries as indexed key and value query paths. + + Each entry increments a 1-based index, uses ``entry.`` (or + ```` when flattened), writes the key at ``...``, and + serializes the value at ``...``. + """ + + def __init__(self, parent: QueryShapeSerializer, schema: Schema) -> None: + self._parent = parent + self._is_flattened = _is_flattened(schema) + self._key_name = _resolve_name(schema.members["key"], "key") + self._value_name = _resolve_name(schema.members["value"], "value") + self._index = 0 + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]) -> None: + self._index += 1 + if self._is_flattened: + entry_path = (str(self._index),) + else: + entry_path = ("entry", str(self._index)) + + self._parent.child(*entry_path, self._key_name).append(key) + value_writer(self._parent.child(*entry_path, self._value_name)) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 709651a4a..ab7078b03 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -2,31 +2,55 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Callable from inspect import iscoroutinefunction +from io import BytesIO from typing import TYPE_CHECKING, Any, Final +from smithy_core import URI as _URI from smithy_core.aio.interfaces import AsyncWriter from smithy_core.aio.interfaces.auth import AuthScheme from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver from smithy_core.aio.types import AsyncBytesReader from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.documents import TypeRegistry from smithy_core.exceptions import ( + CallError, DiscriminatorError, MissingDependencyError, UnsupportedStreamError, ) -from smithy_core.interfaces import TypedProperties +from smithy_core.interfaces import TypedProperties, URI from smithy_core.schemas import APIOperation, Schema from smithy_core.serializers import SerializeableShape from smithy_core.shapes import ShapeID, ShapeType from smithy_core.types import TimestampFormat +from smithy_http import tuples_to_fields +from smithy_http.aio import HTTPRequest as _HTTPRequest from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse -from smithy_http.aio.protocols import HttpBindingClientProtocol -from smithy_json import JSONCodec, JSONDocument +from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol +from smithy_http.deserializers import HTTPResponseDeserializer -from ..traits import RestJson1Trait +from .._private.query.errors import ( + create_aws_query_error, +) +from .._private.query.serializers import QueryShapeSerializer +from ..traits import AwsQueryTrait, RestJson1Trait from ..utils import parse_document_discriminator, parse_error_code +try: + from smithy_json import JSONCodec, JSONDocument + + _HAS_JSON = True +except ImportError: + _HAS_JSON = False # type: ignore + +try: + from smithy_xml import XMLCodec + + _HAS_XML = True +except ImportError: + _HAS_XML = False # type: ignore + try: from smithy_aws_event_stream.aio import ( AWSEventPublisher, @@ -44,10 +68,26 @@ AWSEventReceiver, SigningConfig, ) + from smithy_json import JSONCodec, JSONDocument + from smithy_xml import XMLCodec from typing_extensions import TypeForm -def _assert_event_stream_capable() -> None: +def _assert_json() -> None: + if not _HAS_JSON: + raise MissingDependencyError( + "Attempted to use JSON protocol support, but smithy-json is not installed." + ) + + +def _assert_xml() -> None: + if not _HAS_XML: + raise MissingDependencyError( + "Attempted to use XML protocol support, but smithy-xml is not installed." + ) + + +def _assert_event_stream() -> None: if not _HAS_EVENT_STREAM: raise MissingDependencyError( "Attempted to use event streams, but smithy-aws-event-stream " @@ -99,6 +139,7 @@ def __init__(self, service_schema: Schema) -> None: :param service: The schema for the service to interact with. """ + _assert_json() self._codec: Final = JSONCodec( document_class=AWSJSONDocument, default_namespace=service_schema.id.namespace, @@ -134,7 +175,7 @@ def create_event_publisher[ context: TypedProperties, auth_scheme: AuthScheme[Any, Any, Any, Any] | None = None, ) -> EventPublisher[Event]: - _assert_event_stream_capable() + _assert_event_stream() signing_config: SigningConfig | None = None if auth_scheme is not None: event_signer = auth_scheme.event_signer(request=request) @@ -177,9 +218,134 @@ def create_event_receiver[ event_deserializer: Callable[[ShapeDeserializer], Event], context: TypedProperties, ) -> EventReceiver[Event]: - _assert_event_stream_capable() + _assert_event_stream() return AWSEventReceiver( payload_codec=self.payload_codec, source=AsyncBytesReader(response.body), deserializer=event_deserializer, ) + + +class AwsQueryClientProtocol(HttpClientProtocol): + """An implementation of the aws.protocols#awsQuery protocol.""" + + _id: Final = AwsQueryTrait.id + _content_type: Final = "application/x-www-form-urlencoded" + + def __init__(self, service_schema: Schema, version: str) -> None: + _assert_xml() + self._default_namespace: Final = service_schema.id.namespace + self._version: Final = version + self._codec: Final = XMLCodec(default_namespace=self._default_namespace) + + @property + def id(self) -> ShapeID: + return self._id + + def serialize_request[ + OperationInput: SerializeableShape, + OperationOutput: DeserializeableShape, + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + input: OperationInput, + endpoint: URI, + context: TypedProperties, + ) -> HTTPRequest: + sink = BytesIO() + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=sink, + action=self._action_name(operation), + version=self._version, + params=params, + ) + input.serialize(serializer) + serializer.flush() + content_length = sink.tell() + sink.seek(0) + body = AsyncBytesReader(sink) + return _HTTPRequest( + method="POST", + destination=_URI(host="", path="/"), + fields=tuples_to_fields( + [ + ("content-type", self._content_type), + ("content-length", str(content_length)), + ] + ), + body=body, + ) + + async def deserialize_response[ + OperationInput: SerializeableShape, + OperationOutput: DeserializeableShape, + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + request: HTTPRequest, + response: HTTPResponse, + error_registry: TypeRegistry, + context: TypedProperties, + ) -> OperationOutput: + body = await response.consume_body_async() + + if response.status >= 300: + raise self._create_error( + operation=operation, + response=response, + response_body=body, + error_registry=error_registry, + ) + + if len(body) == 0: + return operation.output.deserialize( + HTTPResponseDeserializer( + payload_codec=self._codec, + response=response, + body=body, + ) + ) + + wrapper_elements = self._response_wrapper_elements(operation) + deserializer = self._codec.create_deserializer( + body, wrapper_elements=wrapper_elements + ) + return operation.output.deserialize(deserializer) + + def _create_error( + self, + *, + operation: APIOperation[Any, Any], + response: HTTPResponse, + response_body: bytes, + error_registry: TypeRegistry, + ) -> CallError: + return create_aws_query_error( + body=response_body, + operation=operation, + error_registry=error_registry, + default_namespace=self._default_namespace, + wrapper_elements=self._error_wrapper_elements(), + status=response.status, + ) + + def _action_name( + self, + operation: APIOperation[SerializeableShape, DeserializeableShape], + ) -> str: + return operation.schema.id.name + + def _response_wrapper_elements( + self, + operation: APIOperation[SerializeableShape, DeserializeableShape], + ) -> tuple[str, str]: + return ( + f"{operation.schema.id.name}Response", + f"{operation.schema.id.name}Result", + ) + + def _error_wrapper_elements(self) -> tuple[str, ...]: + return ("ErrorResponse", "Error") diff --git a/packages/smithy-aws-core/src/smithy_aws_core/traits.py b/packages/smithy-aws-core/src/smithy_aws_core/traits.py index 3902a55ff..2a789b96e 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/traits.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/traits.py @@ -46,6 +46,28 @@ def __init__(self, value: DocumentValue | DynamicTrait = None): ) +@dataclass(frozen=True) +class AwsQueryTrait(Trait, id=ShapeID("aws.protocols#awsQuery")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class AwsQueryErrorTrait(Trait, id=ShapeID("aws.protocols#awsQueryError")): + def __post_init__(self): + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value.get("code"), str) + assert isinstance(self.document_value.get("httpResponseCode"), int | None) + + @property + def code(self) -> str: + return self.document_value["code"] # type: ignore + + @property + def http_response_code(self) -> int | None: + return self.document_value.get("httpResponseCode") # type: ignore + + @dataclass(init=False, frozen=True) class SigV4Trait(Trait, id=ShapeID("aws.auth#sigv4")): def __post_init__(self): diff --git a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py index 7b767a080..d02481948 100644 --- a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py +++ b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py @@ -1,13 +1,27 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any, cast from unittest.mock import Mock import pytest -from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument -from smithy_core.exceptions import DiscriminatorError +from smithy_aws_core.aio.protocols import ( + AWSErrorIdentifier, + AWSJSONDocument, + AwsQueryClientProtocol, +) +from smithy_aws_core.traits import AwsQueryTrait +from smithy_core.deserializers import ShapeDeserializer +from smithy_core.documents import TypeRegistry +from smithy_core.exceptions import CallError, DiscriminatorError, ModeledError +from smithy_core.interfaces import URI +from smithy_core.prelude import STRING from smithy_core.schemas import APIOperation, Schema +from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import Trait +from smithy_core.types import TypedProperties from smithy_http import Fields, tuples_to_fields from smithy_http.aio import HTTPResponse from smithy_json import JSONSettings @@ -36,13 +50,11 @@ def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> N fields = tuples_to_fields([("x-amzn-errortype", header)]) http_response = HTTPResponse(status=500, fields=fields) - operation = Mock(spec=APIOperation) - operation.schema = Schema( - id=ShapeID("com.test#TestOperation"), shape_type=ShapeType.OPERATION - ) - error_identifier = AWSErrorIdentifier() - actual = error_identifier.identify(operation=operation, response=http_response) + actual = error_identifier.identify( + operation=_mock_operation(_operation_schema("TestOperation")), + response=http_response, + ) assert actual == expected @@ -97,3 +109,150 @@ def test_aws_json_document_discriminator( else: discriminator = AWSJSONDocument(document, settings=settings).discriminator assert discriminator == expected + + +_INPUT_SCHEMA = Schema.collection( + id=ShapeID("com.test#TestInput"), + members={"name": {"target": STRING}}, +) +_SERVICE_SCHEMA = Schema.collection( + id=ShapeID("com.test#QueryService"), + shape_type=ShapeType.SERVICE, + traits=[AwsQueryTrait(None)], +) +_INVALID_ACTION_ERROR_SCHEMA = Schema.collection( + id=ShapeID("com.test#InvalidActionError"), + traits=[ + Trait.new(id=ShapeID("smithy.api#error"), value="client"), + Trait.new( + id=ShapeID("aws.protocols#awsQueryError"), + value={"code": "InvalidAction"}, + ), + ], + members={"message": {"target": STRING}}, +) + + +@dataclass +class _TestInput: + name: str | None = None + + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(_INPUT_SCHEMA, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.name is not None: + serializer.write_string(_INPUT_SCHEMA.members["name"], self.name) + + +class _ModeledQueryError(ModeledError): + message: str + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> "_ModeledQueryError": + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + if schema.expect_member_name() == "message": + kwargs["message"] = de.read_string(schema) + + deserializer.read_struct(_INVALID_ACTION_ERROR_SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +def _operation_schema(name: str) -> Schema: + return Schema( + id=ShapeID(f"com.test#{name}"), + shape_type=ShapeType.OPERATION, + ) + + +def _mock_operation( + schema: Schema, + *, + error_schemas: list[Schema] | None = None, +) -> APIOperation[Any, Any]: + operation = Mock(spec=APIOperation) + operation.schema = schema + operation.error_schemas = error_schemas or [] + return cast("APIOperation[Any, Any]", operation) + + +@pytest.mark.asyncio +async def test_aws_query_serializes_base_request_shape() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + request = protocol.serialize_request( + operation=_mock_operation(_operation_schema("TestOperation")), + input=_TestInput(name="example"), + endpoint=cast(URI, Mock()), + context=TypedProperties(), + ) + + assert request.method == "POST" + assert request.destination.path == "/" + assert ( + request.fields["content-type"].as_string() + == "application/x-www-form-urlencoded" + ) + body = await request.consume_body_async() + assert request.fields["content-length"].as_string() == str(len(body)) + assert body == b"Action=TestOperation&Version=2020-01-08&name=example" + + +def test_aws_query_resolves_modeled_error_from_query_error_trait() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + error = getattr(protocol, "_create_error")( + operation=_mock_operation( + _operation_schema("FailingOperation"), + error_schemas=[_INVALID_ACTION_ERROR_SCHEMA], + ), + response=HTTPResponse(status=400, fields=tuples_to_fields([])), + response_body=( + b"InvalidAction" + b"bad request" + ), + error_registry=TypeRegistry( + {ShapeID("com.test#InvalidActionError"): _ModeledQueryError} + ), + ) + + assert isinstance(error, _ModeledQueryError) + assert error.message == "bad request" + + +def test_aws_query_resolves_modeled_error_from_default_namespace_fallback() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + error = getattr(protocol, "_create_error")( + operation=_mock_operation(_operation_schema("FailingOperation")), + response=HTTPResponse(status=503, fields=tuples_to_fields([])), + response_body=( + b"ServiceUnavailable" + b"try again" + ), + error_registry=TypeRegistry( + {ShapeID("com.test#ServiceUnavailable"): _ModeledQueryError} + ), + ) + + assert isinstance(error, _ModeledQueryError) + assert error.message == "try again" + + +def test_aws_query_returns_generic_error_for_unknown_code() -> None: + protocol = AwsQueryClientProtocol(_SERVICE_SCHEMA, "2020-01-08") + error = getattr(protocol, "_create_error")( + operation=_mock_operation(_operation_schema("FailingOperation")), + response=HTTPResponse(status=500, fields=tuples_to_fields([])), + response_body=( + b"UnknownThing" + b"bad request" + ), + error_registry=TypeRegistry({}), + ) + + assert isinstance(error, CallError) + assert not isinstance(error, ModeledError) + assert error.message == ( + "Unknown error for operation com.test#FailingOperation" + " - status: 500, code: UnknownThing" + ) diff --git a/packages/smithy-aws-core/tests/unit/test_query.py b/packages/smithy-aws-core/tests/unit/test_query.py new file mode 100644 index 000000000..ab85bf664 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/test_query.py @@ -0,0 +1,282 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from io import BytesIO + +from smithy_aws_core._private.query.serializers import QueryShapeSerializer +from smithy_core.prelude import STRING +from smithy_core.schemas import Schema +from smithy_core.serializers import ShapeSerializer +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import XmlFlattenedTrait, XmlNameTrait + + +def test_query_list_serialization() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"target": STRING}}, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Items",), + params=params, + ) + with serializer.begin_list(list_schema, 2) as list_serializer: + member_schema = list_schema.members["member"] + list_serializer.write_string(member_schema, "a") + list_serializer.write_string(member_schema, "b") + + assert params == [ + ("Items.member.1", "a"), + ("Items.member.2", "b"), + ] + + +def test_query_flattened_list_serialization() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + traits=[XmlFlattenedTrait()], + members={"member": {"target": STRING}}, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Items",), + params=params, + ) + with serializer.begin_list(list_schema, 2) as list_serializer: + member_schema = list_schema.members["member"] + list_serializer.write_string(member_schema, "a") + list_serializer.write_string(member_schema, "b") + + assert params == [("Items.1", "a"), ("Items.2", "b")] + + +def test_query_empty_list_serialization() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"target": STRING}}, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Items",), + params=params, + ) + with serializer.begin_list(list_schema, 0): + pass + + assert params == [("Items", "")] + + +def test_query_flattened_list_uses_member_xml_name() -> None: + list_schema = Schema.collection( + id=ShapeID("com.test#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"target": STRING, "traits": [XmlNameTrait("item")]}}, + ) + input_schema = Schema.collection( + id=ShapeID("com.test#Input"), + members={ + "values": { + "target": list_schema, + "traits": [XmlFlattenedTrait(), XmlNameTrait("Hi")], + } + }, + ) + + @dataclass + class Input: + values: list[str] + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(input_schema) as struct_serializer: + self.serialize_members(struct_serializer) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + schema = input_schema.members["values"] + member_schema = schema.expect_member_target().members["member"] + with serializer.begin_list(schema, len(self.values)) as list_serializer: + for value in self.values: + list_serializer.write_string(member_schema, value) + + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), action="TestOperation", version="2020-01-08", params=params + ) + Input(values=["a", "b"]).serialize(serializer) + + assert params == [("Hi.1", "a"), ("Hi.2", "b")] + + +def test_query_map_serialization_uses_xml_name_traits() -> None: + map_schema = Schema.collection( + id=ShapeID("com.test#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": {"target": STRING, "traits": [XmlNameTrait("K")]}, + "value": {"target": STRING, "traits": [XmlNameTrait("V")]}, + }, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Attributes",), + params=params, + ) + with serializer.begin_map(map_schema, 1) as map_serializer: + map_serializer.entry( + "one", lambda value_serializer: value_serializer.write_string(STRING, "1") + ) + + assert params == [ + ("Attributes.entry.1.K", "one"), + ("Attributes.entry.1.V", "1"), + ] + + +def test_query_flattened_map_serialization() -> None: + map_schema = Schema.collection( + id=ShapeID("com.test#StringMap"), + shape_type=ShapeType.MAP, + traits=[XmlFlattenedTrait()], + members={ + "key": {"target": STRING}, + "value": {"target": STRING}, + }, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Attributes",), + params=params, + ) + with serializer.begin_map(map_schema, 2) as map_serializer: + map_serializer.entry( + "one", lambda value_serializer: value_serializer.write_string(STRING, "1") + ) + map_serializer.entry( + "two", lambda value_serializer: value_serializer.write_string(STRING, "2") + ) + + assert params == [ + ("Attributes.1.key", "one"), + ("Attributes.1.value", "1"), + ("Attributes.2.key", "two"), + ("Attributes.2.value", "2"), + ] + + +def test_query_empty_map_is_omitted() -> None: + map_schema = Schema.collection( + id=ShapeID("com.test#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": {"target": STRING}, + "value": {"target": STRING}, + }, + ) + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Attributes",), + params=params, + ) + with serializer.begin_map(map_schema, 0): + pass + + assert params == [] + + +def test_query_null_member_is_omitted() -> None: + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), + action="TestOperation", + version="2020-01-08", + path=("Nullable",), + params=params, + ) + + serializer.write_null(STRING) + + assert params == [] + + +def test_query_serializer_flush_writes_body_to_sink() -> None: + sink = BytesIO() + serializer = QueryShapeSerializer( + sink=sink, + action="TestOperation", + version="2020-01-08", + path=("Member Name",), + ) + serializer.write_string(STRING, "hello world") + serializer.flush() + + expected = b"Action=TestOperation&Version=2020-01-08&Member%20Name=hello%20world" + assert sink.getvalue() == expected + + +def test_query_serializer_flush_omits_action_and_version_when_unset() -> None: + sink = BytesIO() + serializer = QueryShapeSerializer(sink=sink, path=("MemberName",)) + serializer.write_string(STRING, "hello world") + serializer.flush() + + assert sink.getvalue() == b"MemberName=hello%20world" + + +def test_query_nested_struct_serialization() -> None: + inner_schema = Schema.collection( + id=ShapeID("com.test#Inner"), + members={"value": {"target": STRING}}, + ) + outer_schema = Schema.collection( + id=ShapeID("com.test#Outer"), + members={"inner": {"target": inner_schema}}, + ) + + @dataclass + class Inner: + value: str + + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(inner_schema, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(inner_schema.members["value"], self.value) + + @dataclass + class Outer: + inner: Inner + + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(outer_schema, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(outer_schema.members["inner"], self.inner) + + params: list[tuple[str, str]] = [] + serializer = QueryShapeSerializer( + sink=BytesIO(), action="TestOperation", version="2020-01-08", params=params + ) + Outer(inner=Inner("x")).serialize(serializer) + + assert params == [("inner.value", "x")] diff --git a/packages/smithy-aws-core/tests/unit/test_traits.py b/packages/smithy-aws-core/tests/unit/test_traits.py index d4f04ebf1..bf4d7fe6c 100644 --- a/packages/smithy-aws-core/tests/unit/test_traits.py +++ b/packages/smithy-aws-core/tests/unit/test_traits.py @@ -1,9 +1,21 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from smithy_aws_core.traits import RestJson1Trait +from smithy_aws_core.traits import AwsQueryErrorTrait, AwsQueryTrait, RestJson1Trait def test_allows_empty_restjson1_value() -> None: trait = RestJson1Trait(None) assert trait.http == ("http/1.1",) + assert trait.event_stream_http == ("http/1.1",) + + +def test_allows_empty_aws_query_trait_value() -> None: + trait = AwsQueryTrait(None) + assert trait.document_value is None + + +def test_parses_aws_query_error_trait() -> None: + trait = AwsQueryErrorTrait({"code": "InvalidAction", "httpResponseCode": 400}) + assert trait.code == "InvalidAction" + assert trait.http_response_code == 400 diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 925073b42..8c593fdfc 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -303,6 +303,9 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]: effective_auth_schemes: Sequence[ShapeID] """A list of effective auth schemes for the operation.""" + error_schemas: Sequence[Schema] = field(repr=False) + """A list of modeled error schemas for the operation.""" + @property def idempotency_token_member(self) -> Schema | None: """The input schema member that serves as the idempotency token.""" diff --git a/uv.lock b/uv.lock index 53f054dc5..2a844cde3 100644 --- a/uv.lock +++ b/uv.lock @@ -669,6 +669,9 @@ eventstream = [ json = [ { name = "smithy-json" }, ] +xml = [ + { name = "smithy-xml" }, +] [package.metadata] requires-dist = [ @@ -677,8 +680,9 @@ requires-dist = [ { name = "smithy-core", editable = "packages/smithy-core" }, { name = "smithy-http", editable = "packages/smithy-http" }, { name = "smithy-json", marker = "extra == 'json'", editable = "packages/smithy-json" }, + { name = "smithy-xml", marker = "extra == 'xml'", editable = "packages/smithy-xml" }, ] -provides-extras = ["eventstream", "json"] +provides-extras = ["eventstream", "json", "xml"] [[package]] name = "smithy-aws-event-stream" From 65c5ee7380cc3bb7d5f4167ce393db392675f34d Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:49:08 -0400 Subject: [PATCH 2/3] codegen: generate AwsQueryClientProtocol for awsQuery services and generate protocol tests --- Makefile | 10 +-- codegen/aws/core/build.gradle.kts | 1 + .../aws/codegen/AwsProtocolsIntegration.java | 2 +- .../codegen/AwsQueryProtocolGenerator.java | 69 +++++++++++++++++++ .../codegen/HttpProtocolTestGenerator.java | 36 +++++++++- .../generators/OperationGenerator.java | 13 +++- codegen/protocol-test/build.gradle.kts | 1 + codegen/protocol-test/smithy-build.json | 22 ++++++ 8 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java diff --git a/Makefile b/Makefile index 5e931e1c5..ec2452598 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,12 @@ build-java: ## Builds the Java code generation packages. cd codegen && ./gradlew clean build -test-protocols: ## Generates and runs the restJson1 protocol tests. - cd codegen && ./gradlew :protocol-test:build - uv pip install codegen/protocol-test/build/smithyprojections/protocol-test/rest-json-1/python-client-codegen - uv run pytest codegen/protocol-test/build/smithyprojections/protocol-test/rest-json-1/python-client-codegen +test-protocols: ## Generates and runs protocol tests for all supported protocols. + cd codegen && ./gradlew :protocol-test:clean :protocol-test:build + @set -e; for projection_dir in codegen/protocol-test/build/smithyprojections/protocol-test/*/python-client-codegen; do \ + uv pip install "$$projection_dir"; \ + uv run pytest "$$projection_dir"; \ + done lint-py: ## Runs linters and formatters on the python packages. diff --git a/codegen/aws/core/build.gradle.kts b/codegen/aws/core/build.gradle.kts index 3a81c5190..49d506d8d 100644 --- a/codegen/aws/core/build.gradle.kts +++ b/codegen/aws/core/build.gradle.kts @@ -12,4 +12,5 @@ extra["moduleName"] = "software.amazon.smithy.python.aws.codegen" dependencies { implementation(project(":core")) implementation(libs.smithy.aws.traits) + implementation(libs.smithy.protocol.test.traits) } diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java index 63601dd5d..d7d24a4af 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java @@ -16,6 +16,6 @@ public class AwsProtocolsIntegration implements PythonIntegration { @Override public List getProtocolGenerators() { - return List.of(); + return List.of(new AwsQueryProtocolGenerator()); } } diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java new file mode 100644 index 000000000..f02c5e815 --- /dev/null +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsQueryProtocolGenerator.java @@ -0,0 +1,69 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.aws.codegen; + +import java.util.Set; +import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.python.codegen.ApplicationProtocol; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.HttpProtocolTestGenerator; +import software.amazon.smithy.python.codegen.SymbolProperties; +import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class AwsQueryProtocolGenerator implements ProtocolGenerator { + private static final Set TESTS_TO_SKIP = Set.of( + // TODO: support the request compression trait + // https://smithy.io/2.0/spec/behavior-traits.html#smithy-api-requestcompression-trait + "SDKAppliedContentEncoding_awsQuery", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", + + // TODO: support idempotency token autofill + "QueryProtocolIdempotencyTokenAutoFill", + + // This test asserts nan == nan, which is never true. + // We should update the generator to make specific assertions for these. + "AwsQuerySupportsNaNFloatOutputs", + + // TODO: support of the endpoint trait + "AwsQueryEndpointTraitWithHostLabel", + "AwsQueryEndpointTrait"); + + @Override + public ShapeId getProtocol() { + return AwsQueryTrait.ID; + } + + @Override + public ApplicationProtocol getApplicationProtocol(GenerationContext context) { + return ApplicationProtocol.createDefaultHttpApplicationProtocol(); + } + + @Override + public void initializeProtocol(GenerationContext context, PythonWriter writer) { + writer.addDependency(AwsPythonDependency.SMITHY_AWS_CORE.withOptionalDependencies("xml")); + writer.addImport("smithy_aws_core.aio.protocols", "AwsQueryClientProtocol"); + var service = context.settings().service(context.model()); + var serviceSymbol = context.symbolProvider().toSymbol(service); + var serviceSchema = serviceSymbol.expectProperty(SymbolProperties.SCHEMA); + var version = service.getVersion(); + writer.write("AwsQueryClientProtocol($T, $S)", serviceSchema, version); + } + + @Override + public void generateProtocolTests(GenerationContext context) { + context.writerDelegator() + .useFileWriter("./tests/test_awsquery_protocol.py", "tests.test_awsquery_protocol", writer -> { + new HttpProtocolTestGenerator( + context, + getProtocol(), + writer, + (shape, testCase) -> TESTS_TO_SKIP.contains(testCase.getId())).run(); + }); + } +} diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index 0a06f15bf..33f5191e3 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -16,6 +16,7 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.Stream; +import software.amazon.smithy.aws.traits.auth.SigV4Trait; import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.model.Model; @@ -188,12 +189,14 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t endpoint_uri="https://$L/$L", transport = $T(), retry_strategy=SimpleRetryStrategy(max_attempts=1), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), host, path, - REQUEST_TEST_ASYNC_HTTP_CLIENT_SYMBOL); + REQUEST_TEST_ASYNC_HTTP_CLIENT_SYMBOL, + (Runnable) this::writeSigV4TestConfig); })); // Generate the input using the expected shape and params @@ -418,6 +421,16 @@ private void compareMediaBlob(HttpMessageTestCase testCase, PythonWriter writer) """); return; } + if (contentType.equals("application/x-www-form-urlencoded")) { + writer.addStdlibImport("urllib.parse", "parse_qsl"); + writer.write(""" + actual_params = sorted(parse_qsl(actual_body_content.decode())) + expected_params = sorted(parse_qsl(expected_body_content.decode())) + assert actual_params == expected_params + + """); + return; + } writer.write("assert actual_body_content == expected_body_content\n"); } @@ -437,13 +450,15 @@ private void generateResponseTest(OperationShape operation, HttpResponseTestCase headers=$J, body=b$S, ), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), RESPONSE_TEST_ASYNC_HTTP_CLIENT_SYMBOL, testCase.getCode(), CodegenUtils.toTuples(testCase.getHeaders()), - testCase.getBody().filter(body -> !body.isEmpty()).orElse("")); + testCase.getBody().filter(body -> !body.isEmpty()).orElse(""), + (Runnable) this::writeSigV4TestConfig); })); // Create an empty input object to pass var inputShape = model.expectShape(operation.getInputShape(), StructureShape.class); @@ -490,13 +505,15 @@ private void generateErrorResponseTest( headers=$J, body=b$S, ), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), RESPONSE_TEST_ASYNC_HTTP_CLIENT_SYMBOL, testCase.getCode(), CodegenUtils.toTuples(testCase.getHeaders()), - testCase.getBody().orElse("")); + testCase.getBody().orElse(""), + (Runnable) this::writeSigV4TestConfig); })); // Create an empty input object to pass var inputShape = model.expectShape(operation.getInputShape(), StructureShape.class); @@ -607,6 +624,19 @@ private void writeClientBlock( }); } + private void writeSigV4TestConfig() { + if (!service.hasTrait(SigV4Trait.class)) { + return; + } + writer.addImport("smithy_aws_core.identity", "StaticCredentialsResolver"); + writer.write(""" + region="us-east-1", + aws_access_key_id="test-access-key-id", + aws_secret_access_key="test-secret-access-key", + aws_credentials_identity_resolver=StaticCredentialsResolver(), + """); + } + private void writeUtilStubs(Symbol serviceSymbol) { LOGGER.fine(String.format("Writing utility stubs for %s : %s", serviceSymbol.getName(), protocol.getName())); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java index 9ada3907a..9557a93ae 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java @@ -58,6 +58,9 @@ public void run() { }), effective_auth_schemes = [ $8C + ], + error_schemas = [ + $9C ] ) """, @@ -68,7 +71,8 @@ public void run() { inSymbol.expectProperty(SymbolProperties.SCHEMA), outSymbol.expectProperty(SymbolProperties.SCHEMA), writer.consumer(this::writeErrorTypeRegistry), - writer.consumer(this::writeAuthSchemes)); + writer.consumer(this::writeAuthSchemes), + writer.consumer(this::writeErrorSchemas)); } private void writeErrorTypeRegistry(PythonWriter writer) { @@ -82,6 +86,13 @@ private void writeErrorTypeRegistry(PythonWriter writer) { } } + private void writeErrorSchemas(PythonWriter writer) { + for (var error : shape.getErrors()) { + var errSymbol = symbolProvider.toSymbol(model.expectShape(error)); + writer.write("$T,", errSymbol.expectProperty(SymbolProperties.SCHEMA)); + } + } + private void writeAuthSchemes(PythonWriter writer) { var authSchemes = ServiceIndex.of(model) .getEffectiveAuthSchemes(context.settings().service(), diff --git a/codegen/protocol-test/build.gradle.kts b/codegen/protocol-test/build.gradle.kts index 5c470b9c4..cddc35e75 100644 --- a/codegen/protocol-test/build.gradle.kts +++ b/codegen/protocol-test/build.gradle.kts @@ -30,5 +30,6 @@ repositories { dependencies { implementation(project(":core")) + implementation(project(":aws:core")) implementation(libs.smithy.aws.protocol.tests) } diff --git a/codegen/protocol-test/smithy-build.json b/codegen/protocol-test/smithy-build.json index cbaccad98..f5ec825f2 100644 --- a/codegen/protocol-test/smithy-build.json +++ b/codegen/protocol-test/smithy-build.json @@ -22,6 +22,28 @@ "moduleVersion": "0.0.1" } } + }, + "aws-query": { + "transforms": [ + { + "name": "includeServices", + "args": { + "services": [ + "aws.protocoltests.query#AwsQuery" + ] + } + }, + { + "name": "removeUnusedShapes" + } + ], + "plugins": { + "python-client-codegen": { + "service": "aws.protocoltests.query#AwsQuery", + "module": "awsquery", + "moduleVersion": "0.0.1" + } + } } } } From b9a4a2c60f4acf3a1ff71eb02994c20e78229ae4 Mon Sep 17 00:00:00 2001 From: Antonio Aranda <102337110+arandito@users.noreply.github.com> Date: Mon, 23 Mar 2026 15:30:51 -0400 Subject: [PATCH 3/3] smithy-aws-core: update casing for xml traits --- .../smithy_aws_core/_private/query/serializers.py | 6 +++--- packages/smithy-aws-core/tests/unit/test_query.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py index 03b585321..5f2eac127 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/serializers.py @@ -19,7 +19,7 @@ MapSerializer, ShapeSerializer, ) -from smithy_core.traits import TimestampFormatTrait, XmlFlattenedTrait, XmlNameTrait +from smithy_core.traits import TimestampFormatTrait, XMLFlattenedTrait, XMLNameTrait from smithy_core.types import TimestampFormat from smithy_core.utils import serialize_float @@ -31,14 +31,14 @@ def _percent_encode_query(value: str) -> str: def _resolve_name(schema: Schema, default: str) -> str: """Return ``@xmlName`` when present, otherwise ``default``.""" - if (xml_name := schema.get_trait(XmlNameTrait)) is not None: + if (xml_name := schema.get_trait(XMLNameTrait)) is not None: return xml_name.value return default def _is_flattened(schema: Schema) -> bool: """Return whether a collection is ``@xmlFlattened``.""" - return schema.get_trait(XmlFlattenedTrait) is not None + return schema.get_trait(XMLFlattenedTrait) is not None class QueryShapeSerializer(ShapeSerializer): diff --git a/packages/smithy-aws-core/tests/unit/test_query.py b/packages/smithy-aws-core/tests/unit/test_query.py index ab85bf664..a6a069d3f 100644 --- a/packages/smithy-aws-core/tests/unit/test_query.py +++ b/packages/smithy-aws-core/tests/unit/test_query.py @@ -8,7 +8,7 @@ from smithy_core.schemas import Schema from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import XmlFlattenedTrait, XmlNameTrait +from smithy_core.traits import XMLFlattenedTrait, XMLNameTrait def test_query_list_serialization() -> None: @@ -40,7 +40,7 @@ def test_query_flattened_list_serialization() -> None: list_schema = Schema.collection( id=ShapeID("com.test#StringList"), shape_type=ShapeType.LIST, - traits=[XmlFlattenedTrait()], + traits=[XMLFlattenedTrait()], members={"member": {"target": STRING}}, ) params: list[tuple[str, str]] = [] @@ -83,14 +83,14 @@ def test_query_flattened_list_uses_member_xml_name() -> None: list_schema = Schema.collection( id=ShapeID("com.test#StringList"), shape_type=ShapeType.LIST, - members={"member": {"target": STRING, "traits": [XmlNameTrait("item")]}}, + members={"member": {"target": STRING, "traits": [XMLNameTrait("item")]}}, ) input_schema = Schema.collection( id=ShapeID("com.test#Input"), members={ "values": { "target": list_schema, - "traits": [XmlFlattenedTrait(), XmlNameTrait("Hi")], + "traits": [XMLFlattenedTrait(), XMLNameTrait("Hi")], } }, ) @@ -124,8 +124,8 @@ def test_query_map_serialization_uses_xml_name_traits() -> None: id=ShapeID("com.test#StringMap"), shape_type=ShapeType.MAP, members={ - "key": {"target": STRING, "traits": [XmlNameTrait("K")]}, - "value": {"target": STRING, "traits": [XmlNameTrait("V")]}, + "key": {"target": STRING, "traits": [XMLNameTrait("K")]}, + "value": {"target": STRING, "traits": [XMLNameTrait("V")]}, }, ) params: list[tuple[str, str]] = [] @@ -151,7 +151,7 @@ def test_query_flattened_map_serialization() -> None: map_schema = Schema.collection( id=ShapeID("com.test#StringMap"), shape_type=ShapeType.MAP, - traits=[XmlFlattenedTrait()], + traits=[XMLFlattenedTrait()], members={ "key": {"target": STRING}, "value": {"target": STRING},