From 4152a0835b8667827f9740c37c28b83e1d9e9f67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Thu, 22 May 2025 16:21:46 +0200 Subject: [PATCH 1/6] Support using system llhttp library (#10760) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 🇺🇦 Sviatoslav Sydorenko (Святослав Сидоренко) --- CHANGES/10759.packaging.rst | 5 +++++ aiohttp/_cparser.pxd | 2 +- docs/glossary.rst | 11 +++++++++++ docs/spelling_wordlist.txt | 1 + pyproject.toml | 1 + requirements/test.in | 1 + setup.py | 37 +++++++++++++++++++++++++++++++------ 7 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 CHANGES/10759.packaging.rst diff --git a/CHANGES/10759.packaging.rst b/CHANGES/10759.packaging.rst new file mode 100644 index 00000000000..6f41e873229 --- /dev/null +++ b/CHANGES/10759.packaging.rst @@ -0,0 +1,5 @@ +Added support for building against the system ``llhttp`` library -- by :user:`mgorny`. + +This change adds support for the :envvar:`AIOHTTP_USE_SYSTEM_DEPS` environment variable that +can be used to build aiohttp against the system install of the ``llhttp`` library rather +than the vendored one. diff --git a/aiohttp/_cparser.pxd b/aiohttp/_cparser.pxd index c2cd5a92fda..1b3be6d4efb 100644 --- a/aiohttp/_cparser.pxd +++ b/aiohttp/_cparser.pxd @@ -1,7 +1,7 @@ from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t -cdef extern from "../vendor/llhttp/build/llhttp.h": +cdef extern from "llhttp.h": struct llhttp__internal_s: int32_t _index diff --git a/docs/glossary.rst b/docs/glossary.rst index 392ef740cd1..996ea982d58 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -151,6 +151,17 @@ Environment Variables ===================== +.. envvar:: AIOHTTP_NO_EXTENSIONS + + If set to a non-empty value while building from source, aiohttp will be built without speedups + written as C extensions. This option is primarily useful for debugging. + +.. envvar:: AIOHTTP_USE_SYSTEM_DEPS + + If set to a non-empty value while building from source, aiohttp will be built against + the system installation of llhttp rather than the vendored library. This option is primarily + meant to be used by downstream redistributors. + +.. 
envvar:: NETRC If set, HTTP Basic Auth will be read from the file pointed to by this environment variable, diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index db6c500d5f7..885e79d8466 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -176,6 +176,7 @@ kwargs latin lifecycle linux +llhttp localhost Locator login diff --git a/pyproject.toml b/pyproject.toml index e21ba283b11..a9b4200a06c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [build-system] requires = [ + "pkgconfig", "setuptools >= 46.4.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements/test.in b/requirements/test.in index b8b82abd1ce..bf000f27443 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -5,6 +5,7 @@ coverage freezegun isal mypy; implementation_name == "cpython" +pkgconfig proxy.py >= 2.4.4rc5 pytest pytest-cov diff --git a/setup.py b/setup.py index c9a2c5c856c..fded89876f2 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,9 @@ raise RuntimeError("aiohttp 4.x requires Python 3.9+") +USE_SYSTEM_DEPS = bool( + os.environ.get("AIOHTTP_USE_SYSTEM_DEPS", os.environ.get("USE_SYSTEM_DEPS")) +) NO_EXTENSIONS: bool = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS")) HERE = pathlib.Path(__file__).parent IS_GIT_REPO = (HERE / ".git").exists() @@ -17,7 +20,11 @@ NO_EXTENSIONS = True -if IS_GIT_REPO and not (HERE / "vendor/llhttp/README.md").exists(): +if ( + not USE_SYSTEM_DEPS + and IS_GIT_REPO + and not (HERE / "vendor/llhttp/README.md").exists() +): print("Install submodules when building from git clone", file=sys.stderr) print("Hint:", file=sys.stderr) print(" git submodule update --init", file=sys.stderr) @@ -26,6 +33,27 @@ # NOTE: makefile cythonizes all Cython modules +if USE_SYSTEM_DEPS: + import shlex + + import pkgconfig + + llhttp_sources = [] + llhttp_kwargs = { + "extra_compile_args": shlex.split(pkgconfig.cflags("libllhttp")), + "extra_link_args": shlex.split(pkgconfig.libs("libllhttp")), + } +else: + llhttp_sources = [ + "vendor/llhttp/build/c/llhttp.c", + "vendor/llhttp/src/native/api.c", + "vendor/llhttp/src/native/http.c", + ] + llhttp_kwargs = { + "define_macros": [("LLHTTP_STRICT_MODE", 0)], + "include_dirs": ["vendor/llhttp/build"], + } + extensions = [ Extension("aiohttp._websocket.mask", ["aiohttp/_websocket/mask.c"]), Extension( @@ -33,12 +61,9 @@ [ "aiohttp/_http_parser.c", "aiohttp/_find_header.c", - "vendor/llhttp/build/c/llhttp.c", - "vendor/llhttp/src/native/api.c", - "vendor/llhttp/src/native/http.c", + *llhttp_sources, ], - define_macros=[("LLHTTP_STRICT_MODE", 0)], - include_dirs=["vendor/llhttp/build"], + **llhttp_kwargs, ), Extension("aiohttp._http_writer", ["aiohttp/_http_writer.c"]), Extension("aiohttp._websocket.reader_c", ["aiohttp/_websocket/reader_c.c"]), From 5fac5f1988d294ab811318face9dda67602bb923 Mon Sep 17 00:00:00 2001 From: Vizonex <114684698+Vizonex@users.noreply.github.com> Date: Thu, 22 May 2025 09:27:01 -0500 Subject: [PATCH 2/6] Add Winloop to test suite if User is using Windows (#10922) Co-authored-by: J. Nick Koston Co-authored-by: J. 
Nick Koston Co-authored-by: Sam Bull --- CHANGES/10922.contrib.rst | 1 + CONTRIBUTORS.txt | 1 + docs/spelling_wordlist.txt | 1 + requirements/base.in | 1 + requirements/base.txt | 1 + tests/conftest.py | 5 ++++- 6 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 CHANGES/10922.contrib.rst diff --git a/CHANGES/10922.contrib.rst b/CHANGES/10922.contrib.rst new file mode 100644 index 00000000000..e5e1cfd8af6 --- /dev/null +++ b/CHANGES/10922.contrib.rst @@ -0,0 +1 @@ +Added Winloop to the test suite to support it in the future -- by :user:`Vizonex`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 1e5f0da2684..ada385c74e3 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -369,6 +369,7 @@ Vincent Maillol Vitalik Verhovodov Vitaly Haritonsky Vitaly Magerya +Vizonex Vladimir Kamarzin Vladimir Kozlovski Vladimir Rutsky diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 885e79d8466..34642ec64da 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -369,6 +369,7 @@ websocket’s websockets Websockets wildcard +Winloop Workflow ws wsgi diff --git a/requirements/base.in b/requirements/base.in index 70493b6c83a..816a4e84026 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -2,3 +2,4 @@ gunicorn uvloop; platform_system != "Windows" and implementation_name == "cpython" # MagicStack/uvloop#14 +winloop; platform_system == "Windows" and implementation_name == "cpython" diff --git a/requirements/base.txt b/requirements/base.txt index f446aa95b2a..7c568a3b3e0 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -41,6 +41,7 @@ pycparser==2.22 typing-extensions==4.13.2 # via multidict uvloop==0.21.0 ; platform_system != "Windows" and implementation_name == "cpython" +winloop==0.1.8; platform_system == "Windows" and implementation_name == "cpython" # via -r requirements/base.in yarl==1.20.0 # via -r requirements/runtime-deps.in diff --git a/tests/conftest.py b/tests/conftest.py index 97b8c960a69..9b27519410e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,10 @@ try: - import uvloop + if sys.platform == "win32": + import winloop as uvloop + else: + import uvloop except ImportError: uvloop = None # type: ignore[assignment] From 545783b9e91c69d48ffd3b85c7eb13d1b19eb55e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 22 May 2025 09:58:04 -0500 Subject: [PATCH 3/6] Fix connection reuse for file-like data payloads (#10915) --- CHANGES/10325.bugfix.rst | 1 + CHANGES/10915.bugfix.rst | 3 + aiohttp/client_reqrep.py | 77 +++++- aiohttp/payload.py | 428 ++++++++++++++++++++++++++++---- tests/conftest.py | 16 ++ tests/test_client_functional.py | 159 +++++++++++- tests/test_client_request.py | 90 ++++++- tests/test_payload.py | 334 ++++++++++++++++++++++++- 8 files changed, 1044 insertions(+), 64 deletions(-) create mode 120000 CHANGES/10325.bugfix.rst create mode 100644 CHANGES/10915.bugfix.rst diff --git a/CHANGES/10325.bugfix.rst b/CHANGES/10325.bugfix.rst new file mode 120000 index 00000000000..aa085cc590d --- /dev/null +++ b/CHANGES/10325.bugfix.rst @@ -0,0 +1 @@ +10915.bugfix.rst \ No newline at end of file diff --git a/CHANGES/10915.bugfix.rst b/CHANGES/10915.bugfix.rst new file mode 100644 index 00000000000..f564603306b --- /dev/null +++ b/CHANGES/10915.bugfix.rst @@ -0,0 +1,3 @@ +Fixed connection reuse for file-like data payloads by ensuring buffer +truncation respects content-length boundaries and preventing a premature +connection closure race -- by :user:`bdraco`. 
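The truncation rule this changelog entry describes can be pictured with a short, self-contained Python sketch. It is illustrative only, not code from the patch: MemoryWriter and write_limited are invented stand-ins, while the real logic lands in ClientRequest.write_bytes and the payload classes in the diff below.

import asyncio
from typing import Iterable, Optional


class MemoryWriter:
    """Stand-in for aiohttp's AbstractStreamWriter that records written bytes."""

    def __init__(self) -> None:
        self.buffer = bytearray()

    async def write(self, chunk: bytes) -> None:
        self.buffer.extend(chunk)


async def write_limited(
    writer: MemoryWriter, body: Iterable[bytes], content_length: Optional[int]
) -> None:
    if content_length is None:
        for chunk in body:  # no Content-Length: write the entire body
            await writer.write(chunk)
        return
    # Content-Length set: slice each chunk to the remaining budget and stop
    # once the budget is exhausted, mirroring the constrained branch the
    # patch adds to write_bytes().
    remaining = content_length
    for chunk in body:
        await writer.write(chunk[:remaining])
        remaining -= len(chunk)
        if remaining <= 0:
            break


async def main() -> None:
    writer = MemoryWriter()
    await write_limited(writer, [b"Chunk1" * 100, b"Chunk2" * 100], 800)
    # 600 bytes from the first chunk plus 200 from the second, matching
    # test_content_length_limit_with_multiple_reads further down.
    assert len(writer.buffer) == 800


asyncio.run(main())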
diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index db4018efa1d..cf3c48a1c16 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -304,6 +304,23 @@ def __init__( def __reset_writer(self, _: object = None) -> None: self.__writer = None + def _get_content_length(self) -> Optional[int]: + """Extract and validate Content-Length header value. + + Returns parsed Content-Length value or None if not set. + Raises ValueError if header exists but cannot be parsed as an integer. + """ + if hdrs.CONTENT_LENGTH not in self.headers: + return None + + content_length_hdr = self.headers[hdrs.CONTENT_LENGTH] + try: + return int(content_length_hdr) + except ValueError: + raise ValueError( + f"Invalid Content-Length header: {content_length_hdr}" + ) from None + @property def skip_auto_headers(self) -> CIMultiDict[None]: return self._skip_auto_headers or CIMultiDict() @@ -596,9 +613,37 @@ def update_proxy( self.proxy_headers = proxy_headers async def write_bytes( - self, writer: AbstractStreamWriter, conn: "Connection" + self, + writer: AbstractStreamWriter, + conn: "Connection", + content_length: Optional[int], ) -> None: - """Support coroutines that yields bytes objects.""" + """ + Write the request body to the connection stream. + + This method handles writing different types of request bodies: + 1. Payload objects (using their specialized write_with_length method) + 2. Bytes/bytearray objects + 3. Iterable body content + + Args: + writer: The stream writer to write the body to + conn: The connection being used for this request + content_length: Optional maximum number of bytes to write from the body + (None means write the entire body) + + The method properly handles: + - Waiting for 100-Continue responses if required + - Content length constraints for chunked encoding + - Error handling for network issues, cancellation, and other exceptions + - Signaling EOF and timeout management + + Raises: + ClientOSError: When there's an OS-level error writing the body + ClientConnectionError: When there's a general connection error + asyncio.CancelledError: When the operation is cancelled + + """ # 100 response if self._continue is not None: await writer.drain() @@ -608,16 +653,30 @@ async def write_bytes( assert protocol is not None try: if isinstance(self.body, payload.Payload): - await self.body.write(writer) + # Specialized handling for Payload objects that know how to write themselves + await self.body.write_with_length(writer, content_length) else: + # Handle bytes/bytearray by converting to an iterable for consistent handling if isinstance(self.body, (bytes, bytearray)): self.body = (self.body,) - for chunk in self.body: - await writer.write(chunk) + if content_length is None: + # Write the entire body without length constraint + for chunk in self.body: + await writer.write(chunk) + else: + # Write with length constraint, respecting content_length limit + # If the body is larger than content_length, we truncate it + remaining_bytes = content_length + for chunk in self.body: + await writer.write(chunk[:remaining_bytes]) + remaining_bytes -= len(chunk) + if remaining_bytes <= 0: + break except OSError as underlying_exc: reraised_exc = underlying_exc + # Distinguish between timeout and other OS errors for better error reporting exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( underlying_exc, asyncio.TimeoutError ) @@ -629,18 +688,20 @@ async def write_bytes( set_exception(protocol, reraised_exc, underlying_exc) except asyncio.CancelledError: - # Body 
hasn't been fully sent, so connection can't be reused. + # Body hasn't been fully sent, so connection can't be reused conn.close() raise except Exception as underlying_exc: set_exception( protocol, ClientConnectionError( - f"Failed to send bytes into the underlying connection {conn !s}", + "Failed to send bytes into the underlying connection " + f"{conn !s}: {underlying_exc!r}", ), underlying_exc, ) else: + # Successfully wrote the body, signal EOF and start response timeout await writer.write_eof() protocol.start_timeout() @@ -705,7 +766,7 @@ async def send(self, conn: "Connection") -> "ClientResponse": await writer.write_headers(status_line, self.headers) task: Optional["asyncio.Task[None]"] if self.body or self._continue is not None or protocol.writing_paused: - coro = self.write_bytes(writer, conn) + coro = self.write_bytes(writer, conn, self._get_content_length()) if sys.version_info >= (3, 12): # Optimization for Python 3.12, try to write # bytes immediately to avoid having to schedule diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 55a7a677f49..7339e720fc9 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -16,6 +16,7 @@ Final, Iterable, Optional, + Set, TextIO, Tuple, Type, @@ -53,6 +54,9 @@ ) TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB +READ_SIZE: Final[int] = 2**16 # 64 KB +_CLOSE_FUTURES: Set[asyncio.Future[None]] = set() + if TYPE_CHECKING: from typing import List @@ -238,10 +242,46 @@ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: @abstractmethod async def write(self, writer: AbstractStreamWriter) -> None: - """Write payload. + """Write payload to the writer stream. + + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + + This is a legacy method that writes the entire payload without length constraints. + + Important: + For new implementations, use write_with_length() instead of this method. + This method is maintained for backwards compatibility and will eventually + delegate to write_with_length(writer, None) in all implementations. + + All payload subclasses must override this method for backwards compatibility, + but new code should use write_with_length for more flexibility and control. + """ + + # write_with_length is new in aiohttp 3.12 + # it should be overridden by subclasses + async def write_with_length( + self, writer: AbstractStreamWriter, content_length: Optional[int] + ) -> None: + """ + Write payload with a specific content length constraint. + + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + content_length: Maximum number of bytes to write (None for unlimited) + + This method allows writing payload content with a specific length constraint, + which is particularly useful for HTTP responses with Content-Length header. + + Note: + This is the base implementation that provides backwards compatibility + for subclasses that don't override this method. Specific payload types + should override this method to implement proper length-constrained writing. - writer is an AbstractStreamWriter instance: """ + # Backwards compatibility for subclasses that don't override this method + # and for the default implementation + await self.write(writer) class BytesPayload(Payload): @@ -275,8 +315,40 @@ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: return self._value.decode(encoding, errors) async def write(self, writer: AbstractStreamWriter) -> None: + """Write the entire bytes payload to the writer stream. 
+ + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + + This method writes the entire bytes content without any length constraint. + + Note: + For new implementations that need length control, use write_with_length(). + This method is maintained for backwards compatibility and is equivalent + to write_with_length(writer, None). + """ await writer.write(self._value) + async def write_with_length( + self, writer: AbstractStreamWriter, content_length: Optional[int] + ) -> None: + """ + Write bytes payload with a specific content length constraint. + + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + content_length: Maximum number of bytes to write (None for unlimited) + + This method writes either the entire byte sequence or a slice of it + up to the specified content_length. For BytesPayload, this operation + is performed efficiently using array slicing. + + """ + if content_length is not None: + await writer.write(self._value[:content_length]) + else: + await writer.write(self._value) + class StringPayload(BytesPayload): def __init__( @@ -328,15 +400,165 @@ def __init__( if hdrs.CONTENT_DISPOSITION not in self.headers: self.set_content_disposition(disposition, filename=self._filename) + def _read_and_available_len( + self, remaining_content_len: Optional[int] + ) -> Tuple[Optional[int], bytes]: + """ + Read the file-like object and return both its total size and the first chunk. + + Args: + remaining_content_len: Optional limit on how many bytes to read in this operation. + If None, READ_SIZE will be used as the default chunk size. + + Returns: + A tuple containing: + - The total size of the remaining unread content (None if size cannot be determined) + - The first chunk of bytes read from the file object + + This method is optimized to perform both size calculation and initial read + in a single operation, which is executed in a single executor job to minimize + context switches and file operations when streaming content. + + """ + size = self.size # Call size only once since it does I/O + return size, self._value.read( + min(size or READ_SIZE, remaining_content_len or READ_SIZE) + ) + + def _read(self, remaining_content_len: Optional[int]) -> bytes: + """ + Read a chunk of data from the file-like object. + + Args: + remaining_content_len: Optional maximum number of bytes to read. + If None, READ_SIZE will be used as the default chunk size. + + Returns: + A chunk of bytes read from the file object, respecting the + remaining_content_len limit if specified. + + This method is used for subsequent reads during streaming after + the initial _read_and_available_len call has been made. + + """ + return self._value.read(remaining_content_len or READ_SIZE) # type: ignore[no-any-return] + + @property + def size(self) -> Optional[int]: + try: + return os.fstat(self._value.fileno()).st_size - self._value.tell() + except (AttributeError, OSError): + return None + async def write(self, writer: AbstractStreamWriter) -> None: - loop = asyncio.get_event_loop() + """ + Write the entire file-like payload to the writer stream. + + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + + This method writes the entire file content without any length constraint. + It delegates to write_with_length() with no length limit for implementation + consistency. + + Note: + For new implementations that need length control, use write_with_length() directly. 
+ This method is maintained for backwards compatibility with existing code. + + """ + await self.write_with_length(writer, None) + + async def write_with_length( + self, writer: AbstractStreamWriter, content_length: Optional[int] + ) -> None: + """ + Write file-like payload with a specific content length constraint. + + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + content_length: Maximum number of bytes to write (None for unlimited) + + This method implements optimized streaming of file content with length constraints: + + 1. File reading is performed in a thread pool to avoid blocking the event loop + 2. Content is read and written in chunks to maintain memory efficiency + 3. Writing stops when either: + - All available file content has been written (when size is known) + - The specified content_length has been reached + 4. File resources are properly closed even if the operation is cancelled + + The implementation carefully handles both known-size and unknown-size payloads, + as well as constrained and unconstrained content lengths. + + """ + loop = asyncio.get_running_loop() + total_written_len = 0 + remaining_content_len = content_length + try: - chunk = await loop.run_in_executor(None, self._value.read, 2**16) + # Get initial data and available length + available_len, chunk = await loop.run_in_executor( + None, self._read_and_available_len, remaining_content_len + ) + # Process data chunks until done while chunk: - await writer.write(chunk) - chunk = await loop.run_in_executor(None, self._value.read, 2**16) + chunk_len = len(chunk) + + # Write data with or without length constraint + if remaining_content_len is None: + await writer.write(chunk) + else: + await writer.write(chunk[:remaining_content_len]) + remaining_content_len -= chunk_len + + total_written_len += chunk_len + + # Check if we're done writing + if self._should_stop_writing( + available_len, total_written_len, remaining_content_len + ): + return + + # Read next chunk + chunk = await loop.run_in_executor( + None, self._read, remaining_content_len + ) finally: - await loop.run_in_executor(None, self._value.close) + # Handle closing the file without awaiting to prevent cancellation issues + # when the StreamReader reaches EOF + self._schedule_file_close(loop) + + def _should_stop_writing( + self, + available_len: Optional[int], + total_written_len: int, + remaining_content_len: Optional[int], + ) -> bool: + """ + Determine if we should stop writing data. + + Args: + available_len: Known size of the payload if available (None if unknown) + total_written_len: Number of bytes already written + remaining_content_len: Remaining bytes to be written for content-length limited responses + + Returns: + True if we should stop writing data, based on either: + - Having written all available data (when size is known) + - Having written all requested content (when content-length is specified) + + """ + return (available_len is not None and total_written_len >= available_len) or ( + remaining_content_len is not None and remaining_content_len <= 0 + ) + + def _schedule_file_close(self, loop: asyncio.AbstractEventLoop) -> None: + """Schedule file closing without awaiting to prevent cancellation issues.""" + close_future = loop.run_in_executor(None, self._value.close) + # Hold a strong reference to the future to prevent it from being + # garbage collected before it completes. 
+ _CLOSE_FUTURES.add(close_future) + close_future.add_done_callback(_CLOSE_FUTURES.remove) def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: return "".join(r.decode(encoding, errors) for r in self._value.readlines()) @@ -372,31 +594,60 @@ def __init__( **kwargs, ) - @property - def size(self) -> Optional[int]: - try: - return os.fstat(self._value.fileno()).st_size - self._value.tell() - except OSError: - return None + def _read_and_available_len( + self, remaining_content_len: Optional[int] + ) -> Tuple[Optional[int], bytes]: + """ + Read the text file-like object and return both its total size and the first chunk. + + Args: + remaining_content_len: Optional limit on how many bytes to read in this operation. + If None, READ_SIZE will be used as the default chunk size. + + Returns: + A tuple containing: + - The total size of the remaining unread content (None if size cannot be determined) + - The first chunk of bytes read from the file object, encoded using the payload's encoding + + This method is optimized to perform both size calculation and initial read + in a single operation, which is executed in a single executor job to minimize + context switches and file operations when streaming content. + + Note: + TextIOPayload handles encoding of the text content before writing it + to the stream. If no encoding is specified, UTF-8 is used as the default. + + """ + size = self.size + chunk = self._value.read( + min(size or READ_SIZE, remaining_content_len or READ_SIZE) + ) + return size, chunk.encode(self._encoding) if self._encoding else chunk.encode() + + def _read(self, remaining_content_len: Optional[int]) -> bytes: + """ + Read a chunk of data from the text file-like object. + + Args: + remaining_content_len: Optional maximum number of bytes to read. + If None, READ_SIZE will be used as the default chunk size. + + Returns: + A chunk of bytes read from the file object and encoded using the payload's + encoding. The data is automatically converted from text to bytes. + + This method is used for subsequent reads during streaming after + the initial _read_and_available_len call has been made. It properly + handles text encoding, converting the text content to bytes using + the specified encoding (or UTF-8 if none was provided). 
+ + """ + chunk = self._value.read(remaining_content_len or READ_SIZE) + return chunk.encode(self._encoding) if self._encoding else chunk.encode() def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: return self._value.read() - async def write(self, writer: AbstractStreamWriter) -> None: - loop = asyncio.get_event_loop() - try: - chunk = await loop.run_in_executor(None, self._value.read, 2**16) - while chunk: - data = ( - chunk.encode(encoding=self._encoding) - if self._encoding - else chunk.encode() - ) - await writer.write(data) - chunk = await loop.run_in_executor(None, self._value.read, 2**16) - finally: - await loop.run_in_executor(None, self._value.close) - class BytesIOPayload(IOBasePayload): _value: io.BytesIO @@ -411,20 +662,55 @@ def size(self) -> int: def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: return self._value.read().decode(encoding, errors) + async def write(self, writer: AbstractStreamWriter) -> None: + return await self.write_with_length(writer, None) -class BufferedReaderPayload(IOBasePayload): - _value: io.BufferedIOBase + async def write_with_length( + self, writer: AbstractStreamWriter, content_length: Optional[int] + ) -> None: + """ + Write BytesIO payload with a specific content length constraint. - @property - def size(self) -> Optional[int]: + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + content_length: Maximum number of bytes to write (None for unlimited) + + This implementation is specifically optimized for BytesIO objects: + + 1. Reads content in chunks to maintain memory efficiency + 2. Yields control back to the event loop periodically to prevent blocking + when dealing with large BytesIO objects + 3. Respects content_length constraints when specified + 4. Properly cleans up by closing the BytesIO object when done or on error + + The periodic yielding to the event loop is important for maintaining + responsiveness when processing large in-memory buffers. + + """ + loop_count = 0 + remaining_bytes = content_length try: - return os.fstat(self._value.fileno()).st_size - self._value.tell() - except (OSError, AttributeError): - # data.fileno() is not supported, e.g. - # io.BufferedReader(io.BytesIO(b'data')) - # For some file-like objects (e.g. tarfile), the fileno() attribute may - # not exist at all, and will instead raise an AttributeError. - return None + while chunk := self._value.read(READ_SIZE): + if loop_count > 0: + # Avoid blocking the event loop + # if they pass a large BytesIO object + # and we are not in the first iteration + # of the loop + await asyncio.sleep(0) + if remaining_bytes is None: + await writer.write(chunk) + else: + await writer.write(chunk[:remaining_bytes]) + remaining_bytes -= len(chunk) + if remaining_bytes <= 0: + return + loop_count += 1 + finally: + self._value.close() + + +class BufferedReaderPayload(IOBasePayload): + _value: io.BufferedIOBase def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: return self._value.read().decode(encoding, errors) @@ -481,15 +767,63 @@ def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: self._iter = value.__aiter__() async def write(self, writer: AbstractStreamWriter) -> None: - if self._iter: - try: - # iter is not None check prevents rare cases - # when the case iterable is used twice - while True: - chunk = await self._iter.__anext__() + """ + Write the entire async iterable payload to the writer stream. 
+ + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + + This method iterates through the async iterable and writes each chunk + to the writer without any length constraint. + + Note: + For new implementations that need length control, use write_with_length() directly. + This method is maintained for backwards compatibility with existing code. + + """ + await self.write_with_length(writer, None) + + async def write_with_length( + self, writer: AbstractStreamWriter, content_length: Optional[int] + ) -> None: + """ + Write async iterable payload with a specific content length constraint. + + Args: + writer: An AbstractStreamWriter instance that handles the actual writing + content_length: Maximum number of bytes to write (None for unlimited) + + This implementation handles streaming of async iterable content with length constraints: + + 1. Iterates through the async iterable one chunk at a time + 2. Respects content_length constraints when specified + 3. Handles the case when the iterable might be used twice + + Since async iterables are consumed as they're iterated, there is no way to + restart the iteration if it's already in progress or completed. + + """ + if self._iter is None: + return + + remaining_bytes = content_length + + try: + while True: + chunk = await self._iter.__anext__() + if remaining_bytes is None: await writer.write(chunk) - except StopAsyncIteration: - self._iter = None + # If we have a content length limit + elif remaining_bytes > 0: + await writer.write(chunk[:remaining_bytes]) + remaining_bytes -= len(chunk) + # We still want to exhaust the iterator even + # if we have reached the content length limit + # since the file handle may not get closed by + # the iterator if we don't do this + except StopAsyncIteration: + # Iterator is exhausted + self._iter = None def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: raise TypeError("Unable to decode.") diff --git a/tests/conftest.py b/tests/conftest.py index 9b27519410e..7e3f85fdd95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ import zlib_ng.zlib_ng from blockbuster import blockbuster_ctx +from aiohttp import payload from aiohttp.client_proto import ResponseHandler from aiohttp.compression_utils import ZLibBackend, ZLibBackendProtocol, set_zlib_backend from aiohttp.http import WS_KEY @@ -334,3 +335,18 @@ def parametrize_zlib_backend( yield set_zlib_backend(original_backend) + + +@pytest.fixture() +def cleanup_payload_pending_file_closes( + loop: asyncio.AbstractEventLoop, +) -> Generator[None, None, None]: + """Ensure all pending file close operations complete during test teardown.""" + yield + if payload._CLOSE_FUTURES: + # Only wait for futures from the current loop + loop_futures = [f for f in payload._CLOSE_FUTURES if f.get_loop() is loop] + if loop_futures: + loop.run_until_complete( + asyncio.gather(*loop_futures, return_exceptions=True) + ) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 15b1d48a686..1eb0fd28c9f 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -53,6 +53,13 @@ from aiohttp.typedefs import Handler, Query +@pytest.fixture(autouse=True) +def cleanup( + cleanup_payload_pending_file_closes: None, +) -> None: + """Ensure all pending file close operations complete during test teardown.""" + + @pytest.fixture def here() -> pathlib.Path: return pathlib.Path(__file__).parent @@ -1576,7 +1583,10 @@ async def handler(request: web.Request) -> web.Response: 
original_write_bytes = ClientRequest.write_bytes async def write_bytes( - self: ClientRequest, writer: StreamWriter, conn: Connection + self: ClientRequest, + writer: StreamWriter, + conn: Connection, + content_length: Optional[int] = None, ) -> None: nonlocal write_mock original_write = writer._write @@ -1584,7 +1594,7 @@ with mock.patch.object( writer, "_write", autospec=True, spec_set=True, side_effect=original_write ) as write_mock: - await original_write_bytes(self, writer, conn) + await original_write_bytes(self, writer, conn, content_length) with mock.patch.object(ClientRequest, "write_bytes", write_bytes): app = web.Application() @@ -1983,8 +1993,7 @@ async def handler(request: web.Request) -> web.Response: app.router.add_post("/", handler) client = await aiohttp_client(app) - with fname.open("rb") as f: - data_size = len(f.read()) + data_size = len(expected) async def gen(fname: pathlib.Path) -> AsyncIterator[bytes]: with fname.open("rb") as f: @@ -4226,3 +4235,145 @@ async def handler(request: web.Request) -> web.Response: with pytest.raises(RuntimeError, match="Connection closed"): await resp.read() + + +async def test_content_length_limit_enforced(aiohttp_server: AiohttpServer) -> None: + """Test that the Content-Length header value limits the amount of data sent to the server.""" + received_data = bytearray() + + async def handler(request: web.Request) -> web.Response: + # Read all data from the request and store it + data = await request.read() + received_data.extend(data) + return web.Response(text="OK") + + app = web.Application() + app.router.add_post("/", handler) + + server = await aiohttp_server(app) + + # Create data larger than what we'll limit with Content-Length + data = b"X" * 1000 + # Only send 500 bytes even though data is 1000 bytes + headers = {"Content-Length": "500"} + + async with aiohttp.ClientSession() as session: + await session.post(server.make_url("/"), data=data, headers=headers) + + # Verify only 500 bytes (not the full 1000) were received by the server + assert len(received_data) == 500 + assert received_data == b"X" * 500 + + +async def test_content_length_limit_with_multiple_reads( + aiohttp_server: AiohttpServer, +) -> None: + """Test that the Content-Length header value limits data sent across multiple reads.""" + received_data = bytearray() + + async def handler(request: web.Request) -> web.Response: + # Read all data from the request and store it + data = await request.read() + received_data.extend(data) + return web.Response(text="OK") + + app = web.Application() + app.router.add_post("/", handler) + + server = await aiohttp_server(app) + + # Create an async generator of data + async def data_generator() -> AsyncIterator[bytes]: + yield b"Chunk1" * 100 # 600 bytes + yield b"Chunk2" * 100 # another 600 bytes + + # Limit to 800 bytes even though we'd generate 1200 bytes + headers = {"Content-Length": "800"} + + async with aiohttp.ClientSession() as session: + await session.post(server.make_url("/"), data=data_generator(), headers=headers) + + # Verify only 800 bytes (not the full 1200) were received by the server + assert len(received_data) == 800 + # First chunk fully sent (600 bytes) + assert received_data.startswith(b"Chunk1" * 100) + + # The rest should be from the second chunk (the exact split might vary by implementation) + assert b"Chunk2" in received_data # Some part of the second chunk was sent + # 200 bytes from the second chunk + assert len(received_data) - len(b"Chunk1" * 100) == 200 + + +async def 
test_post_connection_cleanup_with_bytesio( + aiohttp_client: AiohttpClient, +) -> None: + """Test that connections are properly cleaned up when using BytesIO data.""" + + async def handler(request: web.Request) -> web.Response: + return web.Response(body=b"") + + app = web.Application() + app.router.add_post("/hello", handler) + client = await aiohttp_client(app) + + # Test with direct bytes and BytesIO multiple times to ensure connection cleanup + for _ in range(10): + async with client.post( + "/hello", + data=b"x", + headers={"Content-Length": "1"}, + ) as response: + response.raise_for_status() + + assert client._session.connector is not None + assert len(client._session.connector._conns) == 1 + + x = io.BytesIO(b"x") + async with client.post( + "/hello", + data=x, + headers={"Content-Length": "1"}, + ) as response: + response.raise_for_status() + + assert len(client._session.connector._conns) == 1 + + +async def test_post_connection_cleanup_with_file( + aiohttp_client: AiohttpClient, here: pathlib.Path +) -> None: + """Test that connections are properly cleaned up when using file data.""" + + async def handler(request: web.Request) -> web.Response: + await request.read() + return web.Response(body=b"") + + app = web.Application() + app.router.add_post("/hello", handler) + client = await aiohttp_client(app) + + test_file = here / "data.unknown_mime_type" + + # Test with direct bytes and file multiple times to ensure connection cleanup + for _ in range(10): + async with client.post( + "/hello", + data=b"xx", + headers={"Content-Length": "2"}, + ) as response: + response.raise_for_status() + + assert client._session.connector is not None + assert len(client._session.connector._conns) == 1 + fh = await asyncio.get_running_loop().run_in_executor( + None, open, test_file, "rb" + ) + + async with client.post( + "/hello", + data=fh, + headers={"Content-Length": str(test_file.stat().st_size)}, + ) as response: + response.raise_for_status() + + assert len(client._session.connector._conns) == 1 diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 8458c376b78..6b094171012 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -12,7 +12,9 @@ Iterable, Iterator, List, + Optional, Protocol, + Union, ) from unittest import mock @@ -33,7 +35,7 @@ ) from aiohttp.compression_utils import ZLibBackend from aiohttp.connector import Connection -from aiohttp.http import HttpVersion10, HttpVersion11 +from aiohttp.http import HttpVersion10, HttpVersion11, StreamWriter from aiohttp.typedefs import LooseCookies @@ -1054,10 +1056,12 @@ async def gen() -> AsyncIterator[bytes]: assert req.headers["TRANSFER-ENCODING"] == "chunked" original_write_bytes = req.write_bytes - async def _mock_write_bytes(writer: AbstractStreamWriter, conn: mock.Mock) -> None: + async def _mock_write_bytes( + writer: AbstractStreamWriter, conn: mock.Mock, content_length: Optional[int] + ) -> None: # Ensure the task is scheduled await asyncio.sleep(0) - await original_write_bytes(writer, conn) + await original_write_bytes(writer, conn, content_length) with mock.patch.object(req, "write_bytes", _mock_write_bytes): resp = await req.send(conn) @@ -1260,7 +1264,7 @@ async def test_oserror_on_write_bytes( writer = WriterMock() writer.write.side_effect = OSError - await req.write_bytes(writer, conn) + await req.write_bytes(writer, conn, None) assert conn.protocol.set_exception.called exc = conn.protocol.set_exception.call_args[0][0] @@ -1576,3 +1580,81 @@ def test_request_info_tuple_new() -> 
None: ).real_url is url ) + + +def test_get_content_length(make_request: _RequestMaker) -> None: + """Test _get_content_length method extracts Content-Length correctly.""" + req = make_request("get", "http://python.org/") + + # No Content-Length header + assert req._get_content_length() is None + + # Valid Content-Length header + req.headers["Content-Length"] = "42" + assert req._get_content_length() == 42 + + # Invalid Content-Length header + req.headers["Content-Length"] = "invalid" + with pytest.raises(ValueError, match="Invalid Content-Length header: invalid"): + req._get_content_length() + + +async def test_write_bytes_with_content_length_limit( + loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock +) -> None: + """Test that write_bytes respects content_length limit for different body types.""" + # Test with bytes data + data = b"Hello World" + req = ClientRequest("post", URL("http://python.org/"), loop=loop) + + req.body = data + + writer = StreamWriter(protocol=conn.protocol, loop=loop) + # Use content_length=5 to truncate data + await req.write_bytes(writer, conn, 5) + + # Verify only the first 5 bytes were written + assert buf == b"Hello" + await req.close() + + +@pytest.mark.parametrize( + "data", + [ + [b"Part1", b"Part2", b"Part3"], + b"Part1Part2Part3", + ], +) +async def test_write_bytes_with_iterable_content_length_limit( + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + data: Union[List[bytes], bytes], +) -> None: + """Test that write_bytes respects content_length limit for iterable data.""" + # Test with iterable data + req = ClientRequest("post", URL("http://python.org/"), loop=loop) + req.body = data + + writer = StreamWriter(protocol=conn.protocol, loop=loop) + # Use content_length=7 to truncate at the middle of Part2 + await req.write_bytes(writer, conn, 7) + assert len(buf) == 7 + assert buf == b"Part1Pa" + await req.close() + + +async def test_write_bytes_empty_iterable_with_content_length( + loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock +) -> None: + """Test that write_bytes handles empty iterable body with content_length.""" + req = ClientRequest("post", URL("http://python.org/"), loop=loop) + req.body = [] # Empty iterable + + writer = StreamWriter(protocol=conn.protocol, loop=loop) + # Use content_length=10 with empty body + await req.write_bytes(writer, conn, 10) + + # Verify nothing was written + assert len(buf) == 0 + await req.close() diff --git a/tests/test_payload.py b/tests/test_payload.py index 9594f808408..24dcbaeb819 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -1,13 +1,23 @@ import array +import io +import unittest.mock from io import StringIO -from typing import AsyncIterator, Iterator +from typing import AsyncIterator, Iterator, List, Optional, Union import pytest +from multidict import CIMultiDict from aiohttp import payload from aiohttp.abc import AbstractStreamWriter +@pytest.fixture(autouse=True) +def cleanup( + cleanup_payload_pending_file_closes: None, +) -> None: + """Ensure all pending file close operations complete during test teardown.""" + + @pytest.fixture def registry() -> Iterator[payload.PayloadRegistry]: old = payload.PAYLOAD_REGISTRY @@ -121,3 +131,325 @@ async def gen() -> AsyncIterator[bytes]: def test_async_iterable_payload_not_async_iterable() -> None: with pytest.raises(TypeError): payload.AsyncIterablePayload(object()) # type: ignore[arg-type] + + +class MockStreamWriter(AbstractStreamWriter): + """Mock stream writer for testing payload writes.""" + + 
def __init__(self) -> None: + self.written: List[bytes] = [] + + async def write( + self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ) -> None: + """Store the chunk in the written list.""" + self.written.append(bytes(chunk)) + + async def write_eof(self, chunk: Optional[bytes] = None) -> None: + """write_eof implementation - no-op for tests.""" + + async def drain(self) -> None: + """Drain implementation - no-op for tests.""" + + def enable_compression( + self, encoding: str = "deflate", strategy: Optional[int] = None + ) -> None: + """Enable compression - no-op for tests.""" + + def enable_chunking(self) -> None: + """Enable chunking - no-op for tests.""" + + async def write_headers(self, status_line: str, headers: CIMultiDict[str]) -> None: + """Write headers - no-op for tests.""" + + def get_written_bytes(self) -> bytes: + """Return all written bytes as a single bytes object.""" + return b"".join(self.written) + + +async def test_bytes_payload_write_with_length_no_limit() -> None: + """Test BytesPayload writing with no content length limit.""" + data = b"0123456789" + p = payload.BytesPayload(data) + writer = MockStreamWriter() + + await p.write_with_length(writer, None) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == 10 + + +async def test_bytes_payload_write_with_length_exact() -> None: + """Test BytesPayload writing with exact content length.""" + data = b"0123456789" + p = payload.BytesPayload(data) + writer = MockStreamWriter() + + await p.write_with_length(writer, 10) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == 10 + + +async def test_bytes_payload_write_with_length_truncated() -> None: + """Test BytesPayload writing with truncated content length.""" + data = b"0123456789" + p = payload.BytesPayload(data) + writer = MockStreamWriter() + + await p.write_with_length(writer, 5) + assert writer.get_written_bytes() == b"01234" + assert len(writer.get_written_bytes()) == 5 + + +async def test_iobase_payload_write_with_length_no_limit() -> None: + """Test IOBasePayload writing with no content length limit.""" + data = b"0123456789" + p = payload.IOBasePayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await p.write_with_length(writer, None) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == 10 + + +async def test_iobase_payload_write_with_length_exact() -> None: + """Test IOBasePayload writing with exact content length.""" + data = b"0123456789" + p = payload.IOBasePayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await p.write_with_length(writer, 10) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == 10 + + +async def test_iobase_payload_write_with_length_truncated() -> None: + """Test IOBasePayload writing with truncated content length.""" + data = b"0123456789" + p = payload.IOBasePayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await p.write_with_length(writer, 5) + assert writer.get_written_bytes() == b"01234" + assert len(writer.get_written_bytes()) == 5 + + +async def test_bytesio_payload_write_with_length_no_limit() -> None: + """Test BytesIOPayload writing with no content length limit.""" + data = b"0123456789" + p = payload.BytesIOPayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await p.write_with_length(writer, None) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == 10 + + +async def 
test_bytesio_payload_write_with_length_exact() -> None: + """Test BytesIOPayload writing with exact content length.""" + data = b"0123456789" + p = payload.BytesIOPayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await p.write_with_length(writer, 10) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == 10 + + +async def test_bytesio_payload_write_with_length_truncated() -> None: + """Test BytesIOPayload writing with truncated content length.""" + data = b"0123456789" + payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await payload_bytesio.write_with_length(writer, 5) + assert writer.get_written_bytes() == b"01234" + assert len(writer.get_written_bytes()) == 5 + + +async def test_bytesio_payload_write_with_length_remaining_zero() -> None: + """Test BytesIOPayload with content_length smaller than first read chunk.""" + data = b"0123456789" * 10 # 100 bytes + bio = io.BytesIO(data) + payload_bytesio = payload.BytesIOPayload(bio) + writer = MockStreamWriter() + + # Mock the read method to return smaller chunks + original_read = bio.read + read_calls = 0 + + def mock_read(size: Optional[int] = None) -> bytes: + nonlocal read_calls + read_calls += 1 + if read_calls == 1: + # First call: return 3 bytes (less than content_length=5) + return original_read(3) + else: + # Subsequent calls return remaining data normally + return original_read(size) + + with unittest.mock.patch.object(bio, "read", mock_read): + await payload_bytesio.write_with_length(writer, 5) + + assert len(writer.get_written_bytes()) == 5 + assert writer.get_written_bytes() == b"01234" + + +async def test_bytesio_payload_large_data_multiple_chunks() -> None: + """Test BytesIOPayload with large data requiring multiple read chunks.""" + chunk_size = 2**16 # 64KB (READ_SIZE) + data = b"x" * (chunk_size + 1000) # Slightly larger than READ_SIZE + payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await payload_bytesio.write_with_length(writer, None) + assert writer.get_written_bytes() == data + assert len(writer.get_written_bytes()) == chunk_size + 1000 + + +async def test_bytesio_payload_remaining_bytes_exhausted() -> None: + """Test BytesIOPayload when remaining_bytes becomes <= 0.""" + data = b"0123456789abcdef" * 1000 # 16000 bytes + payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await payload_bytesio.write_with_length(writer, 8000) # Exactly half the data + written = writer.get_written_bytes() + assert len(written) == 8000 + assert written == data[:8000] + + +async def test_iobase_payload_exact_chunk_size_limit() -> None: + """Test IOBasePayload with content length matching exactly one read chunk.""" + chunk_size = 2**16 # 65536 bytes (READ_SIZE) + data = b"x" * chunk_size + b"extra" # Slightly larger than one read chunk + p = payload.IOBasePayload(io.BytesIO(data)) + writer = MockStreamWriter() + + await p.write_with_length(writer, chunk_size) + written = writer.get_written_bytes() + assert len(written) == chunk_size + assert written == data[:chunk_size] + + +async def test_async_iterable_payload_write_with_length_no_limit() -> None: + """Test AsyncIterablePayload writing with no content length limit.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"0123" + yield b"4567" + yield b"89" + + p = payload.AsyncIterablePayload(gen()) + writer = MockStreamWriter() + + await p.write_with_length(writer, None) + assert writer.get_written_bytes() == 
b"0123456789" + assert len(writer.get_written_bytes()) == 10 + + +async def test_async_iterable_payload_write_with_length_exact() -> None: + """Test AsyncIterablePayload writing with exact content length.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"0123" + yield b"4567" + yield b"89" + + p = payload.AsyncIterablePayload(gen()) + writer = MockStreamWriter() + + await p.write_with_length(writer, 10) + assert writer.get_written_bytes() == b"0123456789" + assert len(writer.get_written_bytes()) == 10 + + +async def test_async_iterable_payload_write_with_length_truncated_mid_chunk() -> None: + """Test AsyncIterablePayload writing with content length truncating mid-chunk.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"0123" + yield b"4567" + yield b"89" # pragma: no cover + + p = payload.AsyncIterablePayload(gen()) + writer = MockStreamWriter() + + await p.write_with_length(writer, 6) + assert writer.get_written_bytes() == b"012345" + assert len(writer.get_written_bytes()) == 6 + + +async def test_async_iterable_payload_write_with_length_truncated_at_chunk() -> None: + """Test AsyncIterablePayload writing with content length truncating at chunk boundary.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"0123" + yield b"4567" # pragma: no cover + yield b"89" # pragma: no cover + + p = payload.AsyncIterablePayload(gen()) + writer = MockStreamWriter() + + await p.write_with_length(writer, 4) + assert writer.get_written_bytes() == b"0123" + assert len(writer.get_written_bytes()) == 4 + + +async def test_bytes_payload_backwards_compatibility() -> None: + """Test BytesPayload.write() backwards compatibility delegates to write_with_length().""" + p = payload.BytesPayload(b"1234567890") + writer = MockStreamWriter() + + await p.write(writer) + assert writer.get_written_bytes() == b"1234567890" + + +async def test_textio_payload_with_encoding() -> None: + """Test TextIOPayload reading with encoding and size constraints.""" + data = io.StringIO("hello world") + p = payload.TextIOPayload(data, encoding="utf-8") + writer = MockStreamWriter() + + await p.write_with_length(writer, 8) + # Should write exactly 8 bytes: "hello wo" + assert writer.get_written_bytes() == b"hello wo" + + +async def test_bytesio_payload_backwards_compatibility() -> None: + """Test BytesIOPayload.write() backwards compatibility delegates to write_with_length().""" + data = io.BytesIO(b"test data") + p = payload.BytesIOPayload(data) + writer = MockStreamWriter() + + await p.write(writer) + assert writer.get_written_bytes() == b"test data" + + +async def test_async_iterable_payload_backwards_compatibility() -> None: + """Test AsyncIterablePayload.write() backwards compatibility delegates to write_with_length().""" + + async def gen() -> AsyncIterator[bytes]: + yield b"chunk1" + yield b"chunk2" # pragma: no cover + + p = payload.AsyncIterablePayload(gen()) + writer = MockStreamWriter() + + await p.write(writer) + assert writer.get_written_bytes() == b"chunk1chunk2" + + +async def test_async_iterable_payload_with_none_iterator() -> None: + """Test AsyncIterablePayload with None iterator returns early without writing.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"test" # pragma: no cover + + p = payload.AsyncIterablePayload(gen()) + # Manually set _iter to None to test the guard clause + p._iter = None + writer = MockStreamWriter() + + # Should return early without writing anything + await p.write_with_length(writer, 10) + assert writer.get_written_bytes() == b"" From 
6b3672f08aeb1575f84e81b4de8fb5a8caa07043 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 22 May 2025 10:46:05 -0500 Subject: [PATCH 4/6] Fix flakey test_normal_closure_while_client_sends_msg test (#10932) --- tests/test_web_websocket_functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 6bdd5808362..2ca95f71d9e 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1297,13 +1297,13 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_normal_closure_while_client_sends_msg( aiohttp_client: AiohttpClient, ) -> None: - """Test abnormal closure when the server closes and the client doesn't respond.""" + """Test normal closure when the server closes and the client responds properly.""" close_code: Optional[WSCloseCode] = None got_close_code = asyncio.Event() async def handler(request: web.Request) -> web.WebSocketResponse: - # Setting a short close timeout - ws = web.WebSocketResponse(timeout=0.2) + # Setting a longer close timeout to avoid race conditions + ws = web.WebSocketResponse(timeout=1.0) await ws.prepare(request) await ws.close() From 597161d1127d04e7842642d7548d4bd8de3115a6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 22 May 2025 10:53:03 -0500 Subject: [PATCH 5/6] Fix flakey client functional keep alive tests (#10933) --- tests/test_client_functional.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 1eb0fd28c9f..269248bd876 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -260,7 +260,7 @@ async def handler(request: web.Request) -> web.Response: assert 0 == len(client._session.connector._conns) -async def test_keepalive_timeout_async_sleep() -> None: +async def test_keepalive_timeout_async_sleep(unused_port_socket: socket.socket) -> None: async def handler(request: web.Request) -> web.Response: body = await request.read() assert b"" == body @@ -272,17 +272,18 @@ async def handler(request: web.Request) -> web.Response: runner = web.AppRunner(app, tcp_keepalive=True, keepalive_timeout=0.001) await runner.setup() - port = unused_port() - site = web.TCPSite(runner, host="localhost", port=port) + site = web.SockSite(runner, unused_port_socket) await site.start() + host, port = unused_port_socket.getsockname()[:2] + try: async with aiohttp.ClientSession() as sess: - resp1 = await sess.get(f"http://localhost:{port}/") + resp1 = await sess.get(f"http://{host}:{port}/") await resp1.read() # wait for server keepalive_timeout await asyncio.sleep(0.01) - resp2 = await sess.get(f"http://localhost:{port}/") + resp2 = await sess.get(f"http://{host}:{port}/") await resp2.read() finally: await asyncio.gather(runner.shutdown(), site.stop()) @@ -292,7 +293,7 @@ async def handler(request: web.Request) -> web.Response: sys.version_info[:2] == (3, 11), reason="https://github.com/pytest-dev/pytest/issues/10763", ) -async def test_keepalive_timeout_sync_sleep() -> None: +async def test_keepalive_timeout_sync_sleep(unused_port_socket: socket.socket) -> None: async def handler(request: web.Request) -> web.Response: body = await request.read() assert b"" == body @@ -304,18 +305,19 @@ async def handler(request: web.Request) -> web.Response: runner = web.AppRunner(app, tcp_keepalive=True, keepalive_timeout=0.001) await runner.setup() - port = 
unused_port() - site = web.TCPSite(runner, host="localhost", port=port) + site = web.SockSite(runner, unused_port_socket) await site.start() + host, port = unused_port_socket.getsockname()[:2] + try: async with aiohttp.ClientSession() as sess: - resp1 = await sess.get(f"http://localhost:{port}/") + resp1 = await sess.get(f"http://{host}:{port}/") await resp1.read() # wait for server keepalive_timeout # time.sleep is a more challenging scenario than asyncio.sleep time.sleep(0.01) - resp2 = await sess.get(f"http://localhost:{port}/") + resp2 = await sess.get(f"http://{host}:{port}/") await resp2.read() finally: await asyncio.gather(runner.shutdown(), site.stop()) From 77c0115eb878ad797cc6ac127001fdaa7f3efc40 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 22 May 2025 11:54:42 -0500 Subject: [PATCH 6/6] Fix flakey test_content_length_limit_with_multiple_reads test (#10938) --- tests/test_client_functional.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 269248bd876..2123b48ae14 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -4293,7 +4293,10 @@ async def data_generator() -> AsyncIterator[bytes]: headers = {"Content-Length": "800"} async with aiohttp.ClientSession() as session: - await session.post(server.make_url("/"), data=data_generator(), headers=headers) + async with session.post( + server.make_url("/"), data=data_generator(), headers=headers + ) as resp: + await resp.read() # Ensure response is fully read and connection cleaned up # Verify only 800 bytes (not the full 1200) were received by the server assert len(received_data) == 800
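Taken together, patches 3 and 6 settle on a client-side pattern: send a length-limited body, then read the response to EOF inside the response context manager so the connection is returned to the pool deterministically. A usage sketch under that assumption follows; the URL is a placeholder, everything else mirrors the test above.

from typing import AsyncIterator

import aiohttp


async def post_with_content_length_limit(url: str) -> None:
    async def data_generator() -> AsyncIterator[bytes]:
        yield b"Chunk1" * 100  # 600 bytes
        yield b"Chunk2" * 100  # another 600 bytes; the last 400 are truncated

    # Only 800 of the 1200 generated bytes go on the wire.
    headers = {"Content-Length": "800"}
    async with aiohttp.ClientSession() as session:
        async with session.post(url, data=data_generator(), headers=headers) as resp:
            # Reading the body to EOF lets the connection be released back
            # to the pool before the session closes (the patch 6 fix).
            await resp.read()


# Example (placeholder URL): asyncio.run(post_with_content_length_limit("http://127.0.0.1:8080/"))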