Skip to content

Commit 751d829

Browse files
committed
Avoid async-generator finalization side effects
1 parent 7fdd8d6 commit 751d829

File tree

3 files changed

+116
-13
lines changed

3 files changed

+116
-13
lines changed

openapi_core/contrib/falcon/requests.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from openapi_core.contrib.falcon.util import unpack_params
1616
from openapi_core.datatypes import RequestParameters
1717

18+
_BODY_NOT_SET = object()
19+
1820

1921
class FalconOpenAPIRequest:
2022
def __init__(
@@ -28,7 +30,7 @@ def __init__(
2830
if default_when_empty is None:
2931
default_when_empty = {}
3032
self.default_when_empty = default_when_empty
31-
self._body: Optional[bytes] = None
33+
self._body: Any = _BODY_NOT_SET
3234

3335
# Path gets deduced by path finder against spec
3436
self.parameters = RequestParameters(
@@ -54,21 +56,23 @@ def method(self) -> str:
5456

5557
@property
5658
def body(self) -> Optional[bytes]:
57-
if self._body is not None:
58-
return self._body
59+
if self._body is not _BODY_NOT_SET:
60+
return cast(Optional[bytes], self._body)
5961

6062
# Falcon doesn't store raw request stream.
6163
# That's why we need to revert deserialized data
6264

6365
# Support falcon-jsonify.
6466
request_json = getattr(cast(Any, self.request), "json", None)
6567
if request_json is not None:
66-
return dumps(request_json).encode("utf-8")
68+
self._body = dumps(request_json).encode("utf-8")
69+
return cast(Optional[bytes], self._body)
6770

6871
media = self.request.get_media(
6972
default_when_empty=self.default_when_empty,
7073
)
71-
return serialize_body(self.request, media, self.content_type)
74+
self._body = serialize_body(self.request, media, self.content_type)
75+
return cast(Optional[bytes], self._body)
7276

7377
@property
7478
def content_type(self) -> str:

openapi_core/contrib/falcon/responses.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from io import BytesIO
55
from itertools import tee
66
from typing import Any
7-
from typing import AsyncIterator
87
from typing import Iterable
98
from typing import List
109

@@ -78,7 +77,7 @@ async def _get_asgi_response_data(cls, response: Any) -> bytes:
7877
assert isinstance(data, bytes)
7978
return data
8079

81-
charset = response_any.charset or "utf-8"
80+
charset = getattr(response_any, "charset", None) or "utf-8"
8281
chunks: List[bytes] = []
8382
stream_any = stream
8483

@@ -108,14 +107,26 @@ async def _get_asgi_response_data(cls, response: Any) -> bytes:
108107
chunks.append(chunk)
109108
return b"".join(chunks)
110109

111-
response_any.stream = cls._iter_chunks(chunks)
110+
response_any.stream = _AsyncChunksIterator(chunks)
112111
return b"".join(chunks)
113112

114-
@staticmethod
115-
async def _iter_chunks(chunks: Iterable[bytes]) -> AsyncIterator[bytes]:
116-
for chunk in chunks:
117-
yield chunk
118-
119113
@property
120114
def data(self) -> bytes:
121115
return self._data
116+
117+
118+
class _AsyncChunksIterator:
119+
def __init__(self, chunks: List[bytes]):
120+
self._chunks = chunks
121+
self._index = 0
122+
123+
def __aiter__(self) -> "_AsyncChunksIterator":
124+
return self
125+
126+
async def __anext__(self) -> bytes:
127+
if self._index >= len(self._chunks):
128+
raise StopAsyncIteration
129+
130+
chunk = self._chunks[self._index]
131+
self._index += 1
132+
return chunk

tests/integration/contrib/falcon/test_falcon_asgi_middleware.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from json import dumps
22
from pathlib import Path
3+
from typing import Any
4+
from typing import cast
35

46
import pytest
57
import yaml
68
from falcon import status_codes
79
from falcon.asgi import App
10+
from falcon.asgi import Response
811
from falcon.constants import MEDIA_JSON
912
from falcon.testing import TestClient
1013
from jsonschema_path import SchemaPath
1114

1215
from openapi_core.contrib.falcon.middlewares import FalconASGIOpenAPIMiddleware
1316
from openapi_core.contrib.falcon.middlewares import FalconOpenAPIMiddleware
17+
from openapi_core.contrib.falcon.requests import FalconAsgiOpenAPIRequest
18+
from openapi_core.contrib.falcon.responses import FalconAsgiOpenAPIResponse
19+
from openapi_core.contrib.falcon.util import serialize_body
1420

1521

1622
@pytest.fixture
@@ -52,6 +58,23 @@ async def on_get(self, req, resp):
5258
resp.set_header("X-Rate-Limit", "12")
5359

5460

61+
class _AsyncStream:
62+
def __init__(self, chunks):
63+
self._chunks = chunks
64+
self._index = 0
65+
66+
def __aiter__(self):
67+
return self
68+
69+
async def __anext__(self):
70+
if self._index >= len(self._chunks):
71+
raise StopAsyncIteration
72+
73+
chunk = self._chunks[self._index]
74+
self._index += 1
75+
return chunk
76+
77+
5578
def test_dual_mode_sync_middleware_works_with_asgi_app(spec):
5679
middleware = FalconOpenAPIMiddleware.from_spec(spec)
5780
app = App(middleware=[middleware])
@@ -121,3 +144,68 @@ def test_explicit_asgi_middleware_validates_response(spec):
121144

122145
assert response.status_code == 400
123146
assert "errors" in response.json
147+
148+
149+
@pytest.mark.asyncio
150+
async def test_asgi_response_adapter_handles_stream_without_charset():
151+
chunks = [
152+
b'{"data": [',
153+
b'{"id": 12, "name": "Cat", "ears": {"healthy": true}}',
154+
b"]}",
155+
]
156+
response = Response()
157+
response.content_type = MEDIA_JSON
158+
response.stream = _AsyncStream(chunks)
159+
160+
openapi_response = await FalconAsgiOpenAPIResponse.from_response(response)
161+
162+
assert openapi_response.data == b"".join(chunks)
163+
assert response.stream is not None
164+
165+
replayed_chunks = []
166+
async for chunk in response.stream:
167+
replayed_chunks.append(chunk)
168+
assert b"".join(replayed_chunks) == b"".join(chunks)
169+
170+
171+
def test_asgi_request_body_cached_none_skips_media_deserialization():
172+
class _DummyRequest:
173+
def get_media(self, *args, **kwargs):
174+
raise AssertionError("get_media should not be called")
175+
176+
openapi_request = object.__new__(FalconAsgiOpenAPIRequest)
177+
openapi_request.request = cast(Any, _DummyRequest())
178+
openapi_request._body = None
179+
180+
assert openapi_request.body is None
181+
182+
183+
def test_multipart_unsupported_serialization_warns_and_returns_none():
184+
content_type = "multipart/form-data; boundary=test"
185+
186+
class _DummyHandler:
187+
def serialize(self, media, content_type):
188+
raise NotImplementedError(
189+
"multipart form serialization unsupported"
190+
)
191+
192+
class _DummyMediaHandlers:
193+
def _resolve(self, content_type, default_media_type):
194+
return (_DummyHandler(), content_type, None)
195+
196+
class _DummyOptions:
197+
media_handlers = _DummyMediaHandlers()
198+
default_media_type = MEDIA_JSON
199+
200+
class _DummyRequest:
201+
options = _DummyOptions()
202+
203+
with pytest.warns(
204+
UserWarning,
205+
match="body serialization for multipart/form-data",
206+
):
207+
body = serialize_body(
208+
cast(Any, _DummyRequest()), {"name": "Cat"}, content_type
209+
)
210+
211+
assert body is None

0 commit comments

Comments
 (0)