mirror of
https://github.com/python/cpython.git
synced 2025-07-07 19:35:27 +00:00
bpo-24334: Cleanup SSLSocket (#5252)
* The SSLSocket is no longer implemented on top of SSLObject to avoid an extra level of indirection. * Owner and session are now handled in the internal constructor. * _ssl._SSLSocket now uses the same method names as SSLSocket and SSLObject. * Channel binding type check is now handled in C code. Channel binding is always available. The patch also changes the signature of SSLObject.__init__(). In my opinion it's fine. A SSLObject is not a user-constructable object. SSLContext.wrap_bio() is the only valid factory.
This commit is contained in:
parent
b18f8bc1a7
commit
141c5e8c24
5 changed files with 183 additions and 117 deletions
116
Lib/ssl.py
116
Lib/ssl.py
|
@ -166,10 +166,7 @@ import warnings
|
|||
|
||||
socket_error = OSError # keep that public name in module namespace
|
||||
|
||||
if _ssl.HAS_TLS_UNIQUE:
|
||||
CHANNEL_BINDING_TYPES = ['tls-unique']
|
||||
else:
|
||||
CHANNEL_BINDING_TYPES = []
|
||||
CHANNEL_BINDING_TYPES = ['tls-unique']
|
||||
|
||||
HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')
|
||||
|
||||
|
@ -407,11 +404,11 @@ class SSLContext(_SSLContext):
|
|||
server_hostname=None, session=None):
|
||||
# Need to encode server_hostname here because _wrap_bio() can only
|
||||
# handle ASCII str.
|
||||
sslobj = self._wrap_bio(
|
||||
return self.sslobject_class(
|
||||
incoming, outgoing, server_side=server_side,
|
||||
server_hostname=self._encode_hostname(server_hostname)
|
||||
server_hostname=self._encode_hostname(server_hostname),
|
||||
session=session, _context=self,
|
||||
)
|
||||
return self.sslobject_class(sslobj, session=session)
|
||||
|
||||
def set_npn_protocols(self, npn_protocols):
|
||||
protos = bytearray()
|
||||
|
@ -616,12 +613,13 @@ class SSLObject:
|
|||
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
|
||||
"""
|
||||
|
||||
def __init__(self, sslobj, owner=None, session=None):
|
||||
self._sslobj = sslobj
|
||||
# Note: _sslobj takes a weak reference to owner
|
||||
self._sslobj.owner = owner or self
|
||||
if session is not None:
|
||||
self._sslobj.session = session
|
||||
def __init__(self, incoming, outgoing, server_side=False,
|
||||
server_hostname=None, session=None, _context=None):
|
||||
self._sslobj = _context._wrap_bio(
|
||||
incoming, outgoing, server_side=server_side,
|
||||
server_hostname=server_hostname,
|
||||
owner=self, session=session
|
||||
)
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
|
@ -684,7 +682,7 @@ class SSLObject:
|
|||
Return None if no certificate was provided, {} if a certificate was
|
||||
provided, but not validated.
|
||||
"""
|
||||
return self._sslobj.peer_certificate(binary_form)
|
||||
return self._sslobj.getpeercert(binary_form)
|
||||
|
||||
def selected_npn_protocol(self):
|
||||
"""Return the currently selected NPN protocol as a string, or ``None``
|
||||
|
@ -732,13 +730,7 @@ class SSLObject:
|
|||
"""Get channel binding data for current connection. Raise ValueError
|
||||
if the requested `cb_type` is not supported. Return bytes of the data
|
||||
or None if the data is not available (e.g. before the handshake)."""
|
||||
if cb_type not in CHANNEL_BINDING_TYPES:
|
||||
raise ValueError("Unsupported channel binding type")
|
||||
if cb_type != "tls-unique":
|
||||
raise NotImplementedError(
|
||||
"{0} channel binding type not implemented"
|
||||
.format(cb_type))
|
||||
return self._sslobj.tls_unique_cb()
|
||||
return self._sslobj.get_channel_binding(cb_type)
|
||||
|
||||
def version(self):
|
||||
"""Return a string identifying the protocol version used by the
|
||||
|
@ -832,10 +824,10 @@ class SSLSocket(socket):
|
|||
if connected:
|
||||
# create the SSL object
|
||||
try:
|
||||
sslobj = self._context._wrap_socket(self, server_side,
|
||||
self.server_hostname)
|
||||
self._sslobj = SSLObject(sslobj, owner=self,
|
||||
session=self._session)
|
||||
self._sslobj = self._context._wrap_socket(
|
||||
self, server_side, self.server_hostname,
|
||||
owner=self, session=self._session,
|
||||
)
|
||||
if do_handshake_on_connect:
|
||||
timeout = self.gettimeout()
|
||||
if timeout == 0.0:
|
||||
|
@ -895,10 +887,13 @@ class SSLSocket(socket):
|
|||
Return zero-length string on EOF."""
|
||||
|
||||
self._checkClosed()
|
||||
if not self._sslobj:
|
||||
if self._sslobj is None:
|
||||
raise ValueError("Read on closed or unwrapped SSL socket.")
|
||||
try:
|
||||
return self._sslobj.read(len, buffer)
|
||||
if buffer is not None:
|
||||
return self._sslobj.read(len, buffer)
|
||||
else:
|
||||
return self._sslobj.read(len)
|
||||
except SSLError as x:
|
||||
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||
if buffer is not None:
|
||||
|
@ -913,7 +908,7 @@ class SSLSocket(socket):
|
|||
number of bytes of DATA actually transmitted."""
|
||||
|
||||
self._checkClosed()
|
||||
if not self._sslobj:
|
||||
if self._sslobj is None:
|
||||
raise ValueError("Write on closed or unwrapped SSL socket.")
|
||||
return self._sslobj.write(data)
|
||||
|
||||
|
@ -929,41 +924,42 @@ class SSLSocket(socket):
|
|||
|
||||
def selected_npn_protocol(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj or not _ssl.HAS_NPN:
|
||||
if self._sslobj is None or not _ssl.HAS_NPN:
|
||||
return None
|
||||
else:
|
||||
return self._sslobj.selected_npn_protocol()
|
||||
|
||||
def selected_alpn_protocol(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj or not _ssl.HAS_ALPN:
|
||||
if self._sslobj is None or not _ssl.HAS_ALPN:
|
||||
return None
|
||||
else:
|
||||
return self._sslobj.selected_alpn_protocol()
|
||||
|
||||
def cipher(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj:
|
||||
if self._sslobj is None:
|
||||
return None
|
||||
else:
|
||||
return self._sslobj.cipher()
|
||||
|
||||
def shared_ciphers(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj:
|
||||
if self._sslobj is None:
|
||||
return None
|
||||
return self._sslobj.shared_ciphers()
|
||||
else:
|
||||
return self._sslobj.shared_ciphers()
|
||||
|
||||
def compression(self):
|
||||
self._checkClosed()
|
||||
if not self._sslobj:
|
||||
if self._sslobj is None:
|
||||
return None
|
||||
else:
|
||||
return self._sslobj.compression()
|
||||
|
||||
def send(self, data, flags=0):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
if flags != 0:
|
||||
raise ValueError(
|
||||
"non-zero flags not allowed in calls to send() on %s" %
|
||||
|
@ -974,7 +970,7 @@ class SSLSocket(socket):
|
|||
|
||||
def sendto(self, data, flags_or_addr, addr=None):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
raise ValueError("sendto not allowed on instances of %s" %
|
||||
self.__class__)
|
||||
elif addr is None:
|
||||
|
@ -990,7 +986,7 @@ class SSLSocket(socket):
|
|||
|
||||
def sendall(self, data, flags=0):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
if flags != 0:
|
||||
raise ValueError(
|
||||
"non-zero flags not allowed in calls to sendall() on %s" %
|
||||
|
@ -1008,15 +1004,15 @@ class SSLSocket(socket):
|
|||
"""Send a file, possibly by using os.sendfile() if this is a
|
||||
clear-text socket. Return the total number of bytes sent.
|
||||
"""
|
||||
if self._sslobj is None:
|
||||
if self._sslobj is not None:
|
||||
return self._sendfile_use_send(file, offset, count)
|
||||
else:
|
||||
# os.sendfile() works with plain sockets only
|
||||
return super().sendfile(file, offset, count)
|
||||
else:
|
||||
return self._sendfile_use_send(file, offset, count)
|
||||
|
||||
def recv(self, buflen=1024, flags=0):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
if flags != 0:
|
||||
raise ValueError(
|
||||
"non-zero flags not allowed in calls to recv() on %s" %
|
||||
|
@ -1031,7 +1027,7 @@ class SSLSocket(socket):
|
|||
nbytes = len(buffer)
|
||||
elif nbytes is None:
|
||||
nbytes = 1024
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
if flags != 0:
|
||||
raise ValueError(
|
||||
"non-zero flags not allowed in calls to recv_into() on %s" %
|
||||
|
@ -1042,7 +1038,7 @@ class SSLSocket(socket):
|
|||
|
||||
def recvfrom(self, buflen=1024, flags=0):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
raise ValueError("recvfrom not allowed on instances of %s" %
|
||||
self.__class__)
|
||||
else:
|
||||
|
@ -1050,7 +1046,7 @@ class SSLSocket(socket):
|
|||
|
||||
def recvfrom_into(self, buffer, nbytes=None, flags=0):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
raise ValueError("recvfrom_into not allowed on instances of %s" %
|
||||
self.__class__)
|
||||
else:
|
||||
|
@ -1066,7 +1062,7 @@ class SSLSocket(socket):
|
|||
|
||||
def pending(self):
|
||||
self._checkClosed()
|
||||
if self._sslobj:
|
||||
if self._sslobj is not None:
|
||||
return self._sslobj.pending()
|
||||
else:
|
||||
return 0
|
||||
|
@ -1078,7 +1074,7 @@ class SSLSocket(socket):
|
|||
|
||||
def unwrap(self):
|
||||
if self._sslobj:
|
||||
s = self._sslobj.unwrap()
|
||||
s = self._sslobj.shutdown()
|
||||
self._sslobj = None
|
||||
return s
|
||||
else:
|
||||
|
@ -1096,6 +1092,11 @@ class SSLSocket(socket):
|
|||
if timeout == 0.0 and block:
|
||||
self.settimeout(None)
|
||||
self._sslobj.do_handshake()
|
||||
if self.context.check_hostname:
|
||||
if not self.server_hostname:
|
||||
raise ValueError("check_hostname needs server_hostname "
|
||||
"argument")
|
||||
match_hostname(self.getpeercert(), self.server_hostname)
|
||||
finally:
|
||||
self.settimeout(timeout)
|
||||
|
||||
|
@ -1104,11 +1105,12 @@ class SSLSocket(socket):
|
|||
raise ValueError("can't connect in server-side mode")
|
||||
# Here we assume that the socket is client-side, and not
|
||||
# connected at the time of the call. We connect it, then wrap it.
|
||||
if self._connected:
|
||||
if self._connected or self._sslobj is not None:
|
||||
raise ValueError("attempt to connect already-connected SSLSocket!")
|
||||
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
|
||||
self._sslobj = SSLObject(sslobj, owner=self,
|
||||
session=self._session)
|
||||
self._sslobj = self.context._wrap_socket(
|
||||
self, False, self.server_hostname,
|
||||
owner=self, session=self._session
|
||||
)
|
||||
try:
|
||||
if connect_ex:
|
||||
rc = super().connect_ex(addr)
|
||||
|
@ -1151,18 +1153,24 @@ class SSLSocket(socket):
|
|||
if the requested `cb_type` is not supported. Return bytes of the data
|
||||
or None if the data is not available (e.g. before the handshake).
|
||||
"""
|
||||
if self._sslobj is None:
|
||||
if self._sslobj is not None:
|
||||
return self._sslobj.get_channel_binding(cb_type)
|
||||
else:
|
||||
if cb_type not in CHANNEL_BINDING_TYPES:
|
||||
raise ValueError(
|
||||
"{0} channel binding type not implemented".format(cb_type)
|
||||
)
|
||||
return None
|
||||
return self._sslobj.get_channel_binding(cb_type)
|
||||
|
||||
def version(self):
|
||||
"""
|
||||
Return a string identifying the protocol version used by the
|
||||
current SSL channel, or None if there is no established channel.
|
||||
"""
|
||||
if self._sslobj is None:
|
||||
if self._sslobj is not None:
|
||||
return self._sslobj.version()
|
||||
else:
|
||||
return None
|
||||
return self._sslobj.version()
|
||||
|
||||
|
||||
# Python does not support forward declaration of types.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue