diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 80584a86..897b054a 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1263,6 +1263,116 @@ class _TestSSL(tb.SSLTestCase): PAYLOAD_SIZE = 1024 * 100 TIMEOUT = 60 + def test_start_tls_buffer_transfer(self): + TIMEOUT = 10 + + if ( + self.implementation == 'asyncio' + or sys.version_info[:2] < (3, 11) + ): + # StreamWriter.start_tls() introduced in Python 3.11 + raise unittest.SkipTest( + 'StreamWriter.start_tls() not supported' + ) + self.loop.set_exception_handler(lambda loop, ctx: None) + + client_read_buffered = asyncio.Event() + server_sent_ok = asyncio.Event() + + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + BUFFERED_MSG = b'buffered data before TLS' + + server_context = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + client_context = self._create_client_ssl_context() + + async def handle_client(reader, writer): + try: + # Send data before TLS upgrade + writer.write(BUFFERED_MSG) + await writer.drain() + + await asyncio.wait_for( + client_read_buffered.wait(), + timeout=TIMEOUT + ) + + # Read pre-TLS data + data = await asyncio.wait_for( + reader.readexactly(len(HELLO_MSG)), + timeout=TIMEOUT, + ) + self.assertEqual(len(data), len(HELLO_MSG)) + + # Upgrade to TLS (server side) + # We need the wait_for because the broken version hangs here + await asyncio.wait_for( + writer.start_tls(server_context), + timeout=TIMEOUT,) + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + + # Send/receive over TLS + writer.write(b'OK') + await writer.drain() + server_sent_ok.set() + + data = await asyncio.wait_for( + reader.readexactly(len(HELLO_MSG)), + timeout=TIMEOUT, + ) + self.assertEqual(len(data), len(HELLO_MSG)) + finally: + if not writer.is_closing(): + writer.close() + await self.wait_closed(writer) + + async def client(addr): + # Use open_connection for StreamReader/StreamWriter + reader, writer = await asyncio.open_connection(*addr) + + try: + # Read buffered data before TLS + buffered = await reader.readexactly(len(BUFFERED_MSG)) + self.assertEqual(buffered, BUFFERED_MSG, + "Wrong pre-TLS buffered data from server") + client_read_buffered.set() + + # Write before TLS upgrade + writer.write(HELLO_MSG) + await writer.drain() + + # Upgrade to TLS + await writer.start_tls(client_context) + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + + # Verify communication over TLS + await server_sent_ok.wait() + tls_data = await reader.readexactly(2) + self.assertEqual(tls_data, b'OK', + "Wrong data from server after TLS upgrade") + + # Continue over TLS + writer.write(HELLO_MSG) + await writer.drain() + finally: + if not writer.is_closing(): + writer.close() + await self.wait_closed(writer) + + async def run_test(): + srv = await asyncio.start_server( + handle_client, '127.0.0.1', 0, family=socket.AF_INET) + + try: + addr = srv.sockets[0].getsockname() + + await asyncio.wait_for(client(addr), timeout=self.TIMEOUT) + finally: + srv.close() + await srv.wait_closed() + + self.loop.run_until_complete(run_test()) + def test_create_server_ssl_1(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 2ed1f272..9dc2bda0 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -1616,6 +1616,14 @@ cdef class Loop: ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) + # Transfer buffered data from the old protocol to the new one. + stream_reader = getattr(protocol, '_stream_reader', None) + if stream_reader is not None: + stream_buff = getattr(stream_reader, '_buffer', None) + if stream_buff is not None: + ssl_protocol._incoming.write(stream_buff) + stream_buff.clear() + # Pause early so that "ssl_protocol.data_received()" doesn't # have a chance to get called before "ssl_protocol.connection_made()". transport.pause_reading()