diff --git a/docs/integrations/falcon.md b/docs/integrations/falcon.md index f0c96da9..e801d636 100644 --- a/docs/integrations/falcon.md +++ b/docs/integrations/falcon.md @@ -10,11 +10,13 @@ The integration supports Falcon version 4. ## Middleware The Falcon API can be integrated using the `FalconOpenAPIMiddleware` middleware. +For explicit transport classes, use `FalconWSGIOpenAPIMiddleware` for +`falcon.App` and `FalconASGIOpenAPIMiddleware` for `falcon.asgi.App`. ``` python hl_lines="1 3 7" -from openapi_core.contrib.falcon.middlewares import FalconOpenAPIMiddleware +from openapi_core.contrib.falcon.middlewares import FalconWSGIOpenAPIMiddleware -openapi_middleware = FalconOpenAPIMiddleware.from_spec(spec) +openapi_middleware = FalconWSGIOpenAPIMiddleware.from_spec(spec) app = falcon.App( # ... @@ -22,6 +24,21 @@ app = falcon.App( ) ``` +`FalconOpenAPIMiddleware` supports both WSGI and ASGI Falcon apps. +For an explicit ASGI middleware class name, use +`FalconASGIOpenAPIMiddleware`. + +``` python hl_lines="1 3 7" +from openapi_core.contrib.falcon.middlewares import FalconASGIOpenAPIMiddleware + +openapi_middleware = FalconASGIOpenAPIMiddleware.from_spec(spec) + +app = falcon.asgi.App( + # ... + middleware=[openapi_middleware], +) +``` + Additional customization parameters can be passed to the middleware. ``` python hl_lines="5" diff --git a/openapi_core/app.py b/openapi_core/app.py index 854f3cc3..1a0df30a 100644 --- a/openapi_core/app.py +++ b/openapi_core/app.py @@ -2,6 +2,7 @@ from functools import cached_property from pathlib import Path +from typing import Any from typing import Optional from jsonschema._utils import Unset @@ -279,6 +280,51 @@ def from_file( sp = SchemaPath.from_file(fileobj, base_uri=base_uri) return cls(sp, config=config) + @classmethod + def build( + cls, + spec: Annotated[ + SchemaPath, + Doc(""" + OpenAPI specification schema path object. + """), + ], + request_unmarshaller_cls: Annotated[ + Optional[RequestUnmarshallerType], + Doc(""" + Custom request unmarshaller class. + """), + ] = None, + response_unmarshaller_cls: Annotated[ + Optional[ResponseUnmarshallerType], + Doc(""" + Custom response unmarshaller class. + """), + ] = None, + ) -> "OpenAPI": + """Builds an `OpenAPI` from a `SchemaPath` object with optional configuration parameters. + + Example: + ```python + from openapi_core import OpenAPI + app = OpenAPI.build(spec, request_unmarshaller_cls=CustomRequestUnmarshaller) + ``` + + Returns: + OpenAPI: An instance of the OpenAPI class. + """ + config_kwargs: dict[str, Any] = {} + if request_unmarshaller_cls is not None: + config_kwargs["request_unmarshaller_cls"] = ( + request_unmarshaller_cls + ) + if response_unmarshaller_cls is not None: + config_kwargs["response_unmarshaller_cls"] = ( + response_unmarshaller_cls + ) + config = Config(**config_kwargs) + return cls(spec, config=config) + def _get_version(self) -> SpecVersion: try: return get_spec_version(self.spec.read_value()) diff --git a/openapi_core/contrib/falcon/__init__.py b/openapi_core/contrib/falcon/__init__.py index 67c28a13..ef2a2411 100644 --- a/openapi_core/contrib/falcon/__init__.py +++ b/openapi_core/contrib/falcon/__init__.py @@ -1,7 +1,13 @@ +from openapi_core.contrib.falcon.middlewares import FalconASGIOpenAPIMiddleware +from openapi_core.contrib.falcon.middlewares import FalconOpenAPIMiddleware +from openapi_core.contrib.falcon.middlewares import FalconWSGIOpenAPIMiddleware from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse __all__ = [ + "FalconASGIOpenAPIMiddleware", + "FalconOpenAPIMiddleware", + "FalconWSGIOpenAPIMiddleware", "FalconOpenAPIRequest", "FalconOpenAPIResponse", ] diff --git a/openapi_core/contrib/falcon/integrations.py b/openapi_core/contrib/falcon/integrations.py index 8c3fa544..0b1ddc8f 100644 --- a/openapi_core/contrib/falcon/integrations.py +++ b/openapi_core/contrib/falcon/integrations.py @@ -1,8 +1,15 @@ +from typing import Optional +from typing import Type + from falcon.request import Request from falcon.response import Response +from openapi_core import OpenAPI +from openapi_core.contrib.falcon.requests import FalconAsgiOpenAPIRequest from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest +from openapi_core.contrib.falcon.responses import FalconAsgiOpenAPIResponse from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse +from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor from openapi_core.unmarshalling.processors import UnmarshallingProcessor from openapi_core.unmarshalling.typing import ErrorsHandlerCallable @@ -32,3 +39,46 @@ def handle_response( if not self.should_validate_response(): return response return super().handle_response(request, response, errors_handler) + + +class AsyncFalconIntegration(AsyncUnmarshallingProcessor[Request, Response]): + request_cls: Type[FalconAsgiOpenAPIRequest] = FalconAsgiOpenAPIRequest + response_cls: Optional[Type[FalconAsgiOpenAPIResponse]] = ( + FalconAsgiOpenAPIResponse + ) + + def __init__( + self, + openapi: OpenAPI, + request_cls: Type[FalconAsgiOpenAPIRequest] = FalconAsgiOpenAPIRequest, + response_cls: Optional[Type[FalconAsgiOpenAPIResponse]] = ( + FalconAsgiOpenAPIResponse + ), + ): + super().__init__(openapi) + self.request_cls = request_cls or self.request_cls + self.response_cls = response_cls + + async def get_openapi_request( + self, request: Request + ) -> FalconAsgiOpenAPIRequest: + return await self.request_cls.from_request(request) + + async def get_openapi_response( + self, response: Response + ) -> FalconAsgiOpenAPIResponse: + assert self.response_cls is not None + return await self.response_cls.from_response(response) + + def should_validate_response(self) -> bool: + return self.response_cls is not None + + async def handle_response( + self, + request: Request, + response: Response, + errors_handler: ErrorsHandlerCallable[Response], + ) -> Response: + if not self.should_validate_response(): + return response + return await super().handle_response(request, response, errors_handler) diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index 13e4c5e8..89d9f81a 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -1,29 +1,39 @@ """OpenAPI core contrib falcon middlewares module""" from typing import Any +from typing import Optional from typing import Type -from typing import Union +from typing import cast from falcon.request import Request from falcon.response import Response -from jsonschema._utils import Unset -from jsonschema.validators import _UNSET from jsonschema_path import SchemaPath -from openapi_core import Config from openapi_core import OpenAPI from openapi_core.contrib.falcon.handlers import FalconOpenAPIErrorsHandler from openapi_core.contrib.falcon.handlers import ( FalconOpenAPIValidRequestHandler, ) +from openapi_core.contrib.falcon.integrations import AsyncFalconIntegration from openapi_core.contrib.falcon.integrations import FalconIntegration +from openapi_core.contrib.falcon.requests import FalconAsgiOpenAPIRequest from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest +from openapi_core.contrib.falcon.responses import FalconAsgiOpenAPIResponse from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult from openapi_core.unmarshalling.request.types import RequestUnmarshallerType from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +_DEFAULT_ASYNC = object() + + +class FalconWSGIOpenAPIMiddleware(FalconIntegration): + """OpenAPI middleware for Falcon WSGI applications. + + This class wires Falcon's synchronous middleware hooks to the + synchronous OpenAPI integration. + """ -class FalconOpenAPIMiddleware(FalconIntegration): valid_request_handler_cls = FalconOpenAPIValidRequestHandler errors_handler_cls: Type[FalconOpenAPIErrorsHandler] = ( FalconOpenAPIErrorsHandler @@ -37,7 +47,7 @@ def __init__( errors_handler_cls: Type[ FalconOpenAPIErrorsHandler ] = FalconOpenAPIErrorsHandler, - **unmarshaller_kwargs: Any, + **kwargs: Any, ): super().__init__(openapi) self.request_cls = request_cls or self.request_cls @@ -48,32 +58,26 @@ def __init__( def from_spec( cls, spec: SchemaPath, - request_unmarshaller_cls: Union[ - RequestUnmarshallerType, Unset - ] = _UNSET, - response_unmarshaller_cls: Union[ - ResponseUnmarshallerType, Unset - ] = _UNSET, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler_cls: Type[ FalconOpenAPIErrorsHandler ] = FalconOpenAPIErrorsHandler, - **unmarshaller_kwargs: Any, - ) -> "FalconOpenAPIMiddleware": - config = Config( + **kwargs: Any, + ) -> "FalconWSGIOpenAPIMiddleware": + openapi = OpenAPI.build( + spec, request_unmarshaller_cls=request_unmarshaller_cls, response_unmarshaller_cls=response_unmarshaller_cls, ) - openapi = OpenAPI(spec, config=config) return cls( openapi, - request_unmarshaller_cls=request_unmarshaller_cls, - response_unmarshaller_cls=response_unmarshaller_cls, request_cls=request_cls, response_cls=response_cls, errors_handler_cls=errors_handler_cls, - **unmarshaller_kwargs, + **kwargs, ) def process_request(self, req: Request, resp: Response) -> None: @@ -86,3 +90,200 @@ def process_response( ) -> None: errors_handler = self.errors_handler_cls(req, resp) self.handle_response(req, resp, errors_handler) + + +class FalconASGIOpenAPIMiddleware(AsyncFalconIntegration): + """OpenAPI middleware for Falcon ASGI applications. + + This class wires Falcon's asynchronous middleware hooks to the + asynchronous OpenAPI integration. + """ + + valid_request_handler_cls = FalconOpenAPIValidRequestHandler + errors_handler_cls: Type[FalconOpenAPIErrorsHandler] = ( + FalconOpenAPIErrorsHandler + ) + + def __init__( + self, + openapi: OpenAPI, + request_cls: Type[FalconAsgiOpenAPIRequest] = FalconAsgiOpenAPIRequest, + response_cls: Optional[Type[FalconAsgiOpenAPIResponse]] = ( + FalconAsgiOpenAPIResponse + ), + errors_handler_cls: Type[ + FalconOpenAPIErrorsHandler + ] = FalconOpenAPIErrorsHandler, + **kwargs: Any, + ): + super().__init__( + openapi, + request_cls=request_cls, + response_cls=response_cls, + ) + self.errors_handler_cls = errors_handler_cls or self.errors_handler_cls + + @classmethod + def from_spec( + cls, + spec: SchemaPath, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + request_cls: Type[FalconAsgiOpenAPIRequest] = FalconAsgiOpenAPIRequest, + response_cls: Optional[Type[FalconAsgiOpenAPIResponse]] = ( + FalconAsgiOpenAPIResponse + ), + errors_handler_cls: Type[ + FalconOpenAPIErrorsHandler + ] = FalconOpenAPIErrorsHandler, + **kwargs: Any, + ) -> "FalconASGIOpenAPIMiddleware": + openapi = OpenAPI.build( + spec, + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, + ) + return cls( + openapi, + request_cls=request_cls, + response_cls=response_cls, + errors_handler_cls=errors_handler_cls, + **kwargs, + ) + + async def process_request_async( + self, req: Request, resp: Response + ) -> None: + errors_handler = self.errors_handler_cls(req, resp) + valid_request_handler = self.valid_request_handler_cls(req, resp) + + async def async_valid_request_handler( + request_unmarshal_result: RequestUnmarshalResult, + ) -> Response: + return valid_request_handler(request_unmarshal_result) + + await self.handle_request( + req, + async_valid_request_handler, + errors_handler, + ) + + async def process_response_async( + self, + req: Request, + resp: Response, + resource: Any, + req_succeeded: bool, + ) -> None: + errors_handler = self.errors_handler_cls(req, resp) + await self.handle_response(req, resp, errors_handler) + + +class FalconOpenAPIMiddleware: + """OpenAPI middleware compatible with both WSGI and ASGI Falcon apps. + + This class delegates to transport-specific middleware implementations: + :class:`FalconWSGIOpenAPIMiddleware` for sync hooks and + :class:`FalconASGIOpenAPIMiddleware` for async hooks. + """ + + def __init__( + self, + openapi: OpenAPI, + request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, + response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, + request_async_cls: Any = _DEFAULT_ASYNC, + response_async_cls: Any = _DEFAULT_ASYNC, + errors_handler_cls: Type[ + FalconOpenAPIErrorsHandler + ] = FalconOpenAPIErrorsHandler, + **kwargs: Any, + ): + if request_async_cls is _DEFAULT_ASYNC: + request_async_cls = FalconAsgiOpenAPIRequest + if response_async_cls is _DEFAULT_ASYNC: + response_async_cls = ( + FalconAsgiOpenAPIResponse if response_cls is not None else None + ) + + self.wsgi_middleware = FalconWSGIOpenAPIMiddleware( + openapi, + request_cls=request_cls, + response_cls=response_cls, + errors_handler_cls=errors_handler_cls, + **kwargs, + ) + self.asgi_middleware = FalconASGIOpenAPIMiddleware( + openapi, + request_cls=cast( + Type[FalconAsgiOpenAPIRequest], request_async_cls + ), + response_cls=cast( + Optional[Type[FalconAsgiOpenAPIResponse]], + response_async_cls, + ), + errors_handler_cls=errors_handler_cls, + **kwargs, + ) + + @classmethod + def from_spec( + cls, + spec: SchemaPath, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, + response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, + request_async_cls: Any = _DEFAULT_ASYNC, + response_async_cls: Any = _DEFAULT_ASYNC, + errors_handler_cls: Type[ + FalconOpenAPIErrorsHandler + ] = FalconOpenAPIErrorsHandler, + **kwargs: Any, + ) -> "FalconOpenAPIMiddleware": + openapi = OpenAPI.build( + spec, + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, + ) + return cls( + openapi, + request_cls=request_cls, + response_cls=response_cls, + request_async_cls=request_async_cls, + response_async_cls=response_async_cls, + errors_handler_cls=errors_handler_cls, + **kwargs, + ) + + def process_request(self, req: Request, resp: Response) -> None: + self.wsgi_middleware.process_request(req, resp) + + def process_response( + self, req: Request, resp: Response, resource: Any, req_succeeded: bool + ) -> None: + self.wsgi_middleware.process_response( + req, + resp, + resource, + req_succeeded, + ) + + async def process_request_async( + self, req: Request, resp: Response + ) -> None: + await self.asgi_middleware.process_request_async(req, resp) + + async def process_response_async( + self, + req: Request, + resp: Response, + resource: Any, + req_succeeded: bool, + ) -> None: + await self.asgi_middleware.process_response_async( + req, + resp, + resource, + req_succeeded, + ) diff --git a/openapi_core/contrib/falcon/requests.py b/openapi_core/contrib/falcon/requests.py index 586bd82d..9c59af31 100644 --- a/openapi_core/contrib/falcon/requests.py +++ b/openapi_core/contrib/falcon/requests.py @@ -1,19 +1,22 @@ """OpenAPI core contrib falcon responses module""" -import warnings from json import dumps from typing import Any from typing import Dict from typing import Optional +from typing import cast from falcon.request import Request from falcon.request import RequestOptions from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict +from openapi_core.contrib.falcon.util import serialize_body from openapi_core.contrib.falcon.util import unpack_params from openapi_core.datatypes import RequestParameters +_BODY_NOT_SET = object() + class FalconOpenAPIRequest: def __init__( @@ -27,6 +30,7 @@ def __init__( if default_when_empty is None: default_when_empty = {} self.default_when_empty = default_when_empty + self._body: Any = _BODY_NOT_SET # Path gets deduced by path finder against spec self.parameters = RequestParameters( @@ -52,30 +56,23 @@ def method(self) -> str: @property def body(self) -> Optional[bytes]: + if self._body is not _BODY_NOT_SET: + return cast(Optional[bytes], self._body) + # Falcon doesn't store raw request stream. # That's why we need to revert deserialized data # Support falcon-jsonify. - if hasattr(self.request, "json"): - return dumps(self.request.json).encode("utf-8") + request_json = getattr(cast(Any, self.request), "json", None) + if request_json is not None: + self._body = dumps(request_json).encode("utf-8") + return cast(Optional[bytes], self._body) media = self.request.get_media( default_when_empty=self.default_when_empty, ) - handler, _, _ = self.request.options.media_handlers._resolve( - self.request.content_type, self.request.options.default_media_type - ) - try: - body = handler.serialize(media, content_type=self.content_type) - # multipart form serialization is not supported - except NotImplementedError: - warnings.warn( - f"body serialization for {self.request.content_type} not supported" - ) - return None - else: - assert isinstance(body, bytes) - return body + self._body = serialize_body(self.request, media, self.content_type) + return cast(Optional[bytes], self._body) @property def content_type(self) -> str: @@ -86,3 +83,21 @@ def content_type(self) -> str: assert isinstance(self.request.options, RequestOptions) assert isinstance(self.request.options.default_media_type, str) return self.request.options.default_media_type + + +class FalconAsgiOpenAPIRequest(FalconOpenAPIRequest): + @classmethod + async def from_request( + cls, + request: Request, + default_when_empty: Optional[Dict[Any, Any]] = None, + ) -> "FalconAsgiOpenAPIRequest": + instance = cls( + request, + default_when_empty=default_when_empty, + ) + media = await request.get_media( + default_when_empty=instance.default_when_empty + ) + instance._body = serialize_body(request, media, instance.content_type) + return instance diff --git a/openapi_core/contrib/falcon/responses.py b/openapi_core/contrib/falcon/responses.py index 22bdb81a..a6a74484 100644 --- a/openapi_core/contrib/falcon/responses.py +++ b/openapi_core/contrib/falcon/responses.py @@ -1,8 +1,11 @@ """OpenAPI core contrib falcon responses module""" +import inspect from io import BytesIO from itertools import tee +from typing import Any from typing import Iterable +from typing import List from falcon.response import Response from werkzeug.datastructures import Headers @@ -48,3 +51,82 @@ def content_type(self) -> str: @property def headers(self) -> Headers: return Headers(self.response.headers) + + +class FalconAsgiOpenAPIResponse(FalconOpenAPIResponse): + def __init__(self, response: Response, data: bytes): + super().__init__(response) + self._data = data + + @classmethod + async def from_response( + cls, + response: Any, + ) -> "FalconAsgiOpenAPIResponse": + data = await cls._get_asgi_response_data(response) + return cls(response, data=data) + + @classmethod + async def _get_asgi_response_data(cls, response: Any) -> bytes: + response_any = response + stream = response_any.stream + if stream is None: + data = await response_any.render_body() + if data is None: + return b"" + assert isinstance(data, bytes) + return data + + charset = getattr(response_any, "charset", None) or "utf-8" + chunks: List[bytes] = [] + stream_any = stream + + if hasattr(stream_any, "__aiter__"): + async for chunk in stream_any: + if chunk is None: + break + if not isinstance(chunk, bytes): + chunk = chunk.encode(charset) + chunks.append(chunk) + elif hasattr(stream_any, "read"): + while True: + chunk = stream_any.read() + if inspect.isawaitable(chunk): + chunk = await chunk + if not chunk: + break + if not isinstance(chunk, bytes): + chunk = chunk.encode(charset) + chunks.append(chunk) + elif isinstance(stream_any, Iterable): + response_iter1, response_iter2 = tee(stream_any) + response_any.stream = response_iter1 + for chunk in response_iter2: + if not isinstance(chunk, bytes): + chunk = chunk.encode(charset) + chunks.append(chunk) + return b"".join(chunks) + + response_any.stream = _AsyncChunksIterator(chunks) + return b"".join(chunks) + + @property + def data(self) -> bytes: + return self._data + + +class _AsyncChunksIterator: + def __init__(self, chunks: List[bytes]): + self._chunks = chunks + self._index = 0 + + def __aiter__(self) -> "_AsyncChunksIterator": + return self + + async def __anext__(self) -> bytes: + if self._index >= len(self._chunks): + raise StopAsyncIteration + + chunk = self._chunks[self._index] + self._index += 1 + return chunk diff --git a/openapi_core/contrib/falcon/util.py b/openapi_core/contrib/falcon/util.py index aa8725a0..83fa43e5 100644 --- a/openapi_core/contrib/falcon/util.py +++ b/openapi_core/contrib/falcon/util.py @@ -1,8 +1,32 @@ +import warnings from typing import Any from typing import Generator from typing import Mapping +from typing import Optional from typing import Tuple +from falcon.request import Request + + +def serialize_body( + request: Request, + media: Any, + content_type: str, +) -> Optional[bytes]: + """Serialize request body using media handlers.""" + handler, _, _ = request.options.media_handlers._resolve( + content_type, + request.options.default_media_type, + ) + try: + body = handler.serialize(media, content_type=content_type) + # multipart form serialization is not supported + except NotImplementedError: + warnings.warn(f"body serialization for {content_type} not supported") + return None + assert isinstance(body, bytes) + return body + def unpack_params( params: Mapping[str, Any], diff --git a/tests/integration/contrib/aiohttp/conftest.py b/tests/integration/contrib/aiohttp/conftest.py index ead341a5..56a0c9c2 100644 --- a/tests/integration/contrib/aiohttp/conftest.py +++ b/tests/integration/contrib/aiohttp/conftest.py @@ -1,9 +1,11 @@ import asyncio import pathlib +from collections.abc import AsyncGenerator from typing import Any from unittest import mock import pytest +import pytest_asyncio from aiohttp import web from aiohttp.test_utils import TestClient @@ -114,6 +116,10 @@ def app(router): return app -@pytest.fixture -async def client(app, aiohttp_client) -> TestClient: - return await aiohttp_client(app) +@pytest_asyncio.fixture +async def client(app, aiohttp_client) -> AsyncGenerator[TestClient, None]: + test_client = await aiohttp_client(app) + try: + yield test_client + finally: + await test_client.close() diff --git a/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/__main__.py b/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/__main__.py index 13109d64..f6e9d9f3 100644 --- a/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/__main__.py +++ b/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/__main__.py @@ -6,8 +6,8 @@ ] -def get_app(loop=None): - app = web.Application(loop=loop) +def get_app(): + app = web.Application() app.add_routes(routes) return app diff --git a/tests/integration/contrib/aiohttp/test_aiohttp_project.py b/tests/integration/contrib/aiohttp/test_aiohttp_project.py index 54f7297d..723b309a 100644 --- a/tests/integration/contrib/aiohttp/test_aiohttp_project.py +++ b/tests/integration/contrib/aiohttp/test_aiohttp_project.py @@ -1,9 +1,11 @@ import os import sys from base64 import b64encode +from collections.abc import AsyncGenerator from io import BytesIO import pytest +import pytest_asyncio @pytest.fixture(autouse=True, scope="session") @@ -22,9 +24,13 @@ def app(project_setup): return get_app() -@pytest.fixture -async def client(app, aiohttp_client): - return await aiohttp_client(app) +@pytest_asyncio.fixture +async def client(app, aiohttp_client) -> AsyncGenerator: + test_client = await aiohttp_client(app) + try: + yield test_client + finally: + await test_client.close() class BaseTestPetstore: diff --git a/tests/integration/contrib/falcon/test_falcon_asgi_middleware.py b/tests/integration/contrib/falcon/test_falcon_asgi_middleware.py new file mode 100644 index 00000000..eeca2e17 --- /dev/null +++ b/tests/integration/contrib/falcon/test_falcon_asgi_middleware.py @@ -0,0 +1,214 @@ +from json import dumps +from pathlib import Path +from typing import Any +from typing import cast + +import pytest +import yaml +from falcon import status_codes +from falcon.asgi import App +from falcon.asgi import Response +from falcon.constants import MEDIA_JSON +from falcon.testing import ASGIConductor +from jsonschema_path import SchemaPath + +from openapi_core.contrib.falcon.middlewares import FalconASGIOpenAPIMiddleware +from openapi_core.contrib.falcon.middlewares import FalconOpenAPIMiddleware +from openapi_core.contrib.falcon.requests import FalconAsgiOpenAPIRequest +from openapi_core.contrib.falcon.responses import FalconAsgiOpenAPIResponse +from openapi_core.contrib.falcon.util import serialize_body + + +@pytest.fixture +def spec(): + openapi_spec_path = Path("tests/integration/data/v3.0/petstore.yaml") + spec_dict = yaml.load(openapi_spec_path.read_text(), yaml.Loader) + return SchemaPath.from_dict(spec_dict) + + +class PetListResource: + async def on_get(self, req, resp): + assert req.context.openapi + assert not req.context.openapi.errors + resp.status = status_codes.HTTP_200 + resp.content_type = MEDIA_JSON + resp.text = dumps( + { + "data": [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + } + ] + } + ) + resp.set_header("X-Rate-Limit", "12") + + +class InvalidPetListResource: + async def on_get(self, req, resp): + assert req.context.openapi + assert not req.context.openapi.errors + resp.status = status_codes.HTTP_200 + resp.content_type = MEDIA_JSON + resp.text = dumps({"data": [{"id": "12", "name": 13}]}) + resp.set_header("X-Rate-Limit", "12") + + +class _AsyncStream: + def __init__(self, chunks): + self._chunks = chunks + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._chunks): + raise StopAsyncIteration + + chunk = self._chunks[self._index] + self._index += 1 + return chunk + + +@pytest.mark.asyncio +async def test_dual_mode_sync_middleware_works_with_asgi_app(spec): + middleware = FalconOpenAPIMiddleware.from_spec(spec) + app = App(middleware=[middleware]) + app.add_route("/v1/pets", PetListResource()) + + async with ASGIConductor(app) as conductor: + with pytest.warns(DeprecationWarning): + response = await conductor.simulate_get( + "/v1/pets", + host="petstore.swagger.io", + query_string="limit=12", + ) + + assert response.status_code == 200 + assert response.json == { + "data": [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + } + ] + } + + +@pytest.mark.asyncio +async def test_explicit_asgi_middleware_handles_request_validation(spec): + middleware = FalconASGIOpenAPIMiddleware.from_spec(spec) + app = App(middleware=[middleware]) + app.add_route("/v1/pets", PetListResource()) + + async with ASGIConductor(app) as conductor: + with pytest.warns(DeprecationWarning): + response = await conductor.simulate_get( + "/v1/pets", + host="petstore.swagger.io", + ) + + assert response.status_code == 400 + assert response.json == { + "errors": [ + { + "type": ( + "" + ), + "status": 400, + "title": "Missing required query parameter: limit", + } + ] + } + + +@pytest.mark.asyncio +async def test_explicit_asgi_middleware_validates_response(spec): + middleware = FalconASGIOpenAPIMiddleware.from_spec(spec) + app = App(middleware=[middleware]) + app.add_route("/v1/pets", InvalidPetListResource()) + + async with ASGIConductor(app) as conductor: + with pytest.warns(DeprecationWarning): + response = await conductor.simulate_get( + "/v1/pets", + host="petstore.swagger.io", + query_string="limit=12", + ) + + assert response.status_code == 400 + assert "errors" in response.json + + +@pytest.mark.asyncio +async def test_asgi_response_adapter_handles_stream_without_charset(): + chunks = [ + b'{"data": [', + b'{"id": 12, "name": "Cat", "ears": {"healthy": true}}', + b"]}", + ] + response = Response() + response.content_type = MEDIA_JSON + response.stream = _AsyncStream(chunks) + + openapi_response = await FalconAsgiOpenAPIResponse.from_response(response) + + assert openapi_response.data == b"".join(chunks) + assert response.stream is not None + + replayed_chunks = [] + async for chunk in response.stream: + replayed_chunks.append(chunk) + assert b"".join(replayed_chunks) == b"".join(chunks) + + +def test_asgi_request_body_cached_none_skips_media_deserialization(): + class _DummyRequest: + def get_media(self, *args, **kwargs): + raise AssertionError("get_media should not be called") + + openapi_request = object.__new__(FalconAsgiOpenAPIRequest) + openapi_request.request = cast(Any, _DummyRequest()) + openapi_request._body = None + + assert openapi_request.body is None + + +def test_multipart_unsupported_serialization_warns_and_returns_none(): + content_type = "multipart/form-data; boundary=test" + + class _DummyHandler: + def serialize(self, media, content_type): + raise NotImplementedError( + "multipart form serialization unsupported" + ) + + class _DummyMediaHandlers: + def _resolve(self, content_type, default_media_type): + return (_DummyHandler(), content_type, None) + + class _DummyOptions: + media_handlers = _DummyMediaHandlers() + default_media_type = MEDIA_JSON + + class _DummyRequest: + options = _DummyOptions() + + with pytest.warns( + UserWarning, + match="body serialization for multipart/form-data", + ): + body = serialize_body( + cast(Any, _DummyRequest()), {"name": "Cat"}, content_type + ) + + assert body is None diff --git a/tests/integration/contrib/falcon/test_falcon_wsgi_middleware.py b/tests/integration/contrib/falcon/test_falcon_wsgi_middleware.py new file mode 100644 index 00000000..d130ebc5 --- /dev/null +++ b/tests/integration/contrib/falcon/test_falcon_wsgi_middleware.py @@ -0,0 +1,68 @@ +from json import dumps +from pathlib import Path + +import pytest +import yaml +from falcon import App +from falcon.constants import MEDIA_JSON +from falcon.status_codes import HTTP_200 +from falcon.testing import TestClient +from jsonschema_path import SchemaPath + +from openapi_core.contrib.falcon.middlewares import FalconWSGIOpenAPIMiddleware + + +@pytest.fixture +def spec(): + openapi_spec_path = Path("tests/integration/data/v3.0/petstore.yaml") + spec_dict = yaml.load(openapi_spec_path.read_text(), yaml.Loader) + return SchemaPath.from_dict(spec_dict) + + +class PetListResource: + def on_get(self, req, resp): + assert req.context.openapi + assert not req.context.openapi.errors + resp.status = HTTP_200 + resp.content_type = MEDIA_JSON + resp.text = dumps( + { + "data": [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + } + ] + } + ) + resp.set_header("X-Rate-Limit", "12") + + +def test_explicit_wsgi_middleware_works(spec): + middleware = FalconWSGIOpenAPIMiddleware.from_spec(spec) + app = App(middleware=[middleware]) + app.add_route("/v1/pets", PetListResource()) + client = TestClient(app) + + with pytest.warns(DeprecationWarning): + response = client.simulate_get( + "/v1/pets", + host="petstore.swagger.io", + query_string="limit=12", + ) + + assert response.status_code == 200 + assert response.json == { + "data": [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + } + ] + } diff --git a/tests/integration/contrib/fastapi/test_fastapi_project.py b/tests/integration/contrib/fastapi/test_fastapi_project.py index 242613bc..c43e0950 100644 --- a/tests/integration/contrib/fastapi/test_fastapi_project.py +++ b/tests/integration/contrib/fastapi/test_fastapi_project.py @@ -24,7 +24,8 @@ def app(): @pytest.fixture def client(app): - return TestClient(app, base_url="http://petstore.swagger.io") + with TestClient(app, base_url="http://petstore.swagger.io") as test_client: + yield test_client class BaseTestPetstore: diff --git a/tests/integration/contrib/starlette/test_starlette_project.py b/tests/integration/contrib/starlette/test_starlette_project.py index 9ee65c06..0207316d 100644 --- a/tests/integration/contrib/starlette/test_starlette_project.py +++ b/tests/integration/contrib/starlette/test_starlette_project.py @@ -26,7 +26,11 @@ def app(self): @pytest.fixture def client(self, app): - return TestClient(app, base_url="http://petstore.swagger.io") + with TestClient( + app, + base_url="http://petstore.swagger.io", + ) as test_client: + yield test_client @property def api_key_encoded(self): @@ -45,7 +49,11 @@ def app(self): @pytest.fixture def client(self, app): - return TestClient(app, base_url="http://petstore.swagger.io") + with TestClient( + app, + base_url="http://petstore.swagger.io", + ) as test_client: + yield test_client class TestPetListEndpoint(BaseTestPetstore): diff --git a/tests/integration/contrib/starlette/test_starlette_validation.py b/tests/integration/contrib/starlette/test_starlette_validation.py index 6bebcfbb..03c42b63 100644 --- a/tests/integration/contrib/starlette/test_starlette_validation.py +++ b/tests/integration/contrib/starlette/test_starlette_validation.py @@ -43,7 +43,8 @@ async def test_route(scope, receive, send): @pytest.fixture def client(self, app): - return TestClient(app, base_url="http://localhost") + with TestClient(app, base_url="http://localhost") as test_client: + yield test_client def test_request_validator_path_pattern(self, client, schema_path): response_data = {"data": "data"} @@ -65,18 +66,18 @@ async def test_route(request): Route("/browse/12/", test_route, methods=["POST"]), ], ) - client = TestClient(app, base_url="http://localhost") - query_string = { - "q": "string", - } - headers = {"content-type": "application/json"} - data = {"param1": 1} - response = client.post( - "/browse/12/", - params=query_string, - json=data, - headers=headers, - ) + with TestClient(app, base_url="http://localhost") as client: + query_string = { + "q": "string", + } + headers = {"content-type": "application/json"} + data = {"param1": 1} + response = client.post( + "/browse/12/", + params=query_string, + json=data, + headers=headers, + ) assert response.status_code == 200 assert response.json() == response_data @@ -104,18 +105,18 @@ def test_route(request): Route("/browse/12/", test_route, methods=["POST"]), ], ) - client = TestClient(app, base_url="http://localhost") - query_string = { - "q": "string", - } - headers = {"content-type": "application/json"} - data = {"param1": 1} - response = client.post( - "/browse/12/", - params=query_string, - json=data, - headers=headers, - ) + with TestClient(app, base_url="http://localhost") as client: + query_string = { + "q": "string", + } + headers = {"content-type": "application/json"} + data = {"param1": 1} + response = client.post( + "/browse/12/", + params=query_string, + json=data, + headers=headers, + ) assert response.status_code == 200 assert response.json() == response_data