[3.12] gh-115514: Fix incomplete writes after close while using ssl in asyncio(GH-128037) (#129582)

gh-115514: Fix incomplete writes after close while using ssl in asyncio(GH-128037)

(cherry picked from commit 4e38eeafe2)

Co-authored-by: Vojtěch Boček <vbocek@gmail.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
Miss Islington (bot) 2025-02-02 16:47:37 +01:00 committed by GitHub
parent a7084f6075
commit e20963a12a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 213 additions and 5 deletions

View file

@ -12,6 +12,7 @@ import sys
import tempfile
import threading
import time
import unittest.mock
import weakref
import unittest
@ -1431,6 +1432,166 @@ class TestSSL(test_utils.TestCase):
with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))
def test_remote_shutdown_receives_trailing_data_on_slow_socket(self):
# This test is the same as test_remote_shutdown_receives_trailing_data,
# except it simulates a socket that is not able to write data in time,
# thus triggering different code path in _SelectorSocketTransport.
# This triggers bug gh-115514, also tested using mocks in
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
# The slow path is triggered here by setting SO_SNDBUF, see code and comment below.
CHUNK = 1024 * 128
SIZE = 32
sslctx = self._create_server_ssl_context(
test_utils.ONLYCERT,
test_utils.ONLYKEY
)
client_sslctx = self._create_client_ssl_context()
future = None
def server(sock):
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
while True:
try:
sslobj.do_handshake()
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
incoming.write(sock.recv(16384))
else:
if outgoing.pending:
sock.send(outgoing.read())
break
while True:
try:
data = sslobj.read(4)
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
else:
break
self.assertEqual(data, b'ping')
sslobj.write(b'pong')
sock.send(outgoing.read())
time.sleep(0.2) # wait for the peer to fill its backlog
# send close_notify but don't wait for response
with self.assertRaises(ssl.SSLWantReadError):
sslobj.unwrap()
sock.send(outgoing.read())
# should receive all data
data_len = 0
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
except ssl.SSLZeroReturnError:
break
self.assertEqual(data_len, CHUNK * SIZE*2)
# verify that close_notify is received
sslobj.unwrap()
sock.close()
def eof_server(sock):
sock.starttls(sslctx, server_side=True)
self.assertEqual(sock.recv_all(4), b'ping')
sock.send(b'pong')
time.sleep(0.2) # wait for the peer to fill its backlog
# send EOF
sock.shutdown(socket.SHUT_WR)
# should receive all data
data = sock.recv_all(CHUNK * SIZE)
self.assertEqual(len(data), CHUNK * SIZE)
sock.close()
async def client(addr):
nonlocal future
future = self.loop.create_future()
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')
# fill write backlog in a hacky way - renegotiation won't help
for _ in range(SIZE*2):
writer.transport._test__append_write_backlog(b'x' * CHUNK)
try:
data = await reader.read()
self.assertEqual(data, b'')
except (BrokenPipeError, ConnectionResetError):
pass
# Make sure _SelectorSocketTransport enters the delayed write
# path in its `write` method by wrapping socket in a fake class
# that acts as if there is not enough space in socket buffer.
# This triggers bug gh-115514, also tested using mocks in
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
socket_transport = writer.transport._ssl_protocol._transport
class SocketWrapper:
def __init__(self, sock) -> None:
self.sock = sock
def __getattr__(self, name):
return getattr(self.sock, name)
def send(self, data):
# Fake that our write buffer is full, send only half
to_send = len(data)//2
return self.sock.send(data[:to_send])
def _fake_full_write_buffer(data):
if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper):
socket_transport._sock = SocketWrapper(socket_transport._sock)
return unittest.mock.DEFAULT
with unittest.mock.patch.object(
socket_transport, "write",
wraps=socket_transport.write,
side_effect=_fake_full_write_buffer
):
await future
writer.close()
await self.wait_closed(writer)
def run(meth):
def wrapper(sock):
try:
meth(sock)
except Exception as ex:
self.loop.call_soon_threadsafe(future.set_exception, ex)
else:
self.loop.call_soon_threadsafe(future.set_result, None)
return wrapper
with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))
with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))
def test_connect_timeout_warning(self):
s = socket.socket(socket.AF_INET)
s.bind(('127.0.0.1', 0))