Skip to content

Commit 64f2a11

Browse files
committed
use asgi_tools for WS responses
1 parent a85e635 commit 64f2a11

File tree

1 file changed

+50
-44
lines changed

1 file changed

+50
-44
lines changed

src/reactpy/asgi/middleware.py

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any
1212

1313
import orjson
14+
from asgi_tools import ResponseWebSocket
1415
from asgiref import typing as asgi_types
1516
from asgiref.compatibility import guarantee_single_callable
1617
from servestatic import ServeStaticASGI
@@ -26,6 +27,8 @@
2627
AsgiHttpApp,
2728
AsgiLifespanApp,
2829
AsgiWebsocketApp,
30+
AsgiWebsocketReceive,
31+
AsgiWebsocketSend,
2932
Connection,
3033
Location,
3134
ReactPyConfig,
@@ -153,48 +156,52 @@ async def __call__(
153156
send: asgi_types.ASGISendCallable,
154157
) -> None:
155158
"""ASGI app for rendering ReactPy Python components."""
156-
dispatcher: asyncio.Task[Any] | None = None
157-
recv_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
158-
159159
# Start a loop that handles ASGI websocket events
160-
while True:
161-
event = await receive()
162-
if event["type"] == "websocket.connect":
163-
await send(
164-
{"type": "websocket.accept", "subprotocol": None, "headers": []}
165-
)
166-
dispatcher = asyncio.create_task(
167-
self.run_dispatcher(scope, receive, send, recv_queue)
168-
)
169-
170-
elif event["type"] == "websocket.disconnect":
171-
if dispatcher:
172-
dispatcher.cancel()
173-
break
174-
175-
elif event["type"] == "websocket.receive" and event["text"]:
176-
msg = orjson.loads(event["text"])
177-
msg_type = msg.get("type")
178-
if msg_type == "layout-event":
179-
queue_put_func = recv_queue.put(msg)
180-
await queue_put_func
181-
else:
182-
await asyncio.to_thread(
183-
_logger.warning, f"Unknown message type: {msg_type}"
184-
)
185-
186-
async def run_dispatcher(
160+
async with ReactPyWebsocket(scope, receive, send, parent=self.parent) as ws: # type: ignore
161+
while True:
162+
event: dict[str, Any] = await ws.receive(raw=True) # type: ignore
163+
if event["type"] == "websocket.receive" and event["text"]:
164+
msg: dict[str, str] = orjson.loads(event["text"])
165+
if msg.get("type") == "layout-event":
166+
queue_put_func = ws.recv_queue.put(msg)
167+
await queue_put_func
168+
else:
169+
await asyncio.to_thread(
170+
_logger.warning, f"Unknown message type: {msg.get('type')}"
171+
)
172+
elif event["type"] == "websocket.disconnect":
173+
break
174+
175+
176+
class ReactPyWebsocket(ResponseWebSocket):
177+
def __init__(
187178
self,
188179
scope: asgi_types.WebSocketScope,
189-
receive: asgi_types.ASGIReceiveCallable,
190-
send: asgi_types.ASGISendCallable,
191-
recv_queue: asyncio.Queue[dict[str, Any]],
180+
receive: AsgiWebsocketReceive,
181+
send: AsgiWebsocketSend,
182+
parent: ReactPyMiddleware,
192183
) -> None:
184+
super().__init__(scope=scope, receive=receive, send=send) # type: ignore
185+
self.scope = scope
186+
self.parent = parent
187+
self.recv_queue: asyncio.Queue[dict[str, str]] = asyncio.Queue()
188+
self.dispatcher: asyncio.Task[Any] | None = None
189+
190+
async def __aenter__(self) -> ReactPyWebsocket:
191+
self.dispatcher = asyncio.create_task(self.run_dispatcher())
192+
return await super().__aenter__() # type: ignore
193+
194+
async def __aexit__(self, *_: Any) -> None:
195+
if self.dispatcher:
196+
self.dispatcher.cancel()
197+
await super().__aexit__() # type: ignore
198+
199+
async def run_dispatcher(self) -> None:
193200
"""Asyncio background task that renders and transmits layout updates of ReactPy components."""
194201
try:
195202
# Determine component to serve by analyzing the URL and/or class parameters.
196203
if self.parent.multiple_root_components:
197-
url_match = re.match(self.parent.dispatcher_pattern, scope["path"])
204+
url_match = re.match(self.parent.dispatcher_pattern, self.scope["path"])
198205
if not url_match: # pragma: no cover
199206
raise RuntimeError("Could not find component in URL path.")
200207
dotted_path = url_match["dotted_path"]
@@ -210,10 +217,10 @@ async def run_dispatcher(
210217

211218
# Create a connection object by analyzing the websocket's query string.
212219
ws_query_string = urllib.parse.parse_qs(
213-
scope["query_string"].decode(), strict_parsing=True
220+
self.scope["query_string"].decode(), strict_parsing=True
214221
)
215222
connection = Connection(
216-
scope=scope,
223+
scope=self.scope,
217224
location=Location(
218225
path=ws_query_string.get("http_pathname", [""])[0],
219226
query_string=ws_query_string.get("http_query_string", [""])[0],
@@ -224,20 +231,19 @@ async def run_dispatcher(
224231
# Start the ReactPy component rendering loop
225232
await serve_layout(
226233
Layout(ConnectionContext(component(), value=connection)),
227-
lambda msg: send(
228-
{
229-
"type": "websocket.send",
230-
"text": orjson.dumps(msg).decode(),
231-
"bytes": None,
232-
}
233-
),
234-
recv_queue.get, # type: ignore
234+
self.send_json,
235+
self.recv_queue.get, # type: ignore
235236
)
236237

237238
# Manually log exceptions since this function is running in a separate asyncio task.
238239
except Exception as error:
239240
await asyncio.to_thread(_logger.error, f"{error}\n{traceback.format_exc()}")
240241

242+
async def send_json(self, data: Any) -> None:
243+
return await self._send(
244+
{"type": "websocket.send", "text": orjson.dumps(data).decode()}
245+
)
246+
241247

242248
@dataclass
243249
class StaticFileApp:

0 commit comments

Comments
 (0)