Issue #19509: Finish implementation of check_hostname

The new asyncio package now supports the new feature and comes with additional tests for SSL.
This commit is contained in:
Christian Heimes 2013-12-06 00:23:13 +01:00
parent 8ff6f3e895
commit 6d8c1abb00
8 changed files with 319 additions and 57 deletions

View file

@ -17,7 +17,7 @@ import time
import errno
import unittest
import unittest.mock
from test.support import find_unused_port, IPV6_ENABLED
from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR
from asyncio import futures
@ -30,10 +30,27 @@ from asyncio import test_utils
from asyncio import locks
def data_file(filename):
if hasattr(support, 'TEST_HOME_DIR'):
fullname = os.path.join(support.TEST_HOME_DIR, filename)
if os.path.isfile(fullname):
return fullname
fullname = os.path.join(os.path.dirname(__file__), filename)
if os.path.isfile(fullname):
return fullname
raise FileNotFoundError(filename)
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
class MyProto(protocols.Protocol):
done = None
def __init__(self, loop=None):
self.transport = None
self.state = 'INITIAL'
self.nbytes = 0
if loop is not None:
@ -523,7 +540,7 @@ class EventLoopTestsMixin:
def test_create_connection_local_addr(self):
with test_utils.run_test_server() as httpd:
port = find_unused_port()
port = support.find_unused_port()
f = self.loop.create_connection(
lambda: MyProto(loop=self.loop),
*httpd.address, local_addr=(httpd.address[0], port))
@ -587,6 +604,20 @@ class EventLoopTestsMixin:
# close server
server.close()
def _make_ssl_server(self, factory, certfile, keyfile=None):
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.load_cert_chain(certfile, keyfile)
f = self.loop.create_server(
factory, '127.0.0.1', 0, ssl=sslcontext)
server = self.loop.run_until_complete(f)
sock = server.sockets[0]
host, port = sock.getsockname()
self.assertEqual(host, '127.0.0.1')
return server, host, port
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl(self):
proto = None
@ -602,19 +633,7 @@ class EventLoopTestsMixin:
proto = MyProto(loop=self.loop)
return proto
here = os.path.dirname(__file__)
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.load_cert_chain(
certfile=os.path.join(here, 'sample.crt'),
keyfile=os.path.join(here, 'sample.key'))
f = self.loop.create_server(
factory, '127.0.0.1', 0, ssl=sslcontext)
server = self.loop.run_until_complete(f)
sock = server.sockets[0]
host, port = sock.getsockname()
self.assertEqual(host, '127.0.0.1')
server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
f_c = self.loop.create_connection(ClientMyProto, host, port,
ssl=test_utils.dummy_ssl_context())
@ -646,6 +665,93 @@ class EventLoopTestsMixin:
# stop serving
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl_verify_failed(self):
proto = None
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# no CA loaded
f_c = self.loop.create_connection(MyProto, host, port,
ssl=sslcontext_client)
with self.assertRaisesRegex(ssl.SSLError,
'certificate verify failed '):
self.loop.run_until_complete(f_c)
# close connection
self.assertIsNone(proto.transport)
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl_match_failed(self):
proto = None
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
sslcontext_client.load_verify_locations(
cafile=SIGNING_CA)
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# incorrect server_hostname
f_c = self.loop.create_connection(MyProto, host, port,
ssl=sslcontext_client)
with self.assertRaisesRegex(ssl.CertificateError,
"hostname '127.0.0.1' doesn't match 'localhost'"):
self.loop.run_until_complete(f_c)
# close connection
proto.transport.close()
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl_verified(self):
proto = None
def factory():
nonlocal proto
proto = MyProto(loop=self.loop)
return proto
server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
sslcontext_client.verify_mode = ssl.CERT_REQUIRED
sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True
# Connection succeeds with correct CA and server hostname.
f_c = self.loop.create_connection(MyProto, host, port,
ssl=sslcontext_client,
server_hostname='localhost')
client, pr = self.loop.run_until_complete(f_c)
# close connection
proto.transport.close()
client.close()
server.close()
def test_create_server_sock(self):
proto = futures.Future(loop=self.loop)
@ -688,7 +794,7 @@ class EventLoopTestsMixin:
server.close()
@unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled')
def test_create_server_dual_stack(self):
f_proto = futures.Future(loop=self.loop)
@ -700,7 +806,7 @@ class EventLoopTestsMixin:
try_count = 0
while True:
try:
port = find_unused_port()
port = support.find_unused_port()
f = self.loop.create_server(TestMyProto, host=None, port=port)
server = self.loop.run_until_complete(f)
except OSError as ex: