mirror of
https://github.com/python/cpython.git
synced 2025-08-25 03:04:55 +00:00
Write flow control for asyncio (includes asyncio.streams overhaul).
This commit is contained in:
parent
051a331488
commit
355491dc47
5 changed files with 288 additions and 93 deletions
|
@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *,
|
|||
protocol = StreamReaderProtocol(reader)
|
||||
transport, _ = yield from loop.create_connection(
|
||||
lambda: protocol, host, port, **kwds)
|
||||
return reader, transport # (reader, writer)
|
||||
writer = StreamWriter(transport, protocol, reader, loop)
|
||||
return reader, writer
|
||||
|
||||
|
||||
class StreamReaderProtocol(protocols.Protocol):
|
||||
|
@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol):
|
|||
"""
|
||||
|
||||
def __init__(self, stream_reader):
|
||||
self.stream_reader = stream_reader
|
||||
self._stream_reader = stream_reader
|
||||
self._drain_waiter = None
|
||||
self._paused = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.stream_reader.set_transport(transport)
|
||||
self._stream_reader.set_transport(transport)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc is None:
|
||||
self.stream_reader.feed_eof()
|
||||
self._stream_reader.feed_eof()
|
||||
else:
|
||||
self.stream_reader.set_exception(exc)
|
||||
self._stream_reader.set_exception(exc)
|
||||
# Also wake up the writing side.
|
||||
if self._paused:
|
||||
waiter = self._drain_waiter
|
||||
if waiter is not None:
|
||||
self._drain_waiter = None
|
||||
if not waiter.done():
|
||||
if exc is None:
|
||||
waiter.set_result(None)
|
||||
else:
|
||||
waiter.set_exception(exc)
|
||||
|
||||
def data_received(self, data):
|
||||
self.stream_reader.feed_data(data)
|
||||
self._stream_reader.feed_data(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.stream_reader.feed_eof()
|
||||
self._stream_reader.feed_eof()
|
||||
|
||||
def pause_writing(self):
|
||||
assert not self._paused
|
||||
self._paused = True
|
||||
|
||||
def resume_writing(self):
|
||||
assert self._paused
|
||||
self._paused = False
|
||||
waiter = self._drain_waiter
|
||||
if waiter is not None:
|
||||
self._drain_waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
|
||||
class StreamWriter:
|
||||
"""Wraps a Transport.
|
||||
|
||||
This exposes write(), writelines(), [can_]write_eof(),
|
||||
get_extra_info() and close(). It adds drain() which returns an
|
||||
optional Future on which you can wait for flow control. It also
|
||||
adds a transport attribute which references the Transport
|
||||
directly.
|
||||
"""
|
||||
|
||||
def __init__(self, transport, protocol, reader, loop):
|
||||
self._transport = transport
|
||||
self._protocol = protocol
|
||||
self._reader = reader
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def transport(self):
|
||||
return self._transport
|
||||
|
||||
def write(self, data):
|
||||
self._transport.write(data)
|
||||
|
||||
def writelines(self, data):
|
||||
self._transport.writelines(data)
|
||||
|
||||
def write_eof(self):
|
||||
return self._transport.write_eof()
|
||||
|
||||
def can_write_eof(self):
|
||||
return self._transport.can_write_eof()
|
||||
|
||||
def close(self):
|
||||
return self._transport.close()
|
||||
|
||||
def get_extra_info(self, name, default=None):
|
||||
return self._transport.get_extra_info(name, default)
|
||||
|
||||
def drain(self):
|
||||
"""This method has an unusual return value.
|
||||
|
||||
The intended use is to write
|
||||
|
||||
w.write(data)
|
||||
yield from w.drain()
|
||||
|
||||
When there's nothing to wait for, drain() returns (), and the
|
||||
yield-from continues immediately. When the transport buffer
|
||||
is full (the protocol is paused), drain() creates and returns
|
||||
a Future and the yield-from will block until that Future is
|
||||
completed, which will happen when the buffer is (partially)
|
||||
drained and the protocol is resumed.
|
||||
"""
|
||||
if self._reader._exception is not None:
|
||||
raise self._writer._exception
|
||||
if self._transport._conn_lost: # Uses private variable.
|
||||
raise ConnectionResetError('Connection lost')
|
||||
if not self._protocol._paused:
|
||||
return ()
|
||||
waiter = self._protocol._drain_waiter
|
||||
assert waiter is None or waiter.cancelled()
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._protocol._drain_waiter = waiter
|
||||
return waiter
|
||||
|
||||
|
||||
class StreamReader:
|
||||
|
@ -75,14 +167,14 @@ class StreamReader:
|
|||
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
|
||||
# The line length limit is a security feature;
|
||||
# it also doubles as half the buffer limit.
|
||||
self.limit = limit
|
||||
self._limit = limit
|
||||
if loop is None:
|
||||
loop = events.get_event_loop()
|
||||
self.loop = loop
|
||||
self.buffer = collections.deque() # Deque of bytes objects.
|
||||
self.byte_count = 0 # Bytes in buffer.
|
||||
self.eof = False # Whether we're done.
|
||||
self.waiter = None # A future.
|
||||
self._loop = loop
|
||||
self._buffer = collections.deque() # Deque of bytes objects.
|
||||
self._byte_count = 0 # Bytes in buffer.
|
||||
self._eof = False # Whether we're done.
|
||||
self._waiter = None # A future.
|
||||
self._exception = None
|
||||
self._transport = None
|
||||
self._paused = False
|
||||
|
@ -93,9 +185,9 @@ class StreamReader:
|
|||
def set_exception(self, exc):
|
||||
self._exception = exc
|
||||
|
||||
waiter = self.waiter
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_exception(exc)
|
||||
|
||||
|
@ -104,15 +196,15 @@ class StreamReader:
|
|||
self._transport = transport
|
||||
|
||||
def _maybe_resume_transport(self):
|
||||
if self._paused and self.byte_count <= self.limit:
|
||||
if self._paused and self._byte_count <= self._limit:
|
||||
self._paused = False
|
||||
self._transport.resume_reading()
|
||||
|
||||
def feed_eof(self):
|
||||
self.eof = True
|
||||
waiter = self.waiter
|
||||
self._eof = True
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(True)
|
||||
|
||||
|
@ -120,18 +212,18 @@ class StreamReader:
|
|||
if not data:
|
||||
return
|
||||
|
||||
self.buffer.append(data)
|
||||
self.byte_count += len(data)
|
||||
self._buffer.append(data)
|
||||
self._byte_count += len(data)
|
||||
|
||||
waiter = self.waiter
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(False)
|
||||
|
||||
if (self._transport is not None and
|
||||
not self._paused and
|
||||
self.byte_count > 2*self.limit):
|
||||
self._byte_count > 2*self._limit):
|
||||
try:
|
||||
self._transport.pause_reading()
|
||||
except NotImplementedError:
|
||||
|
@ -152,8 +244,8 @@ class StreamReader:
|
|||
not_enough = True
|
||||
|
||||
while not_enough:
|
||||
while self.buffer and not_enough:
|
||||
data = self.buffer.popleft()
|
||||
while self._buffer and not_enough:
|
||||
data = self._buffer.popleft()
|
||||
ichar = data.find(b'\n')
|
||||
if ichar < 0:
|
||||
parts.append(data)
|
||||
|
@ -162,29 +254,29 @@ class StreamReader:
|
|||
ichar += 1
|
||||
head, tail = data[:ichar], data[ichar:]
|
||||
if tail:
|
||||
self.buffer.appendleft(tail)
|
||||
self._buffer.appendleft(tail)
|
||||
not_enough = False
|
||||
parts.append(head)
|
||||
parts_size += len(head)
|
||||
|
||||
if parts_size > self.limit:
|
||||
self.byte_count -= parts_size
|
||||
if parts_size > self._limit:
|
||||
self._byte_count -= parts_size
|
||||
self._maybe_resume_transport()
|
||||
raise ValueError('Line is too long')
|
||||
|
||||
if self.eof:
|
||||
if self._eof:
|
||||
break
|
||||
|
||||
if not_enough:
|
||||
assert self.waiter is None
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
assert self._waiter is None
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
|
||||
line = b''.join(parts)
|
||||
self.byte_count -= parts_size
|
||||
self._byte_count -= parts_size
|
||||
self._maybe_resume_transport()
|
||||
|
||||
return line
|
||||
|
@ -198,42 +290,42 @@ class StreamReader:
|
|||
return b''
|
||||
|
||||
if n < 0:
|
||||
while not self.eof:
|
||||
assert not self.waiter
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
while not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
else:
|
||||
if not self.byte_count and not self.eof:
|
||||
assert not self.waiter
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
if not self._byte_count and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
|
||||
if n < 0 or self.byte_count <= n:
|
||||
data = b''.join(self.buffer)
|
||||
self.buffer.clear()
|
||||
self.byte_count = 0
|
||||
if n < 0 or self._byte_count <= n:
|
||||
data = b''.join(self._buffer)
|
||||
self._buffer.clear()
|
||||
self._byte_count = 0
|
||||
self._maybe_resume_transport()
|
||||
return data
|
||||
|
||||
parts = []
|
||||
parts_bytes = 0
|
||||
while self.buffer and parts_bytes < n:
|
||||
data = self.buffer.popleft()
|
||||
while self._buffer and parts_bytes < n:
|
||||
data = self._buffer.popleft()
|
||||
data_bytes = len(data)
|
||||
if n < parts_bytes + data_bytes:
|
||||
data_bytes = n - parts_bytes
|
||||
data, rest = data[:data_bytes], data[data_bytes:]
|
||||
self.buffer.appendleft(rest)
|
||||
self._buffer.appendleft(rest)
|
||||
|
||||
parts.append(data)
|
||||
parts_bytes += data_bytes
|
||||
self.byte_count -= data_bytes
|
||||
self._byte_count -= data_bytes
|
||||
self._maybe_resume_transport()
|
||||
|
||||
return b''.join(parts)
|
||||
|
@ -246,12 +338,12 @@ class StreamReader:
|
|||
if n <= 0:
|
||||
return b''
|
||||
|
||||
while self.byte_count < n and not self.eof:
|
||||
assert not self.waiter
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
while self._byte_count < n and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
|
||||
return (yield from self.read(n))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue