From 0d21d8d3f2d69ebfce1155edc76f17281b750377 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Apr 2025 18:50:32 -1000 Subject: [PATCH] Refactor WebSocket reader to avoid creating lists (#10740) --- CHANGES/10740.misc.rst | 1 + aiohttp/_websocket/reader_c.pxd | 36 ++-- aiohttp/_websocket/reader_py.py | 332 +++++++++++++++---------------- tests/test_websocket_parser.py | 342 +++++++++++++++----------------- 4 files changed, 340 insertions(+), 371 deletions(-) create mode 100644 CHANGES/10740.misc.rst diff --git a/CHANGES/10740.misc.rst b/CHANGES/10740.misc.rst new file mode 100644 index 00000000000..34ed19aebba --- /dev/null +++ b/CHANGES/10740.misc.rst @@ -0,0 +1 @@ +Improved performance of the WebSocket reader -- by :user:`bdraco`. diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 07a7d979553..5c519961f82 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -8,12 +8,17 @@ cdef unsigned int READ_PAYLOAD_LENGTH cdef unsigned int READ_PAYLOAD_MASK cdef unsigned int READ_PAYLOAD -cdef unsigned int OP_CODE_CONTINUATION -cdef unsigned int OP_CODE_TEXT -cdef unsigned int OP_CODE_BINARY -cdef unsigned int OP_CODE_CLOSE -cdef unsigned int OP_CODE_PING -cdef unsigned int OP_CODE_PONG +cdef int OP_CODE_NOT_SET +cdef int OP_CODE_CONTINUATION +cdef int OP_CODE_TEXT +cdef int OP_CODE_BINARY +cdef int OP_CODE_CLOSE +cdef int OP_CODE_PING +cdef int OP_CODE_PONG + +cdef int COMPRESSED_NOT_SET +cdef int COMPRESSED_FALSE +cdef int COMPRESSED_TRUE cdef object UNPACK_LEN3 cdef object UNPACK_CLOSE_CODE @@ -66,9 +71,9 @@ cdef class WebSocketReader: cdef bytearray _partial cdef unsigned int _state - cdef object _opcode - cdef object _frame_fin - cdef object _frame_opcode + cdef int _opcode + cdef bint _frame_fin + cdef int _frame_opcode cdef object _frame_payload cdef unsigned long long _frame_payload_len @@ -77,7 +82,7 @@ cdef class WebSocketReader: cdef bytes _frame_mask cdef unsigned long long 
_payload_length cdef unsigned int _payload_length_flag - cdef object _compressed + cdef int _compressed cdef object _decompressobj cdef bint _compress @@ -88,22 +93,21 @@ cdef class WebSocketReader: fin=bint, has_partial=bint, payload_merged=bytes, - opcode="unsigned int", ) - cpdef void _feed_data(self, bytes data) + cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except * @cython.locals( start_pos="unsigned int", - buf_len="unsigned int", + data_len="unsigned int", length="unsigned int", chunk_size="unsigned int", chunk_len="unsigned int", - buf_length="unsigned int", - buf_cstr="const unsigned char *", + data_length="unsigned int", + data_cstr="const unsigned char *", first_byte="unsigned char", second_byte="unsigned char", end_pos="unsigned int", has_mask=bint, fin=bint, ) - cpdef list parse_frame(self, bytes buf) + cpdef void _feed_data(self, bytes data) except * diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 5daf91d7140..aa30834f402 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -3,7 +3,7 @@ import asyncio import builtins from collections import deque -from typing import Deque, Final, List, Optional, Set, Tuple, Type, Union +from typing import Deque, Final, Optional, Set, Tuple, Type, Union from ..base_protocol import BaseProtocol from ..compression_utils import ZLibDecompressor @@ -36,6 +36,7 @@ WS_MSG_TYPE_TEXT = WSMsgType.TEXT # WSMsgType values unpacked so they can by cythonized to ints +OP_CODE_NOT_SET = -1 OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value OP_CODE_TEXT = WSMsgType.TEXT.value OP_CODE_BINARY = WSMsgType.BINARY.value @@ -46,8 +47,14 @@ EMPTY_FRAME_ERROR = (True, b"") EMPTY_FRAME = (False, b"") +COMPRESSED_NOT_SET = -1 +COMPRESSED_FALSE = 0 +COMPRESSED_TRUE = 1 + TUPLE_NEW = tuple.__new__ +cython_int = int # Typed to int in Python, but cython will use a signed int in the pxd + class WebSocketDataQueue: """WebSocketDataQueue 
resumes and pauses an underlying stream. @@ -141,9 +148,9 @@ def __init__( self._partial = bytearray() self._state = READ_HEADER - self._opcode: Optional[int] = None + self._opcode: int = OP_CODE_NOT_SET self._frame_fin = False - self._frame_opcode: Optional[int] = None + self._frame_opcode: int = OP_CODE_NOT_SET self._frame_payload: Union[bytes, bytearray] = b"" self._frame_payload_len = 0 @@ -152,7 +159,7 @@ def __init__( self._frame_mask: Optional[bytes] = None self._payload_length = 0 self._payload_length_flag = 0 - self._compressed: Optional[bool] = None + self._compressed: int = COMPRESSED_NOT_SET self._decompressobj: Optional[ZLibDecompressor] = None self._compress = compress @@ -180,173 +187,165 @@ def feed_data( return EMPTY_FRAME - def _feed_data(self, data: bytes) -> None: + def _handle_frame( + self, + fin: bool, + opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int + payload: Union[bytes, bytearray], + compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int + ) -> None: msg: WSMessage - for frame in self.parse_frame(data): - fin = frame[0] - opcode = frame[1] - payload = frame[2] - compressed = frame[3] - - is_continuation = opcode == OP_CODE_CONTINUATION - if opcode == OP_CODE_TEXT or opcode == OP_CODE_BINARY or is_continuation: - # load text/binary - if not fin: - # got partial frame payload - if not is_continuation: - self._opcode = opcode - self._partial += payload - if self._max_msg_size and len(self._partial) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - f"Message size {len(self._partial)} " - f"exceeds limit {self._max_msg_size}", - ) - continue - - has_partial = bool(self._partial) - if is_continuation: - if self._opcode is None: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Continuation frame for non started message", - ) - opcode = self._opcode - self._opcode = None - # previous frame was non finished - # we should get continuation opcode - elif has_partial: 
+ if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}: + # load text/binary + if not fin: + # got partial frame payload + if opcode != OP_CODE_CONTINUATION: + self._opcode = opcode + self._partial += payload + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + f"Message size {len(self._partial)} " + f"exceeds limit {self._max_msg_size}", + ) + return + + has_partial = bool(self._partial) + if opcode == OP_CODE_CONTINUATION: + if self._opcode == OP_CODE_NOT_SET: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, - "The opcode in non-fin frame is expected " - f"to be zero, got {opcode!r}", + "Continuation frame for non started message", ) + opcode = self._opcode + self._opcode = OP_CODE_NOT_SET + # previous frame was non finished + # we should get continuation opcode + elif has_partial: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + f"to be zero, got {opcode!r}", + ) - assembled_payload: Union[bytes, bytearray] - if has_partial: - assembled_payload = self._partial + payload - self._partial.clear() - else: - assembled_payload = payload + assembled_payload: Union[bytes, bytearray] + if has_partial: + assembled_payload = self._partial + payload + self._partial.clear() + else: + assembled_payload = payload + + if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + f"Message size {len(assembled_payload)} " + f"exceeds limit {self._max_msg_size}", + ) - if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + # Decompress process must be done after all packets + # received. + if compressed: + if not self._decompressobj: + self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) + # XXX: It's possible that the zlib backend (isal is known to + # do this, maybe others too?) 
will return max_length bytes, + # but internally buffer more data such that the payload is + # >max_length, so we return one extra byte and if we're able + # to do that, then the message is too big. + payload_merged = self._decompressobj.decompress_sync( + assembled_payload + WS_DEFLATE_TRAILING, + ( + self._max_msg_size + 1 + if self._max_msg_size + else self._max_msg_size + ), + ) + if self._max_msg_size and len(payload_merged) > self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - f"Message size {len(assembled_payload)} " - f"exceeds limit {self._max_msg_size}", + f"Decompressed message exceeds size limit {self._max_msg_size}", ) + elif type(assembled_payload) is bytes: + payload_merged = assembled_payload + else: + payload_merged = bytes(assembled_payload) - # Decompress process must to be done after all packets - # received. - if compressed: - if not self._decompressobj: - self._decompressobj = ZLibDecompressor( - suppress_deflate_header=True - ) - # XXX: It's possible that the zlib backend (isal is known to - # do this, maybe others too?) will return max_length bytes, - # but internally buffer more data such that the payload is - # >max_length, so we return one extra byte and if we're able - # to do that, then the message is too big. 
- payload_merged = self._decompressobj.decompress_sync( - assembled_payload + WS_DEFLATE_TRAILING, - ( - self._max_msg_size + 1 - if self._max_msg_size - else self._max_msg_size - ), - ) - if self._max_msg_size and len(payload_merged) > self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - f"Decompressed message exceeds size limit {self._max_msg_size}", - ) - elif type(assembled_payload) is bytes: - payload_merged = assembled_payload - else: - payload_merged = bytes(assembled_payload) - - size = len(payload_merged) - if opcode == OP_CODE_TEXT: - try: - text = payload_merged.decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - - # XXX: The Text and Binary messages here can be a performance - # bottleneck, so we use tuple.__new__ to improve performance. - # This is not type safe, but many tests should fail in - # test_client_ws_functional.py if this is wrong. - msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) - else: - msg = TUPLE_NEW( - WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY) - ) + size = len(payload_merged) + if opcode == OP_CODE_TEXT: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. 
+ msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + else: + msg = TUPLE_NEW( + WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY) + ) - self.queue.feed_data(msg) - elif opcode == OP_CODE_CLOSE: - payload_len = len(payload) - if payload_len >= 2: - close_code = UNPACK_CLOSE_CODE(payload[:2])[0] - if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - f"Invalid close code: {close_code}", - ) - try: - close_message = payload[2:].decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - msg = WSMessageClose( - data=close_code, size=payload_len, extra=close_message - ) - elif payload: + self.queue.feed_data(msg) + elif opcode == OP_CODE_CLOSE: + payload_len = len(payload) + if payload_len >= 2: + close_code = UNPACK_CLOSE_CODE(payload[:2])[0] + if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, - f"Invalid close frame: {fin} {opcode} {payload!r}", + f"Invalid close code: {close_code}", ) - else: - msg = WSMessageClose(data=0, size=payload_len, extra="") - - self.queue.feed_data(msg) - elif opcode == OP_CODE_PING: - self.queue.feed_data( - WSMessagePing(data=bytes(payload), size=len(payload), extra="") - ) - elif opcode == OP_CODE_PONG: - self.queue.feed_data( - WSMessagePong(data=bytes(payload), size=len(payload), extra="") + try: + close_message = payload[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + msg = WSMessageClose( + data=close_code, size=payload_len, extra=close_message ) - else: + elif payload: raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close frame: {fin} {opcode} {payload!r}", ) + else: + msg = WSMessageClose(data=0, 
size=payload_len, extra="") + + self.queue.feed_data(msg) + elif opcode == OP_CODE_PING: + self.queue.feed_data( + WSMessagePing(data=bytes(payload), size=len(payload), extra="") + ) + elif opcode == OP_CODE_PONG: + self.queue.feed_data( + WSMessagePong(data=bytes(payload), size=len(payload), extra="") + ) + else: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + ) - def parse_frame( - self, buf: bytes - ) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]: + def _feed_data(self, data: bytes) -> None: """Return the next frame from the socket.""" - frames: List[ - Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]] - ] = [] if self._tail: - buf, self._tail = self._tail + buf, b"" + data, self._tail = self._tail + data, b"" start_pos: int = 0 - buf_length = len(buf) - buf_cstr = buf + data_length = len(data) + data_cstr = data while True: # read header if self._state == READ_HEADER: - if buf_length - start_pos < 2: + if data_length - start_pos < 2: break - first_byte = buf_cstr[start_pos] - second_byte = buf_cstr[start_pos + 1] + first_byte = data_cstr[start_pos] + second_byte = data_cstr[start_pos + 1] start_pos += 2 fin = (first_byte >> 7) & 1 @@ -391,8 +390,8 @@ def parse_frame( # Set compress status if last package is FIN # OR set compress status if this is first fragment # Raise error if not first fragment with rsv1 = 0x1 - if self._frame_fin or self._compressed is None: - self._compressed = True if rsv1 else False + if self._frame_fin or self._compressed == COMPRESSED_NOT_SET: + self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE elif rsv1: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, @@ -409,18 +408,17 @@ def parse_frame( if self._state == READ_PAYLOAD_LENGTH: length_flag = self._payload_length_flag if length_flag == 126: - if buf_length - start_pos < 2: + if data_length - start_pos < 2: break - first_byte = buf_cstr[start_pos] - second_byte = buf_cstr[start_pos + 
1] + first_byte = data_cstr[start_pos] + second_byte = data_cstr[start_pos + 1] start_pos += 2 self._payload_length = first_byte << 8 | second_byte elif length_flag > 126: - if buf_length - start_pos < 8: + if data_length - start_pos < 8: break - data = buf_cstr[start_pos : start_pos + 8] + self._payload_length = UNPACK_LEN3(data, start_pos)[0] start_pos += 8 - self._payload_length = UNPACK_LEN3(data)[0] else: self._payload_length = length_flag @@ -428,16 +426,16 @@ def parse_frame( # read payload mask if self._state == READ_PAYLOAD_MASK: - if buf_length - start_pos < 4: + if data_length - start_pos < 4: break - self._frame_mask = buf_cstr[start_pos : start_pos + 4] + self._frame_mask = data_cstr[start_pos : start_pos + 4] start_pos += 4 self._state = READ_PAYLOAD if self._state == READ_PAYLOAD: - chunk_len = buf_length - start_pos + chunk_len = data_length - start_pos if self._payload_length >= chunk_len: - end_pos = buf_length + end_pos = data_length self._payload_length -= chunk_len else: end_pos = start_pos + self._payload_length @@ -446,10 +444,10 @@ def parse_frame( if self._frame_payload_len: if type(self._frame_payload) is not bytearray: self._frame_payload = bytearray(self._frame_payload) - self._frame_payload += buf_cstr[start_pos:end_pos] + self._frame_payload += data_cstr[start_pos:end_pos] else: # Fast path for the first frame - self._frame_payload = buf_cstr[start_pos:end_pos] + self._frame_payload = data_cstr[start_pos:end_pos] self._frame_payload_len += end_pos - start_pos start_pos = end_pos @@ -463,19 +461,17 @@ def parse_frame( self._frame_payload = bytearray(self._frame_payload) websocket_mask(self._frame_mask, self._frame_payload) - frames.append( - ( - self._frame_fin, - self._frame_opcode, - self._frame_payload, - self._compressed, - ) + self._handle_frame( + self._frame_fin, + self._frame_opcode, + self._frame_payload, + self._compressed, ) self._frame_payload = b"" self._frame_payload_len = 0 self._state = READ_HEADER # XXX: Cython needs 
slices to be bounded, so we can't omit the slice end here. - self._tail = buf_cstr[start_pos:buf_length] if start_pos < buf_length else b"" - - return frames + self._tail = ( + data_cstr[start_pos:data_length] if start_pos < data_length else b"" + ) diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index a28c7e5ac9c..41da6b4e16e 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -26,6 +26,25 @@ class PatchableWebSocketReader(WebSocketReader): """WebSocketReader subclass that allows for patching parse_frame.""" + def parse_frame( + self, data: bytes + ) -> list[tuple[bool, int, Union[bytes, bytearray], int]]: + # This method is overridden to allow for patching in tests. + frames: list[tuple[bool, int, Union[bytes, bytearray], int]] = [] + + def _handle_frame( + fin: bool, + opcode: int, + payload: Union[bytes, bytearray], + compressed: int, + ) -> None: + # This method is overridden to allow for patching in tests. + frames.append((fin, opcode, payload, compressed)) + + with mock.patch.object(self, "_handle_frame", _handle_frame): + self._feed_data(data) + return frames + def build_frame( message: bytes, @@ -117,32 +136,32 @@ def test_feed_data_remembers_exception(parser: WebSocketReader) -> None: assert data == b"" -def test_parse_frame(parser: WebSocketReader) -> None: +def test_parse_frame(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b00000001, 0b00000001)) res = parser.parse_frame(b"1") fin, opcode, payload, compress = res[0] - assert (0, 1, b"1", False) == (fin, opcode, payload, not not compress) + assert (0, 1, b"1", 0) == (fin, opcode, payload, not not compress) -def test_parse_frame_length0(parser: WebSocketReader) -> None: +def test_parse_frame_length0(parser: PatchableWebSocketReader) -> None: fin, opcode, payload, compress = parser.parse_frame( struct.pack("!BB", 0b00000001, 0b00000000) )[0] - assert (0, 1, b"", False) == (fin, opcode, payload, not not 
compress) + assert (0, 1, b"", 0) == (fin, opcode, payload, not not compress) -def test_parse_frame_length2(parser: WebSocketReader) -> None: +def test_parse_frame_length2(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) parser.parse_frame(struct.pack("!H", 4)) res = parser.parse_frame(b"1234") fin, opcode, payload, compress = res[0] - assert (0, 1, b"1234", False) == (fin, opcode, payload, not not compress) + assert (0, 1, b"1234", 0) == (fin, opcode, payload, not not compress) -def test_parse_frame_length2_multi_byte(parser: WebSocketReader) -> None: +def test_parse_frame_length2_multi_byte(parser: PatchableWebSocketReader) -> None: """Ensure a multi-byte length is parsed correctly.""" expected_payload = b"1" * 32768 parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) @@ -150,10 +169,12 @@ def test_parse_frame_length2_multi_byte(parser: WebSocketReader) -> None: res = parser.parse_frame(b"1" * 32768) fin, opcode, payload, compress = res[0] - assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress) + assert (0, 1, expected_payload, 0) == (fin, opcode, payload, not not compress) -def test_parse_frame_length2_multi_byte_multi_packet(parser: WebSocketReader) -> None: +def test_parse_frame_length2_multi_byte_multi_packet( + parser: PatchableWebSocketReader, +) -> None: """Ensure a multi-byte length with multiple packets is parsed correctly.""" expected_payload = b"1" * 32768 assert parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) == [] @@ -164,34 +185,34 @@ def test_parse_frame_length2_multi_byte_multi_packet(parser: WebSocketReader) -> res = parser.parse_frame(b"1" * 8192) fin, opcode, payload, compress = res[0] assert len(payload) == 32768 - assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress) + assert (0, 1, expected_payload, 0) == (fin, opcode, payload, not not compress) -def test_parse_frame_length4(parser: WebSocketReader) -> None: +def 
test_parse_frame_length4(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b00000001, 127)) parser.parse_frame(struct.pack("!Q", 4)) fin, opcode, payload, compress = parser.parse_frame(b"1234")[0] - assert (0, 1, b"1234", False) == (fin, opcode, payload, not not compress) + assert (0, 1, b"1234", 0) == (fin, opcode, payload, compress) -def test_parse_frame_mask(parser: WebSocketReader) -> None: +def test_parse_frame_mask(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b00000001, 0b10000001)) parser.parse_frame(b"0001") fin, opcode, payload, compress = parser.parse_frame(b"1")[0] - assert (0, 1, b"\x01", False) == (fin, opcode, payload, not not compress) + assert (0, 1, b"\x01", 0) == (fin, opcode, payload, compress) def test_parse_frame_header_reversed_bits( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: with pytest.raises(WebSocketError): parser.parse_frame(struct.pack("!BB", 0b01100000, 0b00000000)) def test_parse_frame_header_control_frame( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: with pytest.raises(WebSocketError): parser.parse_frame(struct.pack("!BB", 0b00001000, 0b00000000)) @@ -199,14 +220,14 @@ def test_parse_frame_header_control_frame( @pytest.mark.xfail() def test_parse_frame_header_new_data_err( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: with pytest.raises(WebSocketError): parser.parse_frame(struct.pack("!BB", 0b000000000, 0b00000000)) def test_parse_frame_header_payload_size( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: with pytest.raises(WebSocketError): parser.parse_frame(struct.pack("!BB", 0b10001000, 0b01111110)) @@ -221,56 +242,44 @@ def test_parse_frame_header_payload_size( ) 
def test_ping_frame( out: WebSocketDataQueue, - parser: WebSocketReader, + parser: PatchableWebSocketReader, data: Union[bytes, bytearray, memoryview], ) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.PING, b"data", False)] - - parser.feed_data(data) - res = out._buffer[0] - assert res == WSMessagePing(data=b"data", size=4, extra="") - - -def test_pong_frame(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.PONG, b"data", False)] - - parser.feed_data(b"") - res = out._buffer[0] - assert res == WSMessagePong(data=b"data", size=4, extra="") - + parser._handle_frame(True, WSMsgType.PING, b"data", 0) + res = out._buffer[0] + assert res == WSMessagePing(data=b"data", size=4, extra="") -def test_close_frame(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.CLOSE, b"", False)] - parser.feed_data(b"") - res = out._buffer[0] - assert res == WSMessageClose(data=0, size=0, extra="") +def test_pong_frame(out: WebSocketDataQueue, parser: PatchableWebSocketReader) -> None: + parser._handle_frame(True, WSMsgType.PONG, b"data", 0) + res = out._buffer[0] + assert res == WSMessagePong(data=b"data", size=4, extra="") -def test_close_frame_info(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.CLOSE, b"0112345", False)] +def test_close_frame(out: WebSocketDataQueue, parser: PatchableWebSocketReader) -> None: + parser._handle_frame(True, WSMsgType.CLOSE, b"", 0) + res = out._buffer[0] + assert res == WSMessageClose(data=0, size=0, extra="") - parser.feed_data(b"") - res = out._buffer[0] - assert res == WSMessageClose(data=12337, size=7, extra="12345") +def test_close_frame_info( + out: 
WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: + parser._handle_frame(True, WSMsgType.CLOSE, b"0112345", 0) + res = out._buffer[0] + assert res == WSMessageClose(data=12337, size=7, extra="12345") -def test_close_frame_invalid(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.CLOSE, b"1", False)] - parser.feed_data(b"") - exc = out.exception() - assert isinstance(exc, WebSocketError) - assert exc.code == WSCloseCode.PROTOCOL_ERROR +def test_close_frame_invalid( + out: WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: + with pytest.raises(WebSocketError) as ctx: + parser._handle_frame(True, WSMsgType.CLOSE, b"1", 0) + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR def test_close_frame_invalid_2( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: data = build_close_frame(code=1) @@ -280,7 +289,7 @@ def test_close_frame_invalid_2( assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR -def test_close_frame_unicode_err(parser: WebSocketReader) -> None: +def test_close_frame_unicode_err(parser: PatchableWebSocketReader) -> None: data = build_close_frame(code=1000, message=b"\xf4\x90\x80\x80") with pytest.raises(WebSocketError) as ctx: @@ -289,22 +298,21 @@ def test_close_frame_unicode_err(parser: WebSocketReader) -> None: assert ctx.value.code == WSCloseCode.INVALID_TEXT -def test_unknown_frame(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.CONTINUATION, b"", False)] - - parser.feed_data(b"") - assert isinstance(out.exception(), WebSocketError) +def test_unknown_frame( + out: WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: + with pytest.raises(WebSocketError): + parser._handle_frame(True, WSMsgType.CONTINUATION, b"", 0) -def 
test_simple_text(out: WebSocketDataQueue, parser: WebSocketReader) -> None: +def test_simple_text(out: WebSocketDataQueue, parser: PatchableWebSocketReader) -> None: data = build_frame(b"text", WSMsgType.TEXT) parser._feed_data(data) res = out._buffer[0] assert res == WSMessageText(data="text", size=4, extra="") -def test_simple_text_unicode_err(parser: WebSocketReader) -> None: +def test_simple_text_unicode_err(parser: PatchableWebSocketReader) -> None: data = build_frame(b"\xf4\x90\x80\x80", WSMsgType.TEXT) with pytest.raises(WebSocketError) as ctx: @@ -313,16 +321,18 @@ def test_simple_text_unicode_err(parser: WebSocketReader) -> None: assert ctx.value.code == WSCloseCode.INVALID_TEXT -def test_simple_binary(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.BINARY, b"binary", False)] - - parser.feed_data(b"") - res = out._buffer[0] - assert res == WSMessageBinary(data=b"binary", size=6, extra="") +def test_simple_binary( + out: WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: + data = build_frame(b"binary", WSMsgType.BINARY) + parser._feed_data(data) + res = out._buffer[0] + assert res == WSMessageBinary(data=b"binary", size=6, extra="") -def test_fragmentation_header(out: WebSocketDataQueue, parser: WebSocketReader) -> None: +def test_fragmentation_header( + out: WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: data = build_frame(b"a", WSMsgType.TEXT) parser._feed_data(data[:1]) parser._feed_data(data[1:]) @@ -331,7 +341,9 @@ def test_fragmentation_header(out: WebSocketDataQueue, parser: WebSocketReader) assert res == WSMessageText(data="a", size=1, extra="") -def test_continuation(out: WebSocketDataQueue, parser: WebSocketReader) -> None: +def test_continuation( + out: WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: data1 = build_frame(b"line1", WSMsgType.TEXT, is_fin=False) parser._feed_data(data1) @@ 
-343,131 +355,97 @@ def test_continuation(out: WebSocketDataQueue, parser: WebSocketReader) -> None: def test_continuation_with_ping( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - (0, WSMsgType.PING, b"", False), - (1, WSMsgType.CONTINUATION, b"line2", False), - ] - - data1 = build_frame(b"line1", WSMsgType.TEXT, is_fin=False) - parser._feed_data(data1) - - data2 = build_frame(b"", WSMsgType.PING) - parser._feed_data(data2) + data1 = build_frame(b"line1", WSMsgType.TEXT, is_fin=False) + parser._feed_data(data1) - data3 = build_frame(b"line2", WSMsgType.CONTINUATION) - parser._feed_data(data3) + data2 = build_frame(b"", WSMsgType.PING) + parser._feed_data(data2) - res = out._buffer[0] - assert res == WSMessagePing(data=b"", size=0, extra="") - res = out._buffer[1] - assert res == WSMessageText(data="line1line2", size=10, extra="") + data3 = build_frame(b"line2", WSMsgType.CONTINUATION) + parser._feed_data(data3) + res = out._buffer[0] + assert res == WSMessagePing(data=b"", size=0, extra="") + res = out._buffer[1] + assert res == WSMessageText(data="line1line2", size=10, extra="") -def test_continuation_err(out: WebSocketDataQueue, parser: WebSocketReader) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - (1, WSMsgType.TEXT, b"line2", False), - ] - with pytest.raises(WebSocketError): - parser._feed_data(b"") +def test_continuation_err( + out: WebSocketDataQueue, parser: PatchableWebSocketReader +) -> None: + parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0) + with pytest.raises(WebSocketError): + parser._handle_frame(True, WSMsgType.TEXT, b"line2", 0) def test_continuation_with_close( out: WebSocketDataQueue, parser: WebSocketReader ) -> None: - with 
mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - ( - 0, - WSMsgType.CLOSE, - build_close_frame(1002, b"test", noheader=True), - False, - ), - (1, WSMsgType.CONTINUATION, b"line2", False), - ] - - parser.feed_data(b"") - res = out._buffer[0] - assert res == WSMessageClose(data=1002, size=6, extra="test") - res = out._buffer[1] - assert res == WSMessageText(data="line1line2", size=10, extra="") + parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0) + parser._handle_frame( + False, + WSMsgType.CLOSE, + build_close_frame(1002, b"test", noheader=True), + False, + ) + parser._handle_frame(True, WSMsgType.CONTINUATION, b"line2", 0) + res = out._buffer[0] + assert res == WSMessageClose(data=1002, size=6, extra="test") + res = out._buffer[1] + assert res == WSMessageText(data="line1line2", size=10, extra="") def test_continuation_with_close_unicode_err( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - ( - 0, - WSMsgType.CLOSE, - build_close_frame(1000, b"\xf4\x90\x80\x80", noheader=True), - False, - ), - (1, WSMsgType.CONTINUATION, b"line2", False), - ] - - with pytest.raises(WebSocketError) as ctx: - parser._feed_data(b"") - - assert ctx.value.code == WSCloseCode.INVALID_TEXT + parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0) + with pytest.raises(WebSocketError) as ctx: + parser._handle_frame( + False, + WSMsgType.CLOSE, + build_close_frame(1000, b"\xf4\x90\x80\x80", noheader=True), + 0, + ) + parser._handle_frame(True, WSMsgType.CONTINUATION, b"line2", 0) + assert ctx.value.code == WSCloseCode.INVALID_TEXT def test_continuation_with_close_bad_code( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: - with 
mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - (0, WSMsgType.CLOSE, build_close_frame(1, b"test", noheader=True), False), - (1, WSMsgType.CONTINUATION, b"line2", False), - ] - - with pytest.raises(WebSocketError) as ctx: - parser._feed_data(b"") + parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0) + with pytest.raises(WebSocketError) as ctx: - assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + parser._handle_frame( + False, WSMsgType.CLOSE, build_close_frame(1, b"test", noheader=True), 0 + ) + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + parser._handle_frame(True, WSMsgType.CONTINUATION, b"line2", 0) def test_continuation_with_close_bad_payload( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - (0, WSMsgType.CLOSE, b"1", False), - (1, WSMsgType.CONTINUATION, b"line2", False), - ] - - with pytest.raises(WebSocketError) as ctx: - parser._feed_data(b"") - - assert ctx.value.code, WSCloseCode.PROTOCOL_ERROR + parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0) + with pytest.raises(WebSocketError) as ctx: + parser._handle_frame(False, WSMsgType.CLOSE, b"1", 0) + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + parser._handle_frame(True, WSMsgType.CONTINUATION, b"line2", 0) def test_continuation_with_close_empty( - out: WebSocketDataQueue, parser: WebSocketReader + out: WebSocketDataQueue, parser: PatchableWebSocketReader ) -> None: - with mock.patch.object(parser, "parse_frame", autospec=True) as m: - m.return_value = [ - (0, WSMsgType.TEXT, b"line1", False), - (0, WSMsgType.CLOSE, b"", False), - (1, WSMsgType.CONTINUATION, b"line2", False), - ] + parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0) + parser._handle_frame(False, WSMsgType.CLOSE, b"", 0) + 
parser._handle_frame(True, WSMsgType.CONTINUATION, b"line2", 0) - parser.feed_data(b"") - res = out._buffer[0] - assert res == WSMessageClose(data=0, size=0, extra="") - res = out._buffer[1] - assert res == WSMessageText(data="line1line2", size=10, extra="") + res = out._buffer[0] + assert res == WSMessageClose(data=0, size=0, extra="") + res = out._buffer[1] + assert res == WSMessageText(data="line1line2", size=10, extra="") websocket_mask_data: bytes = b"some very long data for masking by websocket" @@ -510,7 +488,7 @@ def test_websocket_mask_cython_empty() -> None: assert message == bytearray() -def test_parse_compress_frame_single(parser: WebSocketReader) -> None: +def test_parse_compress_frame_single(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) res = parser.parse_frame(b"1") fin, opcode, payload, compress = res[0] @@ -518,7 +496,7 @@ def test_parse_compress_frame_single(parser: WebSocketReader) -> None: assert (1, 1, b"1", True) == (fin, opcode, payload, not not compress) -def test_parse_compress_frame_multi(parser: WebSocketReader) -> None: +def test_parse_compress_frame_multi(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b01000001, 126)) parser.parse_frame(struct.pack("!H", 4)) res = parser.parse_frame(b"1234") @@ -538,7 +516,7 @@ def test_parse_compress_frame_multi(parser: WebSocketReader) -> None: assert (1, 1, b"1234", False) == (fin, opcode, payload, not not compress) -def test_parse_compress_error_frame(parser: WebSocketReader) -> None: +def test_parse_compress_error_frame(parser: PatchableWebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b01000001, 0b00000001)) parser.parse_frame(b"1") @@ -550,7 +528,7 @@ def test_parse_compress_error_frame(parser: WebSocketReader) -> None: def test_parse_no_compress_frame_single(out: WebSocketDataQueue) -> None: - parser_no_compress = WebSocketReader(out, 0, compress=False) + parser_no_compress = 
PatchableWebSocketReader(out, 0, compress=False) with pytest.raises(WebSocketError) as ctx: parser_no_compress.parse_frame(struct.pack("!BB", 0b11000001, 0b00000001)) parser_no_compress.parse_frame(b"1") @@ -603,16 +581,11 @@ def test_pickle(self) -> None: def test_flow_control_binary( protocol: BaseProtocol, out_low_limit: WebSocketDataQueue, - parser_low_limit: WebSocketReader, + parser_low_limit: PatchableWebSocketReader, ) -> None: large_payload = b"b" * (1 + 16 * 2) large_payload_size = len(large_payload) - - with mock.patch.object(parser_low_limit, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.BINARY, large_payload, False)] - - parser_low_limit.feed_data(b"") - + parser_low_limit._handle_frame(True, WSMsgType.BINARY, large_payload, 0) res = out_low_limit._buffer[0] assert res == WSMessageBinary(data=large_payload, size=large_payload_size, extra="") assert protocol._reading_paused is True @@ -621,17 +594,12 @@ def test_flow_control_binary( def test_flow_control_multi_byte_text( protocol: BaseProtocol, out_low_limit: WebSocketDataQueue, - parser_low_limit: WebSocketReader, + parser_low_limit: PatchableWebSocketReader, ) -> None: large_payload_text = "𒀁" * (1 + 16 * 2) large_payload = large_payload_text.encode("utf-8") large_payload_size = len(large_payload) - - with mock.patch.object(parser_low_limit, "parse_frame", autospec=True) as m: - m.return_value = [(1, WSMsgType.TEXT, large_payload, False)] - - parser_low_limit.feed_data(b"") - + parser_low_limit._handle_frame(True, WSMsgType.TEXT, large_payload, 0) res = out_low_limit._buffer[0] assert res == WSMessageText( data=large_payload_text, size=large_payload_size, extra=""