diff --git a/CHANGES/10662.packaging.rst b/CHANGES/10662.packaging.rst new file mode 100644 index 00000000000..2ed3a69cb56 --- /dev/null +++ b/CHANGES/10662.packaging.rst @@ -0,0 +1 @@ +Removed non SPDX-license description from ``setup.cfg`` -- by :user:`devanshu-ziphq`. diff --git a/CHANGES/9732.feature.rst b/CHANGES/9732.feature.rst new file mode 100644 index 00000000000..bf6dd8ebde3 --- /dev/null +++ b/CHANGES/9732.feature.rst @@ -0,0 +1,6 @@ +Added client middleware support -- by :user:`bdraco` and :user:`Dreamsorcerer`. + +This change allows users to add middleware to the client session and requests, enabling features like +authentication, logging, and request/response modification without modifying the core +request logic. Additionally, the ``session`` attribute was added to ``ClientRequest``, +allowing middleware to access the session for making additional requests. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 7c36b570e87..89eb3ae621a 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -105,6 +105,7 @@ Denilson Amorim Denis Matiychuk Denis Moshensky Dennis Kliban +Devanshu Koyalkar Dima Veselov Dimitar Dimitrov Diogo Dutra da Mata diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index f23bf928f37..f2ada6bcf07 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -47,6 +47,7 @@ WSServerHandshakeError, request, ) +from .client_middlewares import ClientHandlerType, ClientMiddlewareType from .compression_utils import set_zlib_backend from .connector import AddrInfoType, SocketFactoryType from .cookiejar import CookieJar, DummyCookieJar @@ -157,6 +158,9 @@ "NamedPipeConnector", "WSServerHandshakeError", "request", + # client_middleware + "ClientMiddlewareType", + "ClientHandlerType", # cookiejar "CookieJar", "DummyCookieJar", diff --git a/aiohttp/client.py b/aiohttp/client.py index 04f03b710f0..30a29c36c01 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -72,6 +72,7 @@ WSMessageTypeError, WSServerHandshakeError, ) +from .client_middlewares import ClientMiddlewareType, build_client_middlewares from .client_reqrep import ( SSL_ALLOWED_TYPES, ClientRequest, @@ -193,6 +194,7 @@ class _RequestOptions(TypedDict, total=False): auto_decompress: Union[bool, None] max_line_size: Union[int, None] max_field_size: Union[int, None] + middlewares: Optional[Tuple[ClientMiddlewareType, ...]] @frozen_dataclass_decorator @@ -260,6 +262,7 @@ class ClientSession: "_default_proxy", "_default_proxy_auth", "_retry_connection", + "_middlewares", ) def __init__( @@ -292,6 +295,7 @@ def __init__( max_line_size: int = 8190, max_field_size: int = 8190, fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8", + middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None, ) -> None: # We initialise _connector to None immediately, as it's referenced in __del__() # and could cause issues if an exception occurs during initialisation. @@ -376,6 +380,7 @@ def __init__( self._default_proxy = proxy self._default_proxy_auth = proxy_auth self._retry_connection: bool = True + self._middlewares = middlewares def __init_subclass__(cls: Type["ClientSession"]) -> None: raise TypeError( @@ -450,6 +455,7 @@ async def _request( auto_decompress: Optional[bool] = None, max_line_size: Optional[int] = None, max_field_size: Optional[int] = None, + middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None, ) -> ClientResponse: # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants @@ -642,32 +648,33 @@ async def _request( trust_env=self.trust_env, ) - # connection timeout - try: - conn = await self._connector.connect( - req, traces=traces, timeout=real_timeout + # Core request handler - now includes connection logic + async def _connect_and_send_request( + req: ClientRequest, + ) -> ClientResponse: + # connection timeout + assert self._connector is not None + try: + conn = await self._connector.connect( + req, traces=traces, timeout=real_timeout + ) + except asyncio.TimeoutError as exc: + raise ConnectionTimeoutError( + f"Connection timeout to host {req.url}" + ) from exc + + assert conn.protocol is not None + conn.protocol.set_response_params( + timer=timer, + skip_payload=req.method in EMPTY_BODY_METHODS, + read_until_eof=read_until_eof, + auto_decompress=auto_decompress, + read_timeout=real_timeout.sock_read, + read_bufsize=read_bufsize, + timeout_ceil_threshold=self._connector._timeout_ceil_threshold, + max_line_size=max_line_size, + max_field_size=max_field_size, ) - except asyncio.TimeoutError as exc: - raise ConnectionTimeoutError( - f"Connection timeout to host {url}" - ) from exc - - assert conn.transport is not None - - assert conn.protocol is not None - conn.protocol.set_response_params( - timer=timer, - skip_payload=method in EMPTY_BODY_METHODS, - read_until_eof=read_until_eof, - auto_decompress=auto_decompress, - read_timeout=real_timeout.sock_read, - read_bufsize=read_bufsize, - timeout_ceil_threshold=self._connector._timeout_ceil_threshold, - max_line_size=max_line_size, - max_field_size=max_field_size, - ) - - try: try: resp = await req.send(conn) try: @@ -678,6 +685,30 @@ async def _request( except BaseException: conn.close() raise + return resp + + # Apply middleware (if any) - per-request middleware overrides session middleware + effective_middlewares = ( + self._middlewares if middlewares is None else middlewares + ) + + if effective_middlewares: + handler = build_client_middlewares( + _connect_and_send_request, effective_middlewares + ) + else: + handler = _connect_and_send_request + + try: + resp = await handler(req) + # Client connector errors should not be retried + except ( + ConnectionTimeoutError, + ClientConnectorError, + ClientConnectorCertificateError, + ClientConnectorSSLError, + ): + raise except (ClientOSError, ServerDisconnectedError): if retry_persistent_connection: retry_persistent_connection = False diff --git a/aiohttp/client_middlewares.py b/aiohttp/client_middlewares.py new file mode 100644 index 00000000000..6be353c3a40 --- /dev/null +++ b/aiohttp/client_middlewares.py @@ -0,0 +1,58 @@ +"""Client middleware support.""" + +from collections.abc import Awaitable, Callable + +from .client_reqrep import ClientRequest, ClientResponse + +__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares") + +# Type alias for client request handlers - functions that process requests and return responses +ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]] + +# Type for client middleware - similar to server but uses ClientRequest/ClientResponse +ClientMiddlewareType = Callable[ + [ClientRequest, ClientHandlerType], Awaitable[ClientResponse] +] + + +def build_client_middlewares( + handler: ClientHandlerType, + middlewares: tuple[ClientMiddlewareType, ...], +) -> ClientHandlerType: + """ + Apply middlewares to request handler. + + The middlewares are applied in reverse order, so the first middleware + in the list wraps all subsequent middlewares and the handler. + + This implementation avoids using partial/update_wrapper to minimize overhead + and doesn't cache to avoid holding references to stateful middleware. + """ + if not middlewares: + return handler + + # Optimize for single middleware case + if len(middlewares) == 1: + middleware = middlewares[0] + + async def single_middleware_handler(req: ClientRequest) -> ClientResponse: + return await middleware(req, handler) + + return single_middleware_handler + + # Build the chain for multiple middlewares + current_handler = handler + + for middleware in reversed(middlewares): + # Create a new closure that captures the current state + def make_wrapper( + mw: ClientMiddlewareType, next_h: ClientHandlerType + ) -> ClientHandlerType: + async def wrapped(req: ClientRequest) -> ClientResponse: + return await mw(req, next_h) + + return wrapped + + current_handler = make_wrapper(middleware, current_handler) + + return current_handler diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index d30e8704d3e..db4018efa1d 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -210,6 +210,11 @@ class ClientRequest: auth = None response = None + # These class defaults help create_autospec() work correctly. + # If autospec is improved in future, maybe these can be removed. + url = URL() + method = "GET" + __writer: Optional["asyncio.Task[None]"] = None # async task for streaming data _continue = None # waiter future for '100 Continue' response @@ -362,6 +367,16 @@ def request_info(self) -> RequestInfo: RequestInfo, (self.url, self.method, headers, self.original_url) ) + @property + def session(self) -> "ClientSession": + """Return the ClientSession instance. + + This property provides access to the ClientSession that initiated + this request, allowing middleware to make additional requests + using the same session. + """ + return self._session + def update_host(self, url: URL) -> None: """Update destination host, port and connection type (ssl).""" # get host/port diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 0f6eb99974b..107141e69be 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -105,6 +105,218 @@ background. Started keeping the ``Authorization`` header during HTTP → HTTPS redirects when the host remains the same. +.. _aiohttp-client-middleware: + +Client Middleware +----------------- + +aiohttp client supports middleware to intercept requests and responses. This can be +useful for authentication, logging, request/response modification, and retries. + +To create a middleware, you need to define an async function that accepts the request +and a handler function, and returns the response. The middleware must match the +:type:`ClientMiddlewareType` type signature:: + + import logging + from aiohttp import ClientSession, ClientRequest, ClientResponse, ClientHandlerType + + _LOGGER = logging.getLogger(__name__) + + async def my_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + # Process request before sending + _LOGGER.debug(f"Request: {request.method} {request.url}") + + # Call the next handler + response = await handler(request) + + # Process response after receiving + _LOGGER.debug(f"Response: {response.status}") + + return response + +You can apply middleware to a client session or to individual requests:: + + # Apply to all requests in a session + async with ClientSession(middlewares=(my_middleware,)) as session: + resp = await session.get('http://example.com') + + # Apply to a specific request + async with ClientSession() as session: + resp = await session.get('http://example.com', middlewares=(my_middleware,)) + +Middleware Examples +^^^^^^^^^^^^^^^^^^^ + +Here's a simple example showing request modification:: + + async def add_api_key_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + # Add API key to all requests + request.headers['X-API-Key'] = 'my-secret-key' + return await handler(request) + +.. _client-middleware-retry: + +Middleware Retry Pattern +^^^^^^^^^^^^^^^^^^^^^^^^ + +Client middleware can implement retry logic internally using a ``while`` loop. This allows the middleware to: + +- Retry requests based on response status codes or other conditions +- Modify the request between retries (e.g., refreshing tokens) +- Maintain state across retry attempts +- Control when to stop retrying and return the response + +This pattern is particularly useful for: + +- Refreshing authentication tokens after a 401 response +- Switching to fallback servers or authentication methods +- Adding or modifying headers based on error responses +- Implementing back-off strategies with increasing delays + +The middleware can maintain state between retries to track which strategies have been tried and modify the request accordingly for the next attempt. + +Example: Retrying requests with middleware +"""""""""""""""""""""""""""""""""""""""""" + +:: + + import logging + import aiohttp + + _LOGGER = logging.getLogger(__name__) + + class RetryMiddleware: + def __init__(self, max_retries: int = 3): + self.max_retries = max_retries + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + retry_count = 0 + use_fallback_auth = False + + while True: + # Modify request based on retry state + if use_fallback_auth: + request.headers['Authorization'] = 'Bearer fallback-token' + + response = await handler(request) + + # Retry on 401 errors with different authentication + if response.status == 401 and retry_count < self.max_retries: + retry_count += 1 + use_fallback_auth = True + _LOGGER.debug(f"Retrying with fallback auth (attempt {retry_count})") + continue + + # Retry on 5xx errors + if response.status >= 500 and retry_count < self.max_retries: + retry_count += 1 + _LOGGER.debug(f"Retrying request (attempt {retry_count})") + continue + + return response + +Middleware Chaining +^^^^^^^^^^^^^^^^^^^ + +Multiple middlewares are applied in the order they are listed:: + + import logging + + _LOGGER = logging.getLogger(__name__) + + async def logging_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + _LOGGER.debug(f"[LOG] {request.method} {request.url}") + return await handler(request) + + async def auth_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + request.headers['Authorization'] = 'Bearer token123' + return await handler(request) + + # Middlewares are applied in order: logging -> auth -> request + async with ClientSession(middlewares=(logging_middleware, auth_middleware)) as session: + resp = await session.get('http://example.com') + +.. note:: + + Client middleware is a powerful feature but should be used judiciously. + Each middleware adds overhead to request processing. For simple use cases + like adding static headers, you can often use request parameters + (e.g., ``headers``) or session configuration instead. + +.. warning:: + + Using the same session from within middleware can cause infinite recursion if + the middleware makes HTTP requests using the same session that has the middleware + applied. + + To avoid recursion, use one of these approaches: + + **Recommended:** Pass ``middlewares=()`` to requests made inside the middleware to + disable middleware for those specific requests:: + + async def log_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + async with request.session.post( + "https://logapi.example/log", + json={"url": str(request.url)}, + middlewares=() # This prevents infinite recursion + ) as resp: + pass + + return await handler(request) + + **Alternative:** Check the request contents (URL, path, host) to avoid applying + middleware to certain requests:: + + async def log_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + if request.url.host != "logapi.example": # Avoid infinite recursion + async with request.session.post( + "https://logapi.example/log", + json={"url": str(request.url)} + ) as resp: + pass + + return await handler(request) + +Middleware Type +^^^^^^^^^^^^^^^ + +.. type:: ClientMiddlewareType + + Type alias for client middleware functions. Middleware functions must have this signature:: + + Callable[ + [ClientRequest, ClientHandlerType], + Awaitable[ClientResponse] + ] + +.. type:: ClientHandlerType + + Type alias for client request handler functions:: + + Callable[ClientRequest, Awaitable[ClientResponse]] + Custom Cookies -------------- diff --git a/docs/client_reference.rst b/docs/client_reference.rst index a94e079b5f7..84e2f0c7014 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -53,6 +53,7 @@ The client session supports the context manager protocol for self closing. trust_env=False, \ requote_redirect_url=True, \ trace_configs=None, \ + middlewares=None, \ read_bufsize=2**16, \ max_line_size=8190, \ max_field_size=8190, \ @@ -213,6 +214,13 @@ The client session supports the context manager protocol for self closing. disabling. See :ref:`aiohttp-client-tracing-reference` for more information. + :param middlewares: A tuple of middleware instances to apply to all session requests. + Each middleware must match the :type:`ClientMiddlewareType` signature. + ``None`` (default) is used when no middleware is needed. + See :ref:`aiohttp-client-middleware` for more information. + + .. versionadded:: 3.12 + :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). 64 KiB by default. @@ -371,6 +379,7 @@ The client session supports the context manager protocol for self closing. server_hostname=None, \ proxy_headers=None, \ trace_request_ctx=None, \ + middlewares=None, \ read_bufsize=None, \ auto_decompress=None, \ max_line_size=None, \ @@ -519,6 +528,13 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 3.0 + :param middlewares: A tuple of middleware instances to apply to this request only. + Each middleware must match the :type:`ClientMiddlewareType` signature. + ``None`` by default which uses session middlewares. + See :ref:`aiohttp-client-middleware` for more information. + + .. versionadded:: 3.12 + :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). ``None`` by default, it means that the session global value is used. @@ -1474,7 +1490,12 @@ Response object Returns value is ``'application/octet-stream'`` if no Content-Type header present in HTTP headers according to - :rfc:`2616`. To make sure Content-Type header is not present in + :rfc:`9110`. If the *Content-Type* header is invalid (e.g., ``jpg`` + instead of ``image/jpeg``), the value is ``text/plain`` by default + according to :rfc:`2045`. To see the original header check + ``resp.headers['CONTENT-TYPE']``. + + To make sure Content-Type header is not present in the server reply, use :attr:`headers` or :attr:`raw_headers`, e.g. ``'CONTENT-TYPE' not in resp.headers``. diff --git a/setup.cfg b/setup.cfg index 0e6fae807f4..30aa2e87838 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,8 +25,6 @@ classifiers = Intended Audience :: Developers - License :: OSI Approved :: Apache Software License - Operating System :: POSIX Operating System :: MacOS :: MacOS X Operating System :: Microsoft :: Windows diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index a7e229bfaa1..516e81a825c 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1102,6 +1102,30 @@ async def handler(request: web.Request) -> web.StreamResponse: assert resp.status == 200 +async def test_connection_timeout_error( + aiohttp_client: AiohttpClient, mocker: MockerFixture +) -> None: + """Test that ConnectionTimeoutError is raised when connection times out.""" + + async def handler(request: web.Request) -> NoReturn: + assert False, "Handler should not be called" + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Mock the connector's connect method to raise asyncio.TimeoutError + mock_connect = mocker.patch.object( + client.session._connector, "connect", side_effect=asyncio.TimeoutError() + ) + + with pytest.raises(aiohttp.ConnectionTimeoutError) as exc_info: + await client.get("/", timeout=aiohttp.ClientTimeout(connect=0.01)) + + assert "Connection timeout to host" in str(exc_info.value) + mock_connect.assert_called_once() + + async def test_readline_error_on_conn_close(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py new file mode 100644 index 00000000000..2f79e4fd774 --- /dev/null +++ b/tests/test_client_middleware.py @@ -0,0 +1,1117 @@ +"""Tests for client middleware.""" + +import json +import socket +from typing import Dict, List, NoReturn, Optional, Union + +import pytest + +from aiohttp import ( + ClientError, + ClientHandlerType, + ClientRequest, + ClientResponse, + ClientSession, + ClientTimeout, + TCPConnector, + web, +) +from aiohttp.abc import ResolveResult +from aiohttp.client_middlewares import build_client_middlewares +from aiohttp.client_proto import ResponseHandler +from aiohttp.pytest_plugin import AiohttpServer +from aiohttp.resolver import ThreadedResolver +from aiohttp.tracing import Trace + + +class BlockedByMiddleware(ClientError): + """Custom exception for when middleware blocks a request.""" + + +async def test_client_middleware_called(aiohttp_server: AiohttpServer) -> None: + """Test that client middleware is called.""" + middleware_called = False + request_count = 0 + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + return web.Response(text=f"OK {request_count}") + + async def test_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + nonlocal middleware_called + middleware_called = True + response = await handler(request) + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(test_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "OK 1" + + assert middleware_called is True + assert request_count == 1 + + +async def test_client_middleware_retry(aiohttp_server: AiohttpServer) -> None: + """Test that middleware can trigger retries.""" + request_count = 0 + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + if request_count == 1: + return web.Response(status=503) + return web.Response(text=f"OK {request_count}") + + async def retry_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + retry_count = 0 + while True: + response = await handler(request) + if response.status == 503 and retry_count < 1: + retry_count += 1 + continue + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(retry_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "OK 2" + + assert request_count == 2 + + +async def test_client_middleware_per_request(aiohttp_server: AiohttpServer) -> None: + """Test that middleware can be specified per request.""" + session_middleware_called = False + request_middleware_called = False + + async def handler(request: web.Request) -> web.Response: + return web.Response(text="OK") + + async def session_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + nonlocal session_middleware_called + session_middleware_called = True + response = await handler(request) + return response + + async def request_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + nonlocal request_middleware_called + request_middleware_called = True + response = await handler(request) + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Request with session middleware + async with ClientSession(middlewares=(session_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + + assert session_middleware_called is True + assert request_middleware_called is False + + # Reset flags + session_middleware_called = False + + # Request with override middleware + async with ClientSession(middlewares=(session_middleware,)) as session: + async with session.get( + server.make_url("/"), middlewares=(request_middleware,) + ) as resp: + assert resp.status == 200 + + assert session_middleware_called is False + assert request_middleware_called is True + + +async def test_multiple_client_middlewares(aiohttp_server: AiohttpServer) -> None: + """Test that multiple middlewares are executed in order.""" + calls: list[str] = [] + + async def handler(request: web.Request) -> web.Response: + return web.Response(text="OK") + + async def middleware1( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + calls.append("before1") + response = await handler(request) + calls.append("after1") + return response + + async def middleware2( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + calls.append("before2") + response = await handler(request) + calls.append("after2") + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(middleware1, middleware2)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + + # Middlewares are applied in reverse order (like server middlewares) + # So middleware1 wraps middleware2 + assert calls == ["before1", "before2", "after2", "after1"] + + +async def test_client_middleware_auth_example(aiohttp_server: AiohttpServer) -> None: + """Test an authentication middleware example.""" + + async def handler(request: web.Request) -> web.Response: + auth_header = request.headers.get("Authorization") + if auth_header == "Bearer valid-token": + return web.Response(text="Authenticated") + return web.Response(status=401, text="Unauthorized") + + async def auth_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + # Add authentication header before request + request.headers["Authorization"] = "Bearer valid-token" + response = await handler(request) + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Without middleware - should fail + async with ClientSession() as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 401 + + # With middleware - should succeed + async with ClientSession(middlewares=(auth_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Authenticated" + + +async def test_client_middleware_challenge_auth(aiohttp_server: AiohttpServer) -> None: + """Test authentication middleware with challenge/response pattern like digest auth.""" + request_count = 0 + challenge_token = "challenge-123" + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + + auth_header = request.headers.get("Authorization") + + # First request - no auth header, return challenge + if request_count == 1 and not auth_header: + return web.Response( + status=401, + headers={ + "WWW-Authenticate": f'Custom realm="test", nonce="{challenge_token}"' + }, + ) + + # Subsequent requests - check for correct auth with challenge + if auth_header == f'Custom response="{challenge_token}-secret"': + return web.Response(text="Authenticated") + + assert False, "Should not reach here - invalid auth scenario" + + async def challenge_auth_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + challenge_data: Dict[str, Union[bool, str, None]] = { + "nonce": None, + "attempted": False, + } + + while True: + # If we have challenge data from previous attempt, add auth header + if challenge_data["nonce"] and challenge_data["attempted"]: + request.headers["Authorization"] = ( + f'Custom response="{challenge_data["nonce"]}-secret"' + ) + + response = await handler(request) + + # If we get a 401 with challenge, store it and retry + if response.status == 401 and not challenge_data["attempted"]: + www_auth = response.headers.get("WWW-Authenticate") + if www_auth and "nonce=" in www_auth: # pragma: no branch + # Extract nonce from authentication header + nonce_start = www_auth.find('nonce="') + 7 + nonce_end = www_auth.find('"', nonce_start) + challenge_data["nonce"] = www_auth[nonce_start:nonce_end] + challenge_data["attempted"] = True + continue + + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(challenge_auth_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Authenticated" + + # Should have made 2 requests: initial and retry with auth + assert request_count == 2 + + +async def test_client_middleware_multi_step_auth(aiohttp_server: AiohttpServer) -> None: + """Test middleware with multi-step authentication flow.""" + auth_state: dict[str, int] = {} + middleware_state: Dict[str, Optional[Union[int, str]]] = { + "step": 0, + "session": None, + "challenge": None, + } + + async def handler(request: web.Request) -> web.Response: + client_id = request.headers.get("X-Client-ID", "unknown") + auth_header = request.headers.get("Authorization") + step = auth_state.get(client_id, 0) + + # Step 0: No auth, request client ID + if step == 0 and not auth_header: + auth_state[client_id] = 1 + return web.Response( + status=401, headers={"X-Auth-Step": "1", "X-Session": "session-123"} + ) + + # Step 1: Has session, request credentials + if step == 1 and auth_header == "Bearer session-123": + auth_state[client_id] = 2 + return web.Response( + status=401, headers={"X-Auth-Step": "2", "X-Challenge": "challenge-456"} + ) + + # Step 2: Has challenge response, authenticate + if step == 2 and auth_header == "Bearer challenge-456-response": + return web.Response(text="Authenticated") + + assert False, "Should not reach here - invalid multi-step auth flow" + + async def multi_step_auth_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + request.headers["X-Client-ID"] = "test-client" + + while True: + # Apply auth based on current state + if middleware_state["step"] == 1 and middleware_state["session"]: + request.headers["Authorization"] = ( + f"Bearer {middleware_state['session']}" + ) + elif middleware_state["step"] == 2 and middleware_state["challenge"]: + request.headers["Authorization"] = ( + f"Bearer {middleware_state['challenge']}-response" + ) + + response = await handler(request) + + # Handle multi-step auth flow + if response.status == 401: + auth_step = response.headers.get("X-Auth-Step") + + if auth_step == "1": + # First step: store session token + middleware_state["session"] = response.headers.get("X-Session") + middleware_state["step"] = 1 + continue + + elif auth_step == "2": # pragma: no branch + # Second step: store challenge + middleware_state["challenge"] = response.headers.get("X-Challenge") + middleware_state["step"] = 2 + continue + + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(multi_step_auth_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Authenticated" + + +async def test_client_middleware_conditional_retry( + aiohttp_server: AiohttpServer, +) -> None: + """Test middleware with conditional retry based on response content.""" + request_count = 0 + token_state: Dict[str, Union[str, bool]] = { + "token": "old-token", + "refreshed": False, + } + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + + auth_token = request.headers.get("X-Auth-Token") + + if request_count == 1: + # First request returns expired token error + return web.json_response( + {"error": "token_expired", "refresh_required": True}, status=401 + ) + + if auth_token == "refreshed-token": + return web.json_response({"data": "success"}) + + assert False, "Should not reach here - invalid token refresh flow" + + async def token_refresh_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + while True: + # Add token to request + request.headers["X-Auth-Token"] = str(token_state["token"]) + + response = await handler(request) + + # Check if token needs refresh + if response.status == 401 and not token_state["refreshed"]: + data = await response.json() + if data.get("error") == "token_expired" and data.get( + "refresh_required" + ): # pragma: no branch + # Simulate token refresh + token_state["token"] = "refreshed-token" + token_state["refreshed"] = True + continue + + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + async with ClientSession(middlewares=(token_refresh_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + data = await resp.json() + assert data == {"data": "success"} + + assert request_count == 2 # Initial request + retry after refresh + + +async def test_build_client_middlewares_empty() -> None: + """Test build_client_middlewares with empty middlewares.""" + + async def handler(request: ClientRequest) -> NoReturn: + """Dummy handler.""" + assert False + + # Test empty case + result = build_client_middlewares(handler, ()) + assert result is handler # Should return handler unchanged + + +async def test_client_middleware_class_based_auth( + aiohttp_server: AiohttpServer, +) -> None: + """Test middleware using class-based pattern with instance state.""" + + class TokenAuthMiddleware: + """Middleware that handles token-based authentication.""" + + def __init__(self, token: str) -> None: + self.token = token + self.request_count = 0 + + async def __call__( + self, request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + self.request_count += 1 + request.headers["Authorization"] = f"Bearer {self.token}" + return await handler(request) + + async def handler(request: web.Request) -> web.Response: + auth_header = request.headers.get("Authorization") + if auth_header == "Bearer test-token": + return web.Response(text="Authenticated") + assert False, "Should not reach here - class auth should always have token" + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Create middleware instance + auth_middleware = TokenAuthMiddleware("test-token") + + async with ClientSession(middlewares=(auth_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Authenticated" + + # Verify the middleware was called + assert auth_middleware.request_count == 1 + + +async def test_client_middleware_stateful_retry(aiohttp_server: AiohttpServer) -> None: + """Test retry middleware using class with state management.""" + + class RetryMiddleware: + """Middleware that retries failed requests with backoff.""" + + def __init__(self, max_retries: int = 3) -> None: + self.max_retries = max_retries + self.retry_counts: Dict[int, int] = {} # Track retries per request + + async def __call__( + self, request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + retry_count = 0 + + while True: + response = await handler(request) + + if response.status >= 500 and retry_count < self.max_retries: + retry_count += 1 + continue + + return response + + request_count = 0 + + async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + + if request_count < 3: + return web.Response(status=503) + return web.Response(text="Success") + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + retry_middleware = RetryMiddleware(max_retries=2) + + async with ClientSession(middlewares=(retry_middleware,)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Success" + + assert request_count == 3 # Initial + 2 retries + + +async def test_client_middleware_multiple_instances( + aiohttp_server: AiohttpServer, +) -> None: + """Test using multiple instances of the same middleware class.""" + + class HeaderMiddleware: + """Middleware that adds a header with instance-specific value.""" + + def __init__(self, header_name: str, header_value: str) -> None: + self.header_name = header_name + self.header_value = header_value + self.applied = False + + async def __call__( + self, request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + self.applied = True + request.headers[self.header_name] = self.header_value + return await handler(request) + + headers_received = {} + + async def handler(request: web.Request) -> web.Response: + headers_received.update(dict(request.headers)) + return web.Response(text="OK") + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Create two instances with different headers + middleware1 = HeaderMiddleware("X-Custom-1", "value1") + middleware2 = HeaderMiddleware("X-Custom-2", "value2") + + async with ClientSession(middlewares=(middleware1, middleware2)) as session: + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + + # Both middlewares should have been applied + assert middleware1.applied is True + assert middleware2.applied is True + assert headers_received.get("X-Custom-1") == "value1" + assert headers_received.get("X-Custom-2") == "value2" + + +async def test_client_middleware_disable_with_empty_tuple( + aiohttp_server: AiohttpServer, +) -> None: + """Test that passing middlewares=() to a request disables session-level middlewares.""" + session_middleware_called = False + request_middleware_called = False + + async def handler(request: web.Request) -> web.Response: + auth_header = request.headers.get("Authorization") + if auth_header: + return web.Response(text=f"Auth: {auth_header}") + return web.Response(text="No auth") + + async def session_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + nonlocal session_middleware_called + session_middleware_called = True + request.headers["Authorization"] = "Bearer session-token" + response = await handler(request) + return response + + async def request_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + nonlocal request_middleware_called + request_middleware_called = True + request.headers["Authorization"] = "Bearer request-token" + response = await handler(request) + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Create session with middleware + async with ClientSession(middlewares=(session_middleware,)) as session: + # First request uses session middleware + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Auth: Bearer session-token" + assert session_middleware_called is True + assert request_middleware_called is False + + # Reset flags + session_middleware_called = False + request_middleware_called = False + + # Second request explicitly disables middlewares + async with session.get(server.make_url("/"), middlewares=()) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "No auth" + assert session_middleware_called is False + assert request_middleware_called is False + + # Reset flags + session_middleware_called = False + request_middleware_called = False + + # Third request uses request-specific middleware + async with session.get( + server.make_url("/"), middlewares=(request_middleware,) + ) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Auth: Bearer request-token" + assert session_middleware_called is False + assert request_middleware_called is True + + +@pytest.mark.parametrize( + "exception_class,match_text", + [ + (ValueError, "Middleware error"), + (ClientError, "Client error from middleware"), + (OSError, "OS error from middleware"), + ], +) +async def test_client_middleware_exception_closes_connection( + aiohttp_server: AiohttpServer, + exception_class: type[Exception], + match_text: str, +) -> None: + """Test that connections are closed when middleware raises an exception.""" + + async def handler(request: web.Request) -> NoReturn: + assert False, "Handler should not be reached" + + async def failing_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> NoReturn: + # Raise exception before the handler is called + raise exception_class(match_text) + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Create custom connector + connector = TCPConnector() + + async with ClientSession( + connector=connector, middlewares=(failing_middleware,) + ) as session: + # Make request that should fail in middleware + with pytest.raises(exception_class, match=match_text): + await session.get(server.make_url("/")) + + # Check that the connector has no active connections + # If connections were properly closed, _conns should be empty + assert len(connector._conns) == 0 + + await connector.close() + + +async def test_client_middleware_blocks_connection_before_established( + aiohttp_server: AiohttpServer, +) -> None: + """Test that middleware can block connections before they are established.""" + blocked_hosts = {"blocked.example.com", "evil.com"} + connection_attempts: List[str] = [] + + async def handler(request: web.Request) -> web.Response: + return web.Response(text="Reached") + + async def blocking_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + # Record the connection attempt + connection_attempts.append(str(request.url)) + + # Block requests to certain hosts + if request.url.host in blocked_hosts: + raise BlockedByMiddleware(f"Connection to {request.url.host} is blocked") + + # Allow the request to proceed + return await handler(request) + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + connector = TCPConnector() + async with ClientSession( + connector=connector, middlewares=(blocking_middleware,) + ) as session: + # Test allowed request + allowed_url = server.make_url("/") + async with session.get(allowed_url) as resp: + assert resp.status == 200 + assert await resp.text() == "Reached" + + # Test blocked request + with pytest.raises(BlockedByMiddleware) as exc_info: + # Use a fake URL that would fail DNS if connection was attempted + await session.get("https://blocked.example.com/") + + assert "Connection to blocked.example.com is blocked" in str(exc_info.value) + + # Test another blocked host + with pytest.raises(BlockedByMiddleware) as exc_info: + await session.get("https://evil.com/path") + + assert "Connection to evil.com is blocked" in str(exc_info.value) + + # Verify that connections were attempted in the correct order + assert len(connection_attempts) == 3 + assert allowed_url.host and allowed_url.host in connection_attempts[0] + assert "blocked.example.com" in connection_attempts[1] + assert "evil.com" in connection_attempts[2] + + # Check that no connections were leaked + assert len(connector._conns) == 0 + + +async def test_client_middleware_blocks_connection_without_dns_lookup( + aiohttp_server: AiohttpServer, +) -> None: + """Test that middleware prevents DNS lookups for blocked hosts.""" + blocked_hosts = {"blocked.domain.tld"} + dns_lookups_made: List[str] = [] + + # Create a simple server for the allowed request + async def handler(request: web.Request) -> web.Response: + return web.Response(text="OK") + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + class TrackingResolver(ThreadedResolver): + async def resolve( + self, + hostname: str, + port: int = 0, + family: socket.AddressFamily = socket.AF_INET, + ) -> List[ResolveResult]: + dns_lookups_made.append(hostname) + return await super().resolve(hostname, port, family) + + async def blocking_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + # Block requests to certain hosts before DNS lookup + if request.url.host in blocked_hosts: + raise BlockedByMiddleware(f"Blocked by policy: {request.url.host}") + + return await handler(request) + + resolver = TrackingResolver() + connector = TCPConnector(resolver=resolver) + async with ClientSession( + connector=connector, middlewares=(blocking_middleware,) + ) as session: + # Test blocked request to non-existent domain + with pytest.raises(BlockedByMiddleware) as exc_info: + await session.get("https://blocked.domain.tld/") + + assert "Blocked by policy: blocked.domain.tld" in str(exc_info.value) + + # Verify that no DNS lookup was made for the blocked domain + assert "blocked.domain.tld" not in dns_lookups_made + + # Test allowed request to existing server - this should trigger DNS lookup + async with session.get(f"http://localhost:{server.port}") as resp: + assert resp.status == 200 + + # Verify that DNS lookup was made for the allowed request + # The server might use a hostname that requires DNS resolution + assert len(dns_lookups_made) > 0 + + # Make sure blocked domain is still not in DNS lookups + assert "blocked.domain.tld" not in dns_lookups_made + + # Clean up + await connector.close() + + +async def test_client_middleware_retry_reuses_connection( + aiohttp_server: AiohttpServer, +) -> None: + """Test that connections are reused when middleware performs retries.""" + + async def handler(request: web.Request) -> web.Response: + return web.Response(text="OK") + + class TrackingConnector(TCPConnector): + """Connector that tracks connection attempts.""" + + connection_attempts = 0 + + async def _create_connection( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + self.connection_attempts += 1 + return await super()._create_connection(req, traces, timeout) + + class RetryOnceMiddleware: + """Middleware that retries exactly once.""" + + def __init__(self) -> None: + self.attempt_count = 0 + + async def __call__( + self, request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + retry_count = 0 + while True: + self.attempt_count += 1 + response = await handler(request) + if retry_count == 0: + retry_count += 1 + response.release() # Release the response to enable connection reuse + continue + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + connector = TrackingConnector() + middleware = RetryOnceMiddleware() + + async with ClientSession(connector=connector, middlewares=(middleware,)) as session: + # Make initial request + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "OK" + + # Should have made 2 request attempts (initial + 1 retry) + assert middleware.attempt_count == 2 + # Should have created only 1 connection (reused on retry) + assert connector.connection_attempts == 1 + + await connector.close() + + +async def test_middleware_uses_session_avoids_recursion_with_path_check( + aiohttp_server: AiohttpServer, +) -> None: + """Test that middleware can avoid infinite recursion using a path check.""" + log_collector: List[Dict[str, str]] = [] + + async def log_api_handler(request: web.Request) -> web.Response: + """Handle log API requests.""" + data: Dict[str, str] = await request.json() + log_collector.append(data) + return web.Response(text="OK") + + async def main_handler(request: web.Request) -> web.Response: + """Handle main server requests.""" + return web.Response(text=f"Hello from {request.path}") + + # Create log API server + log_app = web.Application() + log_app.router.add_post("/log", log_api_handler) + log_server = await aiohttp_server(log_app) + + # Create main server + main_app = web.Application() + main_app.router.add_get("/{path:.*}", main_handler) + main_server = await aiohttp_server(main_app) + + async def log_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + """Log requests to external API, avoiding recursion with path check.""" + # Avoid infinite recursion by not logging requests to the /log endpoint + if request.url.path != "/log": + # Use the session from the request to make the logging call + async with request.session.post( + f"http://localhost:{log_server.port}/log", + json={"method": str(request.method), "url": str(request.url)}, + ) as resp: + assert resp.status == 200 + + return await handler(request) + + # Create session with the middleware + async with ClientSession(middlewares=(log_middleware,)) as session: + # Make request to main server - should be logged + async with session.get(main_server.make_url("/test")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Hello from /test" + + # Make direct request to log API - should NOT be logged (avoid recursion) + async with session.post( + log_server.make_url("/log"), + json={"method": "DIRECT_POST", "url": "manual_test_entry"}, + ) as resp: + assert resp.status == 200 + + # Check logs + # The first request should be logged + # The second request (to /log) should also be logged but not the middleware's own log request + assert len(log_collector) == 2 + assert log_collector[0]["method"] == "GET" + assert log_collector[0]["url"] == str(main_server.make_url("/test")) + assert log_collector[1]["method"] == "DIRECT_POST" + assert log_collector[1]["url"] == "manual_test_entry" + + +async def test_middleware_uses_session_avoids_recursion_with_disabled_middleware( + aiohttp_server: AiohttpServer, +) -> None: + """Test that middleware can avoid infinite recursion by disabling middleware.""" + log_collector: List[Dict[str, str]] = [] + request_count = 0 + + async def log_api_handler(request: web.Request) -> web.Response: + """Handle log API requests.""" + nonlocal request_count + request_count += 1 + data: Dict[str, str] = await request.json() + log_collector.append(data) + return web.Response(text="OK") + + async def main_handler(request: web.Request) -> web.Response: + """Handle main server requests.""" + return web.Response(text=f"Hello from {request.path}") + + # Create log API server + log_app = web.Application() + log_app.router.add_post("/log", log_api_handler) + log_server = await aiohttp_server(log_app) + + # Create main server + main_app = web.Application() + main_app.router.add_get("/{path:.*}", main_handler) + main_server = await aiohttp_server(main_app) + + async def log_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + """Log all requests using session with disabled middleware.""" + # Use the session from the request to make the logging call + # Disable middleware to avoid infinite recursion + async with request.session.post( + f"http://localhost:{log_server.port}/log", + json={"method": str(request.method), "url": str(request.url)}, + middlewares=(), # This prevents infinite recursion + ) as resp: + assert resp.status == 200 + + return await handler(request) + + # Create session with the middleware + async with ClientSession(middlewares=(log_middleware,)) as session: + # Make request to main server - should be logged + async with session.get(main_server.make_url("/test")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Hello from /test" + + # Make another request - should also be logged + async with session.get(main_server.make_url("/another")) as resp: + assert resp.status == 200 + + # Check logs - both requests should be logged + assert len(log_collector) == 2 + assert log_collector[0]["method"] == "GET" + assert log_collector[0]["url"] == str(main_server.make_url("/test")) + assert log_collector[1]["method"] == "GET" + assert log_collector[1]["url"] == str(main_server.make_url("/another")) + + # Ensure that log requests were made without the middleware + # (request_count equals number of logged requests, not infinite) + assert request_count == 2 + + +async def test_middleware_can_check_request_body( + aiohttp_server: AiohttpServer, +) -> None: + """Test that middleware can check request body.""" + received_bodies: List[str] = [] + received_headers: List[Dict[str, str]] = [] + + async def handler(request: web.Request) -> web.Response: + """Server handler that receives requests.""" + body = await request.text() + received_bodies.append(body) + received_headers.append(dict(request.headers)) + return web.Response(text="OK") + + app = web.Application() + app.router.add_post("/api", handler) + app.router.add_get("/api", handler) # Add GET handler too + server = await aiohttp_server(app) + + class CustomAuth: + """Middleware that follows the GitHub discussion pattern for authentication.""" + + def __init__(self, secretkey: str) -> None: + self.secretkey = secretkey + + def get_hash(self, request: ClientRequest) -> str: + if request.body: + data = request.body.decode("utf-8") + else: + data = "{}" + + # Simulate authentication hash without using real crypto + signature = f"SIGNATURE-{self.secretkey}-{len(data)}-{data[:10]}" + return signature + + async def __call__( + self, request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + request.headers["CUSTOM-AUTH"] = self.get_hash(request) + return await handler(request) + + middleware = CustomAuth("test-secret-key") + + async with ClientSession(middlewares=(middleware,)) as session: + # Test 1: Send JSON data with user/action + data1 = {"user": "alice", "action": "login"} + json_str1 = json.dumps(data1) + async with session.post( + server.make_url("/api"), + data=json_str1, + headers={"Content-Type": "application/json"}, + ) as resp: + assert resp.status == 200 + + # Test 2: Send JSON data with different fields + data2 = {"user": "bob", "value": 42} + json_str2 = json.dumps(data2) + async with session.post( + server.make_url("/api"), + data=json_str2, + headers={"Content-Type": "application/json"}, + ) as resp: + assert resp.status == 200 + + # Test 3: Send GET request with no body + async with session.get(server.make_url("/api")) as resp: + assert resp.status == 200 # GET with empty body still should validate + + # Test 4: Send plain text (non-JSON) + text_data = "plain text body" + async with session.post( + server.make_url("/api"), + data=text_data, + headers={"Content-Type": "text/plain"}, + ) as resp: + assert resp.status == 200 + + # Verify server received the correct headers with authentication + headers1 = received_headers[0] + assert ( + headers1["CUSTOM-AUTH"] + == f"SIGNATURE-test-secret-key-{len(json_str1)}-{json_str1[:10]}" + ) + + headers2 = received_headers[1] + assert ( + headers2["CUSTOM-AUTH"] + == f"SIGNATURE-test-secret-key-{len(json_str2)}-{json_str2[:10]}" + ) + + headers3 = received_headers[2] + # GET request with no body should have empty JSON body + assert headers3["CUSTOM-AUTH"] == "SIGNATURE-test-secret-key-2-{}" + + headers4 = received_headers[3] + assert ( + headers4["CUSTOM-AUTH"] + == f"SIGNATURE-test-secret-key-{len(text_data)}-{text_data[:10]}" + ) + + # Verify all responses were successful + assert received_bodies[0] == json_str1 + assert received_bodies[1] == json_str2 + assert received_bodies[2] == "" # GET request has no body + assert received_bodies[3] == text_data diff --git a/tests/test_web_response.py b/tests/test_web_response.py index b6d9b3e2bd3..046bb89ac53 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -998,6 +998,13 @@ def test_ctor_content_type_with_extra() -> None: assert resp.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" +def test_invalid_content_type_parses_to_text_plain() -> None: + resp = web.Response(text="test test", content_type="jpeg") + + assert resp.content_type == "text/plain" + assert resp.headers["content-type"] == "jpeg; charset=utf-8" + + def test_ctor_both_content_type_param_and_header_with_text() -> None: with pytest.raises(ValueError): web.Response(