GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705) (#96395)

(cherry picked from commit e5b2453e61)

Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>

Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
This commit is contained in:
Miss Islington (bot) 2022-08-30 04:00:21 -07:00 committed by GitHub
parent 126ec34558
commit 2e9f29e6a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 19 deletions

View file

@ -2,6 +2,7 @@ __all__ = (
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server')
import collections
import socket
import sys
import warnings
@ -129,7 +130,7 @@ class FlowControlMixin(protocols.Protocol):
else:
self._loop = loop
self._paused = False
self._drain_waiter = None
self._drain_waiters = collections.deque()
self._connection_lost = False
def pause_writing(self):
@ -144,38 +145,34 @@ class FlowControlMixin(protocols.Protocol):
if self._loop.get_debug():
logger.debug("%r resumes writing", self)
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
for waiter in self._drain_waiters:
if not waiter.done():
waiter.set_result(None)
def connection_lost(self, exc):
self._connection_lost = True
# Wake up the writer if currently paused.
# Wake up the writer(s) if currently paused.
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
for waiter in self._drain_waiters:
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter
self._drain_waiters.append(waiter)
try:
await waiter
finally:
self._drain_waiters.remove(waiter)
def _get_close_waiter(self, stream):
raise NotImplementedError