gh-79156: Add start_tls() method to streams API (#91453)

The existing event loop `start_tls()` method is not sufficient for
connections using the streams API. The existing StreamReader works
because the new transport passes received data to the original protocol.
The StreamWriter must then write data to the new transport, and the
StreamReaderProtocol must be updated to close the new transport
correctly.

The new StreamWriter `start_tls()` updates itself and the reader
protocol to the new SSL transport.

Co-authored-by: Ian Good <icgood@gmail.com>
This commit is contained in:
Oleg Iarygin 2022-04-15 15:23:14 +03:00 committed by GitHub
parent bd26ef5e9e
commit 6217864fe5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 0 deletions

View file

@ -295,6 +295,24 @@ StreamWriter
be resumed. When there is nothing to wait for, the :meth:`drain`
returns immediately.
.. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \
ssl_handshake_timeout=None)
Upgrade an existing stream-based connection to TLS.
Parameters:
* *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
* *server_hostname*: sets or overrides the host name that the target
server's certificate will be matched against.
* *ssl_handshake_timeout* is the time in seconds to wait for the TLS
handshake to complete before aborting the connection. ``60.0`` seconds
if ``None`` (default).
.. versionadded:: 3.8
.. method:: is_closing()
Return ``True`` if the stream is closed or in the process of

View file

@ -246,6 +246,10 @@ asyncio
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
(Contributed by Alex Grönholm in :issue:`46805`.)
* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading
existing stream-based connections to TLS. (Contributed by Ian Good in
:issue:`34975`.)
fractions
---------

View file

@ -217,6 +217,13 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
return None
return self._stream_reader_wr()
def _replace_writer(self, writer):
loop = self._loop
transport = writer.transport
self._stream_writer = writer
self._transport = transport
self._over_ssl = transport.get_extra_info('sslcontext') is not None
def connection_made(self, transport):
if self._reject_connection:
context = {
@ -371,6 +378,20 @@ class StreamWriter:
await sleep(0)
await self._protocol._drain_helper()
async def start_tls(self, sslcontext, *,
server_hostname=None,
ssl_handshake_timeout=None):
"""Upgrade an existing stream-based connection to TLS."""
server_side = self._protocol._client_connected_cb is not None
protocol = self._protocol
await self.drain()
new_transport = await self._loop.start_tls( # type: ignore
self._transport, protocol, sslcontext,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
self._transport = new_transport
protocol._replace_writer(self)
class StreamReader:

View file

@ -706,6 +706,69 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(messages, [])
@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls(self):
class MyServer:
def __init__(self, loop):
self.server = None
self.loop = loop
async def handle_client(self, client_reader, client_writer):
data1 = await client_reader.readline()
client_writer.write(data1)
await client_writer.drain()
assert client_writer.get_extra_info('sslcontext') is None
await client_writer.start_tls(
test_utils.simple_server_sslcontext())
assert client_writer.get_extra_info('sslcontext') is not None
data2 = await client_reader.readline()
client_writer.write(data2)
await client_writer.drain()
client_writer.close()
await client_writer.wait_closed()
def start(self):
sock = socket.create_server(('127.0.0.1', 0))
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client,
sock=sock))
return sock.getsockname()
def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None
async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
writer.write(b"hello world 1!\n")
await writer.drain()
msgback1 = await reader.readline()
assert writer.get_extra_info('sslcontext') is None
await writer.start_tls(test_utils.simple_client_sslcontext())
assert writer.get_extra_info('sslcontext') is not None
writer.write(b"hello world 2!\n")
await writer.drain()
msgback2 = await reader.readline()
writer.close()
await writer.wait_closed()
return msgback1, msgback2
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
server = MyServer(self.loop)
addr = server.start()
msg1, msg2 = self.loop.run_until_complete(client(addr))
server.stop()
self.assertEqual(messages, [])
self.assertEqual(msg1, b"hello world 1!\n")
self.assertEqual(msg2, b"hello world 2!\n")
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example

View file

@ -0,0 +1,3 @@
Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`,
which upgrades the connection with TLS using the given
:class:`~ssl.SSLContext`.