mirror of
https://github.com/python/cpython.git
synced 2025-10-10 00:43:41 +00:00
asyncio: Refactor drain logic in streams.py to be reusable.
This commit is contained in:
parent
aaabc4fdca
commit
4d62d0b353
1 changed files with 61 additions and 36 deletions
|
@ -94,8 +94,63 @@ def start_server(client_connected_cb, host=None, port=None, *,
|
||||||
return (yield from loop.create_server(factory, host, port, **kwds))
|
return (yield from loop.create_server(factory, host, port, **kwds))
|
||||||
|
|
||||||
|
|
||||||
class StreamReaderProtocol(protocols.Protocol):
|
class FlowControlMixin(protocols.Protocol):
|
||||||
"""Trivial helper class to adapt between Protocol and StreamReader.
|
"""Reusable flow control logic for StreamWriter.drain().
|
||||||
|
|
||||||
|
This implements the protocol methods pause_writing(),
|
||||||
|
resume_reading() and connection_lost(). If the subclass overrides
|
||||||
|
these it must call the super methods.
|
||||||
|
|
||||||
|
StreamWriter.drain() must check for error conditions and then call
|
||||||
|
_make_drain_waiter(), which will return either () or a Future
|
||||||
|
depending on the paused state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loop=None):
|
||||||
|
self._loop = loop # May be None; we may never need it.
|
||||||
|
self._paused = False
|
||||||
|
self._drain_waiter = None
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
assert not self._paused
|
||||||
|
self._paused = True
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
assert self._paused
|
||||||
|
self._paused = False
|
||||||
|
waiter = self._drain_waiter
|
||||||
|
if waiter is not None:
|
||||||
|
self._drain_waiter = None
|
||||||
|
if not waiter.done():
|
||||||
|
waiter.set_result(None)
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
# Wake up the writer if currently paused.
|
||||||
|
if not self._paused:
|
||||||
|
return
|
||||||
|
waiter = self._drain_waiter
|
||||||
|
if waiter is None:
|
||||||
|
return
|
||||||
|
self._drain_waiter = None
|
||||||
|
if waiter.done():
|
||||||
|
return
|
||||||
|
if exc is None:
|
||||||
|
waiter.set_result(None)
|
||||||
|
else:
|
||||||
|
waiter.set_exception(exc)
|
||||||
|
|
||||||
|
def _make_drain_waiter(self):
|
||||||
|
if not self._paused:
|
||||||
|
return ()
|
||||||
|
waiter = self._drain_waiter
|
||||||
|
assert waiter is None or waiter.cancelled()
|
||||||
|
waiter = futures.Future(loop=self._loop)
|
||||||
|
self._drain_waiter = waiter
|
||||||
|
return waiter
|
||||||
|
|
||||||
|
|
||||||
|
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
|
||||||
|
"""Helper class to adapt between Protocol and StreamReader.
|
||||||
|
|
||||||
(This is a helper class instead of making StreamReader itself a
|
(This is a helper class instead of making StreamReader itself a
|
||||||
Protocol subclass, because the StreamReader has other potential
|
Protocol subclass, because the StreamReader has other potential
|
||||||
|
@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
||||||
|
super().__init__(loop=loop)
|
||||||
self._stream_reader = stream_reader
|
self._stream_reader = stream_reader
|
||||||
self._stream_writer = None
|
self._stream_writer = None
|
||||||
self._drain_waiter = None
|
|
||||||
self._paused = False
|
|
||||||
self._client_connected_cb = client_connected_cb
|
self._client_connected_cb = client_connected_cb
|
||||||
self._loop = loop # May be None; we may never need it.
|
|
||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
self._stream_reader.set_transport(transport)
|
self._stream_reader.set_transport(transport)
|
||||||
|
@ -127,16 +180,7 @@ class StreamReaderProtocol(protocols.Protocol):
|
||||||
self._stream_reader.feed_eof()
|
self._stream_reader.feed_eof()
|
||||||
else:
|
else:
|
||||||
self._stream_reader.set_exception(exc)
|
self._stream_reader.set_exception(exc)
|
||||||
# Also wake up the writing side.
|
super().connection_lost(exc)
|
||||||
if self._paused:
|
|
||||||
waiter = self._drain_waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._drain_waiter = None
|
|
||||||
if not waiter.done():
|
|
||||||
if exc is None:
|
|
||||||
waiter.set_result(None)
|
|
||||||
else:
|
|
||||||
waiter.set_exception(exc)
|
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
self._stream_reader.feed_data(data)
|
self._stream_reader.feed_data(data)
|
||||||
|
@ -144,19 +188,6 @@ class StreamReaderProtocol(protocols.Protocol):
|
||||||
def eof_received(self):
|
def eof_received(self):
|
||||||
self._stream_reader.feed_eof()
|
self._stream_reader.feed_eof()
|
||||||
|
|
||||||
def pause_writing(self):
|
|
||||||
assert not self._paused
|
|
||||||
self._paused = True
|
|
||||||
|
|
||||||
def resume_writing(self):
|
|
||||||
assert self._paused
|
|
||||||
self._paused = False
|
|
||||||
waiter = self._drain_waiter
|
|
||||||
if waiter is not None:
|
|
||||||
self._drain_waiter = None
|
|
||||||
if not waiter.done():
|
|
||||||
waiter.set_result(None)
|
|
||||||
|
|
||||||
|
|
||||||
class StreamWriter:
|
class StreamWriter:
|
||||||
"""Wraps a Transport.
|
"""Wraps a Transport.
|
||||||
|
@ -211,17 +242,11 @@ class StreamWriter:
|
||||||
completed, which will happen when the buffer is (partially)
|
completed, which will happen when the buffer is (partially)
|
||||||
drained and the protocol is resumed.
|
drained and the protocol is resumed.
|
||||||
"""
|
"""
|
||||||
if self._reader._exception is not None:
|
if self._reader is not None and self._reader._exception is not None:
|
||||||
raise self._reader._exception
|
raise self._reader._exception
|
||||||
if self._transport._conn_lost: # Uses private variable.
|
if self._transport._conn_lost: # Uses private variable.
|
||||||
raise ConnectionResetError('Connection lost')
|
raise ConnectionResetError('Connection lost')
|
||||||
if not self._protocol._paused:
|
return self._protocol._make_drain_waiter()
|
||||||
return ()
|
|
||||||
waiter = self._protocol._drain_waiter
|
|
||||||
assert waiter is None or waiter.cancelled()
|
|
||||||
waiter = futures.Future(loop=self._loop)
|
|
||||||
self._protocol._drain_waiter = waiter
|
|
||||||
return waiter
|
|
||||||
|
|
||||||
|
|
||||||
class StreamReader:
|
class StreamReader:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue