From 52c104a8aaa2fab82c8d7398d87479f899c3b07c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 11:10:51 +0000 Subject: [PATCH 1/7] Bump pydantic from 2.11.4 to 2.11.5 (#10964) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [pydantic](https://github.com/pydantic/pydantic) from 2.11.4 to 2.11.5.
Release notes

Sourced from pydantic's releases.

v2.11.5 2025-05-22

What's Changed

Fixes

Full Changelog: https://github.com/pydantic/pydantic/compare/v2.11.4...v2.11.5

Changelog

Sourced from pydantic's changelog.

v2.11.5 (2025-05-22)

GitHub release

What's Changed

Fixes

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=pydantic&package-manager=pip&previous-version=2.11.4&new-version=2.11.5)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/constraints.txt | 4 +++- requirements/dev.txt | 4 +++- requirements/lint.txt | 2 +- requirements/test.txt | 4 +++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/requirements/constraints.txt b/requirements/constraints.txt index 9d1ddf20173..fd1b65a1beb 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -135,6 +135,8 @@ packaging==25.0 # wheel pip-tools==7.4.1 # via -r requirements/dev.in +pkgconfig==1.5.5 + # via -r requirements/test.in platformdirs==4.3.8 # via virtualenv pluggy==1.6.0 @@ -153,7 +155,7 @@ pycares==4.8.0 # via aiodns pycparser==2.22 # via cffi -pydantic==2.11.4 +pydantic==2.11.5 # via python-on-whales pydantic-core==2.33.2 # via pydantic diff --git a/requirements/dev.txt b/requirements/dev.txt index 181d0cf8f93..d63205d4af5 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -132,6 +132,8 @@ packaging==25.0 # wheel pip-tools==7.4.1 # via -r requirements/dev.in +pkgconfig==1.5.5 + # via -r requirements/test.in platformdirs==4.3.8 # via virtualenv pluggy==1.6.0 @@ -150,7 +152,7 @@ pycares==4.8.0 # via aiodns pycparser==2.22 # via cffi -pydantic==2.11.4 +pydantic==2.11.5 # via python-on-whales pydantic-core==2.33.2 # via pydantic diff --git a/requirements/lint.txt b/requirements/lint.txt index be29f603556..19a589f58dd 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -65,7 +65,7 @@ pycares==4.8.0 # via aiodns pycparser==2.22 # via cffi -pydantic==2.11.4 +pydantic==2.11.5 # via python-on-whales pydantic-core==2.33.2 # via pydantic diff --git a/requirements/test.txt b/requirements/test.txt index 24e677a09e5..f8eb4d6b193 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -69,6 +69,8 @@ packaging==25.0 # via # gunicorn # pytest +pkgconfig==1.5.5 + # via -r requirements/test.in pluggy==1.6.0 # via pytest propcache==0.3.1 @@ -81,7 +83,7 @@ pycares==4.8.0 # via aiodns pycparser==2.22 # via cffi -pydantic==2.11.4 +pydantic==2.11.5 # via python-on-whales pydantic-core==2.33.2 # via pydantic From 84decfe5fce79a4ca29fdc5c7317fc85f3d908f5 Mon Sep 17 00:00:00 2001 From: Cycloctane Date: Fri, 23 May 2025 21:36:26 +0800 Subject: [PATCH 2/7] add example of setting network interface in custom socket creation (#10962) Co-authored-by: J. Nick Koston --- CHANGES/10962.feature.rst | 1 + docs/client_advanced.rst | 13 +++++++++++++ 2 files changed, 14 insertions(+) create mode 120000 CHANGES/10962.feature.rst diff --git a/CHANGES/10962.feature.rst b/CHANGES/10962.feature.rst new file mode 120000 index 00000000000..7c4f9a7b83b --- /dev/null +++ b/CHANGES/10962.feature.rst @@ -0,0 +1 @@ +10520.feature.rst \ No newline at end of file diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 10105249a6a..e1f556062a7 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -721,6 +721,19 @@ make all sockets respect 9*7200 = 18 hours:: return sock conn = aiohttp.TCPConnector(socket_factory=socket_factory) +``socket_factory`` may also be used for binding to the specific network +interface on supported platforms:: + + def socket_factory(addr_info): + family, type_, proto, _, _ = addr_info + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_BINDTODEVICE, b'eth0' + ) + return sock + + conn = aiohttp.TCPConnector(socket_factory=socket_factory) + Named pipes in Windows ^^^^^^^^^^^^^^^^^^^^^^ From 5e68276ca359860e8943b98bb4d2f40da6f4157e Mon Sep 17 00:00:00 2001 From: Cycloctane Date: Fri, 23 May 2025 21:40:18 +0800 Subject: [PATCH 3/7] fix example in socket_factory docs (#10961) Co-authored-by: J. Nick Koston --- CHANGES/10961.feature.rst | 1 + docs/client_advanced.rst | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 120000 CHANGES/10961.feature.rst diff --git a/CHANGES/10961.feature.rst b/CHANGES/10961.feature.rst new file mode 120000 index 00000000000..7c4f9a7b83b --- /dev/null +++ b/CHANGES/10961.feature.rst @@ -0,0 +1 @@ +10520.feature.rst \ No newline at end of file diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index e1f556062a7..c3d800cdd68 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -713,12 +713,13 @@ make all sockets respect 9*7200 = 18 hours:: import socket def socket_factory(addr_info): - family, type_, proto, _, _, _ = addr_info + family, type_, proto, _, _ = addr_info sock = socket.socket(family=family, type=type_, proto=proto) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) return sock + conn = aiohttp.TCPConnector(socket_factory=socket_factory) ``socket_factory`` may also be used for binding to the specific network From 18785096b3628f682f950c612d8acb68333629ba Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 23 May 2025 10:20:56 -0500 Subject: [PATCH 4/7] Add Client Middleware Cookbook (#10945) --- CHANGES/10945.feature.rst | 1 + docs/client.rst | 1 + docs/client_advanced.rst | 2 + docs/client_middleware_cookbook.rst | 358 +++++++++++++++++++++++++++ docs/spelling_wordlist.txt | 1 + examples/basic_auth_middleware.py | 190 ++++++++++++++ examples/combined_middleware.py | 320 ++++++++++++++++++++++++ examples/logging_middleware.py | 169 +++++++++++++ examples/retry_middleware.py | 245 ++++++++++++++++++ examples/token_refresh_middleware.py | 336 +++++++++++++++++++++++++ 10 files changed, 1623 insertions(+) create mode 120000 CHANGES/10945.feature.rst create mode 100644 docs/client_middleware_cookbook.rst create mode 100644 examples/basic_auth_middleware.py create mode 100644 examples/combined_middleware.py create mode 100644 examples/logging_middleware.py create mode 100644 examples/retry_middleware.py create mode 100644 examples/token_refresh_middleware.py diff --git a/CHANGES/10945.feature.rst b/CHANGES/10945.feature.rst new file mode 120000 index 00000000000..b565aa68ee0 --- /dev/null +++ b/CHANGES/10945.feature.rst @@ -0,0 +1 @@ +9732.feature.rst \ No newline at end of file diff --git a/docs/client.rst b/docs/client.rst index 78fbeae4ded..9109c3772da 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -14,6 +14,7 @@ The page contains all information about aiohttp Client API: Quickstart Advanced Usage + Client Middleware Cookbook Reference Tracing Reference The aiohttp Request Lifecycle diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index c3d800cdd68..dcbd743d6fb 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -133,6 +133,8 @@ Client Middleware The client supports middleware to intercept requests and responses. This can be useful for authentication, logging, request/response modification, and retries. +For practical examples and common middleware patterns, see the :ref:`aiohttp-client-middleware-cookbook`. + Creating Middleware ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/client_middleware_cookbook.rst b/docs/client_middleware_cookbook.rst new file mode 100644 index 00000000000..4b8d6ddd5f8 --- /dev/null +++ b/docs/client_middleware_cookbook.rst @@ -0,0 +1,358 @@ +.. currentmodule:: aiohttp + +.. _aiohttp-client-middleware-cookbook: + +Client Middleware Cookbook +========================== + +This cookbook provides practical examples of implementing client middleware for common use cases. + +.. note:: + + All examples in this cookbook are also available as complete, runnable scripts in the + ``examples/`` directory of the aiohttp repository. Look for files named ``*_middleware.py``. + +.. _cookbook-basic-auth-middleware: + +Basic Authentication Middleware +------------------------------- + +Basic authentication is a simple authentication scheme built into the HTTP protocol. +Here's a middleware that automatically adds Basic Auth headers to all requests: + +.. code-block:: python + + import base64 + from aiohttp import ClientRequest, ClientResponse, ClientHandlerType, hdrs + + class BasicAuthMiddleware: + """Middleware that adds Basic Authentication to all requests.""" + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + self._auth_header = self._encode_credentials() + + def _encode_credentials(self) -> str: + """Encode username and password to base64.""" + credentials = f"{self.username}:{self.password}" + encoded = base64.b64encode(credentials.encode()).decode() + return f"Basic {encoded}" + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + """Add Basic Auth header to the request.""" + # Only add auth if not already present + if hdrs.AUTHORIZATION not in request.headers: + request.headers[hdrs.AUTHORIZATION] = self._auth_header + + # Proceed with the request + return await handler(request) + +Usage example: + +.. code-block:: python + + import aiohttp + import asyncio + import logging + + _LOGGER = logging.getLogger(__name__) + + async def main(): + # Create middleware instance + auth_middleware = BasicAuthMiddleware("user", "pass") + + # Use middleware in session + async with aiohttp.ClientSession(middlewares=(auth_middleware,)) as session: + async with session.get("https://httpbin.org/basic-auth/user/pass") as resp: + _LOGGER.debug("Status: %s", resp.status) + data = await resp.json() + _LOGGER.debug("Response: %s", data) + + asyncio.run(main()) + +.. _cookbook-retry-middleware: + +Simple Retry Middleware +----------------------- + +A retry middleware that automatically retries failed requests with exponential backoff: + +.. code-block:: python + + import asyncio + import logging + from http import HTTPStatus + from typing import Union, Set + from aiohttp import ClientRequest, ClientResponse, ClientHandlerType + + _LOGGER = logging.getLogger(__name__) + + DEFAULT_RETRY_STATUSES = { + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT + } + + class RetryMiddleware: + """Middleware that retries failed requests with exponential backoff.""" + + def __init__( + self, + max_retries: int = 3, + retry_statuses: Union[Set[int], None] = None, + initial_delay: float = 1.0, + backoff_factor: float = 2.0 + ) -> None: + self.max_retries = max_retries + self.retry_statuses = retry_statuses or DEFAULT_RETRY_STATUSES + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + """Execute request with retry logic.""" + last_response = None + delay = self.initial_delay + + for attempt in range(self.max_retries + 1): + if attempt > 0: + _LOGGER.info( + "Retrying request to %s (attempt %s/%s)", + request.url, + attempt + 1, + self.max_retries + 1 + ) + + # Execute the request + response = await handler(request) + last_response = response + + # Check if we should retry + if response.status not in self.retry_statuses: + return response + + # Don't retry if we've exhausted attempts + if attempt >= self.max_retries: + _LOGGER.warning( + "Max retries (%s) exceeded for %s", + self.max_retries, + request.url + ) + return response + + # Wait before retrying + _LOGGER.debug("Waiting %ss before retry...", delay) + await asyncio.sleep(delay) + delay *= self.backoff_factor + + # Return the last response + return last_response + +Usage example: + +.. code-block:: python + + import aiohttp + import asyncio + import logging + from http import HTTPStatus + + _LOGGER = logging.getLogger(__name__) + + RETRY_STATUSES = { + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT + } + + async def main(): + # Create retry middleware with custom settings + retry_middleware = RetryMiddleware( + max_retries=3, + retry_statuses=RETRY_STATUSES, + initial_delay=0.5, + backoff_factor=2.0 + ) + + async with aiohttp.ClientSession(middlewares=(retry_middleware,)) as session: + # This will automatically retry on server errors + async with session.get("https://httpbin.org/status/500") as resp: + _LOGGER.debug("Final status: %s", resp.status) + + asyncio.run(main()) + +.. _cookbook-combining-middleware: + +Combining Multiple Middleware +----------------------------- + +You can combine multiple middleware to create powerful request pipelines: + +.. code-block:: python + + import time + import logging + from aiohttp import ClientRequest, ClientResponse, ClientHandlerType + + _LOGGER = logging.getLogger(__name__) + + class LoggingMiddleware: + """Middleware that logs request timing and response status.""" + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + start_time = time.monotonic() + + # Log request + _LOGGER.debug("[REQUEST] %s %s", request.method, request.url) + + # Execute request + response = await handler(request) + + # Log response + duration = time.monotonic() - start_time + _LOGGER.debug("[RESPONSE] %s in %.2fs", response.status, duration) + + return response + + # Combine multiple middleware + async def main(): + # Middleware are applied in order: logging -> auth -> retry -> request + logging_middleware = LoggingMiddleware() + auth_middleware = BasicAuthMiddleware("user", "pass") + retry_middleware = RetryMiddleware(max_retries=2) + + async with aiohttp.ClientSession( + middlewares=(logging_middleware, auth_middleware, retry_middleware) + ) as session: + async with session.get("https://httpbin.org/basic-auth/user/pass") as resp: + text = await resp.text() + _LOGGER.debug("Response text: %s", text) + +.. _cookbook-token-refresh-middleware: + +Token Refresh Middleware +------------------------ + +A more advanced example showing JWT token refresh: + +.. code-block:: python + + import asyncio + import time + from http import HTTPStatus + from typing import Union + from aiohttp import ClientRequest, ClientResponse, ClientHandlerType, hdrs + + class TokenRefreshMiddleware: + """Middleware that handles JWT token refresh automatically.""" + + def __init__(self, token_endpoint: str, refresh_token: str) -> None: + self.token_endpoint = token_endpoint + self.refresh_token = refresh_token + self.access_token: Union[str, None] = None + self.token_expires_at: Union[float, None] = None + self._refresh_lock = asyncio.Lock() + + async def _refresh_access_token(self, session) -> str: + """Refresh the access token using the refresh token.""" + async with self._refresh_lock: + # Check if another coroutine already refreshed the token + if self.token_expires_at and time.time() < self.token_expires_at: + return self.access_token + + # Make refresh request without middleware to avoid recursion + async with session.post( + self.token_endpoint, + json={"refresh_token": self.refresh_token}, + middlewares=() # Disable middleware for this request + ) as resp: + resp.raise_for_status() + data = await resp.json() + + if "access_token" not in data: + raise ValueError("No access_token in refresh response") + + self.access_token = data["access_token"] + # Token expires in 1 hour for demo, refresh 5 min early + expires_in = data.get("expires_in", 3600) + self.token_expires_at = time.time() + expires_in - 300 + return self.access_token + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + """Add auth token to request, refreshing if needed.""" + # Skip token for refresh endpoint + if str(request.url).endswith('/token/refresh'): + return await handler(request) + + # Refresh token if needed + if not self.access_token or ( + self.token_expires_at and time.time() >= self.token_expires_at + ): + await self._refresh_access_token(request.session) + + # Add token to request + request.headers[hdrs.AUTHORIZATION] = f"Bearer {self.access_token}" + + # Execute request + response = await handler(request) + + # If we get 401, try refreshing token once + if response.status == HTTPStatus.UNAUTHORIZED: + await self._refresh_access_token(request.session) + request.headers[hdrs.AUTHORIZATION] = f"Bearer {self.access_token}" + response = await handler(request) + + return response + +Best Practices +-------------- + +1. **Keep middleware focused**: Each middleware should have a single responsibility. + +2. **Order matters**: Middleware execute in the order they're listed. Place logging first, + authentication before retry, etc. + +3. **Avoid infinite recursion**: When making HTTP requests inside middleware, either: + + - Use ``middlewares=()`` to disable middleware for internal requests + - Check the request URL/host to skip middleware for specific endpoints + - Use a separate session for internal requests + +4. **Handle errors gracefully**: Don't let middleware errors break the request flow unless + absolutely necessary. + +5. **Use bounded loops**: Always use ``for`` loops with a maximum iteration count instead + of unbounded ``while`` loops to prevent infinite retries. + +6. **Consider performance**: Each middleware adds overhead. For simple cases like adding + static headers, consider using session or request parameters instead. + +7. **Test thoroughly**: Middleware can affect all requests in subtle ways. Test edge cases + like network errors, timeouts, and concurrent requests. + +See Also +-------- + +- :ref:`aiohttp-client-middleware` - Core middleware documentation +- :ref:`aiohttp-client-advanced` - Advanced client usage +- :class:`DigestAuthMiddleware` - Built-in digest authentication middleware diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 34642ec64da..48266e67486 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -29,6 +29,7 @@ autoformatters autogenerates autogeneration awaitable +backoff backend backends backport diff --git a/examples/basic_auth_middleware.py b/examples/basic_auth_middleware.py new file mode 100644 index 00000000000..4c30f477505 --- /dev/null +++ b/examples/basic_auth_middleware.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +""" +Example of using basic authentication middleware with aiohttp client. + +This example shows how to implement a middleware that automatically adds +Basic Authentication headers to all requests. The middleware encodes the +username and password in base64 format as required by the HTTP Basic Auth +specification. + +This example includes a test server that validates basic auth credentials. +""" + +import asyncio +import base64 +import binascii +import logging + +from aiohttp import ( + ClientHandlerType, + ClientRequest, + ClientResponse, + ClientSession, + hdrs, + web, +) + +logging.basicConfig(level=logging.DEBUG) +_LOGGER = logging.getLogger(__name__) + + +class BasicAuthMiddleware: + """Middleware that adds Basic Authentication to all requests.""" + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + self._auth_header = self._encode_credentials() + + def _encode_credentials(self) -> str: + """Encode username and password to base64.""" + credentials = f"{self.username}:{self.password}" + encoded = base64.b64encode(credentials.encode()).decode() + return f"Basic {encoded}" + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + """Add Basic Auth header to the request.""" + # Only add auth if not already present + if hdrs.AUTHORIZATION not in request.headers: + request.headers[hdrs.AUTHORIZATION] = self._auth_header + + # Proceed with the request + return await handler(request) + + +class TestServer: + """Test server for basic auth endpoints.""" + + async def handle_basic_auth(self, request: web.Request) -> web.Response: + """Handle basic auth validation.""" + # Get expected credentials from path + expected_user = request.match_info["user"] + expected_pass = request.match_info["pass"] + + # Check if Authorization header is present + auth_header = request.headers.get(hdrs.AUTHORIZATION, "") + + if not auth_header.startswith("Basic "): + return web.Response( + status=401, + text="Unauthorized", + headers={hdrs.WWW_AUTHENTICATE: 'Basic realm="test"'}, + ) + + # Decode the credentials + encoded_creds = auth_header[6:] # Remove "Basic " + try: + decoded = base64.b64decode(encoded_creds).decode() + username, password = decoded.split(":", 1) + except (ValueError, binascii.Error): + return web.Response( + status=401, + text="Invalid credentials format", + headers={hdrs.WWW_AUTHENTICATE: 'Basic realm="test"'}, + ) + + # Validate credentials + if username != expected_user or password != expected_pass: + return web.Response( + status=401, + text="Invalid username or password", + headers={hdrs.WWW_AUTHENTICATE: 'Basic realm="test"'}, + ) + + return web.json_response({"authenticated": True, "user": username}) + + async def handle_protected_resource(self, request: web.Request) -> web.Response: + """A protected resource that requires any valid auth.""" + auth_header = request.headers.get(hdrs.AUTHORIZATION, "") + + if not auth_header.startswith("Basic "): + return web.Response( + status=401, + text="Authentication required", + headers={hdrs.WWW_AUTHENTICATE: 'Basic realm="protected"'}, + ) + + return web.json_response( + { + "message": "Access granted to protected resource", + "auth_provided": True, + } + ) + + +async def run_test_server() -> web.AppRunner: + """Run a simple test server with basic auth endpoints.""" + app = web.Application() + server = TestServer() + + app.router.add_get("/basic-auth/{user}/{pass}", server.handle_basic_auth) + app.router.add_get("/protected", server.handle_protected_resource) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + return runner + + +async def run_tests() -> None: + """Run all basic auth middleware tests.""" + # Create middleware instance + auth_middleware = BasicAuthMiddleware("user", "pass") + + # Use middleware in session + async with ClientSession(middlewares=(auth_middleware,)) as session: + # Test 1: Correct credentials endpoint + print("=== Test 1: Correct credentials ===") + async with session.get("http://localhost:8080/basic-auth/user/pass") as resp: + _LOGGER.info("Status: %s", resp.status) + + if resp.status == 200: + data = await resp.json() + _LOGGER.info("Response: %s", data) + print("Authentication successful!") + print(f"Authenticated: {data.get('authenticated')}") + print(f"User: {data.get('user')}") + else: + print("Authentication failed!") + print(f"Status: {resp.status}") + text = await resp.text() + print(f"Response: {text}") + + # Test 2: Wrong credentials endpoint + print("\n=== Test 2: Wrong credentials endpoint ===") + async with session.get("http://localhost:8080/basic-auth/other/secret") as resp: + if resp.status == 401: + print("Authentication failed as expected (wrong credentials)") + text = await resp.text() + print(f"Response: {text}") + else: + print(f"Unexpected status: {resp.status}") + + # Test 3: Protected resource + print("\n=== Test 3: Access protected resource ===") + async with session.get("http://localhost:8080/protected") as resp: + if resp.status == 200: + data = await resp.json() + print("Successfully accessed protected resource!") + print(f"Response: {data}") + else: + print(f"Failed to access protected resource: {resp.status}") + + +async def main() -> None: + # Start test server + server = await run_test_server() + + try: + await run_tests() + finally: + await server.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/combined_middleware.py b/examples/combined_middleware.py new file mode 100644 index 00000000000..8646a182b98 --- /dev/null +++ b/examples/combined_middleware.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +""" +Example of combining multiple middleware with aiohttp client. + +This example shows how to chain multiple middleware together to create +a powerful request pipeline. Middleware are applied in order, demonstrating +how logging, authentication, and retry logic can work together. + +The order of middleware matters: +1. Logging (outermost) - logs all attempts including retries +2. Authentication - adds auth headers before retry logic +3. Retry (innermost) - retries requests on failure +""" + +import asyncio +import base64 +import binascii +import logging +import time +from http import HTTPStatus +from typing import TYPE_CHECKING, Set, Union + +from aiohttp import ( + ClientHandlerType, + ClientRequest, + ClientResponse, + ClientSession, + hdrs, + web, +) + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +_LOGGER = logging.getLogger(__name__) + + +class LoggingMiddleware: + """Middleware that logs request timing and response status.""" + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + start_time = time.monotonic() + + # Log request + _LOGGER.info("[REQUEST] %s %s", request.method, request.url) + + # Execute request + response = await handler(request) + + # Log response + duration = time.monotonic() - start_time + _LOGGER.info( + "[RESPONSE] %s in %.2fs - Status: %s", + request.url.path, + duration, + response.status, + ) + + return response + + +class BasicAuthMiddleware: + """Middleware that adds Basic Authentication to all requests.""" + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + self._auth_header = self._encode_credentials() + + def _encode_credentials(self) -> str: + """Encode username and password to base64.""" + credentials = f"{self.username}:{self.password}" + encoded = base64.b64encode(credentials.encode()).decode() + return f"Basic {encoded}" + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + """Add Basic Auth header to the request.""" + # Only add auth if not already present + if hdrs.AUTHORIZATION not in request.headers: + request.headers[hdrs.AUTHORIZATION] = self._auth_header + _LOGGER.debug("Added Basic Auth header") + + # Proceed with the request + return await handler(request) + + +DEFAULT_RETRY_STATUSES: Set[HTTPStatus] = { + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT, +} + + +class RetryMiddleware: + """Middleware that retries failed requests with exponential backoff.""" + + def __init__( + self, + max_retries: int = 3, + retry_statuses: Union[Set[HTTPStatus], None] = None, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, + ) -> None: + self.max_retries = max_retries + self.retry_statuses = retry_statuses or DEFAULT_RETRY_STATUSES + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + """Execute request with retry logic.""" + last_response: Union[ClientResponse, None] = None + delay = self.initial_delay + + for attempt in range(self.max_retries + 1): + if attempt > 0: + _LOGGER.info( + "Retrying request (attempt %s/%s)", + attempt + 1, + self.max_retries + 1, + ) + + # Execute the request + response = await handler(request) + last_response = response + + # Check if we should retry + if response.status not in self.retry_statuses: + return response + + # Don't retry if we've exhausted attempts + if attempt >= self.max_retries: + _LOGGER.warning("Max retries exceeded") + return response + + # Wait before retrying + _LOGGER.debug("Waiting %ss before retry...", delay) + await asyncio.sleep(delay) + delay *= self.backoff_factor + + if TYPE_CHECKING: + assert last_response is not None # Always set since we loop at least once + return last_response + + +class TestServer: + """Test server with stateful endpoints for middleware testing.""" + + def __init__(self) -> None: + self.flaky_counter = 0 + self.protected_counter = 0 + + async def handle_protected(self, request: web.Request) -> web.Response: + """Protected endpoint that requires authentication and is flaky on first attempt.""" + auth_header = request.headers.get(hdrs.AUTHORIZATION, "") + + if not auth_header.startswith("Basic "): + return web.Response( + status=401, + text="Unauthorized", + headers={hdrs.WWW_AUTHENTICATE: 'Basic realm="test"'}, + ) + + # Decode the credentials + encoded_creds = auth_header[6:] # Remove "Basic " + try: + decoded = base64.b64decode(encoded_creds).decode() + username, password = decoded.split(":", 1) + except (ValueError, binascii.Error): + return web.Response( + status=401, + text="Invalid credentials format", + headers={hdrs.WWW_AUTHENTICATE: 'Basic realm="test"'}, + ) + + # Validate credentials + if username != "user" or password != "pass": + return web.Response(status=401, text="Invalid credentials") + + # Fail with 500 on first attempt to test retry + auth combination + self.protected_counter += 1 + if self.protected_counter == 1: + return web.Response( + status=500, text="Internal server error (first attempt)" + ) + + return web.json_response( + { + "message": "Access granted", + "user": username, + "resource": "protected data", + } + ) + + async def handle_flaky(self, request: web.Request) -> web.Response: + """Endpoint that fails a few times before succeeding.""" + self.flaky_counter += 1 + + # Fail the first 2 requests, succeed on the 3rd + if self.flaky_counter <= 2: + return web.Response( + status=503, + text=f"Service temporarily unavailable (attempt {self.flaky_counter})", + ) + + # Reset counter and return success + self.flaky_counter = 0 + return web.json_response( + { + "message": "Success after retries!", + "data": "Important information retrieved", + } + ) + + async def handle_always_fail(self, request: web.Request) -> web.Response: + """Endpoint that always returns an error.""" + return web.Response(status=500, text="Internal server error") + + async def handle_status(self, request: web.Request) -> web.Response: + """Return the status code specified in the path.""" + status = int(request.match_info["status"]) + return web.Response(status=status, text=f"Status: {status}") + + +async def run_test_server() -> web.AppRunner: + """Run a test server with various endpoints.""" + app = web.Application() + server = TestServer() + + app.router.add_get("/protected", server.handle_protected) + app.router.add_get("/flaky", server.handle_flaky) + app.router.add_get("/always-fail", server.handle_always_fail) + app.router.add_get("/status/{status}", server.handle_status) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + return runner + + +async def run_tests() -> None: + """Run all the middleware tests.""" + # Create middleware instances + logging_middleware = LoggingMiddleware() + auth_middleware = BasicAuthMiddleware("user", "pass") + retry_middleware = RetryMiddleware(max_retries=2, initial_delay=0.5) + + # Combine middleware - order matters! + # Applied in order: logging -> auth -> retry -> request + async with ClientSession( + middlewares=(logging_middleware, auth_middleware, retry_middleware) + ) as session: + + print( + "=== Test 1: Protected endpoint with auth (fails once, then succeeds) ===" + ) + print("This tests retry + auth working together...") + async with session.get("http://localhost:8080/protected") as resp: + if resp.status == 200: + data = await resp.json() + print(f"Success after retry! Response: {data}") + else: + print(f"Failed with status: {resp.status}") + + print("\n=== Test 2: Flaky endpoint (fails twice, then succeeds) ===") + print("Watch the logs to see retries in action...") + async with session.get("http://localhost:8080/flaky") as resp: + if resp.status == 200: + data = await resp.json() + print(f"Success after retries! Response: {data}") + else: + text = await resp.text() + print(f"Failed with status {resp.status}: {text}") + + print("\n=== Test 3: Always failing endpoint ===") + async with session.get("http://localhost:8080/always-fail") as resp: + print(f"Final status after retries: {resp.status}") + + print("\n=== Test 4: Non-retryable status (404) ===") + async with session.get("http://localhost:8080/status/404") as resp: + print(f"Status: {resp.status} (no retries for 404)") + + # Test without middleware for comparison + print("\n=== Test 5: Request without middleware ===") + print("Making a request to protected endpoint without middleware...") + async with session.get( + "http://localhost:8080/protected", middlewares=() + ) as resp: + print(f"Status without middleware: {resp.status}") + if resp.status == 401: + print("Failed as expected - no auth header added") + + +async def main() -> None: + # Start test server + server = await run_test_server() + + try: + await run_tests() + + finally: + await server.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/logging_middleware.py b/examples/logging_middleware.py new file mode 100644 index 00000000000..b6345953db2 --- /dev/null +++ b/examples/logging_middleware.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +""" +Example of using logging middleware with aiohttp client. + +This example shows how to implement a middleware that logs request timing +and response status. This is useful for debugging, monitoring, and +understanding the flow of HTTP requests in your application. + +This example includes a test server with various endpoints. +""" + +import asyncio +import json +import logging +import time +from typing import Any, Coroutine, List + +from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web + +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +_LOGGER = logging.getLogger(__name__) + + +class LoggingMiddleware: + """Middleware that logs request timing and response status.""" + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + start_time = time.monotonic() + + # Log request + _LOGGER.info("[REQUEST] %s %s", request.method, request.url) + if request.headers: + _LOGGER.debug("[REQUEST HEADERS] %s", request.headers) + + # Execute request + response = await handler(request) + + # Log response + duration = time.monotonic() - start_time + _LOGGER.info( + "[RESPONSE] %s %s - Status: %s - Duration: %.3fs", + request.method, + request.url, + response.status, + duration, + ) + _LOGGER.debug("[RESPONSE HEADERS] %s", response.headers) + + return response + + +class TestServer: + """Test server for logging middleware demo.""" + + async def handle_hello(self, request: web.Request) -> web.Response: + """Simple hello endpoint.""" + name = request.match_info.get("name", "World") + return web.json_response({"message": f"Hello, {name}!"}) + + async def handle_slow(self, request: web.Request) -> web.Response: + """Endpoint that simulates slow response.""" + delay = float(request.match_info.get("delay", 1)) + await asyncio.sleep(delay) + return web.json_response({"message": "Slow response completed", "delay": delay}) + + async def handle_error(self, request: web.Request) -> web.Response: + """Endpoint that returns an error.""" + status = int(request.match_info.get("status", 500)) + return web.Response(status=status, text=f"Error response with status {status}") + + async def handle_json_data(self, request: web.Request) -> web.Response: + """Endpoint that echoes JSON data.""" + try: + data = await request.json() + return web.json_response({"echo": data, "received_at": time.time()}) + except json.JSONDecodeError: + return web.json_response({"error": "Invalid JSON"}, status=400) + + +async def run_test_server() -> web.AppRunner: + """Run a simple test server.""" + app = web.Application() + server = TestServer() + + app.router.add_get("/hello", server.handle_hello) + app.router.add_get("/hello/{name}", server.handle_hello) + app.router.add_get("/slow/{delay}", server.handle_slow) + app.router.add_get("/error/{status}", server.handle_error) + app.router.add_post("/echo", server.handle_json_data) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + return runner + + +async def run_tests() -> None: + """Run all the middleware tests.""" + # Create logging middleware + logging_middleware = LoggingMiddleware() + + # Use middleware in session + async with ClientSession(middlewares=(logging_middleware,)) as session: + # Test 1: Simple GET request + print("\n=== Test 1: Simple GET request ===") + async with session.get("http://localhost:8080/hello") as resp: + data = await resp.json() + print(f"Response: {data}") + + # Test 2: GET with parameter + print("\n=== Test 2: GET with parameter ===") + async with session.get("http://localhost:8080/hello/Alice") as resp: + data = await resp.json() + print(f"Response: {data}") + + # Test 3: Slow request + print("\n=== Test 3: Slow request (2 seconds) ===") + async with session.get("http://localhost:8080/slow/2") as resp: + data = await resp.json() + print(f"Response: {data}") + + # Test 4: Error response + print("\n=== Test 4: Error response ===") + async with session.get("http://localhost:8080/error/404") as resp: + text = await resp.text() + print(f"Response: {text}") + + # Test 5: POST with JSON data + print("\n=== Test 5: POST with JSON data ===") + payload = {"name": "Bob", "age": 30, "city": "New York"} + async with session.post("http://localhost:8080/echo", json=payload) as resp: + data = await resp.json() + print(f"Response: {data}") + + # Test 6: Multiple concurrent requests + print("\n=== Test 6: Multiple concurrent requests ===") + coros: List[Coroutine[Any, Any, ClientResponse]] = [] + for i in range(3): + coro = session.get(f"http://localhost:8080/hello/User{i}") + coros.append(coro) + + responses = await asyncio.gather(*coros) + for i, resp in enumerate(responses): + async with resp: + data = await resp.json() + print(f"Concurrent request {i}: {data}") + + +async def main() -> None: + # Start test server + server = await run_test_server() + + try: + await run_tests() + + finally: + # Cleanup server + await server.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/retry_middleware.py b/examples/retry_middleware.py new file mode 100644 index 00000000000..c8fa829455a --- /dev/null +++ b/examples/retry_middleware.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +""" +Example of using retry middleware with aiohttp client. + +This example shows how to implement a middleware that automatically retries +failed requests with exponential backoff. The middleware can be configured +with custom retry statuses, maximum retries, and backoff parameters. + +This example includes a test server that simulates various HTTP responses +and can return different status codes on sequential requests. +""" + +import asyncio +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Dict, List, Set, Union + +from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web + +logging.basicConfig(level=logging.INFO) +_LOGGER = logging.getLogger(__name__) + +DEFAULT_RETRY_STATUSES: Set[HTTPStatus] = { + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT, +} + + +class RetryMiddleware: + """Middleware that retries failed requests with exponential backoff.""" + + def __init__( + self, + max_retries: int = 3, + retry_statuses: Union[Set[HTTPStatus], None] = None, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, + ) -> None: + self.max_retries = max_retries + self.retry_statuses = retry_statuses or DEFAULT_RETRY_STATUSES + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + """Execute request with retry logic.""" + last_response: Union[ClientResponse, None] = None + delay = self.initial_delay + + for attempt in range(self.max_retries + 1): + if attempt > 0: + _LOGGER.info( + "Retrying request to %s (attempt %s/%s)", + request.url, + attempt + 1, + self.max_retries + 1, + ) + + # Execute the request + response = await handler(request) + last_response = response + + # Check if we should retry + if response.status not in self.retry_statuses: + return response + + # Don't retry if we've exhausted attempts + if attempt >= self.max_retries: + _LOGGER.warning( + "Max retries (%s) exceeded for %s", self.max_retries, request.url + ) + return response + + # Wait before retrying + _LOGGER.debug("Waiting %ss before retry...", delay) + await asyncio.sleep(delay) + delay *= self.backoff_factor + + # Return the last response + if TYPE_CHECKING: + assert last_response is not None # Always set since we loop at least once + return last_response + + +class TestServer: + """Test server with stateful endpoints for retry testing.""" + + def __init__(self) -> None: + self.request_counters: Dict[str, int] = {} + self.status_sequences: Dict[str, List[int]] = { + "eventually-ok": [500, 503, 502, 200], # Fails 3 times, then succeeds + "always-error": [500, 500, 500, 500], # Always fails + "immediate-ok": [200], # Succeeds immediately + "flaky": [503, 200], # Fails once, then succeeds + } + + async def handle_status(self, request: web.Request) -> web.Response: + """Return the status code specified in the path.""" + status = int(request.match_info["status"]) + return web.Response(status=status, text=f"Status: {status}") + + async def handle_status_sequence(self, request: web.Request) -> web.Response: + """Return different status codes on sequential requests.""" + path = request.path + + # Initialize counter for this path if needed + if path not in self.request_counters: + self.request_counters[path] = 0 + + # Get the status sequence for this path + sequence_name = request.match_info["name"] + if sequence_name not in self.status_sequences: + return web.Response(status=404, text="Sequence not found") + + sequence = self.status_sequences[sequence_name] + + # Get the current status based on request count + count = self.request_counters[path] + if count < len(sequence): + status = sequence[count] + else: + # After sequence ends, always return the last status + status = sequence[-1] + + # Increment counter for next request + self.request_counters[path] += 1 + + return web.Response( + status=status, text=f"Request #{count + 1}: Status {status}" + ) + + async def handle_delay(self, request: web.Request) -> web.Response: + """Delay response by specified seconds.""" + delay = float(request.match_info["delay"]) + await asyncio.sleep(delay) + return web.json_response({"delay": delay, "message": "Response after delay"}) + + async def handle_reset(self, request: web.Request) -> web.Response: + """Reset request counters.""" + self.request_counters = {} + return web.Response(text="Counters reset") + + +async def run_test_server() -> web.AppRunner: + """Run a simple test server.""" + app = web.Application() + server = TestServer() + + app.router.add_get("/status/{status}", server.handle_status) + app.router.add_get("/sequence/{name}", server.handle_status_sequence) + app.router.add_get("/delay/{delay}", server.handle_delay) + app.router.add_post("/reset", server.handle_reset) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + return runner + + +async def run_tests() -> None: + """Run all retry middleware tests.""" + # Create retry middleware with custom settings + retry_middleware = RetryMiddleware( + max_retries=3, + retry_statuses=DEFAULT_RETRY_STATUSES, + initial_delay=0.5, + backoff_factor=2.0, + ) + + async with ClientSession(middlewares=(retry_middleware,)) as session: + # Reset counters before tests + await session.post("http://localhost:8080/reset") + + # Test 1: Request that succeeds immediately + print("=== Test 1: Immediate success ===") + async with session.get("http://localhost:8080/sequence/immediate-ok") as resp: + text = await resp.text() + print(f"Final status: {resp.status}") + print(f"Response: {text}") + print("Success - no retries needed\n") + + # Test 2: Request that eventually succeeds after retries + print("=== Test 2: Eventually succeeds (500->503->502->200) ===") + async with session.get("http://localhost:8080/sequence/eventually-ok") as resp: + text = await resp.text() + print(f"Final status: {resp.status}") + print(f"Response: {text}") + if resp.status == 200: + print("Success after retries!\n") + else: + print("Failed after retries\n") + + # Test 3: Request that always fails + print("=== Test 3: Always fails (500->500->500->500) ===") + async with session.get("http://localhost:8080/sequence/always-error") as resp: + text = await resp.text() + print(f"Final status: {resp.status}") + print(f"Response: {text}") + print("Failed after exhausting all retries\n") + + # Test 4: Flaky service (fails once then succeeds) + print("=== Test 4: Flaky service (503->200) ===") + await session.post("http://localhost:8080/reset") # Reset counters + async with session.get("http://localhost:8080/sequence/flaky") as resp: + text = await resp.text() + print(f"Final status: {resp.status}") + print(f"Response: {text}") + print("Success after one retry!\n") + + # Test 5: Non-retryable status + print("=== Test 5: Non-retryable status (404) ===") + async with session.get("http://localhost:8080/status/404") as resp: + print(f"Final status: {resp.status}") + print("Failed immediately - not a retryable status\n") + + # Test 6: Delayed response + print("=== Test 6: Testing with delay endpoint ===") + try: + async with session.get("http://localhost:8080/delay/0.5") as resp: + print(f"Status: {resp.status}") + data = await resp.json() + print(f"Response received after delay: {data}\n") + except asyncio.TimeoutError: + print("Request timed out\n") + + +async def main() -> None: + # Start test server + server = await run_test_server() + + try: + await run_tests() + finally: + await server.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/token_refresh_middleware.py b/examples/token_refresh_middleware.py new file mode 100644 index 00000000000..8a7ff963850 --- /dev/null +++ b/examples/token_refresh_middleware.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +""" +Example of using token refresh middleware with aiohttp client. + +This example shows how to implement a middleware that handles JWT token +refresh automatically. The middleware: +- Adds bearer tokens to requests +- Detects when tokens are expired +- Automatically refreshes tokens when needed +- Handles concurrent requests during token refresh + +This example includes a test server that simulates a JWT auth system. +Note: This is a simplified example for demonstration purposes. +In production, use proper JWT libraries and secure token storage. +""" + +import asyncio +import hashlib +import json +import logging +import secrets +import time +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Union + +from aiohttp import ( + ClientHandlerType, + ClientRequest, + ClientResponse, + ClientSession, + hdrs, + web, +) + +logging.basicConfig(level=logging.INFO) +_LOGGER = logging.getLogger(__name__) + + +class TokenRefreshMiddleware: + """Middleware that handles JWT token refresh automatically.""" + + def __init__(self, token_endpoint: str, refresh_token: str) -> None: + self.token_endpoint = token_endpoint + self.refresh_token = refresh_token + self.access_token: Union[str, None] = None + self.token_expires_at: Union[float, None] = None + self._refresh_lock = asyncio.Lock() + + async def _refresh_access_token(self, session: ClientSession) -> str: + """Refresh the access token using the refresh token.""" + async with self._refresh_lock: + # Check if another coroutine already refreshed the token + if ( + self.token_expires_at + and time.time() < self.token_expires_at + and self.access_token + ): + _LOGGER.debug("Token already refreshed by another request") + return self.access_token + + _LOGGER.info("Refreshing access token...") + + # Make refresh request without middleware to avoid recursion + async with session.post( + self.token_endpoint, + json={"refresh_token": self.refresh_token}, + middlewares=(), # Disable middleware for this request + ) as resp: + resp.raise_for_status() + data = await resp.json() + + if "access_token" not in data: + raise ValueError("No access_token in refresh response") + + self.access_token = data["access_token"] + # Token expires in 5 minutes for demo, refresh 30 seconds early + expires_in = data.get("expires_in", 300) + self.token_expires_at = time.time() + expires_in - 30 + + _LOGGER.info( + "Token refreshed successfully, expires in %s seconds", expires_in + ) + if TYPE_CHECKING: + assert self.access_token is not None # Just assigned above + return self.access_token + + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType, + ) -> ClientResponse: + """Add auth token to request, refreshing if needed.""" + # Skip token for refresh endpoint to avoid recursion + if str(request.url).endswith("/token/refresh"): + return await handler(request) + + # Refresh token if needed + if not self.access_token or ( + self.token_expires_at and time.time() >= self.token_expires_at + ): + await self._refresh_access_token(request.session) + + # Add token to request + request.headers[hdrs.AUTHORIZATION] = f"Bearer {self.access_token}" + _LOGGER.debug("Added Bearer token to request") + + # Execute request + response = await handler(request) + + # If we get 401, try refreshing token once + if response.status == HTTPStatus.UNAUTHORIZED: + _LOGGER.info("Got 401, attempting token refresh...") + await self._refresh_access_token(request.session) + request.headers[hdrs.AUTHORIZATION] = f"Bearer {self.access_token}" + response = await handler(request) + + return response + + +class TestServer: + """Test server with JWT-like token authentication.""" + + def __init__(self) -> None: + self.tokens_db: Dict[str, Dict[str, Union[str, float]]] = {} + self.refresh_tokens_db: Dict[str, Dict[str, Union[str, float]]] = { + # Hash of refresh token -> user data + hashlib.sha256(b"demo_refresh_token_12345").hexdigest(): { + "user_id": "user123", + "username": "testuser", + "issued_at": time.time(), + } + } + + def generate_access_token(self) -> str: + """Generate a secure random access token.""" + return secrets.token_urlsafe(32) + + async def _process_token_refresh(self, data: Dict[str, str]) -> web.Response: + """Process the token refresh request.""" + refresh_token = data.get("refresh_token") + + if not refresh_token: + return web.json_response({"error": "refresh_token required"}, status=400) + + # Hash the refresh token to look it up + refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() + + if refresh_token_hash not in self.refresh_tokens_db: + return web.json_response({"error": "Invalid refresh token"}, status=401) + + user_data = self.refresh_tokens_db[refresh_token_hash] + + # Generate new access token + access_token = self.generate_access_token() + expires_in = 300 # 5 minutes for demo + + # Store the access token with expiry + token_hash = hashlib.sha256(access_token.encode()).hexdigest() + self.tokens_db[token_hash] = { + "user_id": user_data["user_id"], + "username": user_data["username"], + "expires_at": time.time() + expires_in, + "issued_at": time.time(), + } + + # Clean up expired tokens periodically + current_time = time.time() + self.tokens_db = { + k: v + for k, v in self.tokens_db.items() + if isinstance(v["expires_at"], float) and v["expires_at"] > current_time + } + + return web.json_response( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": expires_in, + } + ) + + async def handle_token_refresh(self, request: web.Request) -> web.Response: + """Handle token refresh requests.""" + try: + data = await request.json() + return await self._process_token_refresh(data) + except json.JSONDecodeError: + return web.json_response({"error": "Invalid request"}, status=400) + + async def verify_bearer_token( + self, request: web.Request + ) -> Union[Dict[str, Union[str, float]], None]: + """Verify bearer token and return user data if valid.""" + auth_header = request.headers.get(hdrs.AUTHORIZATION, "") + + if not auth_header.startswith("Bearer "): + return None + + token = auth_header[7:] # Remove "Bearer " + token_hash = hashlib.sha256(token.encode()).hexdigest() + + # Check if token exists and is not expired + if token_hash in self.tokens_db: + token_data = self.tokens_db[token_hash] + if ( + isinstance(token_data["expires_at"], float) + and token_data["expires_at"] > time.time() + ): + return token_data + + return None + + async def handle_protected_resource(self, request: web.Request) -> web.Response: + """Protected endpoint that requires valid bearer token.""" + user_data = await self.verify_bearer_token(request) + + if not user_data: + return web.json_response({"error": "Invalid or expired token"}, status=401) + + return web.json_response( + { + "message": "Access granted to protected resource", + "user": user_data["username"], + "data": "Secret information", + } + ) + + async def handle_user_info(self, request: web.Request) -> web.Response: + """Another protected endpoint.""" + user_data = await self.verify_bearer_token(request) + + if not user_data: + return web.json_response({"error": "Invalid or expired token"}, status=401) + + return web.json_response( + { + "user_id": user_data["user_id"], + "username": user_data["username"], + "email": f"{user_data['username']}@example.com", + "roles": ["user", "admin"], + } + ) + + +async def run_test_server() -> web.AppRunner: + """Run a test server with JWT auth endpoints.""" + test_server = TestServer() + app = web.Application() + app.router.add_post("/token/refresh", test_server.handle_token_refresh) + app.router.add_get("/api/protected", test_server.handle_protected_resource) + app.router.add_get("/api/user", test_server.handle_user_info) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8080) + await site.start() + return runner + + +async def run_tests() -> None: + """Run all token refresh middleware tests.""" + # Create token refresh middleware + # In a real app, this refresh token would be securely stored + token_middleware = TokenRefreshMiddleware( + token_endpoint="http://localhost:8080/token/refresh", + refresh_token="demo_refresh_token_12345", + ) + + async with ClientSession(middlewares=(token_middleware,)) as session: + print("=== Test 1: First request (will trigger token refresh) ===") + async with session.get("http://localhost:8080/api/protected") as resp: + if resp.status == 200: + data = await resp.json() + print(f"Success! Response: {data}") + else: + print(f"Failed with status: {resp.status}") + + print("\n=== Test 2: Second request (uses cached token) ===") + async with session.get("http://localhost:8080/api/user") as resp: + if resp.status == 200: + data = await resp.json() + print(f"User info: {data}") + else: + print(f"Failed with status: {resp.status}") + + print("\n=== Test 3: Multiple concurrent requests ===") + print("(Should only refresh token once)") + coros: List[Coroutine[Any, Any, ClientResponse]] = [] + for i in range(3): + coro = session.get("http://localhost:8080/api/protected") + coros.append(coro) + + responses = await asyncio.gather(*coros) + for i, resp in enumerate(responses): + async with resp: + if resp.status == 200: + print(f"Request {i + 1}: Success") + else: + print(f"Request {i + 1}: Failed with {resp.status}") + + print("\n=== Test 4: Simulate token expiry ===") + # For demo purposes, force token expiry + token_middleware.token_expires_at = time.time() - 1 + + print("Token expired, next request should trigger refresh...") + async with session.get("http://localhost:8080/api/protected") as resp: + if resp.status == 200: + data = await resp.json() + print(f"Success after token refresh! Response: {data}") + else: + print(f"Failed with status: {resp.status}") + + print("\n=== Test 5: Request without middleware (no auth) ===") + # Make a request without any middleware to show the difference + async with session.get( + "http://localhost:8080/api/protected", + middlewares=(), # Bypass all middleware for this request + ) as resp: + print(f"Status: {resp.status}") + if resp.status == 401: + error = await resp.json() + print(f"Failed as expected without auth: {error}") + + +async def main() -> None: + # Start test server + server = await run_test_server() + + try: + await run_tests() + finally: + await server.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) From a023a245f675b77c746d4cac37ac5289e4196070 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 23 May 2025 17:20:58 +0100 Subject: [PATCH 5/7] Upgrade to llhttp 3.9 (#10972) --- CHANGES/10972.feature.rst | 1 + vendor/llhttp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 CHANGES/10972.feature.rst diff --git a/CHANGES/10972.feature.rst b/CHANGES/10972.feature.rst new file mode 100644 index 00000000000..1d3779a3969 --- /dev/null +++ b/CHANGES/10972.feature.rst @@ -0,0 +1 @@ +Upgraded to LLHTTP 9.3.0 -- by :user:`Dreamsorcerer`. diff --git a/vendor/llhttp b/vendor/llhttp index b0b279fb5a6..36151b9a7d6 160000 --- a/vendor/llhttp +++ b/vendor/llhttp @@ -1 +1 @@ -Subproject commit b0b279fb5a617ab3bc2fc11c5f8bd937aac687c1 +Subproject commit 36151b9a7d6320072e24e472a769a5e09f9e969d From ff7feaf4327ccbf07ffc316db697e7bd4dfdbdda Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 23 May 2025 11:41:26 -0500 Subject: [PATCH 6/7] Update Key Features to mention client middleware (#10968) --- CHANGES/10968.feature.rst | 1 + docs/client_reference.rst | 1 - docs/index.rst | 2 ++ 3 files changed, 3 insertions(+), 1 deletion(-) create mode 120000 CHANGES/10968.feature.rst diff --git a/CHANGES/10968.feature.rst b/CHANGES/10968.feature.rst new file mode 120000 index 00000000000..b565aa68ee0 --- /dev/null +++ b/CHANGES/10968.feature.rst @@ -0,0 +1 @@ +9732.feature.rst \ No newline at end of file diff --git a/docs/client_reference.rst b/docs/client_reference.rst index cfb0e191196..40b4f7bcbf9 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -2033,7 +2033,6 @@ Utilities :return: encoded authentication data, :class:`str`. - .. class:: DigestAuthMiddleware(login, password) HTTP digest authentication client middleware. diff --git a/docs/index.rst b/docs/index.rst index 347bb034289..39b95cb1b61 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -23,6 +23,8 @@ Key Features without the Callback Hell. - Web-server has :ref:`aiohttp-web-middlewares`, :ref:`aiohttp-web-signals` and pluggable routing. +- Client supports :ref:`middleware ` for + customizing request/response processing. .. _aiohttp-installation: From 9c6da05b4233f07d41dd60bf68b8f7a4fe425d9b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 23 May 2025 11:54:52 -0500 Subject: [PATCH 7/7] Fix double compression issue when enable_compression() is called on pre-encoded responses (#10974) --- CHANGES/10968.bugfix.rst | 1 + aiohttp/web_response.py | 7 +++++++ tests/test_web_response.py | 25 +++++++++++++++++++++++++ 3 files changed, 33 insertions(+) create mode 100644 CHANGES/10968.bugfix.rst diff --git a/CHANGES/10968.bugfix.rst b/CHANGES/10968.bugfix.rst new file mode 100644 index 00000000000..052a7d2b8f9 --- /dev/null +++ b/CHANGES/10968.bugfix.rst @@ -0,0 +1 @@ +Fixed double compression issue when :py:meth:`~aiohttp.web.StreamResponse.enable_compression` is called on a response with pre-existing Content-Encoding header -- by :user:`bdraco`. diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index b637543b29c..85389963616 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -186,6 +186,13 @@ def enable_compression( strategy: Optional[int] = None, ) -> None: """Enables response compression encoding.""" + # Don't enable compression if content is already encoded. + # This prevents double compression and provides a safe, predictable behavior + # without breaking existing code that may call enable_compression() on + # responses that already have Content-Encoding set (e.g., FileResponse + # serving pre-compressed files). + if hdrs.CONTENT_ENCODING in self._headers: + return self._compression = True self._compression_force = force self._compression_strategy = strategy diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 98384d7eabf..d84201afa2d 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -627,6 +627,31 @@ async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert resp.content_length == 6 +async def test_enable_compression_with_existing_encoding() -> None: + """Test that enable_compression does not override existing content encoding.""" + writer = mock.Mock() + + async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: + # Should preserve the existing content encoding + assert headers[hdrs.CONTENT_ENCODING] == "gzip" + # Should not have double encoding + assert headers.get(hdrs.CONTENT_ENCODING) != "gzip, deflate" + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + resp = web.Response(body=b"answer") + + # Manually set content encoding (simulating FileResponse with pre-compressed file) + resp.headers[hdrs.CONTENT_ENCODING] = "gzip" + + # Try to enable compression - should be ignored + resp.enable_compression(web.ContentCoding.deflate) + + await resp.prepare(req) + # Verify compression was not enabled due to existing encoding + assert not resp.compression + + @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_rm_content_length_if_compression_http11() -> None: writer = mock.Mock()