From ceeca6a9b019b98d1dc3d5d1a622ba66768938d2 Mon Sep 17 00:00:00 2001 From: Tim Menninger Date: Sat, 12 Apr 2025 14:07:21 -0700 Subject: [PATCH] Add support for switching the zlib implementation (#10700) --- CHANGES/9798.feature.rst | 5 + aiohttp/__init__.py | 2 + aiohttp/_websocket/reader_py.py | 18 +- aiohttp/_websocket/writer.py | 9 +- aiohttp/abc.py | 3 +- aiohttp/compression_utils.py | 139 ++++++++++-- aiohttp/http_writer.py | 3 +- aiohttp/multipart.py | 3 +- aiohttp/web_response.py | 5 +- docs/client_reference.rst | 24 +++ docs/conf.py | 2 + docs/spelling_wordlist.txt | 1 + docs/web_reference.rst | 7 +- requirements/dev.txt | 4 + requirements/lint.in | 2 + requirements/lint.txt | 4 + requirements/test.in | 2 + requirements/test.txt | 4 + tests/conftest.py | 16 ++ tests/test_client_functional.py | 25 ++- tests/test_client_request.py | 17 +- tests/test_client_ws_functional.py | 2 + tests/test_compression_utils.py | 18 +- tests/test_http_writer.py | 291 +++++++++++++++++++++++++- tests/test_multipart.py | 9 +- tests/test_web_functional.py | 50 +++-- tests/test_web_response.py | 28 ++- tests/test_web_sendfile_functional.py | 9 +- tests/test_websocket_parser.py | 17 +- tests/test_websocket_writer.py | 1 + 30 files changed, 632 insertions(+), 88 deletions(-) create mode 100644 CHANGES/9798.feature.rst diff --git a/CHANGES/9798.feature.rst b/CHANGES/9798.feature.rst new file mode 100644 index 00000000000..c1584b04491 --- /dev/null +++ b/CHANGES/9798.feature.rst @@ -0,0 +1,5 @@ +Allow user setting zlib compression backend -- by :user:`TimMenninger` + +This change allows the user to call :func:`aiohttp.set_zlib_backend()` with the +zlib compression module of their choice. Default behavior continues to use +the builtin ``zlib`` library. diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 7759a997cb9..f23bf928f37 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -47,6 +47,7 @@ WSServerHandshakeError, request, ) +from .compression_utils import set_zlib_backend from .connector import AddrInfoType, SocketFactoryType from .cookiejar import CookieJar, DummyCookieJar from .formdata import FormData @@ -165,6 +166,7 @@ "BasicAuth", "ChainMapProxy", "ETag", + "set_zlib_backend", # http "HttpVersion", "HttpVersion10", diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 52d6b83925f..5daf91d7140 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -243,15 +243,23 @@ def _feed_data(self, data: bytes) -> None: self._decompressobj = ZLibDecompressor( suppress_deflate_header=True ) + # XXX: It's possible that the zlib backend (isal is known to + # do this, maybe others too?) will return max_length bytes, + # but internally buffer more data such that the payload is + # >max_length, so we return one extra byte and if we're able + # to do that, then the message is too big. 
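+            # Concretely: with a limit of 256 bytes we ask the backend for up
+            # to 257, and receiving all 257 proves the decompressed payload
+            # exceeds the limit even if the backend is still buffering more.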
payload_merged = self._decompressobj.decompress_sync( - assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size + assembled_payload + WS_DEFLATE_TRAILING, + ( + self._max_msg_size + 1 + if self._max_msg_size + else self._max_msg_size + ), ) - if self._decompressobj.unconsumed_tail: - left = len(self._decompressobj.unconsumed_tail) + if self._max_msg_size and len(payload_merged) > self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - f"Decompressed message size {self._max_msg_size + left}" - f" exceeds limit {self._max_msg_size}", + f"Decompressed message exceeds size limit {self._max_msg_size}", ) elif type(assembled_payload) is bytes: payload_merged = assembled_payload diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index ea4962e86fa..8ba98280304 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -2,13 +2,12 @@ import asyncio import random -import zlib from functools import partial from typing import Any, Final, Optional, Union from ..base_protocol import BaseProtocol from ..client_exceptions import ClientConnectionResetError -from ..compression_utils import ZLibCompressor +from ..compression_utils import ZLibBackend, ZLibCompressor from .helpers import ( MASK_LEN, MSG_SIZE, @@ -95,7 +94,9 @@ async def send_frame( message = ( await compressobj.compress(message) + compressobj.flush( - zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH + ZLibBackend.Z_FULL_FLUSH + if self.notakeover + else ZLibBackend.Z_SYNC_FLUSH ) ).removesuffix(WS_DEFLATE_TRAILING) # Its critical that we do not return control to the event @@ -160,7 +161,7 @@ async def send_frame( def _make_compress_obj(self, compress: int) -> ZLibCompressor: return ZLibCompressor( - level=zlib.Z_BEST_SPEED, + level=ZLibBackend.Z_BEST_SPEED, wbits=-compress, max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, ) diff --git a/aiohttp/abc.py b/aiohttp/abc.py index 7ff3fee73e8..53ce6ec495e 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -1,6 +1,5 @@ import logging import socket -import zlib from abc import ABC, abstractmethod from collections.abc import Sized from http.cookies import BaseCookie, Morsel @@ -217,7 +216,7 @@ async def drain(self) -> None: @abstractmethod def enable_compression( - self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY + self, encoding: str = "deflate", strategy: Optional[int] = None ) -> None: """Enable HTTP body compression""" diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 43460af5a27..918b764baf5 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -2,7 +2,7 @@ import sys import zlib from concurrent.futures import Executor -from typing import Optional, cast +from typing import Any, Final, Optional, Protocol, TypedDict, cast if sys.version_info >= (3, 12): from collections.abc import Buffer @@ -24,14 +24,113 @@ MAX_SYNC_CHUNK_SIZE = 1024 +class ZLibCompressObjProtocol(Protocol): + def compress(self, data: Buffer) -> bytes: ... + def flush(self, mode: int = ..., /) -> bytes: ... + + +class ZLibDecompressObjProtocol(Protocol): + def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ... + def flush(self, length: int = ..., /) -> bytes: ... + + @property + def eof(self) -> bool: ... 
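+
+
+# The two Protocols above describe the compressor/decompressor objects a
+# backend must return. zlib.compressobj()/zlib.decompressobj() satisfy them,
+# as do the drop-in equivalents from zlib_ng.zlib_ng and isal.isal_zlib
+# exercised in the test suite; matching is structural, so no inheritance from
+# these classes is required.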
+ + +class ZLibBackendProtocol(Protocol): + MAX_WBITS: int + Z_FULL_FLUSH: int + Z_SYNC_FLUSH: int + Z_BEST_SPEED: int + Z_FINISH: int + + def compressobj( + self, + level: int = ..., + method: int = ..., + wbits: int = ..., + memLevel: int = ..., + strategy: int = ..., + zdict: Optional[Buffer] = ..., + ) -> ZLibCompressObjProtocol: ... + def decompressobj( + self, wbits: int = ..., zdict: Buffer = ... + ) -> ZLibDecompressObjProtocol: ... + + def compress( + self, data: Buffer, /, level: int = ..., wbits: int = ... + ) -> bytes: ... + def decompress( + self, data: Buffer, /, wbits: int = ..., bufsize: int = ... + ) -> bytes: ... + + +class CompressObjArgs(TypedDict, total=False): + wbits: int + strategy: int + level: int + + +class ZLibBackendWrapper: + def __init__(self, _zlib_backend: ZLibBackendProtocol): + self._zlib_backend: ZLibBackendProtocol = _zlib_backend + + @property + def name(self) -> str: + return getattr(self._zlib_backend, "__name__", "undefined") + + @property + def MAX_WBITS(self) -> int: + return self._zlib_backend.MAX_WBITS + + @property + def Z_FULL_FLUSH(self) -> int: + return self._zlib_backend.Z_FULL_FLUSH + + @property + def Z_SYNC_FLUSH(self) -> int: + return self._zlib_backend.Z_SYNC_FLUSH + + @property + def Z_BEST_SPEED(self) -> int: + return self._zlib_backend.Z_BEST_SPEED + + @property + def Z_FINISH(self) -> int: + return self._zlib_backend.Z_FINISH + + def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol: + return self._zlib_backend.compressobj(*args, **kwargs) + + def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol: + return self._zlib_backend.decompressobj(*args, **kwargs) + + def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes: + return self._zlib_backend.compress(data, *args, **kwargs) + + def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes: + return self._zlib_backend.decompress(data, *args, **kwargs) + + # Everything not explicitly listed in the Protocol we just pass through + def __getattr__(self, attrname: str) -> Any: + return getattr(self._zlib_backend, attrname) + + +ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib) + + +def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None: + ZLibBackend._zlib_backend = new_zlib_backend + + def encoding_to_mode( encoding: Optional[str] = None, suppress_deflate_header: bool = False, ) -> int: if encoding == "gzip": - return 16 + zlib.MAX_WBITS + return 16 + ZLibBackend.MAX_WBITS - return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS + return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS class ZlibBaseHandler: @@ -53,7 +152,7 @@ def __init__( suppress_deflate_header: bool = False, level: Optional[int] = None, wbits: Optional[int] = None, - strategy: int = zlib.Z_DEFAULT_STRATEGY, + strategy: Optional[int] = None, executor: Optional[Executor] = None, max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, ): @@ -66,12 +165,15 @@ def __init__( executor=executor, max_sync_chunk_size=max_sync_chunk_size, ) - if level is None: - self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy) - else: - self._compressor = zlib.compressobj( - wbits=self._mode, strategy=strategy, level=level - ) + self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) + + kwargs: CompressObjArgs = {} + kwargs["wbits"] = self._mode + if strategy is not None: + kwargs["strategy"] = strategy + if level is not None: + kwargs["level"] = level + 
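+        # Forward strategy/level only when explicitly provided; otherwise the
+        # selected backend's own defaults apply (Z_DEFAULT_STRATEGY and
+        # Z_DEFAULT_COMPRESSION for the stdlib).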
self._compressor = self._zlib_backend.compressobj(**kwargs) self._compress_lock = asyncio.Lock() def compress_sync(self, data: Buffer) -> bytes: @@ -100,8 +202,10 @@ async def compress(self, data: Buffer) -> bytes: ) return self.compress_sync(data) - def flush(self, mode: int = zlib.Z_FINISH) -> bytes: - return self._compressor.flush(mode) + def flush(self, mode: Optional[int] = None) -> bytes: + return self._compressor.flush( + mode if mode is not None else self._zlib_backend.Z_FINISH + ) class ZLibDecompressor(ZlibBaseHandler): @@ -117,7 +221,8 @@ def __init__( executor=executor, max_sync_chunk_size=max_sync_chunk_size, ) - self._decompressor = zlib.decompressobj(wbits=self._mode) + self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend) + self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode) def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes: return self._decompressor.decompress(data, max_length) @@ -149,14 +254,6 @@ def flush(self, length: int = 0) -> bytes: def eof(self) -> bool: return self._decompressor.eof - @property - def unconsumed_tail(self) -> bytes: - return self._decompressor.unconsumed_tail - - @property - def unused_data(self) -> bytes: - return self._decompressor.unused_data - class BrotliDecompressor: # Supports both 'brotlipy' and 'Brotli' packages diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index f9b0e9b2268..6b13e3cdd1d 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -2,7 +2,6 @@ import asyncio import sys -import zlib from typing import ( # noqa Any, Awaitable, @@ -85,7 +84,7 @@ def enable_chunking(self) -> None: self.chunked = True def enable_compression( - self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY + self, encoding: str = "deflate", strategy: Optional[int] = None ) -> None: self._compress = ZLibCompressor(encoding=encoding, strategy=strategy) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 052e420a302..5c437248ee4 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -5,7 +5,6 @@ import sys import uuid import warnings -import zlib from collections import deque from types import TracebackType from typing import ( @@ -1032,7 +1031,7 @@ def enable_encoding(self, encoding: str) -> None: self._encoding = "quoted-printable" def enable_compression( - self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY + self, encoding: str = "deflate", strategy: Optional[int] = None ) -> None: self._compress = ZLibCompressor( encoding=encoding, diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 56596905a35..b637543b29c 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -6,7 +6,6 @@ import math import time import warnings -import zlib from concurrent.futures import Executor from http import HTTPStatus from typing import ( @@ -83,7 +82,7 @@ class StreamResponse(BaseClass, HeadersMixin, CookieMixin): _keep_alive: Optional[bool] = None _chunked: bool = False _compression: bool = False - _compression_strategy: int = zlib.Z_DEFAULT_STRATEGY + _compression_strategy: Optional[int] = None _compression_force: Optional[ContentCoding] = None _req: Optional["BaseRequest"] = None _payload_writer: Optional[AbstractStreamWriter] = None @@ -184,7 +183,7 @@ def enable_chunked_encoding(self) -> None: def enable_compression( self, force: Optional[ContentCoding] = None, - strategy: int = zlib.Z_DEFAULT_STRATEGY, + strategy: Optional[int] = None, ) -> None: """Enables response compression encoding.""" 
self._compression = True diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 43e02ebfeaa..a94e079b5f7 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -2145,6 +2145,30 @@ Utilities .. versionadded:: 3.0 +.. function:: set_zlib_backend(lib) + + Sets the compression backend for zlib-based operations. + + This function allows you to override the default zlib backend + used internally by passing a module that implements the standard + compression interface. + + The module should implement at minimum the exact interface offered by the + latest version of zlib. + + :param types.ModuleType lib: A module that implements the zlib-compatible compression API. + + Example usage:: + + import zlib_ng.zlib_ng as zng + import aiohttp + + aiohttp.set_zlib_backend(zng) + + .. note:: aiohttp has been tested internally with :mod:`zlib`, :mod:`zlib_ng.zlib_ng`, and :mod:`isal.isal_zlib`. + + .. versionadded:: 3.12 + FormData ^^^^^^^^ diff --git a/docs/conf.py b/docs/conf.py index 15de3598c7e..0be0e21eaef 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -84,6 +84,8 @@ "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None), + "isal": ("https://python-isal.readthedocs.io/en/stable/", None), + "zlib_ng": ("https://python-zlib-ng.readthedocs.io/en/stable/", None), } # Add any paths that contain templates here, relative to this directory. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 5eabd185d05..16c8aa789e9 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -374,3 +374,4 @@ wss www xxx yarl +zlib diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 7f661f44f71..d1c14409a28 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -649,7 +649,7 @@ and :ref:`aiohttp-web-signals` handlers:: .. seealso:: :meth:`enable_compression` - .. method:: enable_compression(force=None, strategy=zlib.Z_DEFAULT_STRATEGY) + .. method:: enable_compression(force=None, strategy=None) Enable compression. @@ -660,7 +660,10 @@ and :ref:`aiohttp-web-signals` handlers:: :class:`ContentCoding`. *strategy* accepts a :mod:`zlib` compression strategy. - See :func:`zlib.compressobj` for possible values. + See :func:`zlib.compressobj` for possible values, or refer to the + docs for the zlib of your using, should you use :func:`aiohttp.set_zlib_backend` + to change zlib backend. If ``None``, the default value adopted by + your zlib backend will be used where applicable. .. 
seealso:: :attr:`compression` diff --git a/requirements/dev.txt b/requirements/dev.txt index 31a410990ea..47f0b92bc17 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -97,6 +97,8 @@ incremental==24.7.2 # via towncrier iniconfig==2.1.0 # via pytest +isal==1.7.2 + # via -r requirements/test.in jinja2==3.1.6 # via # sphinx @@ -275,6 +277,8 @@ wheel==0.46.0 # via pip-tools yarl==1.19.0 # via -r requirements/runtime-deps.in +zlib_ng==0.5.1 + # via -r requirements/test.in # The following packages are considered to be unsafe in a requirements file: pip==25.0.1 diff --git a/requirements/lint.in b/requirements/lint.in index 64b34df92a9..21a9fb4e0f4 100644 --- a/requirements/lint.in +++ b/requirements/lint.in @@ -1,6 +1,7 @@ aiodns blockbuster freezegun +isal mypy; implementation_name == "cpython" pre-commit proxy.py @@ -12,3 +13,4 @@ slotscheck trustme uvloop; platform_system != "Windows" valkey +zlib_ng diff --git a/requirements/lint.txt b/requirements/lint.txt index 3a68f752556..1b9c8849163 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -39,6 +39,8 @@ idna==3.7 # via trustme iniconfig==2.1.0 # via pytest +isal==1.7.2 + # via -r requirements/lint.in markdown-it-py==3.0.0 # via rich mdurl==0.1.2 @@ -113,3 +115,5 @@ valkey==6.1.0 # via -r requirements/lint.in virtualenv==20.30.0 # via pre-commit +zlib-ng==0.5.1 + # via -r requirements/lint.in diff --git a/requirements/test.in b/requirements/test.in index 25813a963b7..b8b82abd1ce 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -3,6 +3,7 @@ blockbuster coverage freezegun +isal mypy; implementation_name == "cpython" proxy.py >= 2.4.4rc5 pytest @@ -14,3 +15,4 @@ python-on-whales setuptools-git trustme; platform_machine != "i686" # no 32-bit wheels wait-for-it +zlib_ng diff --git a/requirements/test.txt b/requirements/test.txt index b708feb7f59..3d4372221d9 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -51,6 +51,8 @@ idna==3.6 # yarl iniconfig==2.1.0 # via pytest +isal==1.7.2 + # via -r requirements/test.in markdown-it-py==3.0.0 # via rich mdurl==0.1.2 @@ -134,3 +136,5 @@ wait-for-it==2.3.0 # via -r requirements/test.in yarl==1.19.0 # via -r requirements/runtime-deps.in +zlib_ng==0.5.1 + # via -r requirements/test.in diff --git a/tests/conftest.py b/tests/conftest.py index f1f4569e3a7..e8e11835393 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import socket import ssl import sys +import zlib from hashlib import md5, sha1, sha256 from pathlib import Path from tempfile import TemporaryDirectory @@ -11,10 +12,13 @@ from unittest import mock from uuid import uuid4 +import isal.isal_zlib import pytest +import zlib_ng.zlib_ng from blockbuster import blockbuster_ctx from aiohttp.client_proto import ResponseHandler +from aiohttp.compression_utils import ZLibBackend, ZLibBackendProtocol, set_zlib_backend from aiohttp.http import WS_KEY from aiohttp.test_utils import get_unused_port_socket, loop_context @@ -296,3 +300,15 @@ def unused_port_socket() -> Generator[socket.socket, None, None]: yield s finally: s.close() + + +@pytest.fixture(params=[zlib, zlib_ng.zlib_ng, isal.isal_zlib]) +def parametrize_zlib_backend( + request: pytest.FixtureRequest, +) -> Generator[None, None, None]: + original_backend: ZLibBackendProtocol = ZLibBackend._zlib_backend + set_zlib_backend(request.param) + + yield + + set_zlib_backend(original_backend) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 04538966062..a7e229bfaa1 100644 --- 
a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -2071,7 +2071,10 @@ async def expect_handler(request: web.Request) -> None: assert expect_called -async def test_encoding_deflate(aiohttp_client: AiohttpClient) -> None: +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_encoding_deflate( + aiohttp_client: AiohttpClient, +) -> None: async def handler(request: web.Request) -> web.Response: resp = web.Response(text="text") resp.enable_chunked_encoding() @@ -2089,7 +2092,10 @@ async def handler(request: web.Request) -> web.Response: resp.close() -async def test_encoding_deflate_nochunk(aiohttp_client: AiohttpClient) -> None: +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_encoding_deflate_nochunk( + aiohttp_client: AiohttpClient, +) -> None: async def handler(request: web.Request) -> web.Response: resp = web.Response(text="text") resp.enable_compression(web.ContentCoding.deflate) @@ -2106,7 +2112,10 @@ async def handler(request: web.Request) -> web.Response: resp.close() -async def test_encoding_gzip(aiohttp_client: AiohttpClient) -> None: +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_encoding_gzip( + aiohttp_client: AiohttpClient, +) -> None: async def handler(request: web.Request) -> web.Response: resp = web.Response(text="text") resp.enable_chunked_encoding() @@ -2124,7 +2133,10 @@ async def handler(request: web.Request) -> web.Response: resp.close() -async def test_encoding_gzip_write_by_chunks(aiohttp_client: AiohttpClient) -> None: +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_encoding_gzip_write_by_chunks( + aiohttp_client: AiohttpClient, +) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() resp.enable_compression(web.ContentCoding.gzip) @@ -2144,7 +2156,10 @@ async def handler(request: web.Request) -> web.StreamResponse: resp.close() -async def test_encoding_gzip_nochunk(aiohttp_client: AiohttpClient) -> None: +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_encoding_gzip_nochunk( + aiohttp_client: AiohttpClient, +) -> None: async def handler(request: web.Request) -> web.Response: resp = web.Response(text="text") resp.enable_compression(web.ContentCoding.gzip) diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 2b5e2725c49..e1e8e3d9992 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -3,7 +3,6 @@ import io import pathlib import sys -import zlib from http.cookies import BaseCookie, Morsel, SimpleCookie from typing import ( Any, @@ -32,6 +31,7 @@ Fingerprint, _gen_default_accept_encoding, ) +from aiohttp.compression_utils import ZLibBackend from aiohttp.connector import Connection from aiohttp.http import HttpVersion10, HttpVersion11 from aiohttp.test_utils import make_mocked_coro @@ -822,8 +822,10 @@ async def test_bytes_data(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> N resp.close() +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_content_encoding( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, ) -> None: req = ClientRequest( "post", URL("http://python.org/"), data="foo", compress="deflate", loop=loop @@ -852,8 +854,10 @@ async def test_content_encoding_dont_set_headers_if_no_body( resp.close() +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_content_encoding_header( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: 
asyncio.AbstractEventLoop, + conn: mock.Mock, ) -> None: req = ClientRequest( "post", @@ -978,8 +982,11 @@ async def test_file_upload_not_chunked(loop: asyncio.AbstractEventLoop) -> None: await req.close() -async def test_precompressed_data_stays_intact(loop: asyncio.AbstractEventLoop) -> None: - data = zlib.compress(b"foobar") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_precompressed_data_stays_intact( + loop: asyncio.AbstractEventLoop, +) -> None: + data = ZLibBackend.compress(b"foobar") req = ClientRequest( "post", URL("http://python.org/"), diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 5ec6d251388..3e871d8d29a 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -977,6 +977,7 @@ async def delayed_send_frame( assert cancelled is True +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() @@ -1002,6 +1003,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert resp.get_extra_info("socket") is None +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_recv_compress_wbits(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() diff --git a/tests/test_compression_utils.py b/tests/test_compression_utils.py index 047a4ff7cf0..fdaf91b36a0 100644 --- a/tests/test_compression_utils.py +++ b/tests/test_compression_utils.py @@ -1,22 +1,34 @@ """Tests for compression utils.""" -from aiohttp.compression_utils import ZLibCompressor, ZLibDecompressor +import pytest +from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor + +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_round_trip_in_executor() -> None: """Ensure that compression and decompression work correctly in the executor.""" - compressor = ZLibCompressor(max_sync_chunk_size=1) + compressor = ZLibCompressor( + strategy=ZLibBackend.Z_DEFAULT_STRATEGY, max_sync_chunk_size=1 + ) + assert type(compressor._compressor) is type(ZLibBackend.compressobj()) decompressor = ZLibDecompressor(max_sync_chunk_size=1) + assert type(decompressor._decompressor) is type(ZLibBackend.decompressobj()) data = b"Hi" * 100 compressed_data = await compressor.compress(data) + compressor.flush() decompressed_data = await decompressor.decompress(compressed_data) assert data == decompressed_data +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_round_trip_in_event_loop() -> None: """Ensure that compression and decompression work correctly in the event loop.""" - compressor = ZLibCompressor(max_sync_chunk_size=10000) + compressor = ZLibCompressor( + strategy=ZLibBackend.Z_DEFAULT_STRATEGY, max_sync_chunk_size=10000 + ) + assert type(compressor._compressor) is type(ZLibBackend.compressobj()) decompressor = ZLibDecompressor(max_sync_chunk_size=10000) + assert type(decompressor._decompressor) is type(ZLibBackend.decompressobj()) data = b"Hi" * 100 compressed_data = await compressor.compress(data) + compressor.flush() decompressed_data = await decompressor.decompress(compressed_data) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 7032e1417b5..4e0ca4b13ea 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -2,7 +2,7 @@ import array import asyncio import 
zlib -from typing import Any, Generator, Iterable +from typing import Any, Generator, Iterable, Union from unittest import mock import pytest @@ -10,6 +10,7 @@ from aiohttp import ClientConnectionResetError, hdrs, http from aiohttp.base_protocol import BaseProtocol +from aiohttp.compression_utils import ZLibBackend from aiohttp.http_writer import _serialize_headers from aiohttp.test_utils import make_mocked_coro @@ -61,6 +62,26 @@ def protocol(loop: asyncio.AbstractEventLoop, transport: asyncio.Transport) -> A ) +def decompress(data: bytes) -> bytes: + d = ZLibBackend.decompressobj() + return d.decompress(data) + + +def decode_chunked(chunked: Union[bytes, bytearray]) -> bytes: + i = 0 + out = b"" + while i < len(chunked): + j = chunked.find(b"\r\n", i) + assert j != -1, "Malformed chunk" + size = int(chunked[i:j], 16) + if size == 0: + break + i = j + 2 + out += chunked[i : i + size] + i += size + 2 # skip \r\n after the chunk + return out + + def test_payloadwriter_properties( transport: asyncio.Transport, protocol: BaseProtocol, @@ -131,6 +152,7 @@ async def test_write_payload_length( @pytest.mark.usefixtures("disable_writelines") +@pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, @@ -156,7 +178,42 @@ async def test_write_large_payload_deflate_compression_data_in_eof( assert zlib.decompress(content) == (b"data" * 4096) + payload +@pytest.mark.usefixtures("disable_writelines") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_large_payload_deflate_compression_data_in_eof_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + + await msg.write(b"data" * 4096) + # Behavior depends on zlib backend, isal compress() returns b'' initially + # and the entire compressed bytes at flush() for this data + backend_to_write_called = { + "isal.isal_zlib": False, + "zlib": True, + "zlib_ng.zlib_ng": True, + } + assert transport.write.called == backend_to_write_called[ZLibBackend.name] # type: ignore[attr-defined] + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + transport.write.reset_mock() # type: ignore[attr-defined] + + # This payload compresses to 20447 bytes + payload = b"".join( + [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] + ) + await msg.write_eof(payload) + chunks.extend([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] + + assert all(chunks) + content = b"".join(chunks) + assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload + + @pytest.mark.usefixtures("enable_writelines") +@pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_data_in_eof_writelines( protocol: BaseProtocol, transport: asyncio.Transport, @@ -183,6 +240,43 @@ async def test_write_large_payload_deflate_compression_data_in_eof_writelines( assert zlib.decompress(content) == (b"data" * 4096) + payload +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_large_payload_deflate_compression_data_in_eof_writelines_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + 
msg.enable_compression("deflate") + + await msg.write(b"data" * 4096) + # Behavior depends on zlib backend, isal compress() returns b'' initially + # and the entire compressed bytes at flush() for this data + backend_to_write_called = { + "isal.isal_zlib": False, + "zlib": True, + "zlib_ng.zlib_ng": True, + } + assert transport.write.called == backend_to_write_called[ZLibBackend.name] # type: ignore[attr-defined] + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + transport.write.reset_mock() # type: ignore[attr-defined] + assert not transport.writelines.called # type: ignore[attr-defined] + + # This payload compresses to 20447 bytes + payload = b"".join( + [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] + ) + await msg.write_eof(payload) + assert transport.writelines.called != transport.write.called # type: ignore[attr-defined] + if transport.writelines.called: # type: ignore[attr-defined] + chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined] + else: # transport.write.called: # type: ignore[attr-defined] + chunks.extend([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] + content = b"".join(chunks) + assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload + + async def test_write_payload_chunked_filter( protocol: BaseProtocol, transport: asyncio.Transport, @@ -219,6 +313,7 @@ async def test_write_payload_chunked_filter_multiple_chunks( ) +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression( protocol: BaseProtocol, transport: asyncio.Transport, @@ -236,6 +331,24 @@ async def test_write_payload_deflate_compression( assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1] +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_compression_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + await msg.write(b"data") + await msg.write_eof() + + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert b"data" == decompress(content) + + +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked( protocol: BaseProtocol, transport: asyncio.Transport, @@ -254,8 +367,27 @@ async def test_write_payload_deflate_compression_chunked( assert content == expected +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_compression_chunked_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof() + + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert b"data" == decompress(decode_chunked(content)) + + @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_writelines( protocol: BaseProtocol, transport: asyncio.Transport, @@ -274,6 +406,27 @@ async def 
test_write_payload_deflate_compression_chunked_writelines( assert content == expected +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("force_writelines_small_payloads") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_compression_chunked_writelines_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof() + + chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert b"data" == decompress(decode_chunked(content)) + + +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_and_chunked( buf: bytearray, protocol: BaseProtocol, @@ -292,6 +445,25 @@ async def test_write_payload_deflate_and_chunked( assert thing == buf +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_and_chunked_all_zlib( + buf: bytearray, + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + + await msg.write(b"da") + await msg.write(b"ta") + await msg.write_eof() + + assert b"data" == decompress(decode_chunked(buf)) + + +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, @@ -310,8 +482,27 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof( assert content == expected +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_compression_chunked_data_in_eof_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof(b"end") + + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert b"dataend" == decompress(decode_chunked(content)) + + @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines( protocol: BaseProtocol, transport: asyncio.Transport, @@ -330,6 +521,27 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines( assert content == expected +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("force_writelines_small_payloads") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + await msg.write_eof(b"end") + + chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] + assert all(chunks) + content = b"".join(chunks) + assert b"dataend" == 
decompress(decode_chunked(content)) + + +@pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_chunked_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, @@ -356,8 +568,36 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof( assert zlib.decompress(content) == (b"data" * 4096) + payload +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_large_payload_deflate_compression_chunked_data_in_eof_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + + await msg.write(b"data" * 4096) + # This payload compresses to 1111 bytes + payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) + await msg.write_eof(payload) + + compressed = [] + chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] + chunked_body = b"".join(chunks) + split_body = chunked_body.split(b"\r\n") + while split_body: + if split_body.pop(0): + compressed.append(split_body.pop(0)) + + content = b"".join(compressed) + assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload + + @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") +@pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines( protocol: BaseProtocol, transport: asyncio.Transport, @@ -384,6 +624,36 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof_write assert zlib.decompress(content) == (b"data" * 4096) + payload +@pytest.mark.usefixtures("enable_writelines") +@pytest.mark.usefixtures("force_writelines_small_payloads") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + + await msg.write(b"data" * 4096) + # This payload compresses to 1111 bytes + payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) + await msg.write_eof(payload) + assert not transport.write.called # type: ignore[attr-defined] + + chunks = [] + for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined] + chunked_payload = list(write_lines_call[1][0])[1:] + chunked_payload.pop() + chunks.extend(chunked_payload) + + assert all(chunks) + content = b"".join(chunks) + assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload + + +@pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_connection_lost( protocol: BaseProtocol, transport: asyncio.Transport, @@ -402,6 +672,25 @@ async def test_write_payload_deflate_compression_chunked_connection_lost( await msg.write_eof(b"end") +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_write_payload_deflate_compression_chunked_connection_lost_all_zlib( + protocol: BaseProtocol, + transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, +) -> None: + msg = http.StreamWriter(protocol, loop) + msg.enable_compression("deflate") + msg.enable_chunking() + await msg.write(b"data") + with ( + pytest.raises( + 
ClientConnectionResetError, match="Cannot write to closing transport" + ), + mock.patch.object(transport, "is_closing", return_value=True), + ): + await msg.write_eof(b"end") + + async def test_write_payload_bytes_memoryview( buf: bytearray, protocol: BaseProtocol, diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 55befdbb60f..6d9707f4a5a 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -3,7 +3,6 @@ import json import pathlib import sys -import zlib from types import TracebackType from typing import Dict, Optional, Tuple, Type from unittest import mock @@ -13,6 +12,7 @@ import aiohttp from aiohttp import payload +from aiohttp.compression_utils import ZLibBackend from aiohttp.hdrs import ( CONTENT_DISPOSITION, CONTENT_ENCODING, @@ -1104,8 +1104,11 @@ async def test_writer_write_no_parts( assert b"--:--\r\n" == bytes(buf) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_writer_serialize_with_content_encoding_gzip( - buf: bytearray, stream: Stream, writer: aiohttp.MultipartWriter + buf: bytearray, + stream: Stream, + writer: aiohttp.MultipartWriter, ) -> None: writer.append("Time to Relax!", {CONTENT_ENCODING: "gzip"}) await writer.write(stream) @@ -1116,7 +1119,7 @@ async def test_writer_serialize_with_content_encoding_gzip( b"Content-Encoding: gzip" == headers ) - decompressor = zlib.decompressobj(wbits=16 + zlib.MAX_WBITS) + decompressor = ZLibBackend.decompressobj(wbits=16 + ZLibBackend.MAX_WBITS) data = decompressor.decompress(message.split(b"\r\n")[0]) data += decompressor.flush() assert b"Time to Relax!" == data diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 332e5cba6bb..ffa27ec8acf 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -4,8 +4,17 @@ import pathlib import socket import sys -import zlib -from typing import AsyncIterator, Awaitable, Callable, Dict, List, NoReturn, Optional +from typing import ( + AsyncIterator, + Awaitable, + Callable, + Dict, + Generator, + List, + NoReturn, + Optional, + Tuple, +) from unittest import mock import pytest @@ -24,6 +33,7 @@ web, ) from aiohttp.abc import AbstractResolver, ResolveResult +from aiohttp.compression_utils import ZLibBackend, ZLibCompressObjProtocol from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.test_utils import make_mocked_coro @@ -1090,19 +1100,30 @@ async def handler(request: web.Request) -> web.Response: resp.release() -@pytest.mark.parametrize( - "compressor,encoding", - [ - (zlib.compressobj(wbits=16 + zlib.MAX_WBITS), "gzip"), - (zlib.compressobj(wbits=zlib.MAX_WBITS), "deflate"), - # Actually, wrong compression format, but - # should be supported for some legacy cases. 
- (zlib.compressobj(wbits=-zlib.MAX_WBITS), "deflate"), - ], -) +@pytest.fixture(params=["gzip", "deflate", "deflate-raw"]) +def compressor_case( + request: pytest.FixtureRequest, + parametrize_zlib_backend: None, +) -> Generator[Tuple[ZLibCompressObjProtocol, str], None, None]: + encoding: str = request.param + max_wbits: int = ZLibBackend.MAX_WBITS + + encoding_to_wbits: Dict[str, int] = { + "deflate": max_wbits, + "deflate-raw": -max_wbits, + "gzip": 16 + max_wbits, + } + + compressor = ZLibBackend.compressobj(wbits=encoding_to_wbits[encoding]) + yield (compressor, "deflate" if encoding.startswith("deflate") else encoding) + + async def test_response_with_precompressed_body( - aiohttp_client: AiohttpClient, compressor: "zlib._Compress", encoding: str + aiohttp_client: AiohttpClient, + compressor_case: Tuple[ZLibCompressObjProtocol, str], ) -> None: + compressor, encoding = compressor_case + async def handler(request: web.Request) -> web.Response: headers = {"Content-Encoding": encoding} data = compressor.compress(b"mydata") + compressor.flush() @@ -2179,6 +2200,7 @@ async def handler(request: web.Request) -> web.Response: @pytest.mark.parametrize( "auto_decompress,len_of", [(True, "uncompressed"), (False, "compressed")] ) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_auto_decompress( aiohttp_client: AiohttpClient, auto_decompress: bool, @@ -2193,7 +2215,7 @@ async def handler(request: web.Request) -> web.Response: client = await aiohttp_client(app) uncompressed = b"dataaaaaaaaaaaaaaaaaaaaaaaaa" - compressor = zlib.compressobj(wbits=16 + zlib.MAX_WBITS) + compressor = ZLibBackend.compressobj(wbits=16 + ZLibBackend.MAX_WBITS) compressed = compressor.compress(uncompressed) + compressor.flush() assert len(compressed) != len(uncompressed) headers = {"content-encoding": "gzip"} diff --git a/tests/test_web_response.py b/tests/test_web_response.py index e5dd4dab7fb..b6d9b3e2bd3 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -5,7 +5,6 @@ import json import re import weakref -import zlib from concurrent.futures import ThreadPoolExecutor from typing import AsyncIterator, Optional, Union from unittest import mock @@ -389,6 +388,7 @@ async def test_chunked_encoding_forbidden_for_http_10() -> None: assert str(ctx.value) == "Using chunked encoding is forbidden for HTTP/1.0" +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_no_accept() -> None: req = make_request("GET", "/") resp = web.StreamResponse() @@ -402,6 +402,7 @@ async def test_compression_no_accept() -> None: assert not msg.enable_compression.called +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_default_coding() -> None: req = make_request( "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) @@ -415,11 +416,12 @@ async def test_compression_default_coding() -> None: msg = await resp.prepare(req) # type: ignore[unreachable] - msg.enable_compression.assert_called_with("deflate", zlib.Z_DEFAULT_STRATEGY) + msg.enable_compression.assert_called_with("deflate", None) assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) assert msg.filter is not None +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_deflate() -> None: req = make_request( "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) @@ -431,10 +433,11 @@ async def test_force_compression_deflate() -> None: msg = await resp.prepare(req) assert msg is not None - 
msg.enable_compression.assert_called_with("deflate", zlib.Z_DEFAULT_STRATEGY) # type: ignore[attr-defined] + msg.enable_compression.assert_called_with("deflate", None) # type: ignore[attr-defined] assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_deflate_large_payload() -> None: """Make sure a warning is thrown for large payloads compressed in the event loop.""" req = make_request( @@ -454,6 +457,7 @@ async def test_force_compression_deflate_large_payload() -> None: assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_no_accept_deflate() -> None: req = make_request("GET", "/") resp = web.StreamResponse() @@ -463,10 +467,11 @@ async def test_force_compression_no_accept_deflate() -> None: msg = await resp.prepare(req) assert msg is not None - msg.enable_compression.assert_called_with("deflate", zlib.Z_DEFAULT_STRATEGY) # type: ignore[attr-defined] + msg.enable_compression.assert_called_with("deflate", None) # type: ignore[attr-defined] assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_gzip() -> None: req = make_request( "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) @@ -478,10 +483,11 @@ async def test_force_compression_gzip() -> None: msg = await resp.prepare(req) assert msg is not None - msg.enable_compression.assert_called_with("gzip", zlib.Z_DEFAULT_STRATEGY) # type: ignore[attr-defined] + msg.enable_compression.assert_called_with("gzip", None) # type: ignore[attr-defined] assert "gzip" == resp.headers.get(hdrs.CONTENT_ENCODING) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_no_accept_gzip() -> None: req = make_request("GET", "/") resp = web.StreamResponse() @@ -491,10 +497,11 @@ async def test_force_compression_no_accept_gzip() -> None: msg = await resp.prepare(req) assert msg is not None - msg.enable_compression.assert_called_with("gzip", zlib.Z_DEFAULT_STRATEGY) # type: ignore[attr-defined] + msg.enable_compression.assert_called_with("gzip", None) # type: ignore[attr-defined] assert "gzip" == resp.headers.get(hdrs.CONTENT_ENCODING) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_change_content_threaded_compression_enabled() -> None: req = make_request("GET", "/") body_thread_size = 1024 @@ -507,6 +514,7 @@ async def test_change_content_threaded_compression_enabled() -> None: assert gzip.decompress(resp._compressed_body) == body +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_change_content_threaded_compression_enabled_explicit() -> None: req = make_request("GET", "/") body_thread_size = 1024 @@ -522,6 +530,7 @@ async def test_change_content_threaded_compression_enabled_explicit() -> None: assert gzip.decompress(resp._compressed_body) == body +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_change_content_length_if_compression_enabled() -> None: req = make_request("GET", "/") resp = web.Response(body=b"answer") @@ -531,6 +540,7 @@ async def test_change_content_length_if_compression_enabled() -> None: assert resp.content_length is not None and resp.content_length != len(b"answer") +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_set_content_length_if_compression_enabled() -> None: writer = mock.Mock() @@ -550,6 +560,7 @@ async def write_headers(status_line: 
str, headers: CIMultiDict[str]) -> None: assert resp.content_length == 26 +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_remove_content_length_if_compression_enabled_http11() -> None: writer = mock.Mock() @@ -566,6 +577,7 @@ async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert resp.content_length is None +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_remove_content_length_if_compression_enabled_http10() -> None: writer = mock.Mock() @@ -582,6 +594,7 @@ async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert resp.content_length is None +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_identity() -> None: writer = mock.Mock() @@ -598,6 +611,7 @@ async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert resp.content_length == 123 +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_identity_response() -> None: writer = mock.Mock() @@ -613,6 +627,7 @@ async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert resp.content_length == 6 +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_rm_content_length_if_compression_http11() -> None: writer = mock.Mock() @@ -630,6 +645,7 @@ async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert resp.content_length is None +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_rm_content_length_if_compression_http10() -> None: writer = mock.Mock() diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index a093b1c01a7..64b97787616 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -3,7 +3,6 @@ import gzip import pathlib import socket -import zlib from typing import Iterable, Iterator, NoReturn, Optional, Protocol, Tuple from unittest import mock @@ -12,6 +11,7 @@ import aiohttp from aiohttp import web +from aiohttp.compression_utils import ZLibBackend from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.typedefs import PathLike @@ -313,6 +313,7 @@ async def handler(request: web.Request) -> web.FileResponse: [("gzip, deflate", "gzip"), ("gzip, deflate, br", "br")], ) @pytest.mark.parametrize("forced_compression", [None, web.ContentCoding.gzip]) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_static_file_with_encoding_and_enable_compression( hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, @@ -1062,8 +1063,10 @@ async def test_static_file_if_range_invalid_date( await client.close() +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_static_file_compression( - aiohttp_client: AiohttpClient, sender: _Sender + aiohttp_client: AiohttpClient, + sender: _Sender, ) -> None: filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" @@ -1078,7 +1081,7 @@ async def handler(request: web.Request) -> web.FileResponse: resp = await client.get("/") assert resp.status == 200 - zcomp = zlib.compressobj(wbits=zlib.MAX_WBITS) + zcomp = ZLibBackend.compressobj(wbits=ZLibBackend.MAX_WBITS) expected_body = zcomp.compress(b"file content\n") + zcomp.flush() assert expected_body == await resp.read() assert "application/octet-stream" == resp.headers["Content-Type"] diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 7e187b897c4..a28c7e5ac9c 100644 --- a/tests/test_websocket_parser.py +++ 
b/tests/test_websocket_parser.py @@ -1,8 +1,7 @@ import asyncio import pickle import struct -import zlib -from typing import Union +from typing import Optional, Union from unittest import mock import pytest @@ -12,6 +11,7 @@ from aiohttp._websocket.models import WS_DEFLATE_TRAILING from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol +from aiohttp.compression_utils import ZLibBackend, ZLibBackendWrapper from aiohttp.http import WebSocketError, WSCloseCode, WSMsgType from aiohttp.http_websocket import ( WebSocketReader, @@ -32,13 +32,15 @@ def build_frame( opcode: int, noheader: bool = False, is_fin: bool = True, - compress: bool = False, + ZLibBackend: Optional[ZLibBackendWrapper] = None, ) -> bytes: # Send a frame over the websocket with message as its payload. - if compress: - compressobj = zlib.compressobj(wbits=-9) + compress = False + if ZLibBackend: + compress = True + compressobj = ZLibBackend.compressobj(wbits=-9) message = compressobj.compress(message) - message = message + compressobj.flush(zlib.Z_SYNC_FLUSH) + message = message + compressobj.flush(ZLibBackend.Z_SYNC_FLUSH) if message.endswith(WS_DEFLATE_TRAILING): message = message[:-4] msg_length = len(message) @@ -572,9 +574,10 @@ def test_msg_too_large_not_fin(out: WebSocketDataQueue) -> None: assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG +@pytest.mark.usefixtures("parametrize_zlib_backend") def test_compressed_msg_too_large(out: WebSocketDataQueue) -> None: parser = WebSocketReader(out, 256, compress=True) - data = build_frame(b"aaa" * 256, WSMsgType.TEXT, compress=True) + data = build_frame(b"aaa" * 256, WSMsgType.TEXT, ZLibBackend=ZLibBackend) with pytest.raises(WebSocketError) as ctx: parser._feed_data(data) assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index fa49410c283..ec22ac0b5eb 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -127,6 +127,7 @@ async def test_send_compress_text_per_message( (32, lambda count: 64 + count if count % 2 else count), ), ) +@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_concurrent_messages( protocol: BaseProtocol, transport: asyncio.Transport,
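
A minimal, illustrative sketch of the backend-switching API added by this patch.
The ``zlib_ng`` import is an optional assumption (used only if installed); the
stdlib ``zlib`` fallback keeps the snippet runnable without extra packages::

    import zlib

    import aiohttp
    from aiohttp.compression_utils import ZLibBackend, ZLibCompressor

    try:
        # Optional third-party backend; shown as an assumption for illustration.
        from zlib_ng import zlib_ng as backend
    except ImportError:
        backend = zlib

    aiohttp.set_zlib_backend(backend)
    assert ZLibBackend.name in ("zlib_ng.zlib_ng", "zlib")

    # Round-trip through the wrapper's module-level helpers.
    data = b"payload " * 64
    assert ZLibBackend.decompress(ZLibBackend.compress(data)) == data

    # Compressor/decompressor instances snapshot the backend at construction
    # time, so objects created before a later set_zlib_backend() call keep
    # using whichever module was active when they were built.
    compressor = ZLibCompressor(encoding="deflate")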