Issue #21965: Add support for in-memory SSL to the ssl module.

Patch by Geert Jansen.
This commit is contained in:
Antoine Pitrou 2014-10-05 20:41:53 +02:00
parent 414e15a88d
commit b1fdf47ff5
5 changed files with 926 additions and 102 deletions

View file

@ -97,7 +97,7 @@ from enum import Enum as _Enum, IntEnum as _IntEnum
import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
from _ssl import _SSLContext
from _ssl import _SSLContext, MemoryBIO
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError,
@ -352,6 +352,12 @@ class SSLContext(_SSLContext):
server_hostname=server_hostname,
_context=self)
def wrap_bio(self, incoming, outgoing, server_side=False,
server_hostname=None):
sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
server_hostname=server_hostname)
return SSLObject(sslobj)
def set_npn_protocols(self, npn_protocols):
protos = bytearray()
for protocol in npn_protocols:
@ -469,6 +475,129 @@ def _create_stdlib_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None,
return context
class SSLObject:
"""This class implements an interface on top of a low-level SSL object as
implemented by OpenSSL. This object captures the state of an SSL connection
but does not provide any network IO itself. IO needs to be performed
through separate "BIO" objects which are OpenSSL's IO abstraction layer.
This class does not have a public constructor. Instances are returned by
``SSLContext.wrap_bio``. This class is typically used by framework authors
that want to implement asynchronous IO for SSL through memory buffers.
When compared to ``SSLSocket``, this object lacks the following features:
* Any form of network IO incluging methods such as ``recv`` and ``send``.
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""
def __init__(self, sslobj, owner=None):
self._sslobj = sslobj
# Note: _sslobj takes a weak reference to owner
self._sslobj.owner = owner or self
@property
def context(self):
"""The SSLContext that is currently in use."""
return self._sslobj.context
@context.setter
def context(self, ctx):
self._sslobj.context = ctx
@property
def server_side(self):
"""Whether this is a server-side socket."""
return self._sslobj.server_side
@property
def server_hostname(self):
"""The currently set server hostname (for SNI), or ``None`` if no
server hostame is set."""
return self._sslobj.server_hostname
def read(self, len=0, buffer=None):
"""Read up to 'len' bytes from the SSL object and return them.
If 'buffer' is provided, read into this buffer and return the number of
bytes read.
"""
if buffer is not None:
v = self._sslobj.read(len, buffer)
else:
v = self._sslobj.read(len or 1024)
return v
def write(self, data):
"""Write 'data' to the SSL object and return the number of bytes
written.
The 'data' argument must support the buffer interface.
"""
return self._sslobj.write(data)
def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the certificate provided
by the other end of the SSL channel.
Return None if no certificate was provided, {} if a certificate was
provided, but not validated.
"""
return self._sslobj.peer_certificate(binary_form)
def selected_npn_protocol(self):
"""Return the currently selected NPN protocol as a string, or ``None``
if a next protocol was not negotiated or if NPN is not supported by one
of the peers."""
if _ssl.HAS_NPN:
return self._sslobj.selected_npn_protocol()
def cipher(self):
"""Return the currently selected cipher as a 3-tuple ``(name,
ssl_version, secret_bits)``."""
return self._sslobj.cipher()
def compression(self):
"""Return the current compression algorithm in use, or ``None`` if
compression was not negotiated or not supported by one of the peers."""
return self._sslobj.compression()
def pending(self):
"""Return the number of bytes that can be read immediately."""
return self._sslobj.pending()
def do_handshake(self, block=False):
"""Start the SSL/TLS handshake."""
self._sslobj.do_handshake()
if self.context.check_hostname:
if not self.server_hostname:
raise ValueError("check_hostname needs server_hostname "
"argument")
match_hostname(self.getpeercert(), self.server_hostname)
def unwrap(self):
"""Start the SSL shutdown handshake."""
return self._sslobj.shutdown()
def get_channel_binding(self, cb_type="tls-unique"):
"""Get channel binding data for current connection. Raise ValueError
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake)."""
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError(
"{0} channel binding type not implemented"
.format(cb_type))
return self._sslobj.tls_unique_cb()
def version(self):
"""Return a string identifying the protocol version used by the
current SSL channel. """
return self._sslobj.version()
class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
@ -556,8 +685,9 @@ class SSLSocket(socket):
if connected:
# create the SSL object
try:
self._sslobj = self._context._wrap_socket(self, server_side,
server_hostname)
sslobj = self._context._wrap_socket(self, server_side,
server_hostname)
self._sslobj = SSLObject(sslobj, owner=self)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
@ -602,11 +732,7 @@ class SSLSocket(socket):
if not self._sslobj:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
if buffer is not None:
v = self._sslobj.read(len, buffer)
else:
v = self._sslobj.read(len or 1024)
return v
return self._sslobj.read(len, buffer)
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
@ -633,7 +759,7 @@ class SSLSocket(socket):
self._checkClosed()
self._check_connected()
return self._sslobj.peer_certificate(binary_form)
return self._sslobj.getpeercert(binary_form)
def selected_npn_protocol(self):
self._checkClosed()
@ -773,7 +899,7 @@ class SSLSocket(socket):
def unwrap(self):
if self._sslobj:
s = self._sslobj.shutdown()
s = self._sslobj.unwrap()
self._sslobj = None
return s
else:
@ -794,12 +920,6 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
if self.context.check_hostname:
if not self.server_hostname:
raise ValueError("check_hostname needs server_hostname "
"argument")
match_hostname(self.getpeercert(), self.server_hostname)
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
@ -807,7 +927,8 @@ class SSLSocket(socket):
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
@ -850,15 +971,9 @@ class SSLSocket(socket):
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake).
"""
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError(
"{0} channel binding type not implemented"
.format(cb_type))
if self._sslobj is None:
return None
return self._sslobj.tls_unique_cb()
return self._sslobj.get_channel_binding(cb_type)
def version(self):
"""