mirror of
https://github.com/python/cpython.git
synced 2025-08-04 17:08:35 +00:00
bpo-33654: Support BufferedProtocol in set_protocol() and start_tls() (GH-7130)
In this commit: * Support BufferedProtocol in set_protocol() and start_tls() * Fix proactor to cancel readers reliably * Update tests to be compatible with OpenSSL 1.1.1 * Clarify BufferedProtocol docs * Bump TLS tests timeouts to 60 seconds; eliminate possible race from start_serving * Rewrite test_start_tls_server_1
This commit is contained in:
parent
e549c4be5f
commit
dbf102271f
13 changed files with 382 additions and 69 deletions
|
@ -1,8 +1,7 @@
|
|||
"""Tests for asyncio/sslproto.py."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import socket
|
||||
import unittest
|
||||
from unittest import mock
|
||||
try:
|
||||
|
@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
|
||||
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||
|
||||
PAYLOAD_SIZE = 1024 * 100
|
||||
TIMEOUT = 60
|
||||
|
||||
def new_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_start_tls_client_1(self):
|
||||
HELLO_MSG = b'1' * 1024 * 1024
|
||||
def test_buf_feed_data(self):
|
||||
|
||||
class Proto(asyncio.BufferedProtocol):
|
||||
|
||||
def __init__(self, bufsize, usemv):
|
||||
self.buf = bytearray(bufsize)
|
||||
self.mv = memoryview(self.buf)
|
||||
self.data = b''
|
||||
self.usemv = usemv
|
||||
|
||||
def get_buffer(self, sizehint):
|
||||
if self.usemv:
|
||||
return self.mv
|
||||
else:
|
||||
return self.buf
|
||||
|
||||
def buffer_updated(self, nsize):
|
||||
if self.usemv:
|
||||
self.data += self.mv[:nsize]
|
||||
else:
|
||||
self.data += self.buf[:nsize]
|
||||
|
||||
for usemv in [False, True]:
|
||||
proto = Proto(1, usemv)
|
||||
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
|
||||
self.assertEqual(proto.data, b'12345')
|
||||
|
||||
proto = Proto(2, usemv)
|
||||
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
|
||||
self.assertEqual(proto.data, b'12345')
|
||||
|
||||
proto = Proto(2, usemv)
|
||||
sslproto._feed_data_to_bufferred_proto(proto, b'1234')
|
||||
self.assertEqual(proto.data, b'1234')
|
||||
|
||||
proto = Proto(4, usemv)
|
||||
sslproto._feed_data_to_bufferred_proto(proto, b'1234')
|
||||
self.assertEqual(proto.data, b'1234')
|
||||
|
||||
proto = Proto(100, usemv)
|
||||
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
|
||||
self.assertEqual(proto.data, b'12345')
|
||||
|
||||
proto = Proto(0, usemv)
|
||||
with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
|
||||
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
|
||||
|
||||
def test_start_tls_client_reg_proto_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(5)
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
sock.sendall(b'O')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.Protocol):
|
||||
|
@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
|
||||
|
||||
def test_start_tls_server_1(self):
|
||||
HELLO_MSG = b'1' * 1024 * 1024
|
||||
def test_start_tls_client_buf_proto_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.start_tls(server_context, server_side=True)
|
||||
|
||||
sock.sendall(b'O')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.BufferedProtocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
self.buf = bytearray(1)
|
||||
|
||||
def connection_made(proto, tr):
|
||||
proto.con_made_cnt += 1
|
||||
# Ensure connection_made gets called only once.
|
||||
self.assertEqual(proto.con_made_cnt, 1)
|
||||
|
||||
def get_buffer(self, sizehint):
|
||||
return self.buf
|
||||
|
||||
def buffer_updated(self, nsize):
|
||||
assert nsize == 1
|
||||
self.on_data.set_result(bytes(self.buf[:nsize]))
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.sleep(0.5, loop=self.loop)
|
||||
|
||||
on_data = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProto(on_data, on_eof), *addr)
|
||||
|
||||
tr.write(HELLO_MSG)
|
||||
new_tr = await self.loop.start_tls(tr, proto, client_context)
|
||||
|
||||
self.assertEqual(await on_data, b'O')
|
||||
new_tr.write(HELLO_MSG)
|
||||
await on_eof
|
||||
|
||||
new_tr.close()
|
||||
|
||||
with self.tcp_server(serve) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr),
|
||||
loop=self.loop, timeout=self.TIMEOUT))
|
||||
|
||||
def test_start_tls_server_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
# TODO: fix TLSv1.3 support
|
||||
client_context.options |= ssl.OP_NO_TLSv1_3
|
||||
|
||||
def client(sock, addr):
|
||||
time.sleep(0.5)
|
||||
sock.settimeout(5)
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
sock.connect(addr)
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
|
@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
|
||||
sock.start_tls(client_context)
|
||||
sock.sendall(HELLO_MSG)
|
||||
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
class ServerProto(asyncio.Protocol):
|
||||
def __init__(self, on_con, on_eof):
|
||||
def __init__(self, on_con, on_eof, on_con_lost):
|
||||
self.on_con = on_con
|
||||
self.on_eof = on_eof
|
||||
self.on_con_lost = on_con_lost
|
||||
self.data = b''
|
||||
|
||||
def connection_made(self, tr):
|
||||
|
@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
def eof_received(self):
|
||||
self.on_eof.set_result(1)
|
||||
|
||||
async def main():
|
||||
def connection_lost(self, exc):
|
||||
if exc is None:
|
||||
self.on_con_lost.set_result(None)
|
||||
else:
|
||||
self.on_con_lost.set_exception(exc)
|
||||
|
||||
async def main(proto, on_con, on_eof, on_con_lost):
|
||||
tr = await on_con
|
||||
tr.write(HELLO_MSG)
|
||||
|
||||
|
@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
server_side=True)
|
||||
|
||||
await on_eof
|
||||
await on_con_lost
|
||||
self.assertEqual(proto.data, HELLO_MSG)
|
||||
new_tr.close()
|
||||
|
||||
async def run_main():
|
||||
on_con = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
on_con_lost = self.loop.create_future()
|
||||
proto = ServerProto(on_con, on_eof, on_con_lost)
|
||||
|
||||
server = await self.loop.create_server(
|
||||
lambda: proto, '127.0.0.1', 0)
|
||||
addr = server.sockets[0].getsockname()
|
||||
|
||||
with self.tcp_client(lambda sock: client(sock, addr)):
|
||||
await asyncio.wait_for(
|
||||
main(proto, on_con, on_eof, on_con_lost),
|
||||
loop=self.loop, timeout=self.TIMEOUT)
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
on_con = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
proto = ServerProto(on_con, on_eof)
|
||||
|
||||
server = self.loop.run_until_complete(
|
||||
self.loop.create_server(
|
||||
lambda: proto, '127.0.0.1', 0))
|
||||
addr = server.sockets[0].getsockname()
|
||||
|
||||
with self.tcp_client(lambda sock: client(sock, addr)):
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(main(), loop=self.loop, timeout=10))
|
||||
self.loop.run_until_complete(run_main())
|
||||
|
||||
def test_start_tls_wrong_args(self):
|
||||
async def main():
|
||||
|
@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
|
||||
@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458')
|
||||
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue