GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705) (#96395)

(cherry picked from commit e5b2453e61)

Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>

Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
This commit is contained in:
Miss Islington (bot) 2022-08-30 04:00:21 -07:00 committed by GitHub
parent 126ec34558
commit 2e9f29e6a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 19 deletions

View file

@ -2,6 +2,7 @@ __all__ = (
'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server') 'open_connection', 'start_server')
import collections
import socket import socket
import sys import sys
import warnings import warnings
@ -129,7 +130,7 @@ class FlowControlMixin(protocols.Protocol):
else: else:
self._loop = loop self._loop = loop
self._paused = False self._paused = False
self._drain_waiter = None self._drain_waiters = collections.deque()
self._connection_lost = False self._connection_lost = False
def pause_writing(self): def pause_writing(self):
@ -144,38 +145,34 @@ class FlowControlMixin(protocols.Protocol):
if self._loop.get_debug(): if self._loop.get_debug():
logger.debug("%r resumes writing", self) logger.debug("%r resumes writing", self)
waiter = self._drain_waiter for waiter in self._drain_waiters:
if waiter is not None:
self._drain_waiter = None
if not waiter.done(): if not waiter.done():
waiter.set_result(None) waiter.set_result(None)
def connection_lost(self, exc): def connection_lost(self, exc):
self._connection_lost = True self._connection_lost = True
# Wake up the writer if currently paused. # Wake up the writer(s) if currently paused.
if not self._paused: if not self._paused:
return return
waiter = self._drain_waiter
if waiter is None: for waiter in self._drain_waiters:
return if not waiter.done():
self._drain_waiter = None if exc is None:
if waiter.done(): waiter.set_result(None)
return else:
if exc is None: waiter.set_exception(exc)
waiter.set_result(None)
else:
waiter.set_exception(exc)
async def _drain_helper(self): async def _drain_helper(self):
if self._connection_lost: if self._connection_lost:
raise ConnectionResetError('Connection lost') raise ConnectionResetError('Connection lost')
if not self._paused: if not self._paused:
return return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future() waiter = self._loop.create_future()
self._drain_waiter = waiter self._drain_waiters.append(waiter)
await waiter try:
await waiter
finally:
self._drain_waiters.remove(waiter)
def _get_close_waiter(self, stream): def _get_close_waiter(self, stream):
raise NotImplementedError raise NotImplementedError

View file

@ -864,6 +864,25 @@ os.close(fd)
self.assertEqual(cm.filename, __file__) self.assertEqual(cm.filename, __file__)
self.assertIs(protocol._loop, self.loop) self.assertIs(protocol._loop, self.loop)
def test_multiple_drain(self):
# See https://github.com/python/cpython/issues/74116
drained = 0
async def drainer(stream):
nonlocal drained
await stream._drain_helper()
drained += 1
async def main():
loop = asyncio.get_running_loop()
stream = asyncio.streams.FlowControlMixin(loop)
stream.pause_writing()
loop.call_later(0.1, stream.resume_writing)
await asyncio.gather(*[drainer(stream) for _ in range(10)])
self.assertEqual(drained, 10)
self.loop.run_until_complete(main())
def test_drain_raises(self): def test_drain_raises(self):
# See http://bugs.python.org/issue25441 # See http://bugs.python.org/issue25441

View file

@ -0,0 +1 @@
Allow :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.