Issue #26470: Port ssl and hashlib module to OpenSSL 1.1.0.

This commit is contained in:
Christian Heimes 2016-09-05 23:19:05 +02:00
parent b3b7a5a16b
commit 598894ff48
6 changed files with 394 additions and 165 deletions

View file

@ -23,6 +23,9 @@ ssl = support.import_module("ssl")
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
def data_file(*name):
return os.path.join(os.path.dirname(__file__), *name)
@ -143,8 +146,8 @@ class BasicSocketTests(unittest.TestCase):
def test_str_for_enums(self):
# Make sure that the PROTOCOL_* constants have enum-like string
# reprs.
proto = ssl.PROTOCOL_SSLv23
self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23')
proto = ssl.PROTOCOL_TLS
self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
ctx = ssl.SSLContext(proto)
self.assertIs(ctx.protocol, proto)
@ -312,8 +315,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15)
# Version string as returned by {Open,Libre}SSL, the format might change
if "LibreSSL" in s:
self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)),
if IS_LIBRESSL:
self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
(s, t, hex(n)))
else:
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
@ -790,7 +793,8 @@ class ContextTests(unittest.TestCase):
def test_constructor(self):
for protocol in PROTOCOLS:
ssl.SSLContext(protocol)
self.assertRaises(TypeError, ssl.SSLContext)
ctx = ssl.SSLContext()
self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
self.assertRaises(ValueError, ssl.SSLContext, -1)
self.assertRaises(ValueError, ssl.SSLContext, 42)
@ -811,15 +815,15 @@ class ContextTests(unittest.TestCase):
def test_options(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
# OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3,
ctx.options)
default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
if not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0):
default |= ssl.OP_NO_COMPRESSION
self.assertEqual(default, ctx.options)
ctx.options |= ssl.OP_NO_TLSv1
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1,
ctx.options)
self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
if can_clear_options():
ctx.options = (ctx.options & ~ssl.OP_NO_SSLv2) | ssl.OP_NO_TLSv1
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_TLSv1 | ssl.OP_NO_SSLv3,
ctx.options)
ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
self.assertEqual(default, ctx.options)
ctx.options = 0
# Ubuntu has OP_NO_SSLv3 forced on by default
self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
@ -1155,6 +1159,7 @@ class ContextTests(unittest.TestCase):
self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
@unittest.skipIf(sys.platform == "win32", "not-Windows specific")
@unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
def test_load_default_certs_env(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
with support.EnvironmentVarGuard() as env:
@ -1750,13 +1755,13 @@ class NetworkedBIOTests(unittest.TestCase):
sslobj = ctx.wrap_bio(incoming, outgoing, False, REMOTE_HOST)
self.assertIs(sslobj._sslobj.owner, sslobj)
self.assertIsNone(sslobj.cipher())
self.assertIsNone(sslobj.shared_ciphers())
self.assertIsNotNone(sslobj.shared_ciphers())
self.assertRaises(ValueError, sslobj.getpeercert)
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
self.assertTrue(sslobj.cipher())
self.assertIsNone(sslobj.shared_ciphers())
self.assertIsNotNone(sslobj.shared_ciphers())
self.assertTrue(sslobj.getpeercert())
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertTrue(sslobj.get_channel_binding('tls-unique'))
@ -2993,7 +2998,7 @@ else:
with context.wrap_socket(socket.socket()) as s:
self.assertIs(s.version(), None)
s.connect((HOST, server.port))
self.assertEqual(s.version(), "TLSv1")
self.assertEqual(s.version(), 'TLSv1')
self.assertIs(s.version(), None)
@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
@ -3135,24 +3140,36 @@ else:
(['http/3.0', 'http/4.0'], None)
]
for client_protocols, expected in protocol_tests:
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
server_context.load_cert_chain(CERTFILE)
server_context.set_alpn_protocols(server_protocols)
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
client_context.load_cert_chain(CERTFILE)
client_context.set_alpn_protocols(client_protocols)
stats = server_params_test(client_context, server_context,
chatty=True, connectionchatty=True)
msg = "failed trying %s (s) and %s (c).\n" \
"was expecting %s, but got %%s from the %%s" \
% (str(server_protocols), str(client_protocols),
str(expected))
client_result = stats['client_alpn_protocol']
self.assertEqual(client_result, expected, msg % (client_result, "client"))
server_result = stats['server_alpn_protocols'][-1] \
if len(stats['server_alpn_protocols']) else 'nothing'
self.assertEqual(server_result, expected, msg % (server_result, "server"))
try:
stats = server_params_test(client_context,
server_context,
chatty=True,
connectionchatty=True)
except ssl.SSLError as e:
stats = e
if expected is None and IS_OPENSSL_1_1:
# OpenSSL 1.1.0 raises handshake error
self.assertIsInstance(stats, ssl.SSLError)
else:
msg = "failed trying %s (s) and %s (c).\n" \
"was expecting %s, but got %%s from the %%s" \
% (str(server_protocols), str(client_protocols),
str(expected))
client_result = stats['client_alpn_protocol']
self.assertEqual(client_result, expected,
msg % (client_result, "client"))
server_result = stats['server_alpn_protocols'][-1] \
if len(stats['server_alpn_protocols']) else 'nothing'
self.assertEqual(server_result, expected,
msg % (server_result, "server"))
def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used
@ -3300,13 +3317,23 @@ else:
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
client_context.verify_mode = ssl.CERT_REQUIRED
client_context.load_verify_locations(SIGNING_CA)
client_context.set_ciphers("RC4")
server_context.set_ciphers("AES:RC4")
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
client_context.set_ciphers("AES128:AES256")
server_context.set_ciphers("AES256")
alg1 = "AES256"
alg2 = "AES-256"
else:
client_context.set_ciphers("AES:3DES")
server_context.set_ciphers("3DES")
alg1 = "3DES"
alg2 = "DES-CBC3"
stats = server_params_test(client_context, server_context)
ciphers = stats['server_shared_ciphers'][0]
self.assertGreater(len(ciphers), 0)
for name, tls_version, bits in ciphers:
self.assertIn("RC4", name.split("-"))
if not alg1 in name.split("-") and alg2 not in name:
self.fail(name)
def test_read_write_after_close_raises_valuerror(self):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)