mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
gh-115514: Fix incomplete writes after close while using ssl in asyncio(#128037)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
parent
853a6b7de2
commit
4e38eeafe2
5 changed files with 213 additions and 5 deletions
|
@ -1185,10 +1185,13 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _call_connection_lost(self, exc):
|
def _call_connection_lost(self, exc):
|
||||||
super()._call_connection_lost(exc)
|
try:
|
||||||
if self._empty_waiter is not None:
|
super()._call_connection_lost(exc)
|
||||||
self._empty_waiter.set_exception(
|
finally:
|
||||||
ConnectionError("Connection is closed by peer"))
|
self._write_ready = None
|
||||||
|
if self._empty_waiter is not None:
|
||||||
|
self._empty_waiter.set_exception(
|
||||||
|
ConnectionError("Connection is closed by peer"))
|
||||||
|
|
||||||
def _make_empty_waiter(self):
|
def _make_empty_waiter(self):
|
||||||
if self._empty_waiter is not None:
|
if self._empty_waiter is not None:
|
||||||
|
@ -1203,7 +1206,6 @@ class _SelectorSocketTransport(_SelectorTransport):
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._read_ready_cb = None
|
self._read_ready_cb = None
|
||||||
self._write_ready = None
|
|
||||||
super().close()
|
super().close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1051,6 +1051,48 @@ class SelectorSocketTransportTests(test_utils.TestCase):
|
||||||
transport.close()
|
transport.close()
|
||||||
remove_writer.assert_called_with(self.sock_fd)
|
remove_writer.assert_called_with(self.sock_fd)
|
||||||
|
|
||||||
|
def test_write_buffer_after_close(self):
|
||||||
|
# gh-115514: If the transport is closed while:
|
||||||
|
# * Transport write buffer is not empty
|
||||||
|
# * Transport is paused
|
||||||
|
# * Protocol has data in its buffer, like SSLProtocol in self._outgoing
|
||||||
|
# The data is still written out.
|
||||||
|
|
||||||
|
# Also tested with real SSL transport in
|
||||||
|
# test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data
|
||||||
|
|
||||||
|
data = memoryview(b'data')
|
||||||
|
self.sock.send.return_value = 2
|
||||||
|
self.sock.send.fileno.return_value = 7
|
||||||
|
|
||||||
|
def _resume_writing():
|
||||||
|
transport.write(b"data")
|
||||||
|
self.protocol.resume_writing.side_effect = None
|
||||||
|
|
||||||
|
self.protocol.resume_writing.side_effect = _resume_writing
|
||||||
|
|
||||||
|
transport = self.socket_transport()
|
||||||
|
transport._high_water = 1
|
||||||
|
|
||||||
|
transport.write(data)
|
||||||
|
|
||||||
|
self.assertTrue(transport._protocol_paused)
|
||||||
|
self.assertTrue(self.sock.send.called)
|
||||||
|
self.loop.assert_writer(7, transport._write_ready)
|
||||||
|
|
||||||
|
transport.close()
|
||||||
|
|
||||||
|
# not called, we still have data in write buffer
|
||||||
|
self.assertFalse(self.protocol.connection_lost.called)
|
||||||
|
|
||||||
|
self.loop.writers[7]._run()
|
||||||
|
# during this ^ run, the _resume_writing mock above was called and added more data
|
||||||
|
|
||||||
|
self.assertEqual(transport.get_write_buffer_size(), 2)
|
||||||
|
self.loop.writers[7]._run()
|
||||||
|
|
||||||
|
self.assertEqual(transport.get_write_buffer_size(), 0)
|
||||||
|
self.assertTrue(self.protocol.connection_lost.called)
|
||||||
|
|
||||||
class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
|
class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import unittest.mock
|
||||||
import weakref
|
import weakref
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
@ -1431,6 +1432,166 @@ class TestSSL(test_utils.TestCase):
|
||||||
with self.tcp_server(run(eof_server)) as srv:
|
with self.tcp_server(run(eof_server)) as srv:
|
||||||
self.loop.run_until_complete(client(srv.addr))
|
self.loop.run_until_complete(client(srv.addr))
|
||||||
|
|
||||||
|
def test_remote_shutdown_receives_trailing_data_on_slow_socket(self):
|
||||||
|
# This test is the same as test_remote_shutdown_receives_trailing_data,
|
||||||
|
# except it simulates a socket that is not able to write data in time,
|
||||||
|
# thus triggering different code path in _SelectorSocketTransport.
|
||||||
|
# This triggers bug gh-115514, also tested using mocks in
|
||||||
|
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
|
||||||
|
# The slow path is triggered here by setting SO_SNDBUF, see code and comment below.
|
||||||
|
|
||||||
|
CHUNK = 1024 * 128
|
||||||
|
SIZE = 32
|
||||||
|
|
||||||
|
sslctx = self._create_server_ssl_context(
|
||||||
|
test_utils.ONLYCERT,
|
||||||
|
test_utils.ONLYKEY
|
||||||
|
)
|
||||||
|
client_sslctx = self._create_client_ssl_context()
|
||||||
|
future = None
|
||||||
|
|
||||||
|
def server(sock):
|
||||||
|
incoming = ssl.MemoryBIO()
|
||||||
|
outgoing = ssl.MemoryBIO()
|
||||||
|
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
sslobj.do_handshake()
|
||||||
|
except ssl.SSLWantReadError:
|
||||||
|
if outgoing.pending:
|
||||||
|
sock.send(outgoing.read())
|
||||||
|
incoming.write(sock.recv(16384))
|
||||||
|
else:
|
||||||
|
if outgoing.pending:
|
||||||
|
sock.send(outgoing.read())
|
||||||
|
break
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = sslobj.read(4)
|
||||||
|
except ssl.SSLWantReadError:
|
||||||
|
incoming.write(sock.recv(16384))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertEqual(data, b'ping')
|
||||||
|
sslobj.write(b'pong')
|
||||||
|
sock.send(outgoing.read())
|
||||||
|
|
||||||
|
time.sleep(0.2) # wait for the peer to fill its backlog
|
||||||
|
|
||||||
|
# send close_notify but don't wait for response
|
||||||
|
with self.assertRaises(ssl.SSLWantReadError):
|
||||||
|
sslobj.unwrap()
|
||||||
|
sock.send(outgoing.read())
|
||||||
|
|
||||||
|
# should receive all data
|
||||||
|
data_len = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = len(sslobj.read(16384))
|
||||||
|
data_len += chunk
|
||||||
|
except ssl.SSLWantReadError:
|
||||||
|
incoming.write(sock.recv(16384))
|
||||||
|
except ssl.SSLZeroReturnError:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertEqual(data_len, CHUNK * SIZE*2)
|
||||||
|
|
||||||
|
# verify that close_notify is received
|
||||||
|
sslobj.unwrap()
|
||||||
|
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
def eof_server(sock):
|
||||||
|
sock.starttls(sslctx, server_side=True)
|
||||||
|
self.assertEqual(sock.recv_all(4), b'ping')
|
||||||
|
sock.send(b'pong')
|
||||||
|
|
||||||
|
time.sleep(0.2) # wait for the peer to fill its backlog
|
||||||
|
|
||||||
|
# send EOF
|
||||||
|
sock.shutdown(socket.SHUT_WR)
|
||||||
|
|
||||||
|
# should receive all data
|
||||||
|
data = sock.recv_all(CHUNK * SIZE)
|
||||||
|
self.assertEqual(len(data), CHUNK * SIZE)
|
||||||
|
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
async def client(addr):
|
||||||
|
nonlocal future
|
||||||
|
future = self.loop.create_future()
|
||||||
|
|
||||||
|
reader, writer = await asyncio.open_connection(
|
||||||
|
*addr,
|
||||||
|
ssl=client_sslctx,
|
||||||
|
server_hostname='')
|
||||||
|
writer.write(b'ping')
|
||||||
|
data = await reader.readexactly(4)
|
||||||
|
self.assertEqual(data, b'pong')
|
||||||
|
|
||||||
|
# fill write backlog in a hacky way - renegotiation won't help
|
||||||
|
for _ in range(SIZE*2):
|
||||||
|
writer.transport._test__append_write_backlog(b'x' * CHUNK)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await reader.read()
|
||||||
|
self.assertEqual(data, b'')
|
||||||
|
except (BrokenPipeError, ConnectionResetError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Make sure _SelectorSocketTransport enters the delayed write
|
||||||
|
# path in its `write` method by wrapping socket in a fake class
|
||||||
|
# that acts as if there is not enough space in socket buffer.
|
||||||
|
# This triggers bug gh-115514, also tested using mocks in
|
||||||
|
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
|
||||||
|
socket_transport = writer.transport._ssl_protocol._transport
|
||||||
|
|
||||||
|
class SocketWrapper:
|
||||||
|
def __init__(self, sock) -> None:
|
||||||
|
self.sock = sock
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.sock, name)
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
# Fake that our write buffer is full, send only half
|
||||||
|
to_send = len(data)//2
|
||||||
|
return self.sock.send(data[:to_send])
|
||||||
|
|
||||||
|
def _fake_full_write_buffer(data):
|
||||||
|
if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper):
|
||||||
|
socket_transport._sock = SocketWrapper(socket_transport._sock)
|
||||||
|
return unittest.mock.DEFAULT
|
||||||
|
|
||||||
|
with unittest.mock.patch.object(
|
||||||
|
socket_transport, "write",
|
||||||
|
wraps=socket_transport.write,
|
||||||
|
side_effect=_fake_full_write_buffer
|
||||||
|
):
|
||||||
|
await future
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
await self.wait_closed(writer)
|
||||||
|
|
||||||
|
def run(meth):
|
||||||
|
def wrapper(sock):
|
||||||
|
try:
|
||||||
|
meth(sock)
|
||||||
|
except Exception as ex:
|
||||||
|
self.loop.call_soon_threadsafe(future.set_exception, ex)
|
||||||
|
else:
|
||||||
|
self.loop.call_soon_threadsafe(future.set_result, None)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
with self.tcp_server(run(server)) as srv:
|
||||||
|
self.loop.run_until_complete(client(srv.addr))
|
||||||
|
|
||||||
|
with self.tcp_server(run(eof_server)) as srv:
|
||||||
|
self.loop.run_until_complete(client(srv.addr))
|
||||||
|
|
||||||
def test_connect_timeout_warning(self):
|
def test_connect_timeout_warning(self):
|
||||||
s = socket.socket(socket.AF_INET)
|
s = socket.socket(socket.AF_INET)
|
||||||
s.bind(('127.0.0.1', 0))
|
s.bind(('127.0.0.1', 0))
|
||||||
|
|
|
@ -189,6 +189,7 @@ Stéphane Blondon
|
||||||
Eric Blossom
|
Eric Blossom
|
||||||
Sergey Bobrov
|
Sergey Bobrov
|
||||||
Finn Bock
|
Finn Bock
|
||||||
|
Vojtěch Boček
|
||||||
Paul Boddie
|
Paul Boddie
|
||||||
Matthew Boedicker
|
Matthew Boedicker
|
||||||
Robin Boerdijk
|
Robin Boerdijk
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Fix exceptions and incomplete writes after :class:`!asyncio._SelectorTransport`
|
||||||
|
is closed before writes are completed.
|
Loading…
Add table
Add a link
Reference in a new issue