From 23d3ee06ecf66428529818cd85d6abfd5fe1b3a5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Apr 2025 21:17:16 -1000 Subject: [PATCH 1/2] Refactor WebSocket reader to avoid frequent realloc when frames are fragmented (#10744) --- CHANGES/10744.misc.rst | 1 + aiohttp/_websocket/reader_c.pxd | 25 ++++---- aiohttp/_websocket/reader_py.py | 103 ++++++++++++++++++-------------- 3 files changed, 72 insertions(+), 57 deletions(-) create mode 100644 CHANGES/10744.misc.rst diff --git a/CHANGES/10744.misc.rst b/CHANGES/10744.misc.rst new file mode 100644 index 00000000000..da0d379475d --- /dev/null +++ b/CHANGES/10744.misc.rst @@ -0,0 +1 @@ +Improved performance of the WebSocket reader with large messages -- by :user:`bdraco`. diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 5c519961f82..9a6fdae3e97 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -74,14 +74,14 @@ cdef class WebSocketReader: cdef int _opcode cdef bint _frame_fin cdef int _frame_opcode - cdef object _frame_payload - cdef unsigned long long _frame_payload_len + cdef list _payload_fragments + cdef Py_ssize_t _frame_payload_len cdef bytes _tail cdef bint _has_mask cdef bytes _frame_mask - cdef unsigned long long _payload_length - cdef unsigned int _payload_length_flag + cdef Py_ssize_t _payload_bytes_to_read + cdef unsigned int _payload_len_flag cdef int _compressed cdef object _decompressobj cdef bint _compress @@ -97,17 +97,20 @@ cdef class WebSocketReader: cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except * @cython.locals( - start_pos="unsigned int", - data_len="unsigned int", - length="unsigned int", - chunk_size="unsigned int", - chunk_len="unsigned int", - data_length="unsigned int", + start_pos=Py_ssize_t, + data_len=Py_ssize_t, + length=Py_ssize_t, + chunk_size=Py_ssize_t, + chunk_len=Py_ssize_t, + data_len=Py_ssize_t, data_cstr="const unsigned char *", first_byte="unsigned 
char", second_byte="unsigned char", - end_pos="unsigned int", + f_start_pos=Py_ssize_t, + f_end_pos=Py_ssize_t, has_mask=bint, fin=bint, + had_fragments=Py_ssize_t, + payload_bytearray=bytearray, ) cpdef void _feed_data(self, bytes data) except * diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index aa30834f402..df322a436dd 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -151,14 +151,14 @@ def __init__( self._opcode: int = OP_CODE_NOT_SET self._frame_fin = False self._frame_opcode: int = OP_CODE_NOT_SET - self._frame_payload: Union[bytes, bytearray] = b"" + self._payload_fragments: list[bytes] = [] self._frame_payload_len = 0 self._tail: bytes = b"" self._has_mask = False self._frame_mask: Optional[bytes] = None - self._payload_length = 0 - self._payload_length_flag = 0 + self._payload_bytes_to_read = 0 + self._payload_len_flag = 0 self._compressed: int = COMPRESSED_NOT_SET self._decompressobj: Optional[ZLibDecompressor] = None self._compress = compress @@ -336,13 +336,13 @@ def _feed_data(self, data: bytes) -> None: data, self._tail = self._tail + data, b"" start_pos: int = 0 - data_length = len(data) + data_len = len(data) data_cstr = data while True: # read header if self._state == READ_HEADER: - if data_length - start_pos < 2: + if data_len - start_pos < 2: break first_byte = data_cstr[start_pos] second_byte = data_cstr[start_pos + 1] @@ -401,77 +401,88 @@ def _feed_data(self, data: bytes) -> None: self._frame_fin = bool(fin) self._frame_opcode = opcode self._has_mask = bool(has_mask) - self._payload_length_flag = length + self._payload_len_flag = length self._state = READ_PAYLOAD_LENGTH # read payload length if self._state == READ_PAYLOAD_LENGTH: - length_flag = self._payload_length_flag - if length_flag == 126: - if data_length - start_pos < 2: + len_flag = self._payload_len_flag + if len_flag == 126: + if data_len - start_pos < 2: break first_byte = data_cstr[start_pos] second_byte = 
data_cstr[start_pos + 1] start_pos += 2 - self._payload_length = first_byte << 8 | second_byte - elif length_flag > 126: - if data_length - start_pos < 8: + self._payload_bytes_to_read = first_byte << 8 | second_byte + elif len_flag > 126: + if data_len - start_pos < 8: break - self._payload_length = UNPACK_LEN3(data, start_pos)[0] + self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0] start_pos += 8 else: - self._payload_length = length_flag + self._payload_bytes_to_read = len_flag self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD # read payload mask if self._state == READ_PAYLOAD_MASK: - if data_length - start_pos < 4: + if data_len - start_pos < 4: break self._frame_mask = data_cstr[start_pos : start_pos + 4] start_pos += 4 self._state = READ_PAYLOAD if self._state == READ_PAYLOAD: - chunk_len = data_length - start_pos - if self._payload_length >= chunk_len: - end_pos = data_length - self._payload_length -= chunk_len + chunk_len = data_len - start_pos + if self._payload_bytes_to_read >= chunk_len: + f_end_pos = data_len + self._payload_bytes_to_read -= chunk_len else: - end_pos = start_pos + self._payload_length - self._payload_length = 0 - - if self._frame_payload_len: - if type(self._frame_payload) is not bytearray: - self._frame_payload = bytearray(self._frame_payload) - self._frame_payload += data_cstr[start_pos:end_pos] - else: - # Fast path for the first frame - self._frame_payload = data_cstr[start_pos:end_pos] - - self._frame_payload_len += end_pos - start_pos - start_pos = end_pos - - if self._payload_length != 0: + f_end_pos = start_pos + self._payload_bytes_to_read + self._payload_bytes_to_read = 0 + + had_fragments = self._frame_payload_len + self._frame_payload_len += f_end_pos - start_pos + f_start_pos = start_pos + start_pos = f_end_pos + + if self._payload_bytes_to_read != 0: + # If we don't have a complete frame, we need to save the + # data for the next call to feed_data. 
+ self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) break - if self._has_mask: + payload: Union[bytes, bytearray] + if had_fragments: + # We have to join the payload fragments to get the payload + self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) + if self._has_mask: + assert self._frame_mask is not None + payload_bytearray = bytearray() + payload_bytearray += b"".join(self._payload_fragments) + websocket_mask(self._frame_mask, payload_bytearray) + payload = payload_bytearray + else: + payload = b"".join(self._payload_fragments) + self._payload_fragments.clear() + elif self._has_mask: assert self._frame_mask is not None - if type(self._frame_payload) is not bytearray: - self._frame_payload = bytearray(self._frame_payload) - websocket_mask(self._frame_mask, self._frame_payload) + payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment] + if type(payload_bytearray) is not bytearray: # pragma: no branch + # Cython will do the conversion for us + # but we need to do it for Python and we + # will always get here in Python + payload_bytearray = bytearray(payload_bytearray) + websocket_mask(self._frame_mask, payload_bytearray) + payload = payload_bytearray + else: + payload = data_cstr[f_start_pos:f_end_pos] self._handle_frame( - self._frame_fin, - self._frame_opcode, - self._frame_payload, - self._compressed, + self._frame_fin, self._frame_opcode, payload, self._compressed ) - self._frame_payload = b"" self._frame_payload_len = 0 self._state = READ_HEADER # XXX: Cython needs slices to be bounded, so we can't omit the slice end here. - self._tail = ( - data_cstr[start_pos:data_length] if start_pos < data_length else b"" - ) + self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b"" From d702fb30a8bce4aec18a48bea68ae74d373eaedc Mon Sep 17 00:00:00 2001 From: "J.
Nick Koston" Date: Fri, 18 Apr 2025 22:46:16 -1000 Subject: [PATCH 2/2] Add compressed binary WebSocket roundtrip benchmark (#10749) --- tests/test_benchmarks_client_ws.py | 66 ++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/test_benchmarks_client_ws.py b/tests/test_benchmarks_client_ws.py index c244d33f6bd..0338b52fb9d 100644 --- a/tests/test_benchmarks_client_ws.py +++ b/tests/test_benchmarks_client_ws.py @@ -105,3 +105,69 @@ async def run_websocket_benchmark() -> None: @benchmark def _run() -> None: loop.run_until_complete(run_websocket_benchmark()) + + +@pytest.mark.usefixtures("parametrize_zlib_backend") +def test_client_send_large_websocket_compressed_messages( + loop: asyncio.AbstractEventLoop, + aiohttp_client: AiohttpClient, + benchmark: BenchmarkFixture, +) -> None: + """Benchmark send of compressed WebSocket binary messages.""" + message_count = 10 + raw_message = b"x" * 2**19 # 512 KiB + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + for _ in range(message_count): + await ws.receive() + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + + async def run_websocket_benchmark() -> None: + client = await aiohttp_client(app) + resp = await client.ws_connect("/", compress=15) + for _ in range(message_count): + await resp.send_bytes(raw_message) + await resp.close() + + @benchmark + def _run() -> None: + loop.run_until_complete(run_websocket_benchmark()) + + +@pytest.mark.usefixtures("parametrize_zlib_backend") +def test_client_receive_large_websocket_compressed_messages( + loop: asyncio.AbstractEventLoop, + aiohttp_client: AiohttpClient, + benchmark: BenchmarkFixture, +) -> None: + """Benchmark receive of compressed WebSocket binary messages.""" + message_count = 10 + raw_message = b"x" * 2**19 # 512 KiB + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = 
web.WebSocketResponse() + await ws.prepare(request) + for _ in range(message_count): + await ws.send_bytes(raw_message) + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + + async def run_websocket_benchmark() -> None: + client = await aiohttp_client(app) + resp = await client.ws_connect("/", compress=15) + for _ in range(message_count): + await resp.receive() + await resp.close() + + @benchmark + def _run() -> None: + loop.run_until_complete(run_websocket_benchmark())