mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
asyncio.streams: Use bytebuffer in StreamReader; Add assertion in feed_data
This commit is contained in:
parent
58af25e930
commit
e694c9745f
2 changed files with 75 additions and 75 deletions
|
@ -4,8 +4,6 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
|
||||||
'open_connection', 'start_server', 'IncompleteReadError',
|
'open_connection', 'start_server', 'IncompleteReadError',
|
||||||
]
|
]
|
||||||
|
|
||||||
import collections
|
|
||||||
|
|
||||||
from . import events
|
from . import events
|
||||||
from . import futures
|
from . import futures
|
||||||
from . import protocols
|
from . import protocols
|
||||||
|
@ -259,9 +257,7 @@ class StreamReader:
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = events.get_event_loop()
|
loop = events.get_event_loop()
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
# TODO: Use a bytearray for a buffer, like the transport.
|
self._buffer = bytearray()
|
||||||
self._buffer = collections.deque() # Deque of bytes objects.
|
|
||||||
self._byte_count = 0 # Bytes in buffer.
|
|
||||||
self._eof = False # Whether we're done.
|
self._eof = False # Whether we're done.
|
||||||
self._waiter = None # A future.
|
self._waiter = None # A future.
|
||||||
self._exception = None
|
self._exception = None
|
||||||
|
@ -285,7 +281,7 @@ class StreamReader:
|
||||||
self._transport = transport
|
self._transport = transport
|
||||||
|
|
||||||
def _maybe_resume_transport(self):
|
def _maybe_resume_transport(self):
|
||||||
if self._paused and self._byte_count <= self._limit:
|
if self._paused and len(self._buffer) <= self._limit:
|
||||||
self._paused = False
|
self._paused = False
|
||||||
self._transport.resume_reading()
|
self._transport.resume_reading()
|
||||||
|
|
||||||
|
@ -298,11 +294,12 @@ class StreamReader:
|
||||||
waiter.set_result(True)
|
waiter.set_result(True)
|
||||||
|
|
||||||
def feed_data(self, data):
|
def feed_data(self, data):
|
||||||
|
assert not self._eof, 'feed_data after feed_eof'
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._buffer.append(data)
|
self._buffer.extend(data)
|
||||||
self._byte_count += len(data)
|
|
||||||
|
|
||||||
waiter = self._waiter
|
waiter = self._waiter
|
||||||
if waiter is not None:
|
if waiter is not None:
|
||||||
|
@ -312,7 +309,7 @@ class StreamReader:
|
||||||
|
|
||||||
if (self._transport is not None and
|
if (self._transport is not None and
|
||||||
not self._paused and
|
not self._paused and
|
||||||
self._byte_count > 2*self._limit):
|
len(self._buffer) > 2*self._limit):
|
||||||
try:
|
try:
|
||||||
self._transport.pause_reading()
|
self._transport.pause_reading()
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
|
@ -338,28 +335,22 @@ class StreamReader:
|
||||||
if self._exception is not None:
|
if self._exception is not None:
|
||||||
raise self._exception
|
raise self._exception
|
||||||
|
|
||||||
parts = []
|
line = bytearray()
|
||||||
parts_size = 0
|
|
||||||
not_enough = True
|
not_enough = True
|
||||||
|
|
||||||
while not_enough:
|
while not_enough:
|
||||||
while self._buffer and not_enough:
|
while self._buffer and not_enough:
|
||||||
data = self._buffer.popleft()
|
ichar = self._buffer.find(b'\n')
|
||||||
ichar = data.find(b'\n')
|
|
||||||
if ichar < 0:
|
if ichar < 0:
|
||||||
parts.append(data)
|
line.extend(self._buffer)
|
||||||
parts_size += len(data)
|
self._buffer.clear()
|
||||||
else:
|
else:
|
||||||
ichar += 1
|
ichar += 1
|
||||||
head, tail = data[:ichar], data[ichar:]
|
line.extend(self._buffer[:ichar])
|
||||||
if tail:
|
del self._buffer[:ichar]
|
||||||
self._buffer.appendleft(tail)
|
|
||||||
not_enough = False
|
not_enough = False
|
||||||
parts.append(head)
|
|
||||||
parts_size += len(head)
|
|
||||||
|
|
||||||
if parts_size > self._limit:
|
if len(line) > self._limit:
|
||||||
self._byte_count -= parts_size
|
|
||||||
self._maybe_resume_transport()
|
self._maybe_resume_transport()
|
||||||
raise ValueError('Line is too long')
|
raise ValueError('Line is too long')
|
||||||
|
|
||||||
|
@ -373,11 +364,8 @@ class StreamReader:
|
||||||
finally:
|
finally:
|
||||||
self._waiter = None
|
self._waiter = None
|
||||||
|
|
||||||
line = b''.join(parts)
|
|
||||||
self._byte_count -= parts_size
|
|
||||||
self._maybe_resume_transport()
|
self._maybe_resume_transport()
|
||||||
|
return bytes(line)
|
||||||
return line
|
|
||||||
|
|
||||||
@tasks.coroutine
|
@tasks.coroutine
|
||||||
def read(self, n=-1):
|
def read(self, n=-1):
|
||||||
|
@ -395,36 +383,23 @@ class StreamReader:
|
||||||
finally:
|
finally:
|
||||||
self._waiter = None
|
self._waiter = None
|
||||||
else:
|
else:
|
||||||
if not self._byte_count and not self._eof:
|
if not self._buffer and not self._eof:
|
||||||
self._waiter = self._create_waiter('read')
|
self._waiter = self._create_waiter('read')
|
||||||
try:
|
try:
|
||||||
yield from self._waiter
|
yield from self._waiter
|
||||||
finally:
|
finally:
|
||||||
self._waiter = None
|
self._waiter = None
|
||||||
|
|
||||||
if n < 0 or self._byte_count <= n:
|
if n < 0 or len(self._buffer) <= n:
|
||||||
data = b''.join(self._buffer)
|
data = bytes(self._buffer)
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
self._byte_count = 0
|
else:
|
||||||
self._maybe_resume_transport()
|
# n > 0 and len(self._buffer) > n
|
||||||
return data
|
data = bytes(self._buffer[:n])
|
||||||
|
del self._buffer[:n]
|
||||||
|
|
||||||
parts = []
|
self._maybe_resume_transport()
|
||||||
parts_bytes = 0
|
return data
|
||||||
while self._buffer and parts_bytes < n:
|
|
||||||
data = self._buffer.popleft()
|
|
||||||
data_bytes = len(data)
|
|
||||||
if n < parts_bytes + data_bytes:
|
|
||||||
data_bytes = n - parts_bytes
|
|
||||||
data, rest = data[:data_bytes], data[data_bytes:]
|
|
||||||
self._buffer.appendleft(rest)
|
|
||||||
|
|
||||||
parts.append(data)
|
|
||||||
parts_bytes += data_bytes
|
|
||||||
self._byte_count -= data_bytes
|
|
||||||
self._maybe_resume_transport()
|
|
||||||
|
|
||||||
return b''.join(parts)
|
|
||||||
|
|
||||||
@tasks.coroutine
|
@tasks.coroutine
|
||||||
def readexactly(self, n):
|
def readexactly(self, n):
|
||||||
|
|
|
@ -79,13 +79,13 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
|
|
||||||
stream.feed_data(b'')
|
stream.feed_data(b'')
|
||||||
self.assertEqual(0, stream._byte_count)
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_feed_data_byte_count(self):
|
def test_feed_nonempty_data(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
|
|
||||||
stream.feed_data(self.DATA)
|
stream.feed_data(self.DATA)
|
||||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
self.assertEqual(self.DATA, stream._buffer)
|
||||||
|
|
||||||
def test_read_zero(self):
|
def test_read_zero(self):
|
||||||
# Read zero bytes.
|
# Read zero bytes.
|
||||||
|
@ -94,7 +94,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
data = self.loop.run_until_complete(stream.read(0))
|
data = self.loop.run_until_complete(stream.read(0))
|
||||||
self.assertEqual(b'', data)
|
self.assertEqual(b'', data)
|
||||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
self.assertEqual(self.DATA, stream._buffer)
|
||||||
|
|
||||||
def test_read(self):
|
def test_read(self):
|
||||||
# Read bytes.
|
# Read bytes.
|
||||||
|
@ -107,7 +107,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
data = self.loop.run_until_complete(read_task)
|
data = self.loop.run_until_complete(read_task)
|
||||||
self.assertEqual(self.DATA, data)
|
self.assertEqual(self.DATA, data)
|
||||||
self.assertFalse(stream._byte_count)
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_read_line_breaks(self):
|
def test_read_line_breaks(self):
|
||||||
# Read bytes without line breaks.
|
# Read bytes without line breaks.
|
||||||
|
@ -118,7 +118,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
data = self.loop.run_until_complete(stream.read(5))
|
data = self.loop.run_until_complete(stream.read(5))
|
||||||
|
|
||||||
self.assertEqual(b'line1', data)
|
self.assertEqual(b'line1', data)
|
||||||
self.assertEqual(5, stream._byte_count)
|
self.assertEqual(b'line2', stream._buffer)
|
||||||
|
|
||||||
def test_read_eof(self):
|
def test_read_eof(self):
|
||||||
# Read bytes, stop at eof.
|
# Read bytes, stop at eof.
|
||||||
|
@ -131,7 +131,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
data = self.loop.run_until_complete(read_task)
|
data = self.loop.run_until_complete(read_task)
|
||||||
self.assertEqual(b'', data)
|
self.assertEqual(b'', data)
|
||||||
self.assertFalse(stream._byte_count)
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_read_until_eof(self):
|
def test_read_until_eof(self):
|
||||||
# Read all bytes until eof.
|
# Read all bytes until eof.
|
||||||
|
@ -147,7 +147,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
data = self.loop.run_until_complete(read_task)
|
data = self.loop.run_until_complete(read_task)
|
||||||
|
|
||||||
self.assertEqual(b'chunk1\nchunk2', data)
|
self.assertEqual(b'chunk1\nchunk2', data)
|
||||||
self.assertFalse(stream._byte_count)
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_read_exception(self):
|
def test_read_exception(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
|
@ -161,7 +161,8 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
ValueError, self.loop.run_until_complete, stream.read(2))
|
ValueError, self.loop.run_until_complete, stream.read(2))
|
||||||
|
|
||||||
def test_readline(self):
|
def test_readline(self):
|
||||||
# Read one line.
|
# Read one line. 'readline' will need to wait for the data
|
||||||
|
# to come from 'cb'
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(b'chunk1 ')
|
stream.feed_data(b'chunk1 ')
|
||||||
read_task = asyncio.Task(stream.readline(), loop=self.loop)
|
read_task = asyncio.Task(stream.readline(), loop=self.loop)
|
||||||
|
@ -174,30 +175,40 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
line = self.loop.run_until_complete(read_task)
|
line = self.loop.run_until_complete(read_task)
|
||||||
self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
|
self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
|
||||||
self.assertEqual(len(b'\n chunk4')-1, stream._byte_count)
|
self.assertEqual(b' chunk4', stream._buffer)
|
||||||
|
|
||||||
def test_readline_limit_with_existing_data(self):
|
def test_readline_limit_with_existing_data(self):
|
||||||
stream = asyncio.StreamReader(3, loop=self.loop)
|
# Read one line. The data is in StreamReader's buffer
|
||||||
|
# before the event loop is run.
|
||||||
|
|
||||||
|
stream = asyncio.StreamReader(limit=3, loop=self.loop)
|
||||||
stream.feed_data(b'li')
|
stream.feed_data(b'li')
|
||||||
stream.feed_data(b'ne1\nline2\n')
|
stream.feed_data(b'ne1\nline2\n')
|
||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError, self.loop.run_until_complete, stream.readline())
|
ValueError, self.loop.run_until_complete, stream.readline())
|
||||||
self.assertEqual([b'line2\n'], list(stream._buffer))
|
# The buffer should contain the remaining data after exception
|
||||||
|
self.assertEqual(b'line2\n', stream._buffer)
|
||||||
|
|
||||||
stream = asyncio.StreamReader(3, loop=self.loop)
|
stream = asyncio.StreamReader(limit=3, loop=self.loop)
|
||||||
stream.feed_data(b'li')
|
stream.feed_data(b'li')
|
||||||
stream.feed_data(b'ne1')
|
stream.feed_data(b'ne1')
|
||||||
stream.feed_data(b'li')
|
stream.feed_data(b'li')
|
||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError, self.loop.run_until_complete, stream.readline())
|
ValueError, self.loop.run_until_complete, stream.readline())
|
||||||
self.assertEqual([b'li'], list(stream._buffer))
|
# No b'\n' at the end. The 'limit' is set to 3. So before
|
||||||
self.assertEqual(2, stream._byte_count)
|
# waiting for the new data in buffer, 'readline' will consume
|
||||||
|
# the entire buffer, and since the length of the consumed data
|
||||||
|
# is more than 3, it will raise a ValueError. The buffer is
|
||||||
|
# expected to be empty now.
|
||||||
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_readline_limit(self):
|
def test_readline_limit(self):
|
||||||
stream = asyncio.StreamReader(7, loop=self.loop)
|
# Read one line. StreamReaders are fed with data after
|
||||||
|
# their 'readline' methods are called.
|
||||||
|
|
||||||
|
stream = asyncio.StreamReader(limit=7, loop=self.loop)
|
||||||
def cb():
|
def cb():
|
||||||
stream.feed_data(b'chunk1')
|
stream.feed_data(b'chunk1')
|
||||||
stream.feed_data(b'chunk2')
|
stream.feed_data(b'chunk2')
|
||||||
|
@ -207,10 +218,25 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError, self.loop.run_until_complete, stream.readline())
|
ValueError, self.loop.run_until_complete, stream.readline())
|
||||||
self.assertEqual([b'chunk3\n'], list(stream._buffer))
|
# The buffer had just one line of data, and after raising
|
||||||
self.assertEqual(7, stream._byte_count)
|
# a ValueError it should be empty.
|
||||||
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_readline_line_byte_count(self):
|
stream = asyncio.StreamReader(limit=7, loop=self.loop)
|
||||||
|
def cb():
|
||||||
|
stream.feed_data(b'chunk1')
|
||||||
|
stream.feed_data(b'chunk2\n')
|
||||||
|
stream.feed_data(b'chunk3\n')
|
||||||
|
stream.feed_eof()
|
||||||
|
self.loop.call_soon(cb)
|
||||||
|
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError, self.loop.run_until_complete, stream.readline())
|
||||||
|
self.assertEqual(b'chunk3\n', stream._buffer)
|
||||||
|
|
||||||
|
def test_readline_nolimit_nowait(self):
|
||||||
|
# All needed data for the first 'readline' call will be
|
||||||
|
# in the buffer.
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
stream.feed_data(self.DATA[:6])
|
stream.feed_data(self.DATA[:6])
|
||||||
stream.feed_data(self.DATA[6:])
|
stream.feed_data(self.DATA[6:])
|
||||||
|
@ -218,7 +244,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
line = self.loop.run_until_complete(stream.readline())
|
line = self.loop.run_until_complete(stream.readline())
|
||||||
|
|
||||||
self.assertEqual(b'line1\n', line)
|
self.assertEqual(b'line1\n', line)
|
||||||
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count)
|
self.assertEqual(b'line2\nline3\n', stream._buffer)
|
||||||
|
|
||||||
def test_readline_eof(self):
|
def test_readline_eof(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
|
@ -244,9 +270,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
data = self.loop.run_until_complete(stream.read(7))
|
data = self.loop.run_until_complete(stream.read(7))
|
||||||
|
|
||||||
self.assertEqual(b'line2\nl', data)
|
self.assertEqual(b'line2\nl', data)
|
||||||
self.assertEqual(
|
self.assertEqual(b'ine3\n', stream._buffer)
|
||||||
len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
|
|
||||||
stream._byte_count)
|
|
||||||
|
|
||||||
def test_readline_exception(self):
|
def test_readline_exception(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
|
@ -258,6 +282,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
stream.set_exception(ValueError())
|
stream.set_exception(ValueError())
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError, self.loop.run_until_complete, stream.readline())
|
ValueError, self.loop.run_until_complete, stream.readline())
|
||||||
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_readexactly_zero_or_less(self):
|
def test_readexactly_zero_or_less(self):
|
||||||
# Read exact number of bytes (zero or less).
|
# Read exact number of bytes (zero or less).
|
||||||
|
@ -266,11 +291,11 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
data = self.loop.run_until_complete(stream.readexactly(0))
|
data = self.loop.run_until_complete(stream.readexactly(0))
|
||||||
self.assertEqual(b'', data)
|
self.assertEqual(b'', data)
|
||||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
self.assertEqual(self.DATA, stream._buffer)
|
||||||
|
|
||||||
data = self.loop.run_until_complete(stream.readexactly(-1))
|
data = self.loop.run_until_complete(stream.readexactly(-1))
|
||||||
self.assertEqual(b'', data)
|
self.assertEqual(b'', data)
|
||||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
self.assertEqual(self.DATA, stream._buffer)
|
||||||
|
|
||||||
def test_readexactly(self):
|
def test_readexactly(self):
|
||||||
# Read exact number of bytes.
|
# Read exact number of bytes.
|
||||||
|
@ -287,7 +312,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
|
|
||||||
data = self.loop.run_until_complete(read_task)
|
data = self.loop.run_until_complete(read_task)
|
||||||
self.assertEqual(self.DATA + self.DATA, data)
|
self.assertEqual(self.DATA + self.DATA, data)
|
||||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
self.assertEqual(self.DATA, stream._buffer)
|
||||||
|
|
||||||
def test_readexactly_eof(self):
|
def test_readexactly_eof(self):
|
||||||
# Read exact number of bytes (eof).
|
# Read exact number of bytes (eof).
|
||||||
|
@ -306,7 +331,7 @@ class StreamReaderTests(unittest.TestCase):
|
||||||
self.assertEqual(cm.exception.expected, n)
|
self.assertEqual(cm.exception.expected, n)
|
||||||
self.assertEqual(str(cm.exception),
|
self.assertEqual(str(cm.exception),
|
||||||
'18 bytes read on a total of 36 expected bytes')
|
'18 bytes read on a total of 36 expected bytes')
|
||||||
self.assertFalse(stream._byte_count)
|
self.assertEqual(b'', stream._buffer)
|
||||||
|
|
||||||
def test_readexactly_exception(self):
|
def test_readexactly_exception(self):
|
||||||
stream = asyncio.StreamReader(loop=self.loop)
|
stream = asyncio.StreamReader(loop=self.loop)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue