mirror of
https://github.com/python/cpython.git
synced 2025-12-15 21:44:50 +00:00
Issue #14204: The ssl module now has support for the Next Protocol Negotiation extension, if available in the underlying OpenSSL library.
Patch by Colin Marc.
This commit is contained in:
parent
a966c6fddb
commit
d5d17eb653
6 changed files with 228 additions and 8 deletions
27
Lib/ssl.py
27
Lib/ssl.py
|
|
@ -90,7 +90,7 @@ from _ssl import (
|
|||
SSL_ERROR_EOF,
|
||||
SSL_ERROR_INVALID_ERROR_CODE,
|
||||
)
|
||||
from _ssl import HAS_SNI, HAS_ECDH
|
||||
from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
|
||||
from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23,
|
||||
PROTOCOL_TLSv1)
|
||||
from _ssl import _OPENSSL_API_VERSION
|
||||
|
|
@ -209,6 +209,17 @@ class SSLContext(_SSLContext):
|
|||
server_hostname=server_hostname,
|
||||
_context=self)
|
||||
|
||||
def set_npn_protocols(self, npn_protocols):
|
||||
protos = bytearray()
|
||||
for protocol in npn_protocols:
|
||||
b = bytes(protocol, 'ascii')
|
||||
if len(b) == 0 or len(b) > 255:
|
||||
raise SSLError('NPN protocols must be 1 to 255 in length')
|
||||
protos.append(len(b))
|
||||
protos.extend(b)
|
||||
|
||||
self._set_npn_protocols(protos)
|
||||
|
||||
|
||||
class SSLSocket(socket):
|
||||
"""This class implements a subtype of socket.socket that wraps
|
||||
|
|
@ -220,7 +231,7 @@ class SSLSocket(socket):
|
|||
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
|
||||
do_handshake_on_connect=True,
|
||||
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
|
||||
suppress_ragged_eofs=True, ciphers=None,
|
||||
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
|
||||
server_hostname=None,
|
||||
_context=None):
|
||||
|
||||
|
|
@ -240,6 +251,8 @@ class SSLSocket(socket):
|
|||
self.context.load_verify_locations(ca_certs)
|
||||
if certfile:
|
||||
self.context.load_cert_chain(certfile, keyfile)
|
||||
if npn_protocols:
|
||||
self.context.set_npn_protocols(npn_protocols)
|
||||
if ciphers:
|
||||
self.context.set_ciphers(ciphers)
|
||||
self.keyfile = keyfile
|
||||
|
|
@ -340,6 +353,13 @@ class SSLSocket(socket):
|
|||
self._checkClosed()
|
||||
return self._sslobj.peer_certificate(binary_form)
|
||||
|
||||
def selected_npn_protocol(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj or not _ssl.HAS_NPN:
|
||||
return None
|
||||
else:
|
||||
return self._sslobj.selected_npn_protocol()
|
||||
|
||||
def cipher(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj:
|
||||
|
|
@ -568,7 +588,8 @@ def wrap_socket(sock, keyfile=None, certfile=None,
|
|||
server_side=False, cert_reqs=CERT_NONE,
|
||||
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
|
||||
do_handshake_on_connect=True,
|
||||
suppress_ragged_eofs=True, ciphers=None):
|
||||
suppress_ragged_eofs=True,
|
||||
ciphers=None):
|
||||
|
||||
return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
|
||||
server_side=server_side, cert_reqs=cert_reqs,
|
||||
|
|
|
|||
|
|
@ -879,6 +879,7 @@ else:
|
|||
try:
|
||||
self.sslconn = self.server.context.wrap_socket(
|
||||
self.sock, server_side=True)
|
||||
self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
|
||||
except ssl.SSLError as e:
|
||||
# XXX Various errors can have happened here, for example
|
||||
# a mismatching protocol version, an invalid certificate,
|
||||
|
|
@ -901,6 +902,8 @@ else:
|
|||
cipher = self.sslconn.cipher()
|
||||
if support.verbose and self.server.chatty:
|
||||
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
|
||||
sys.stdout.write(" server: selected protocol is now "
|
||||
+ str(self.sslconn.selected_npn_protocol()) + "\n")
|
||||
return True
|
||||
|
||||
def read(self):
|
||||
|
|
@ -979,7 +982,7 @@ else:
|
|||
def __init__(self, certificate=None, ssl_version=None,
|
||||
certreqs=None, cacerts=None,
|
||||
chatty=True, connectionchatty=False, starttls_server=False,
|
||||
ciphers=None, context=None):
|
||||
npn_protocols=None, ciphers=None, context=None):
|
||||
if context:
|
||||
self.context = context
|
||||
else:
|
||||
|
|
@ -992,6 +995,8 @@ else:
|
|||
self.context.load_verify_locations(cacerts)
|
||||
if certificate:
|
||||
self.context.load_cert_chain(certificate)
|
||||
if npn_protocols:
|
||||
self.context.set_npn_protocols(npn_protocols)
|
||||
if ciphers:
|
||||
self.context.set_ciphers(ciphers)
|
||||
self.chatty = chatty
|
||||
|
|
@ -1001,6 +1006,7 @@ else:
|
|||
self.port = support.bind_port(self.sock)
|
||||
self.flag = None
|
||||
self.active = False
|
||||
self.selected_protocols = []
|
||||
self.conn_errors = []
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
|
|
@ -1195,6 +1201,7 @@ else:
|
|||
Launch a server, connect a client to it and try various reads
|
||||
and writes.
|
||||
"""
|
||||
stats = {}
|
||||
server = ThreadedEchoServer(context=server_context,
|
||||
chatty=chatty,
|
||||
connectionchatty=False)
|
||||
|
|
@ -1220,12 +1227,14 @@ else:
|
|||
if connectionchatty:
|
||||
if support.verbose:
|
||||
sys.stdout.write(" client: closing connection.\n")
|
||||
stats = {
|
||||
stats.update({
|
||||
'compression': s.compression(),
|
||||
'cipher': s.cipher(),
|
||||
}
|
||||
'client_npn_protocol': s.selected_npn_protocol()
|
||||
})
|
||||
s.close()
|
||||
return stats
|
||||
stats['server_npn_protocols'] = server.selected_protocols
|
||||
return stats
|
||||
|
||||
def try_protocol_combo(server_protocol, client_protocol, expect_success,
|
||||
certsreqs=None, server_options=0, client_options=0):
|
||||
|
|
@ -1853,6 +1862,43 @@ else:
|
|||
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
|
||||
self.fail("Non-DH cipher: " + cipher[0])
|
||||
|
||||
def test_selected_npn_protocol(self):
|
||||
# selected_npn_protocol() is None unless NPN is used
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||
context.load_cert_chain(CERTFILE)
|
||||
stats = server_params_test(context, context,
|
||||
chatty=True, connectionchatty=True)
|
||||
self.assertIs(stats['client_npn_protocol'], None)
|
||||
|
||||
@unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
|
||||
def test_npn_protocols(self):
|
||||
server_protocols = ['http/1.1', 'spdy/2']
|
||||
protocol_tests = [
|
||||
(['http/1.1', 'spdy/2'], 'http/1.1'),
|
||||
(['spdy/2', 'http/1.1'], 'http/1.1'),
|
||||
(['spdy/2', 'test'], 'spdy/2'),
|
||||
(['abc', 'def'], 'abc')
|
||||
]
|
||||
for client_protocols, expected in protocol_tests:
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||
server_context.load_cert_chain(CERTFILE)
|
||||
server_context.set_npn_protocols(server_protocols)
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||
client_context.load_cert_chain(CERTFILE)
|
||||
client_context.set_npn_protocols(client_protocols)
|
||||
stats = server_params_test(client_context, server_context,
|
||||
chatty=True, connectionchatty=True)
|
||||
|
||||
msg = "failed trying %s (s) and %s (c).\n" \
|
||||
"was expecting %s, but got %%s from the %%s" \
|
||||
% (str(server_protocols), str(client_protocols),
|
||||
str(expected))
|
||||
client_result = stats['client_npn_protocol']
|
||||
self.assertEqual(client_result, expected, msg % (client_result, "client"))
|
||||
server_result = stats['server_npn_protocols'][-1] \
|
||||
if len(stats['server_npn_protocols']) else 'nothing'
|
||||
self.assertEqual(server_result, expected, msg % (server_result, "server"))
|
||||
|
||||
|
||||
def test_main(verbose=False):
|
||||
if support.verbose:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue