mirror of
https://github.com/python/cpython.git
synced 2025-07-23 11:15:24 +00:00
GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705) (#96395)
(cherry picked from commit e5b2453e61
)
Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
This commit is contained in:
parent
126ec34558
commit
2e9f29e6a6
3 changed files with 36 additions and 19 deletions
|
@ -2,6 +2,7 @@ __all__ = (
|
||||||
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
|
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
|
||||||
'open_connection', 'start_server')
|
'open_connection', 'start_server')
|
||||||
|
|
||||||
|
import collections
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -129,7 +130,7 @@ class FlowControlMixin(protocols.Protocol):
|
||||||
else:
|
else:
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
self._paused = False
|
self._paused = False
|
||||||
self._drain_waiter = None
|
self._drain_waiters = collections.deque()
|
||||||
self._connection_lost = False
|
self._connection_lost = False
|
||||||
|
|
||||||
def pause_writing(self):
|
def pause_writing(self):
|
||||||
|
@ -144,38 +145,34 @@ class FlowControlMixin(protocols.Protocol):
|
||||||
if self._loop.get_debug():
|
if self._loop.get_debug():
|
||||||
logger.debug("%r resumes writing", self)
|
logger.debug("%r resumes writing", self)
|
||||||
|
|
||||||
waiter = self._drain_waiter
|
for waiter in self._drain_waiters:
|
||||||
if waiter is not None:
|
|
||||||
self._drain_waiter = None
|
|
||||||
if not waiter.done():
|
if not waiter.done():
|
||||||
waiter.set_result(None)
|
waiter.set_result(None)
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
self._connection_lost = True
|
self._connection_lost = True
|
||||||
# Wake up the writer if currently paused.
|
# Wake up the writer(s) if currently paused.
|
||||||
if not self._paused:
|
if not self._paused:
|
||||||
return
|
return
|
||||||
waiter = self._drain_waiter
|
|
||||||
if waiter is None:
|
for waiter in self._drain_waiters:
|
||||||
return
|
if not waiter.done():
|
||||||
self._drain_waiter = None
|
if exc is None:
|
||||||
if waiter.done():
|
waiter.set_result(None)
|
||||||
return
|
else:
|
||||||
if exc is None:
|
waiter.set_exception(exc)
|
||||||
waiter.set_result(None)
|
|
||||||
else:
|
|
||||||
waiter.set_exception(exc)
|
|
||||||
|
|
||||||
async def _drain_helper(self):
|
async def _drain_helper(self):
|
||||||
if self._connection_lost:
|
if self._connection_lost:
|
||||||
raise ConnectionResetError('Connection lost')
|
raise ConnectionResetError('Connection lost')
|
||||||
if not self._paused:
|
if not self._paused:
|
||||||
return
|
return
|
||||||
waiter = self._drain_waiter
|
|
||||||
assert waiter is None or waiter.cancelled()
|
|
||||||
waiter = self._loop.create_future()
|
waiter = self._loop.create_future()
|
||||||
self._drain_waiter = waiter
|
self._drain_waiters.append(waiter)
|
||||||
await waiter
|
try:
|
||||||
|
await waiter
|
||||||
|
finally:
|
||||||
|
self._drain_waiters.remove(waiter)
|
||||||
|
|
||||||
def _get_close_waiter(self, stream):
|
def _get_close_waiter(self, stream):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -864,6 +864,25 @@ os.close(fd)
|
||||||
self.assertEqual(cm.filename, __file__)
|
self.assertEqual(cm.filename, __file__)
|
||||||
self.assertIs(protocol._loop, self.loop)
|
self.assertIs(protocol._loop, self.loop)
|
||||||
|
|
||||||
|
def test_multiple_drain(self):
|
||||||
|
# See https://github.com/python/cpython/issues/74116
|
||||||
|
drained = 0
|
||||||
|
|
||||||
|
async def drainer(stream):
|
||||||
|
nonlocal drained
|
||||||
|
await stream._drain_helper()
|
||||||
|
drained += 1
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
stream = asyncio.streams.FlowControlMixin(loop)
|
||||||
|
stream.pause_writing()
|
||||||
|
loop.call_later(0.1, stream.resume_writing)
|
||||||
|
await asyncio.gather(*[drainer(stream) for _ in range(10)])
|
||||||
|
self.assertEqual(drained, 10)
|
||||||
|
|
||||||
|
self.loop.run_until_complete(main())
|
||||||
|
|
||||||
def test_drain_raises(self):
|
def test_drain_raises(self):
|
||||||
# See http://bugs.python.org/issue25441
|
# See http://bugs.python.org/issue25441
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Allow :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.
|
Loading…
Add table
Add a link
Reference in a new issue