diff --git a/CHANGES/10474.feature.rst b/CHANGES/10474.feature.rst deleted file mode 100644 index d5d6e4b40b9..00000000000 --- a/CHANGES/10474.feature.rst +++ /dev/null @@ -1,2 +0,0 @@ -Added ``tcp_sockopts`` to ``TCPConnector`` to allow specifying custom socket options --- by :user:`TimMenninger`. diff --git a/CHANGES/10520.feature.rst b/CHANGES/10520.feature.rst new file mode 100644 index 00000000000..3d2877b5c09 --- /dev/null +++ b/CHANGES/10520.feature.rst @@ -0,0 +1,2 @@ +Added ``socket_factory`` to :py:class:`aiohttp.TCPConnector` to allow specifying custom socket options +-- by :user:`TimMenninger`. diff --git a/CHANGES/10529.bugfix.rst b/CHANGES/10529.bugfix.rst new file mode 100644 index 00000000000..d6714ffd043 --- /dev/null +++ b/CHANGES/10529.bugfix.rst @@ -0,0 +1,2 @@ +Fixed an issue where dns queries were delayed indefinitely when an exception occurred in a ``trace.send_dns_cache_miss`` +-- by :user:`logioniz`. diff --git a/CHANGES/10551.bugfix.rst b/CHANGES/10551.bugfix.rst new file mode 100644 index 00000000000..8f3eb24d6ae --- /dev/null +++ b/CHANGES/10551.bugfix.rst @@ -0,0 +1 @@ +The connector now raises :exc:`aiohttp.ClientConnectionError` instead of :exc:`OSError` when failing to explicitly close the socket after :py:meth:`asyncio.loop.create_connection` fails -- by :user:`bdraco`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3004ee5cd18..e3ddd3e3d6a 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -31,6 +31,7 @@ Alexandru Mihai Alexey Firsov Alexey Nikitin Alexey Popravka +Alexey Stavrov Alexey Stepanov Almaz Salakhov Amin Etesamian diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index f7864247791..7759a997cb9 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -47,6 +47,7 @@ WSServerHandshakeError, request, ) +from .connector import AddrInfoType, SocketFactoryType from .cookiejar import CookieJar, DummyCookieJar from .formdata import FormData from .helpers import BasicAuth, ChainMapProxy, ETag @@ -112,6 +113,7 @@ __all__: Tuple[str, ...] = ( "hdrs", # client + "AddrInfoType", "BaseConnector", "ClientConnectionError", "ClientConnectionResetError", @@ -146,6 +148,7 @@ "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", + "SocketFactoryType", "SocketTimeoutError", "TCPConnector", "TooManyRedirects", diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 8a3f1bcbf2b..b61a4da33a3 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -20,7 +20,6 @@ DefaultDict, Deque, Dict, - Iterable, Iterator, List, Literal, @@ -34,6 +33,7 @@ ) import aiohappyeyeballs +from aiohappyeyeballs import AddrInfoType, SocketFactoryType from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult @@ -96,7 +96,14 @@ # which first appeared in Python 3.12.7 and 3.13.1 -__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") +__all__ = ( + "BaseConnector", + "TCPConnector", + "UnixConnector", + "NamedPipeConnector", + "AddrInfoType", + "SocketFactoryType", +) if TYPE_CHECKING: @@ -826,8 +833,9 @@ class TCPConnector(BaseConnector): the happy eyeballs algorithm, set to None. interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. - tcp_sockopts - List of tuples of sockopts applied to underlying - socket + socket_factory - A SocketFactoryType function that, if supplied, + will be used to create sockets given an + AddrInfoType. """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) @@ -849,7 +857,7 @@ def __init__( timeout_ceil_threshold: float = 5, happy_eyeballs_delay: Optional[float] = 0.25, interleave: Optional[int] = None, - tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [], + socket_factory: Optional[SocketFactoryType] = None, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -880,7 +888,7 @@ def __init__( self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() - self._tcp_sockopts = tcp_sockopts + self._socket_factory = socket_factory def _close_immediately(self) -> List[Awaitable[object]]: for fut in chain.from_iterable(self._throttle_dns_futures.values()): @@ -1018,11 +1026,11 @@ async def _resolve_host_with_throttle( This method must be run in a task and shielded from cancellation to avoid cancelling the underlying lookup. """ - if traces: - for trace in traces: - await trace.send_dns_cache_miss(host) try: if traces: + for trace in traces: + await trace.send_dns_cache_miss(host) + for trace in traces: await trace.send_dns_resolvehost_start(host) @@ -1105,7 +1113,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, - addr_infos: List[aiohappyeyeballs.AddrInfoType], + addr_infos: List[AddrInfoType], req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, @@ -1122,9 +1130,8 @@ async def _wrap_create_connection( happy_eyeballs_delay=self._happy_eyeballs_delay, interleave=self._interleave, loop=self._loop, + socket_factory=self._socket_factory, ) - for sockopt in self._tcp_sockopts: - sock.setsockopt(*sockopt) connection = await self._loop.create_connection( *args, **kwargs, sock=sock ) @@ -1143,7 +1150,10 @@ async def _wrap_create_connection( # Will be hit if an exception is thrown before the event loop takes the socket. # In that case, proactively close the socket to guard against event loop leaks. # For example, see https://github.com/MagicStack/uvloop/issues/653. - sock.close() + try: + sock.close() + except OSError as exc: + raise client_error(req.connection_key, exc) from exc def _warn_about_tls_in_tls( self, @@ -1256,13 +1266,13 @@ async def _start_tls_connection( def _convert_hosts_to_addr_infos( self, hosts: List[ResolveResult] - ) -> List[aiohappyeyeballs.AddrInfoType]: + ) -> List[AddrInfoType]: """Converts the list of hosts to a list of addr_infos. The list of hosts is the result of a DNS lookup. The list of addr_infos is the result of a call to `socket.getaddrinfo()`. """ - addr_infos: List[aiohappyeyeballs.AddrInfoType] = [] + addr_infos: List[AddrInfoType] = [] for hinfo in hosts: host = hinfo["host"] is_ipv6 = ":" in host diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 8f34fefaf81..4b0a878d715 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use session = aiohttp.ClientSession(connector=conn) -Setting socket options +Custom socket creation ^^^^^^^^^^^^^^^^^^^^^^ -Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed -to the underlying socket when creating a connection. For example, we may -want to change the conditions under which we consider a connection dead. -The following would change that to 9*7200 = 18 hours:: +If the default socket is insufficient for your use case, pass an optional +`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements +`SocketFactoryType`. This will be used to create all sockets for the +lifetime of the class object. For example, we may want to change the +conditions under which we consider a connection dead. The following would +make all sockets respect 9*7200 = 18 hours:: import socket - conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True), - (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200), - (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ]) + def socket_factory(addr_info): + family, type_, proto, _, _, _ = addr_info + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) + return sock + conn = aiohttp.TCPConnector(socket_factory=socket_factory) Named pipes in Windows diff --git a/docs/client_reference.rst b/docs/client_reference.rst index e1128934631..7dabfe1a6db 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1122,6 +1122,34 @@ is controlled by *force_close* constructor's parameter). overridden in subclasses. +.. autodata:: AddrInfoType + +.. note:: + + Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info. + +.. warning:: + + Be sure to use ``aiohttp.AddrInfoType`` rather than + ``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as + it is likely to be removed from ``aiohappyeyeballs`` in the + future. + + +.. autodata:: SocketFactoryType + +.. note:: + + Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info. + +.. warning:: + + Be sure to use ``aiohttp.SocketFactoryType`` rather than + ``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage, + as it is likely to be removed from ``aiohappyeyeballs`` in the + future. + + .. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \ use_dns_cache=True, ttl_dns_cache=10, \ family=0, ssl_context=None, local_addr=None, \ @@ -1129,7 +1157,7 @@ is controlled by *force_close* constructor's parameter). force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, timeout_ceil_threshold=5, \ happy_eyeballs_delay=0.25, interleave=None, loop=None, \ - tcp_sockopts=[]) + socket_factory=None) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1250,9 +1278,9 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.10 - :param list tcp_sockopts: options applied to the socket when a connection is - created. This should be a list of 3-tuples, each a ``(level, optname, value)``. - Each tuple is deconstructed and passed verbatim to ``.setsockopt``. + :param :py:data:``SocketFactoryType`` socket_factory: This function takes an + :py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when + creating TCP connections. .. versionadded:: 3.12 diff --git a/docs/conf.py b/docs/conf.py index 2deabea1b4f..eba93188b44 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,6 +53,7 @@ # ones. extensions = [ # stdlib-party extensions: + "sphinx.ext.autodoc", "sphinx.ext.extlinks", "sphinx.ext.graphviz", "sphinx.ext.intersphinx", @@ -82,6 +83,7 @@ "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), + "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None), } # Add any paths that contain templates here, relative to this directory. @@ -425,6 +427,7 @@ ("py:class", "cgi.FieldStorage"), # undocumented ("py:meth", "aiohttp.web.UrlDispatcher.register_resource"), # undocumented ("py:func", "aiohttp_debugtoolbar.setup"), # undocumented + ("py:class", "socket.SocketKind"), # undocumented ] # -- Options for towncrier_draft extension ----------------------------------- diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index 48ee7016d13..4dcf7a1dea3 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -1,7 +1,7 @@ # Extracted from `setup.cfg` via `make sync-direct-runtime-deps` aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin" -aiohappyeyeballs >= 2.3.0 +aiohappyeyeballs >= 2.5.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 6.0 ; python_version < "3.11" Brotli; platform_python_implementation == 'CPython' diff --git a/setup.cfg b/setup.cfg index 66b779b8db9..674d9ed7c44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ zip_safe = False include_package_data = True install_requires = - aiohappyeyeballs >= 2.3.0 + aiohappyeyeballs >= 2.5.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 6.0 ; python_version < "3.11" frozenlist >= 1.1.1 diff --git a/tests/test_connector.py b/tests/test_connector.py index 076ed556971..aac122c7119 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -26,7 +26,6 @@ from unittest import mock import pytest -from aiohappyeyeballs import AddrInfoType from pytest_mock import MockerFixture from yarl import URL @@ -44,6 +43,7 @@ from aiohttp.connector import ( _SSL_CONTEXT_UNVERIFIED, _SSL_CONTEXT_VERIFIED, + AddrInfoType, Connection, TCPConnector, _DNSCacheTable, @@ -669,6 +669,33 @@ async def test_tcp_connector_closes_socket_on_error( await conn.close() +async def test_tcp_connector_closes_socket_on_error_results_in_another_error( + loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +) -> None: + """Test that when error occurs while closing the socket.""" + req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + start_connection.return_value.close.side_effect = OSError( + 1, "error from closing socket" + ) + + conn = aiohttp.TCPConnector() + with ( + mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + side_effect=ValueError, + ), + pytest.raises(aiohttp.ClientConnectionError, match="error from closing socket"), + ): + await conn.connect(req, [], ClientTimeout()) + + assert start_connection.return_value.close.called + + await conn.close() + + async def test_tcp_connector_server_hostname_default( loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock ) -> None: @@ -3683,6 +3710,61 @@ async def send_dns_cache_hit(self, *args: object, **kwargs: object) -> None: await connector.close() +async def test_connector_resolve_in_case_of_trace_cache_miss_exception( + loop: asyncio.AbstractEventLoop, +) -> None: + token: ResolveResult = { + "hostname": "localhost", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + + request_count = 0 + + class DummyTracer(Trace): + def __init__(self) -> None: + """Dummy""" + + async def send_dns_cache_hit(self, *args: object, **kwargs: object) -> None: + """Dummy send_dns_cache_hit""" + + async def send_dns_resolvehost_start( + self, *args: object, **kwargs: object + ) -> None: + """Dummy send_dns_resolvehost_start""" + + async def send_dns_resolvehost_end( + self, *args: object, **kwargs: object + ) -> None: + """Dummy send_dns_resolvehost_end""" + + async def send_dns_cache_miss(self, *args: object, **kwargs: object) -> None: + nonlocal request_count + request_count += 1 + if request_count <= 1: + raise Exception("first attempt") + + async def resolve_response() -> List[ResolveResult]: + await asyncio.sleep(0) + return [token] + + with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: + m_resolver().resolve.return_value = resolve_response() + + connector = TCPConnector() + traces = [DummyTracer()] + + with pytest.raises(Exception): + await connector._resolve_host("", 0, traces) + + await connector._resolve_host("", 0, traces) == [token] + + await connector.close() + + async def test_connector_does_not_remove_needed_waiters( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: @@ -3767,27 +3849,48 @@ def test_connect() -> Literal[True]: assert raw_response_list == [True, True] -async def test_tcp_connector_setsockopts( +async def test_tcp_connector_socket_factory( loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock ) -> None: - """Check that sockopts get passed to socket""" - conn = aiohttp.TCPConnector( - tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)] - ) - - with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True - ) as create_connection: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - start_connection.return_value = s - create_connection.return_value = mock.Mock(), mock.Mock() + """Check that socket factory is called""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + start_connection.return_value = s - req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + local_addr = None + socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s + happy_eyeballs_delay = 0.123 + interleave = 3 + conn = aiohttp.TCPConnector( + interleave=interleave, + local_addr=local_addr, + happy_eyeballs_delay=happy_eyeballs_delay, + socket_factory=socket_factory, + ) + with mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + return_value=(mock.Mock(), mock.Mock()), + ): + host = "127.0.0.1" + port = 443 + req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2 - - await conn.close() + pass + await conn.close() + + start_connection.assert_called_with( + addr_infos=[ + (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port)) + ], + local_addr_infos=local_addr, + happy_eyeballs_delay=happy_eyeballs_delay, + interleave=interleave, + loop=loop, + socket_factory=socket_factory, + ) def test_default_ssl_context_creation_without_ssl() -> None: