[bpo-28414] Make all hostnames in SSL module IDN A-labels (GH-5128)

Previously, the ssl module stored international domain names (IDNs)
as U-labels. This is problematic for a number of reasons -- for
example, it made it impossible for users to use a different version
of IDNA than the one built into Python.

After this change, we always convert to A-labels as soon as possible,
and use them for all internal processing. In particular, server_hostname
attribute is now an A-label, and on the server side there's a new
sni_callback that receives the SNI servername as an A-label rather than
a U-label.
This commit is contained in:
Christian Heimes 2018-02-24 02:35:08 +01:00 committed by Nathaniel J. Smith
parent 82ab13d756
commit 11a1493bc4
7 changed files with 163 additions and 111 deletions

View file

@ -355,13 +355,20 @@ class SSLContext(_SSLContext):
self = _SSLContext.__new__(cls, protocol)
return self
def __init__(self, protocol=PROTOCOL_TLS):
self.protocol = protocol
def _encode_hostname(self, hostname):
if hostname is None:
return None
elif isinstance(hostname, str):
return hostname.encode('idna').decode('ascii')
else:
return hostname.decode('ascii')
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None, session=None):
# SSLSocket class handles server_hostname encoding before it calls
# ctx._wrap_socket()
return self.sslsocket_class(
sock=sock,
server_side=server_side,
@ -374,8 +381,12 @@ class SSLContext(_SSLContext):
def wrap_bio(self, incoming, outgoing, server_side=False,
server_hostname=None, session=None):
sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
server_hostname=server_hostname)
# Need to encode server_hostname here because _wrap_bio() can only
# handle ASCII str.
sslobj = self._wrap_bio(
incoming, outgoing, server_side=server_side,
server_hostname=self._encode_hostname(server_hostname)
)
return self.sslobject_class(sslobj, session=session)
def set_npn_protocols(self, npn_protocols):
@ -389,6 +400,19 @@ class SSLContext(_SSLContext):
self._set_npn_protocols(protos)
def set_servername_callback(self, server_name_callback):
if server_name_callback is None:
self.sni_callback = None
else:
if not callable(server_name_callback):
raise TypeError("not a callable object")
def shim_cb(sslobj, servername, sslctx):
servername = self._encode_hostname(servername)
return server_name_callback(sslobj, servername, sslctx)
self.sni_callback = shim_cb
def set_alpn_protocols(self, alpn_protocols):
protos = bytearray()
for protocol in alpn_protocols:
@ -447,6 +471,10 @@ class SSLContext(_SSLContext):
def hostname_checks_common_name(self):
return True
@property
def protocol(self):
return _SSLMethod(super().protocol)
@property
def verify_flags(self):
return VerifyFlags(super().verify_flags)
@ -749,7 +777,7 @@ class SSLSocket(socket):
raise ValueError("check_hostname requires server_hostname")
self._session = _session
self.server_side = server_side
self.server_hostname = server_hostname
self.server_hostname = self._context._encode_hostname(server_hostname)
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
if sock is not None:
@ -781,7 +809,7 @@ class SSLSocket(socket):
# create the SSL object
try:
sslobj = self._context._wrap_socket(self, server_side,
server_hostname)
self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self,
session=self._session)
if do_handshake_on_connect: