mirror of
https://github.com/python/cpython.git
synced 2025-12-15 21:44:50 +00:00
Issue #19509: Finish implementation of check_hostname
The new asyncio package now supports the new feature and comes with additional tests for SSL.
This commit is contained in:
parent
8ff6f3e895
commit
6d8c1abb00
8 changed files with 319 additions and 57 deletions
|
|
@ -17,7 +17,7 @@ import time
|
|||
import errno
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from test.support import find_unused_port, IPV6_ENABLED
|
||||
from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR
|
||||
|
||||
|
||||
from asyncio import futures
|
||||
|
|
@ -30,10 +30,27 @@ from asyncio import test_utils
|
|||
from asyncio import locks
|
||||
|
||||
|
||||
def data_file(filename):
|
||||
if hasattr(support, 'TEST_HOME_DIR'):
|
||||
fullname = os.path.join(support.TEST_HOME_DIR, filename)
|
||||
if os.path.isfile(fullname):
|
||||
return fullname
|
||||
fullname = os.path.join(os.path.dirname(__file__), filename)
|
||||
if os.path.isfile(fullname):
|
||||
return fullname
|
||||
raise FileNotFoundError(filename)
|
||||
|
||||
ONLYCERT = data_file('ssl_cert.pem')
|
||||
ONLYKEY = data_file('ssl_key.pem')
|
||||
SIGNED_CERTFILE = data_file('keycert3.pem')
|
||||
SIGNING_CA = data_file('pycacert.pem')
|
||||
|
||||
|
||||
class MyProto(protocols.Protocol):
|
||||
done = None
|
||||
|
||||
def __init__(self, loop=None):
|
||||
self.transport = None
|
||||
self.state = 'INITIAL'
|
||||
self.nbytes = 0
|
||||
if loop is not None:
|
||||
|
|
@ -523,7 +540,7 @@ class EventLoopTestsMixin:
|
|||
|
||||
def test_create_connection_local_addr(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
port = find_unused_port()
|
||||
port = support.find_unused_port()
|
||||
f = self.loop.create_connection(
|
||||
lambda: MyProto(loop=self.loop),
|
||||
*httpd.address, local_addr=(httpd.address[0], port))
|
||||
|
|
@ -587,6 +604,20 @@ class EventLoopTestsMixin:
|
|||
# close server
|
||||
server.close()
|
||||
|
||||
def _make_ssl_server(self, factory, certfile, keyfile=None):
|
||||
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext.load_cert_chain(certfile, keyfile)
|
||||
|
||||
f = self.loop.create_server(
|
||||
factory, '127.0.0.1', 0, ssl=sslcontext)
|
||||
|
||||
server = self.loop.run_until_complete(f)
|
||||
sock = server.sockets[0]
|
||||
host, port = sock.getsockname()
|
||||
self.assertEqual(host, '127.0.0.1')
|
||||
return server, host, port
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_create_server_ssl(self):
|
||||
proto = None
|
||||
|
|
@ -602,19 +633,7 @@ class EventLoopTestsMixin:
|
|||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
here = os.path.dirname(__file__)
|
||||
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext.load_cert_chain(
|
||||
certfile=os.path.join(here, 'sample.crt'),
|
||||
keyfile=os.path.join(here, 'sample.key'))
|
||||
|
||||
f = self.loop.create_server(
|
||||
factory, '127.0.0.1', 0, ssl=sslcontext)
|
||||
|
||||
server = self.loop.run_until_complete(f)
|
||||
sock = server.sockets[0]
|
||||
host, port = sock.getsockname()
|
||||
self.assertEqual(host, '127.0.0.1')
|
||||
server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
|
||||
|
||||
f_c = self.loop.create_connection(ClientMyProto, host, port,
|
||||
ssl=test_utils.dummy_ssl_context())
|
||||
|
|
@ -646,6 +665,93 @@ class EventLoopTestsMixin:
|
|||
# stop serving
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_create_server_ssl_verify_failed(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||
if hasattr(sslcontext_client, 'check_hostname'):
|
||||
sslcontext_client.check_hostname = True
|
||||
|
||||
# no CA loaded
|
||||
f_c = self.loop.create_connection(MyProto, host, port,
|
||||
ssl=sslcontext_client)
|
||||
with self.assertRaisesRegex(ssl.SSLError,
|
||||
'certificate verify failed '):
|
||||
self.loop.run_until_complete(f_c)
|
||||
|
||||
# close connection
|
||||
self.assertIsNone(proto.transport)
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_create_server_ssl_match_failed(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||
sslcontext_client.load_verify_locations(
|
||||
cafile=SIGNING_CA)
|
||||
if hasattr(sslcontext_client, 'check_hostname'):
|
||||
sslcontext_client.check_hostname = True
|
||||
|
||||
# incorrect server_hostname
|
||||
f_c = self.loop.create_connection(MyProto, host, port,
|
||||
ssl=sslcontext_client)
|
||||
with self.assertRaisesRegex(ssl.CertificateError,
|
||||
"hostname '127.0.0.1' doesn't match 'localhost'"):
|
||||
self.loop.run_until_complete(f_c)
|
||||
|
||||
# close connection
|
||||
proto.transport.close()
|
||||
server.close()
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def test_create_server_ssl_verified(self):
|
||||
proto = None
|
||||
|
||||
def factory():
|
||||
nonlocal proto
|
||||
proto = MyProto(loop=self.loop)
|
||||
return proto
|
||||
|
||||
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
|
||||
|
||||
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext_client.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
|
||||
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
|
||||
if hasattr(sslcontext_client, 'check_hostname'):
|
||||
sslcontext_client.check_hostname = True
|
||||
|
||||
# Connection succeeds with correct CA and server hostname.
|
||||
f_c = self.loop.create_connection(MyProto, host, port,
|
||||
ssl=sslcontext_client,
|
||||
server_hostname='localhost')
|
||||
client, pr = self.loop.run_until_complete(f_c)
|
||||
|
||||
# close connection
|
||||
proto.transport.close()
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
def test_create_server_sock(self):
|
||||
proto = futures.Future(loop=self.loop)
|
||||
|
||||
|
|
@ -688,7 +794,7 @@ class EventLoopTestsMixin:
|
|||
|
||||
server.close()
|
||||
|
||||
@unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
|
||||
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled')
|
||||
def test_create_server_dual_stack(self):
|
||||
f_proto = futures.Future(loop=self.loop)
|
||||
|
||||
|
|
@ -700,7 +806,7 @@ class EventLoopTestsMixin:
|
|||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
port = find_unused_port()
|
||||
port = support.find_unused_port()
|
||||
f = self.loop.create_server(TestMyProto, host=None, port=port)
|
||||
server = self.loop.run_until_complete(f)
|
||||
except OSError as ex:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue