1111from typing import Any
1212
1313import orjson
14+ from asgi_tools import ResponseWebSocket
1415from asgiref import typing as asgi_types
1516from asgiref .compatibility import guarantee_single_callable
1617from servestatic import ServeStaticASGI
2627 AsgiHttpApp ,
2728 AsgiLifespanApp ,
2829 AsgiWebsocketApp ,
30+ AsgiWebsocketReceive ,
31+ AsgiWebsocketSend ,
2932 Connection ,
3033 Location ,
3134 ReactPyConfig ,
@@ -153,48 +156,52 @@ async def __call__(
153156 send : asgi_types .ASGISendCallable ,
154157 ) -> None :
155158 """ASGI app for rendering ReactPy Python components."""
156- dispatcher : asyncio .Task [Any ] | None = None
157- recv_queue : asyncio .Queue [dict [str , Any ]] = asyncio .Queue ()
158-
159159 # Start a loop that handles ASGI websocket events
160- while True :
161- event = await receive ()
162- if event ["type" ] == "websocket.connect" :
163- await send (
164- {"type" : "websocket.accept" , "subprotocol" : None , "headers" : []}
165- )
166- dispatcher = asyncio .create_task (
167- self .run_dispatcher (scope , receive , send , recv_queue )
168- )
169-
170- elif event ["type" ] == "websocket.disconnect" :
171- if dispatcher :
172- dispatcher .cancel ()
173- break
174-
175- elif event ["type" ] == "websocket.receive" and event ["text" ]:
176- msg = orjson .loads (event ["text" ])
177- msg_type = msg .get ("type" )
178- if msg_type == "layout-event" :
179- queue_put_func = recv_queue .put (msg )
180- await queue_put_func
181- else :
182- await asyncio .to_thread (
183- _logger .warning , f"Unknown message type: { msg_type } "
184- )
185-
186- async def run_dispatcher (
160+ async with ReactPyWebsocket (scope , receive , send , parent = self .parent ) as ws : # type: ignore
161+ while True :
162+ event : dict [str , Any ] = await ws .receive (raw = True ) # type: ignore
163+ if event ["type" ] == "websocket.receive" and event ["text" ]:
164+ msg : dict [str , str ] = orjson .loads (event ["text" ])
165+ if msg .get ("type" ) == "layout-event" :
166+ queue_put_func = ws .recv_queue .put (msg )
167+ await queue_put_func
168+ else :
169+ await asyncio .to_thread (
170+ _logger .warning , f"Unknown message type: { msg .get ('type' )} "
171+ )
172+ elif event ["type" ] == "websocket.disconnect" :
173+ break
174+
175+
176+ class ReactPyWebsocket (ResponseWebSocket ):
177+ def __init__ (
187178 self ,
188179 scope : asgi_types .WebSocketScope ,
189- receive : asgi_types . ASGIReceiveCallable ,
190- send : asgi_types . ASGISendCallable ,
191- recv_queue : asyncio . Queue [ dict [ str , Any ]] ,
180+ receive : AsgiWebsocketReceive ,
181+ send : AsgiWebsocketSend ,
182+ parent : ReactPyMiddleware ,
192183 ) -> None :
184+ super ().__init__ (scope = scope , receive = receive , send = send ) # type: ignore
185+ self .scope = scope
186+ self .parent = parent
187+ self .recv_queue : asyncio .Queue [dict [str , str ]] = asyncio .Queue ()
188+ self .dispatcher : asyncio .Task [Any ] | None = None
189+
190+ async def __aenter__ (self ) -> ReactPyWebsocket :
191+ self .dispatcher = asyncio .create_task (self .run_dispatcher ())
192+ return await super ().__aenter__ () # type: ignore
193+
194+ async def __aexit__ (self , * _ : Any ) -> None :
195+ if self .dispatcher :
196+ self .dispatcher .cancel ()
197+ await super ().__aexit__ () # type: ignore
198+
199+ async def run_dispatcher (self ) -> None :
193200 """Asyncio background task that renders and transmits layout updates of ReactPy components."""
194201 try :
195202 # Determine component to serve by analyzing the URL and/or class parameters.
196203 if self .parent .multiple_root_components :
197- url_match = re .match (self .parent .dispatcher_pattern , scope ["path" ])
204+ url_match = re .match (self .parent .dispatcher_pattern , self . scope ["path" ])
198205 if not url_match : # pragma: no cover
199206 raise RuntimeError ("Could not find component in URL path." )
200207 dotted_path = url_match ["dotted_path" ]
@@ -210,10 +217,10 @@ async def run_dispatcher(
210217
211218 # Create a connection object by analyzing the websocket's query string.
212219 ws_query_string = urllib .parse .parse_qs (
213- scope ["query_string" ].decode (), strict_parsing = True
220+ self . scope ["query_string" ].decode (), strict_parsing = True
214221 )
215222 connection = Connection (
216- scope = scope ,
223+ scope = self . scope ,
217224 location = Location (
218225 path = ws_query_string .get ("http_pathname" , ["" ])[0 ],
219226 query_string = ws_query_string .get ("http_query_string" , ["" ])[0 ],
@@ -224,20 +231,19 @@ async def run_dispatcher(
224231 # Start the ReactPy component rendering loop
225232 await serve_layout (
226233 Layout (ConnectionContext (component (), value = connection )),
227- lambda msg : send (
228- {
229- "type" : "websocket.send" ,
230- "text" : orjson .dumps (msg ).decode (),
231- "bytes" : None ,
232- }
233- ),
234- recv_queue .get , # type: ignore
234+ self .send_json ,
235+ self .recv_queue .get , # type: ignore
235236 )
236237
237238 # Manually log exceptions since this function is running in a separate asyncio task.
238239 except Exception as error :
239240 await asyncio .to_thread (_logger .error , f"{ error } \n { traceback .format_exc ()} " )
240241
242+ async def send_json (self , data : Any ) -> None :
243+ return await self ._send (
244+ {"type" : "websocket.send" , "text" : orjson .dumps (data ).decode ()}
245+ )
246+
241247
242248@dataclass
243249class StaticFileApp :
0 commit comments