mirror of
https://github.com/python/cpython.git
synced 2025-10-24 07:26:11 +00:00

When wrapped, `_SSLProtocolTransport._force_close(exc)` is called just as `_SelectorTransport._force_close(exc)` or `_ProactorBasePipeTransport._force_close(exc)` would be called in the unwrapped scenario, except that here the exception must be passed through the `SSLProtocol._abort()` method, which previously did not accept an exception object. This commit makes that path work, passing the exception on in the same way as the uvloop implementation of SSLProto (on which the current SSLProto implementation is based).
921 lines
31 KiB
Python
921 lines
31 KiB
Python
import collections
|
|
import enum
|
|
import warnings
|
|
try:
|
|
import ssl
|
|
except ImportError: # pragma: no cover
|
|
ssl = None
|
|
|
|
from . import constants
|
|
from . import exceptions
|
|
from . import protocols
|
|
from . import transports
|
|
from .log import logger
|
|
|
|
if ssl is not None:
    # Exceptions from the in-memory SSL object that this module treats as
    # "cannot make progress yet": handshake, write and shutdown paths catch
    # them, flush any pending outgoing bytes and wait for more input.
    SSLAgainErrors = ssl.SSLWantReadError, ssl.SSLSyscallError
|
|
|
|
|
|
class SSLProtocolState(enum.Enum):
    """States of the TLS session driven by ``SSLProtocol``.

    Each member's value is identical to its name.
    """

    UNWRAPPED = "UNWRAPPED"        # no TLS session (initial state / torn down)
    DO_HANDSHAKE = "DO_HANDSHAKE"  # handshake in progress
    WRAPPED = "WRAPPED"            # handshake done; application data flows
    FLUSHING = "FLUSHING"          # shutdown requested; draining pending reads
    SHUTDOWN = "SHUTDOWN"          # unwrapping the TLS session
|
|
|
|
|
|
class AppProtocolState(enum.Enum):
    """Which callbacks have been delivered to the application protocol.

    Transitions (https://git.io/fj59P)::

        INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST

    where cm = connection_made(), dr = data_received(),
    er = eof_received() and cl = connection_lost().
    """

    STATE_INIT = "STATE_INIT"
    STATE_CON_MADE = "STATE_CON_MADE"
    STATE_EOF = "STATE_EOF"
    STATE_CON_LOST = "STATE_CON_LOST"
|
|
|
|
|
|
def _create_transport_context(server_side, server_hostname):
|
|
if server_side:
|
|
raise ValueError('Server side SSL needs a valid SSLContext')
|
|
|
|
# Client side may pass ssl=True to use a default
|
|
# context; in that case the sslcontext passed is None.
|
|
# The default is secure for client connections.
|
|
# Python 3.4+: use up-to-date strong settings.
|
|
sslcontext = ssl.create_default_context()
|
|
if not server_hostname:
|
|
sslcontext.check_hostname = False
|
|
return sslcontext
|
|
|
|
|
|
def add_flowcontrol_defaults(high, low, kb):
    """Fill in missing flow-control water marks and validate them.

    Args:
        high: high-water mark in bytes, or None.
        low: low-water mark in bytes, or None.
        kb: default high-water mark in KiB, used when both are None.

    Returns:
        ``(high, low)`` with defaults applied: a missing *high* becomes
        ``4 * low`` (or ``kb * 1024`` if *low* is missing too), and a
        missing *low* becomes ``high // 4``.

    Raises:
        ValueError: unless ``high >= low >= 0``.
    """
    if high is not None:
        hi = high
    elif low is not None:
        hi = 4 * low
    else:
        hi = kb * 1024
    lo = hi // 4 if low is None else low

    if not hi >= lo >= 0:
        raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
                         (hi, lo))
    return hi, lo
|
|
|
|
|
|
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport object handed to the application protocol.

    Every operation is delegated to the owning ``SSLProtocol``, which
    encrypts application writes and decrypts incoming data before it
    reaches the application.
    """

    _start_tls_compatible = True
    _sendfile_compatible = constants._SendfileMode.FALLBACK

    def __init__(self, loop, ssl_protocol):
        self._loop = loop
        self._ssl_protocol = ssl_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        self._ssl_protocol._set_app_protocol(protocol)

    def get_protocol(self):
        return self._ssl_protocol._app_protocol

    def is_closing(self):
        return self._closed

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.
        """
        if not self._closed:
            self._closed = True
            self._ssl_protocol._start_shutdown()
        else:
            # Closing an already-closed transport drops the reference to
            # the SSL protocol; later delegating calls will then fail.
            self._ssl_protocol = None

    def __del__(self, _warnings=warnings):
        if not self._closed:
            self._closed = True
            _warnings.warn(
                "unclosed transport <asyncio._SSLProtocolTransport "
                "object>", ResourceWarning)

    def is_reading(self):
        return not self._ssl_protocol._app_reading_paused

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._set_write_buffer_limits(high, low)
        # Re-evaluate immediately so the new limits take effect now.
        self._ssl_protocol._control_app_writing()

    def get_write_buffer_limits(self):
        """Return the write water marks as a ``(low, high)`` tuple."""
        return (self._ssl_protocol._outgoing_low_water,
                self._ssl_protocol._outgoing_high_water)

    def get_write_buffer_size(self):
        """Return the current size of the write buffers."""
        return self._ssl_protocol._get_write_buffer_size()

    def set_read_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for read flow control.

        These two values control when to call the upstream transport's
        pause_reading() and resume_reading() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_reading() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_reading() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._set_read_buffer_limits(high, low)
        # Re-evaluate immediately so the new limits take effect now.
        self._ssl_protocol._control_ssl_reading()

    def get_read_buffer_limits(self):
        """Return the read water marks as a ``(low, high)`` tuple."""
        return (self._ssl_protocol._incoming_low_water,
                self._ssl_protocol._incoming_high_water)

    def get_read_buffer_size(self):
        """Return the current size of the read buffer."""
        return self._ssl_protocol._get_read_buffer_size()

    @property
    def _protocol_paused(self):
        # Required for sendfile fallback pause_writing/resume_writing logic
        return self._ssl_protocol._app_writing_paused

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.  Empty data is ignored.

        Raises:
            TypeError: if *data* is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"data: expecting a bytes-like instance, "
                            f"got {type(data).__name__}")
        if not data:
            return
        self._ssl_protocol._write_appdata((data,))

    def writelines(self, list_of_data):
        """Write a list (or any iterable) of data bytes to the transport.

        The iterable is handed directly to the SSL protocol, which queues
        each item on its write backlog.  Unlike write(), the items are
        neither type-checked nor filtered for emptiness here.
        """
        self._ssl_protocol._write_appdata(list_of_data)

    def write_eof(self):
        """Close the write end after flushing buffered data.

        Not supported by this transport: always raises
        :exc:`NotImplementedError` (see can_write_eof()).
        """
        raise NotImplementedError

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        self._force_close(None)

    def _force_close(self, exc):
        # Mark closed and let the SSL protocol abort the connection,
        # forwarding *exc* (None for a plain abort()).
        self._closed = True
        if self._ssl_protocol is not None:
            self._ssl_protocol._abort(exc)

    def _test__append_write_backlog(self, data):
        # for test only
        self._ssl_protocol._write_backlog.append(data)
        self._ssl_protocol._write_buffer_size += len(data)
|
|
|
|
|
|
class SSLProtocol(protocols.BufferedProtocol):
    """SSL/TLS layer between a low-level transport and an app protocol.

    Ciphertext from the underlying transport is written into an incoming
    ``ssl.MemoryBIO`` and decrypted by an ``SSLObject``.  The plaintext is
    delivered to the application protocol.  Application writes are
    encrypted into an outgoing ``MemoryBIO`` and forwarded to the underlying
    transport.  Progress is tracked by ``SSLProtocolState`` (TLS session)
    and ``AppProtocolState`` (callbacks made on the app protocol).
    """

    max_size = 256 * 1024   # Buffer size passed to read()

    _handshake_start_time = None
    _handshake_timeout_handle = None
    _shutdown_timeout_handle = None

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None,
                 ssl_shutdown_timeout=None):
        """Set up buffers, the SSL object and the app-facing transport.

        *waiter* (may be None) is resolved when the handshake finishes or
        the connection is lost.  A falsy *sslcontext* is replaced by a
        default client context.  When *call_connection_made* is false the
        app protocol is treated as already connected.

        Raises:
            RuntimeError: if the ssl module is not available.
            ValueError: if a timeout is given but is not positive.
        """
        if ssl is None:
            raise RuntimeError("stdlib ssl module not available")

        self._ssl_buffer = bytearray(self.max_size)
        self._ssl_buffer_view = memoryview(self._ssl_buffer)

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")
        if ssl_shutdown_timeout is None:
            ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
        elif ssl_shutdown_timeout <= 0:
            raise ValueError(
                f"ssl_shutdown_timeout should be a positive number, "
                f"got {ssl_shutdown_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = None
        self._app_transport_created = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._ssl_handshake_timeout = ssl_handshake_timeout
        self._ssl_shutdown_timeout = ssl_shutdown_timeout
        # SSL and state machine
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._state = SSLProtocolState.UNWRAPPED
        self._conn_lost = 0  # Set when connection_lost called
        if call_connection_made:
            self._app_state = AppProtocolState.STATE_INIT
        else:
            self._app_state = AppProtocolState.STATE_CON_MADE
        self._sslobj = self._sslcontext.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)

        # Flow Control

        self._ssl_writing_paused = False

        self._app_reading_paused = False

        self._ssl_reading_paused = False
        self._incoming_high_water = 0
        self._incoming_low_water = 0
        self._set_read_buffer_limits()
        self._eof_received = False

        self._app_writing_paused = False
        self._outgoing_high_water = 0
        self._outgoing_low_water = 0
        self._set_write_buffer_limits()
        self._get_app_transport()

    def _set_app_protocol(self, app_protocol):
        """Install *app_protocol*, caching its buffered-protocol hooks."""
        self._app_protocol = app_protocol
        # Make fast hasattr check first
        if (hasattr(app_protocol, 'get_buffer') and
                isinstance(app_protocol, protocols.BufferedProtocol)):
            self._app_protocol_get_buffer = app_protocol.get_buffer
            self._app_protocol_buffer_updated = app_protocol.buffer_updated
            self._app_protocol_is_buffer = True
        else:
            self._app_protocol_is_buffer = False

    def _wakeup_waiter(self, exc=None):
        """Resolve the waiter once (with *exc* if given), then forget it."""
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def _get_app_transport(self):
        """Return the app-facing transport, creating it on first use.

        Raises RuntimeError if it was created before and has since been
        cleared (e.g. by connection_lost()).
        """
        if self._app_transport is None:
            if self._app_transport_created:
                raise RuntimeError('Creating _SSLProtocolTransport twice')
            self._app_transport = _SSLProtocolTransport(self._loop, self)
            self._app_transport_created = True
        return self._app_transport

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        self._write_backlog.clear()
        self._outgoing.read()
        self._conn_lost += 1

        # Just mark the app transport as closed so that its __dealloc__
        # doesn't complain.
        if self._app_transport is not None:
            self._app_transport._closed = True

        if self._state != SSLProtocolState.DO_HANDSHAKE:
            if (
                self._app_state == AppProtocolState.STATE_CON_MADE or
                self._app_state == AppProtocolState.STATE_EOF
            ):
                self._app_state = AppProtocolState.STATE_CON_LOST
                self._loop.call_soon(self._app_protocol.connection_lost, exc)
        self._set_state(SSLProtocolState.UNWRAPPED)
        self._transport = None
        self._app_transport = None
        self._app_protocol = None
        self._wakeup_waiter(exc)

        if self._shutdown_timeout_handle:
            self._shutdown_timeout_handle.cancel()
            self._shutdown_timeout_handle = None
        if self._handshake_timeout_handle:
            self._handshake_timeout_handle.cancel()
            self._handshake_timeout_handle = None

    def get_buffer(self, n):
        """Return a writable view for incoming ciphertext.

        *n* is a size hint; non-positive or oversized hints are clamped
        to ``max_size``.  The internal buffer grows if it is too small.
        """
        want = n
        if want <= 0 or want > self.max_size:
            want = self.max_size
        if len(self._ssl_buffer) < want:
            self._ssl_buffer = bytearray(want)
            self._ssl_buffer_view = memoryview(self._ssl_buffer)
        return self._ssl_buffer_view

    def buffer_updated(self, nbytes):
        """Feed *nbytes* of received ciphertext to the SSL object and
        advance whichever phase the state machine is in."""
        self._incoming.write(self._ssl_buffer_view[:nbytes])

        if self._state == SSLProtocolState.DO_HANDSHAKE:
            self._do_handshake()

        elif self._state == SSLProtocolState.WRAPPED:
            self._do_read()

        elif self._state == SSLProtocolState.FLUSHING:
            self._do_flush()

        elif self._state == SSLProtocolState.SHUTDOWN:
            self._do_shutdown()

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.
        """
        self._eof_received = True
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            if self._state == SSLProtocolState.DO_HANDSHAKE:
                self._on_handshake_complete(ConnectionResetError)

            elif self._state == SSLProtocolState.WRAPPED:
                self._set_state(SSLProtocolState.FLUSHING)
                if self._app_reading_paused:
                    # Keep the transport open; _resume_reading() finishes
                    # the flush once the app resumes reading.
                    return True
                else:
                    self._do_flush()

            elif self._state == SSLProtocolState.FLUSHING:
                self._do_write()
                self._set_state(SSLProtocolState.SHUTDOWN)
                self._do_shutdown()

            elif self._state == SSLProtocolState.SHUTDOWN:
                self._do_shutdown()

        except Exception:
            self._transport.close()
            raise

    def _get_extra_info(self, name, default=None):
        """Look up *name* in SSL extra info, then in the low-level
        transport, falling back to *default*."""
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _set_state(self, new_state):
        """Move the TLS state machine to *new_state*.

        Allowed transitions: any state -> UNWRAPPED, and
        UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING -> SHUTDOWN.

        Raises:
            RuntimeError: for any other transition.
        """
        allowed = False

        if new_state == SSLProtocolState.UNWRAPPED:
            allowed = True

        elif (
            self._state == SSLProtocolState.UNWRAPPED and
            new_state == SSLProtocolState.DO_HANDSHAKE
        ):
            allowed = True

        elif (
            self._state == SSLProtocolState.DO_HANDSHAKE and
            new_state == SSLProtocolState.WRAPPED
        ):
            allowed = True

        elif (
            self._state == SSLProtocolState.WRAPPED and
            new_state == SSLProtocolState.FLUSHING
        ):
            allowed = True

        elif (
            self._state == SSLProtocolState.FLUSHING and
            new_state == SSLProtocolState.SHUTDOWN
        ):
            allowed = True

        if allowed:
            self._state = new_state

        else:
            raise RuntimeError(
                'cannot switch state from {} to {}'.format(
                    self._state, new_state))

    # Handshake flow

    def _start_handshake(self):
        """Enter DO_HANDSHAKE, arm the handshake timeout and begin."""
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None

        self._set_state(SSLProtocolState.DO_HANDSHAKE)

        # start handshake timeout count down
        self._handshake_timeout_handle = \
            self._loop.call_later(self._ssl_handshake_timeout,
                                  self._check_handshake_timeout)

        self._do_handshake()

    def _check_handshake_timeout(self):
        """Abort the connection if the handshake is still running."""
        if self._state == SSLProtocolState.DO_HANDSHAKE:
            msg = (
                f"SSL handshake is taking longer than "
                f"{self._ssl_handshake_timeout} seconds: "
                f"aborting the connection"
            )
            self._fatal_error(ConnectionAbortedError(msg))

    def _do_handshake(self):
        """Advance the handshake by one step."""
        try:
            self._sslobj.do_handshake()
        except SSLAgainErrors:
            # Needs more input: send whatever the SSL object produced.
            self._process_outgoing()
        except ssl.SSLError as exc:
            self._on_handshake_complete(exc)
        else:
            self._on_handshake_complete(None)

    def _on_handshake_complete(self, handshake_exc):
        """Finish the handshake, successfully or not.

        *handshake_exc* is None on success.  Otherwise it (an exception
        instance or class) is raised and reported via _fatal_error(), and
        the waiter receives the exception.  On success the peer
        certificate, cipher, compression and SSL object are added to the
        extra info, the app protocol gets connection_made() if it has not
        had it yet, and the waiter is resolved.
        """
        if self._handshake_timeout_handle is not None:
            self._handshake_timeout_handle.cancel()
            self._handshake_timeout_handle = None

        sslobj = self._sslobj
        try:
            if handshake_exc is None:
                self._set_state(SSLProtocolState.WRAPPED)
            else:
                raise handshake_exc

            peercert = sslobj.getpeercert()
        except Exception as exc:
            # Drop the local reference to the raised exception: its
            # traceback references this frame, which would otherwise form
            # a reference cycle frame -> handshake_exc -> exc -> frame.
            handshake_exc = None
            self._set_state(SSLProtocolState.UNWRAPPED)
            if isinstance(exc, ssl.CertificateError):
                msg = 'SSL handshake failed on verifying the certificate'
            else:
                msg = 'SSL handshake failed'
            self._fatal_error(exc, msg)
            self._wakeup_waiter(exc)
            return

        # The start time is only recorded if debug mode was on when the
        # handshake began; debug may have been enabled since then.
        if self._loop.get_debug() and self._handshake_start_time is not None:
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj)
        if self._app_state == AppProtocolState.STATE_INIT:
            self._app_state = AppProtocolState.STATE_CON_MADE
            self._app_protocol.connection_made(self._get_app_transport())
        self._wakeup_waiter()
        self._do_read()

    # Shutdown flow

    def _start_shutdown(self):
        """Begin a graceful TLS shutdown (no-op if already shutting down).

        During the handshake the connection is aborted instead.  Otherwise
        the shutdown timeout is armed and pending reads are flushed.
        """
        if (
            self._state in (
                SSLProtocolState.FLUSHING,
                SSLProtocolState.SHUTDOWN,
                SSLProtocolState.UNWRAPPED
            )
        ):
            return
        if self._app_transport is not None:
            self._app_transport._closed = True
        if self._state == SSLProtocolState.DO_HANDSHAKE:
            self._abort(None)
        else:
            self._set_state(SSLProtocolState.FLUSHING)
            self._shutdown_timeout_handle = self._loop.call_later(
                self._ssl_shutdown_timeout,
                self._check_shutdown_timeout
            )
            self._do_flush()

    def _check_shutdown_timeout(self):
        """Force-close the transport if shutdown has not completed."""
        if (
            self._state in (
                SSLProtocolState.FLUSHING,
                SSLProtocolState.SHUTDOWN
            )
        ):
            self._transport._force_close(
                exceptions.TimeoutError('SSL shutdown timed out'))

    def _do_flush(self):
        """Drain readable data, then move on to unwrapping."""
        self._do_read()
        self._set_state(SSLProtocolState.SHUTDOWN)
        self._do_shutdown()

    def _do_shutdown(self):
        """Attempt the TLS unwrap (skipped if the peer already sent EOF)."""
        try:
            if not self._eof_received:
                self._sslobj.unwrap()
        except SSLAgainErrors:
            self._process_outgoing()
        except ssl.SSLError as exc:
            self._on_shutdown_complete(exc)
        else:
            self._process_outgoing()
            self._call_eof_received()
            self._on_shutdown_complete(None)

    def _on_shutdown_complete(self, shutdown_exc):
        """Cancel the shutdown timer and close (or fail) the transport."""
        if self._shutdown_timeout_handle is not None:
            self._shutdown_timeout_handle.cancel()
            self._shutdown_timeout_handle = None

        if shutdown_exc:
            self._fatal_error(shutdown_exc)
        else:
            self._loop.call_soon(self._transport.close)

    def _abort(self, exc):
        """Drop to UNWRAPPED and force-close the low-level transport,
        forwarding *exc* (may be None)."""
        self._set_state(SSLProtocolState.UNWRAPPED)
        if self._transport is not None:
            self._transport._force_close(exc)

    # Outgoing flow

    def _write_appdata(self, list_of_data):
        """Queue application data for encryption and try to send it.

        Writes after shutdown has begun are dropped; once enough of them
        accumulate a warning is logged.
        """
        if (
            self._state in (
                SSLProtocolState.FLUSHING,
                SSLProtocolState.SHUTDOWN,
                SSLProtocolState.UNWRAPPED
            )
        ):
            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
                logger.warning('SSL connection is closed')
            self._conn_lost += 1
            return

        for data in list_of_data:
            self._write_backlog.append(data)
            self._write_buffer_size += len(data)

        try:
            if self._state == SSLProtocolState.WRAPPED:
                self._do_write()

        except Exception as ex:
            self._fatal_error(ex, 'Fatal error on SSL protocol')

    def _do_write(self):
        """Encrypt as much of the write backlog as the SSL object accepts,
        then push the ciphertext to the transport."""
        try:
            while self._write_backlog:
                data = self._write_backlog[0]
                count = self._sslobj.write(data)
                data_len = len(data)
                if count < data_len:
                    # Partial write: keep the unsent tail at the front.
                    self._write_backlog[0] = data[count:]
                    self._write_buffer_size -= count
                else:
                    del self._write_backlog[0]
                    self._write_buffer_size -= data_len
        except SSLAgainErrors:
            pass
        self._process_outgoing()

    def _process_outgoing(self):
        """Send pending ciphertext unless the transport asked us to pause,
        then update app-side write flow control."""
        if not self._ssl_writing_paused:
            data = self._outgoing.read()
            if len(data):
                self._transport.write(data)
        self._control_app_writing()

    # Incoming flow

    def _do_read(self):
        """Deliver decrypted data to the app (WRAPPED/FLUSHING only)."""
        if (
            self._state not in (
                SSLProtocolState.WRAPPED,
                SSLProtocolState.FLUSHING,
            )
        ):
            return
        try:
            if not self._app_reading_paused:
                if self._app_protocol_is_buffer:
                    self._do_read__buffered()
                else:
                    self._do_read__copied()
                if self._write_backlog:
                    self._do_write()
                else:
                    self._process_outgoing()
            self._control_ssl_reading()
        except Exception as ex:
            self._fatal_error(ex, 'Fatal error on SSL protocol')

    def _do_read__buffered(self):
        """Decrypt directly into the buffered app protocol's buffer.

        Reads until the buffer from get_buffer() is full, the SSL object
        has nothing more, or a zero-length read signals close_notify.
        """
        offset = 0
        count = 1

        buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
        wants = len(buf)

        try:
            count = self._sslobj.read(wants, buf)

            if count > 0:
                offset = count
                # Slice through a memoryview: slicing the buffer object
                # itself (e.g. a bytearray) makes a copy, so the decrypted
                # bytes would never reach the app's buffer.  The view is
                # released before buffer_updated() so the app may resize
                # its buffer there.
                with memoryview(buf) as view:
                    while offset < wants:
                        count = self._sslobj.read(wants - offset,
                                                  view[offset:])
                        if count > 0:
                            offset += count
                        else:
                            break
                    else:
                        # The app buffer was filled completely; schedule
                        # another pass in case more plaintext is pending.
                        self._loop.call_soon(self._do_read)
        except SSLAgainErrors:
            pass
        if offset > 0:
            self._app_protocol_buffer_updated(offset)
        if not count:
            # close_notify
            self._call_eof_received()
            self._start_shutdown()

    def _do_read__copied(self):
        """Read all available plaintext and pass it to data_received().

        A single chunk is delivered as-is; several chunks are joined into
        one bytes object.
        """
        chunk = b'1'
        zero = True
        one = False

        try:
            while True:
                chunk = self._sslobj.read(self.max_size)
                if not chunk:
                    break
                if zero:
                    zero = False
                    one = True
                    first = chunk
                elif one:
                    one = False
                    data = [first, chunk]
                else:
                    data.append(chunk)
        except SSLAgainErrors:
            pass
        if one:
            self._app_protocol.data_received(first)
        elif not zero:
            self._app_protocol.data_received(b''.join(data))
        if not chunk:
            # close_notify
            self._call_eof_received()
            self._start_shutdown()

    def _call_eof_received(self):
        """Call the app's eof_received() at most once (from CON_MADE)."""
        try:
            if self._app_state == AppProtocolState.STATE_CON_MADE:
                self._app_state = AppProtocolState.STATE_EOF
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as ex:
            self._fatal_error(ex, 'Error calling eof_received()')

    # Flow control for writes from APP socket

    def _control_app_writing(self):
        """Pause/resume the app's writing based on the write buffer size
        relative to the outgoing water marks."""
        size = self._get_write_buffer_size()
        if size >= self._outgoing_high_water and not self._app_writing_paused:
            self._app_writing_paused = True
            try:
                self._app_protocol.pause_writing()
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as exc:
                self._loop.call_exception_handler({
                    'message': 'protocol.pause_writing() failed',
                    'exception': exc,
                    'transport': self._app_transport,
                    'protocol': self,
                })
        elif size <= self._outgoing_low_water and self._app_writing_paused:
            self._app_writing_paused = False
            try:
                self._app_protocol.resume_writing()
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as exc:
                self._loop.call_exception_handler({
                    'message': 'protocol.resume_writing() failed',
                    'exception': exc,
                    'transport': self._app_transport,
                    'protocol': self,
                })

    def _get_write_buffer_size(self):
        """Pending ciphertext plus not-yet-encrypted app data, in bytes."""
        return self._outgoing.pending + self._write_buffer_size

    def _set_write_buffer_limits(self, high=None, low=None):
        high, low = add_flowcontrol_defaults(
            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
        self._outgoing_high_water = high
        self._outgoing_low_water = low

    # Flow control for reads to APP socket

    def _pause_reading(self):
        self._app_reading_paused = True

    def _resume_reading(self):
        """Unpause app reading and, on the next loop iteration, continue
        whatever phase was interrupted."""
        if self._app_reading_paused:
            self._app_reading_paused = False

            def resume():
                if self._state == SSLProtocolState.WRAPPED:
                    self._do_read()
                elif self._state == SSLProtocolState.FLUSHING:
                    self._do_flush()
                elif self._state == SSLProtocolState.SHUTDOWN:
                    self._do_shutdown()
            self._loop.call_soon(resume)

    # Flow control for reads from SSL socket

    def _control_ssl_reading(self):
        """Pause/resume the low-level transport based on how much
        ciphertext is buffered in the incoming BIO."""
        size = self._get_read_buffer_size()
        if size >= self._incoming_high_water and not self._ssl_reading_paused:
            self._ssl_reading_paused = True
            self._transport.pause_reading()
        elif size <= self._incoming_low_water and self._ssl_reading_paused:
            self._ssl_reading_paused = False
            self._transport.resume_reading()

    def _set_read_buffer_limits(self, high=None, low=None):
        high, low = add_flowcontrol_defaults(
            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
        self._incoming_high_water = high
        self._incoming_low_water = low

    def _get_read_buffer_size(self):
        """Bytes of ciphertext waiting in the incoming BIO."""
        return self._incoming.pending

    # Flow control for writes to SSL socket

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        assert not self._ssl_writing_paused
        self._ssl_writing_paused = True

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        assert self._ssl_writing_paused
        self._ssl_writing_paused = False
        self._process_outgoing()

    def _fatal_error(self, exc, message='Fatal error on transport'):
        """Force-close the low-level transport and report *exc*.

        OSErrors are only logged in debug mode and CancelledError is not
        reported; anything else goes to the loop's exception handler.
        """
        if self._transport:
            self._transport._force_close(exc)

        if isinstance(exc, OSError):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        elif not isinstance(exc, exceptions.CancelledError):
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
|