diff --git a/doc/changelog.rst b/doc/changelog.rst index ebbc72047b..585ea045c8 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,14 @@ Changelog ========= +Changes in Version 4.18.0 +------------------------- + +- Improved TLS connection performance by reusing TLS sessions across connections + to the same server, avoiding a full handshake on each new connection. + Session resumption is supported on all Python versions for synchronous clients + and on Python 3.11+ for async clients. + Changes in Version 4.17.0 (2026/04/20) -------------------------------------- diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 475f4bfa99..e8bea0980a 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -754,6 +754,9 @@ def __init__( self._pending = 0 self._max_connecting = self.opts.max_connecting self._client_id = client_id + self._ssl_session_cache: Optional[list[Any]] = ( + [None] if self.opts._ssl_context is not None else None + ) # Log before publishing event to prevent potential listener preemption in tests if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -1040,7 +1043,9 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A ) try: - networking_interface = await _configured_protocol_interface(self.address, self.opts) + networking_interface = await _configured_protocol_interface( + self.address, self.opts, self._ssl_session_cache + ) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: async with self.lock: diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index a6f434885b..6a98141067 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -46,6 +46,14 @@ from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address + +def _get_ssl_session(ssl_sock: Any) -> Optional[Any]: + """Return the TLS session from an SSL socket, handling both PyOpenSSL and stdlib ssl.""" + if hasattr(ssl_sock, "get_session"): + return ssl_sock.get_session() + return getattr(ssl_sock, "session", None) + + try: from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl @@ -298,7 +306,9 @@ async def _async_configured_socket( async def _configured_protocol_interface( - address: _Address, options: PoolOptions + address: _Address, + options: PoolOptions, + ssl_session_cache: Optional[list[Any]] = None, ) -> AsyncNetworkingInterface: """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface. @@ -318,6 +328,22 @@ async def _configured_protocol_interface( ) host = address[0] + # asyncio does not support TLS session resumption natively (cpython#79152, + # closed without a fix). On Python 3.11+ SSLProtocol.__init__ calls + # wrap_bio() synchronously before the first event-loop yield, so setting + # sslobject_class is race-free. Session injection is skipped on older + # Python versions. (The async path always uses stdlib ssl, never PyOpenSSL.) + if ssl_session_cache is not None and sys.version_info >= (3, 11): + session = ssl_session_cache[0] + if session is not None: + _session = session + + class _SessionSSLObject(ssl.SSLObject): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.session = _session + + ssl_context.sslobject_class = _SessionSSLObject # type: ignore[attr-defined] try: # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. @@ -337,6 +363,7 @@ async def _configured_protocol_interface( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( ssl_context.verify_mode and not ssl_context.check_hostname @@ -348,6 +375,13 @@ async def _configured_protocol_interface( transport.abort() raise + if ssl_session_cache is not None: + ssl_obj = transport.get_extra_info("ssl_object") + if ssl_obj is not None: + new_session = ssl_obj.session + if new_session is not None: + ssl_session_cache[0] = new_session + return AsyncNetworkingInterface((transport, protocol)) @@ -470,7 +504,11 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. return ssl_sock -def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface: +def _configured_socket_interface( + address: _Address, + options: PoolOptions, + ssl_session_cache: Optional[list[Any]] = None, +) -> NetworkingInterface: """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket. Can raise socket.error, ConnectionFailure, or _CertificateError. @@ -485,13 +523,14 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net return NetworkingInterface(sock) host = address[0] + session = ssl_session_cache[0] if ssl_session_cache is not None else None try: # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. if _has_sni(True): - ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host, session=session) else: - ssl_sock = ssl_context.wrap_socket(sock) + ssl_sock = ssl_context.wrap_socket(sock, session=session) except _CertificateError: sock.close() # Raise _CertificateError directly like we do after match_hostname @@ -515,5 +554,10 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net ssl_sock.close() raise + if ssl_session_cache is not None: + new_session = _get_ssl_session(ssl_sock) + if new_session is not None: + ssl_session_cache[0] = new_session + ssl_sock.settimeout(options.socket_timeout) return NetworkingInterface(ssl_sock) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 938eca42bd..467faaf20c 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -752,6 +752,9 @@ def __init__( self._pending = 0 self._max_connecting = self.opts.max_connecting self._client_id = client_id + self._ssl_session_cache: Optional[list[Any]] = ( + [None] if self.opts._ssl_context is not None else None + ) # Log before publishing event to prevent potential listener preemption in tests if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -1036,7 +1039,9 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect ) try: - networking_interface = _configured_socket_interface(self.address, self.opts) + networking_interface = _configured_socket_interface( + self.address, self.opts, self._ssl_session_cache + ) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: with self.lock: diff --git a/test/asynchronous/test_ssl.py b/test/asynchronous/test_ssl.py index 7fe57e8503..5a50f54729 100644 --- a/test/asynchronous/test_ssl.py +++ b/test/asynchronous/test_ssl.py @@ -128,6 +128,298 @@ def test_config_ssl(self): def test_use_pyopenssl_when_available(self): self.assertTrue(HAVE_PYSSL) + def test_ssl_session_cache(self): + cache: list = [None] + self.assertIsNone(cache[0]) + cache[0] = "session" + self.assertEqual(cache[0], "session") + cache[0] = "new_session" + self.assertEqual(cache[0], "new_session") + + @unittest.skipUnless(_IS_SYNC, "Tests sync wrap_socket path only") + def test_tls_session_reused_on_second_connection(self): + """Cached TLS session is passed to wrap_socket on subsequent connections.""" + import unittest.mock as mock + + from pymongo.pool_shared import _configured_socket_interface + + fake_session = object() + cache: list = [fake_session] + + fake_ssl_sock = mock.MagicMock() + fake_ssl_sock.getpeercert.return_value = {} + + mock_ssl_context = mock.MagicMock() + mock_ssl_context.wrap_socket.return_value = fake_ssl_sock + mock_ssl_context.verify_mode = False + mock_ssl_context.check_hostname = False + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = mock_ssl_context + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + with mock.patch("pymongo.pool_shared._create_connection") as mock_create: + mock_create.return_value = mock.MagicMock() + _configured_socket_interface(("localhost", 27017), mock_opts, cache) + + mock_ssl_context.wrap_socket.assert_called_once() + _, kwargs = mock_ssl_context.wrap_socket.call_args + self.assertIs(kwargs.get("session"), fake_session) + + def test_get_ssl_session_pyopenssl_style(self): + """_get_ssl_session uses get_session() when available (PyOpenSSL path).""" + import unittest.mock as mock + + from pymongo.pool_shared import _get_ssl_session + + fake_session = object() + conn = mock.MagicMock() + conn.get_session.return_value = fake_session + self.assertIs(_get_ssl_session(conn), fake_session) + conn.get_session.assert_called_once() + + def test_get_ssl_session_stdlib_style(self): + """_get_ssl_session falls back to .session attribute (stdlib ssl path).""" + from pymongo.pool_shared import _get_ssl_session + + fake_session = object() + + class FakeSSLSock: + session = fake_session + + self.assertIs(_get_ssl_session(FakeSSLSock()), fake_session) + + @unittest.skipUnless( + not _IS_SYNC and sys.version_info >= (3, 11), + "Tests async sslobject_class injection (Python 3.11+ only)", + ) + def test_async_tls_session_injected_via_sslobject_class(self): + """On Python 3.11+, a cached session is injected by setting sslobject_class.""" + import ssl + + fake_session = object() + cache: list = [fake_session] + + real_ctx = ssl.create_default_context() + self.assertIs(real_ctx.sslobject_class, ssl.SSLObject) + + # Simulate what _configured_protocol_interface does + session = cache[0] + assert session is not None + _session = session + + class _SessionSSLObject(ssl.SSLObject): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.session = _session + + real_ctx.sslobject_class = _SessionSSLObject + + self.assertIs(real_ctx.sslobject_class, _SessionSSLObject) + self.assertTrue(issubclass(real_ctx.sslobject_class, ssl.SSLObject)) + + @unittest.skipUnless(not _IS_SYNC, "Tests async _configured_protocol_interface only") + def test_async_configured_protocol_saves_session_to_cache(self): + """After a successful TLS connection the session is stored in the cache.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_protocol_interface + + fake_session = object() + cache: list = [None] + + mock_ssl_obj = mock.MagicMock() + mock_ssl_obj.session = fake_session + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.side_effect = lambda key: ( + mock_ssl_obj if key == "ssl_object" else None + ) + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.AsyncMock( + return_value=(mock_transport, mock.MagicMock()) + ) + + async def run(): + with ( + mock.patch( + "pymongo.pool_shared._async_create_connection", + new=mock.AsyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + await _configured_protocol_interface(("localhost", 27017), mock_opts, cache) + + asyncio.run(run()) + self.assertIs(cache[0], fake_session) + + @unittest.skipUnless( + not _IS_SYNC and sys.version_info >= (3, 11), + "Tests async session injection on Python 3.11+", + ) + def test_async_configured_protocol_injects_session_via_sslobject_class(self): + """When the cache has a session, sslobject_class is set and its __init__ body runs.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_protocol_interface + + initial_session = object() + cache: list = [initial_session] + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.return_value = None # no ssl_object → save block skips + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.AsyncMock( + return_value=(mock_transport, mock.MagicMock()) + ) + + async def run(): + with ( + mock.patch( + "pymongo.pool_shared._async_create_connection", + new=mock.AsyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + await _configured_protocol_interface(("localhost", 27017), mock_opts, cache) + + asyncio.run(run()) + + session_cls = real_ctx.sslobject_class # type: ignore[attr-defined] + self.assertIsNot(session_cls, ssl.SSLObject) + self.assertTrue(issubclass(session_cls, ssl.SSLObject)) + + # Exercise the __init__ body (super().__init__ + self.session = _session) by + # calling wrap_bio, patching the session setter to accept non-SSLSession objects. + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + no_op_session = property(lambda s: None, lambda s, v: None) + with mock.patch.object(ssl.SSLObject, "session", no_op_session): + ssl_obj = real_ctx.wrap_bio(incoming, outgoing, server_side=False) + self.assertIsInstance(ssl_obj, ssl.SSLObject) + + @unittest.skipUnless(not _IS_SYNC, "Tests async _configured_protocol_interface only") + def test_async_configured_protocol_no_cache(self): + """When ssl_session_cache is None, no injection or save occurs.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_protocol_interface + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.return_value = None + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.AsyncMock( + return_value=(mock_transport, mock.MagicMock()) + ) + + async def run(): + with ( + mock.patch( + "pymongo.pool_shared._async_create_connection", + new=mock.AsyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + await _configured_protocol_interface(("localhost", 27017), mock_opts, None) + + asyncio.run(run()) + self.assertIs(real_ctx.sslobject_class, ssl.SSLObject) # type: ignore[attr-defined] + + @unittest.skipUnless(not _IS_SYNC, "Tests async _configured_protocol_interface only") + def test_async_configured_protocol_new_session_is_none(self): + """When ssl_object.session is None after connect, the cache is not updated.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_protocol_interface + + cache: list = [None] + + mock_ssl_obj = mock.MagicMock() + mock_ssl_obj.session = None + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.side_effect = lambda key: ( + mock_ssl_obj if key == "ssl_object" else None + ) + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.AsyncMock( + return_value=(mock_transport, mock.MagicMock()) + ) + + async def run(): + with ( + mock.patch( + "pymongo.pool_shared._async_create_connection", + new=mock.AsyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + await _configured_protocol_interface(("localhost", 27017), mock_opts, cache) + + asyncio.run(run()) + self.assertIsNone(cache[0]) + class TestSSL(AsyncIntegrationTest): saved_port: int @@ -673,6 +965,22 @@ async def test_pyopenssl_ignored_in_async(self): await client.admin.command("ping") # command doesn't matter, just needs it to connect await client.close() + @async_client_context.require_tls + async def test_pool_has_ssl_session_cache(self): + pool = list(self.client._topology._servers.values())[0].pool + self.assertIsInstance(pool._ssl_session_cache, list) + + @async_client_context.require_tls + @unittest.skipUnless( + _IS_SYNC and _HAVE_PYOPENSSL, + "Sync stdlib ssl may return None for session on TLS 1.3; test limited to PyOpenSSL", + ) + async def test_tls_session_cached_after_connect(self): + await self.client.admin.command("ping") + pool = list(self.client._topology._servers.values())[0].pool + self.assertIsNotNone(pool._ssl_session_cache) + self.assertIsNotNone(pool._ssl_session_cache[0]) + if __name__ == "__main__": unittest.main() diff --git a/test/test_ssl.py b/test/test_ssl.py index 77bb086ecb..f5070a7e39 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -128,6 +128,290 @@ def test_config_ssl(self): def test_use_pyopenssl_when_available(self): self.assertTrue(HAVE_PYSSL) + def test_ssl_session_cache(self): + cache: list = [None] + self.assertIsNone(cache[0]) + cache[0] = "session" + self.assertEqual(cache[0], "session") + cache[0] = "new_session" + self.assertEqual(cache[0], "new_session") + + @unittest.skipUnless(_IS_SYNC, "Tests sync wrap_socket path only") + def test_tls_session_reused_on_second_connection(self): + """Cached TLS session is passed to wrap_socket on subsequent connections.""" + import unittest.mock as mock + + from pymongo.pool_shared import _configured_socket_interface + + fake_session = object() + cache: list = [fake_session] + + fake_ssl_sock = mock.MagicMock() + fake_ssl_sock.getpeercert.return_value = {} + + mock_ssl_context = mock.MagicMock() + mock_ssl_context.wrap_socket.return_value = fake_ssl_sock + mock_ssl_context.verify_mode = False + mock_ssl_context.check_hostname = False + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = mock_ssl_context + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + with mock.patch("pymongo.pool_shared._create_connection") as mock_create: + mock_create.return_value = mock.MagicMock() + _configured_socket_interface(("localhost", 27017), mock_opts, cache) + + mock_ssl_context.wrap_socket.assert_called_once() + _, kwargs = mock_ssl_context.wrap_socket.call_args + self.assertIs(kwargs.get("session"), fake_session) + + def test_get_ssl_session_pyopenssl_style(self): + """_get_ssl_session uses get_session() when available (PyOpenSSL path).""" + import unittest.mock as mock + + from pymongo.pool_shared import _get_ssl_session + + fake_session = object() + conn = mock.MagicMock() + conn.get_session.return_value = fake_session + self.assertIs(_get_ssl_session(conn), fake_session) + conn.get_session.assert_called_once() + + def test_get_ssl_session_stdlib_style(self): + """_get_ssl_session falls back to .session attribute (stdlib ssl path).""" + from pymongo.pool_shared import _get_ssl_session + + fake_session = object() + + class FakeSSLSock: + session = fake_session + + self.assertIs(_get_ssl_session(FakeSSLSock()), fake_session) + + @unittest.skipUnless( + not _IS_SYNC and sys.version_info >= (3, 11), + "Tests async sslobject_class injection (Python 3.11+ only)", + ) + def test_async_tls_session_injected_via_sslobject_class(self): + """On Python 3.11+, a cached session is injected by setting sslobject_class.""" + import ssl + + fake_session = object() + cache: list = [fake_session] + + real_ctx = ssl.create_default_context() + self.assertIs(real_ctx.sslobject_class, ssl.SSLObject) + + # Simulate what _configured_socket_interface does + session = cache[0] + assert session is not None + _session = session + + class _SessionSSLObject(ssl.SSLObject): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.session = _session + + real_ctx.sslobject_class = _SessionSSLObject + + self.assertIs(real_ctx.sslobject_class, _SessionSSLObject) + self.assertTrue(issubclass(real_ctx.sslobject_class, ssl.SSLObject)) + + @unittest.skipUnless(not _IS_SYNC, "Tests async _configured_socket_interface only") + def test_async_configured_protocol_saves_session_to_cache(self): + """After a successful TLS connection the session is stored in the cache.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_socket_interface + + fake_session = object() + cache: list = [None] + + mock_ssl_obj = mock.MagicMock() + mock_ssl_obj.session = fake_session + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.side_effect = lambda key: ( + mock_ssl_obj if key == "ssl_object" else None + ) + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.Mock(return_value=(mock_transport, mock.MagicMock())) + + def run(): + with ( + mock.patch( + "pymongo.pool_shared._create_connection", + new=mock.SyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + _configured_socket_interface(("localhost", 27017), mock_opts, cache) + + asyncio.run(run()) + self.assertIs(cache[0], fake_session) + + @unittest.skipUnless( + not _IS_SYNC and sys.version_info >= (3, 11), + "Tests async session injection on Python 3.11+", + ) + def test_async_configured_protocol_injects_session_via_sslobject_class(self): + """When the cache has a session, sslobject_class is set and its __init__ body runs.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_socket_interface + + initial_session = object() + cache: list = [initial_session] + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.return_value = None # no ssl_object → save block skips + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.Mock(return_value=(mock_transport, mock.MagicMock())) + + def run(): + with ( + mock.patch( + "pymongo.pool_shared._create_connection", + new=mock.SyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + _configured_socket_interface(("localhost", 27017), mock_opts, cache) + + asyncio.run(run()) + + session_cls = real_ctx.sslobject_class # type: ignore[attr-defined] + self.assertIsNot(session_cls, ssl.SSLObject) + self.assertTrue(issubclass(session_cls, ssl.SSLObject)) + + # Exercise the __init__ body (super().__init__ + self.session = _session) by + # calling wrap_bio, patching the session setter to accept non-SSLSession objects. + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + no_op_session = property(lambda s: None, lambda s, v: None) + with mock.patch.object(ssl.SSLObject, "session", no_op_session): + ssl_obj = real_ctx.wrap_bio(incoming, outgoing, server_side=False) + self.assertIsInstance(ssl_obj, ssl.SSLObject) + + @unittest.skipUnless(not _IS_SYNC, "Tests async _configured_socket_interface only") + def test_async_configured_protocol_no_cache(self): + """When ssl_session_cache is None, no injection or save occurs.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_socket_interface + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.return_value = None + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.Mock(return_value=(mock_transport, mock.MagicMock())) + + def run(): + with ( + mock.patch( + "pymongo.pool_shared._create_connection", + new=mock.SyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + _configured_socket_interface(("localhost", 27017), mock_opts, None) + + asyncio.run(run()) + self.assertIs(real_ctx.sslobject_class, ssl.SSLObject) # type: ignore[attr-defined] + + @unittest.skipUnless(not _IS_SYNC, "Tests async _configured_socket_interface only") + def test_async_configured_protocol_new_session_is_none(self): + """When ssl_object.session is None after connect, the cache is not updated.""" + import asyncio + import ssl + import unittest.mock as mock + + from pymongo.pool_shared import _configured_socket_interface + + cache: list = [None] + + mock_ssl_obj = mock.MagicMock() + mock_ssl_obj.session = None + + mock_transport = mock.MagicMock() + mock_transport.get_extra_info.side_effect = lambda key: ( + mock_ssl_obj if key == "ssl_object" else None + ) + + real_ctx = ssl.create_default_context() + real_ctx.check_hostname = False + real_ctx.verify_mode = ssl.CERT_NONE + + mock_opts = mock.MagicMock() + mock_opts._ssl_context = real_ctx + mock_opts.socket_timeout = None + mock_opts.tls_allow_invalid_hostnames = True + + mock_loop = mock.MagicMock() + mock_loop.create_connection = mock.Mock(return_value=(mock_transport, mock.MagicMock())) + + def run(): + with ( + mock.patch( + "pymongo.pool_shared._create_connection", + new=mock.SyncMock(return_value=mock.MagicMock()), + ), + mock.patch( + "pymongo.pool_shared.asyncio.get_running_loop", + return_value=mock_loop, + ), + ): + _configured_socket_interface(("localhost", 27017), mock_opts, cache) + + asyncio.run(run()) + self.assertIsNone(cache[0]) + class TestSSL(IntegrationTest): saved_port: int @@ -671,6 +955,22 @@ def test_pyopenssl_ignored_in_async(self): client.admin.command("ping") # command doesn't matter, just needs it to connect client.close() + @client_context.require_tls + def test_pool_has_ssl_session_cache(self): + pool = list(self.client._topology._servers.values())[0].pool + self.assertIsInstance(pool._ssl_session_cache, list) + + @client_context.require_tls + @unittest.skipUnless( + _IS_SYNC and _HAVE_PYOPENSSL, + "Sync stdlib ssl may return None for session on TLS 1.3; test limited to PyOpenSSL", + ) + def test_tls_session_cached_after_connect(self): + self.client.admin.command("ping") + pool = list(self.client._topology._servers.values())[0].pool + self.assertIsNotNone(pool._ssl_session_cache) + self.assertIsNotNone(pool._ssl_session_cache[0]) + if __name__ == "__main__": unittest.main()