mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
asyncio: WriteTransport.set_write_buffer_size to call _maybe_pause_protocol
This commit is contained in:
parent
a6919aa4ed
commit
1589920977
2 changed files with 29 additions and 2 deletions
|
@ -241,7 +241,7 @@ class _FlowControlMixin(Transport):
|
||||||
def __init__(self, extra=None):
|
def __init__(self, extra=None):
|
||||||
super().__init__(extra)
|
super().__init__(extra)
|
||||||
self._protocol_paused = False
|
self._protocol_paused = False
|
||||||
self.set_write_buffer_limits()
|
self._set_write_buffer_limits()
|
||||||
|
|
||||||
def _maybe_pause_protocol(self):
|
def _maybe_pause_protocol(self):
|
||||||
size = self.get_write_buffer_size()
|
size = self.get_write_buffer_size()
|
||||||
|
@ -273,7 +273,7 @@ class _FlowControlMixin(Transport):
|
||||||
'protocol': self._protocol,
|
'protocol': self._protocol,
|
||||||
})
|
})
|
||||||
|
|
||||||
def set_write_buffer_limits(self, high=None, low=None):
|
def _set_write_buffer_limits(self, high=None, low=None):
|
||||||
if high is None:
|
if high is None:
|
||||||
if low is None:
|
if low is None:
|
||||||
high = 64*1024
|
high = 64*1024
|
||||||
|
@ -287,5 +287,9 @@ class _FlowControlMixin(Transport):
|
||||||
self._high_water = high
|
self._high_water = high
|
||||||
self._low_water = low
|
self._low_water = low
|
||||||
|
|
||||||
|
def set_write_buffer_limits(self, high=None, low=None):
|
||||||
|
self._set_write_buffer_limits(high=high, low=low)
|
||||||
|
self._maybe_pause_protocol()
|
||||||
|
|
||||||
def get_write_buffer_size(self):
|
def get_write_buffer_size(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -4,6 +4,7 @@ import unittest
|
||||||
import unittest.mock
|
import unittest.mock
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from asyncio import transports
|
||||||
|
|
||||||
|
|
||||||
class TransportTests(unittest.TestCase):
|
class TransportTests(unittest.TestCase):
|
||||||
|
@ -60,6 +61,28 @@ class TransportTests(unittest.TestCase):
|
||||||
self.assertRaises(NotImplementedError, transport.terminate)
|
self.assertRaises(NotImplementedError, transport.terminate)
|
||||||
self.assertRaises(NotImplementedError, transport.kill)
|
self.assertRaises(NotImplementedError, transport.kill)
|
||||||
|
|
||||||
|
def test_flowcontrol_mixin_set_write_limits(self):
|
||||||
|
|
||||||
|
class MyTransport(transports._FlowControlMixin,
|
||||||
|
transports.Transport):
|
||||||
|
|
||||||
|
def get_write_buffer_size(self):
|
||||||
|
return 512
|
||||||
|
|
||||||
|
transport = MyTransport()
|
||||||
|
transport._protocol = unittest.mock.Mock()
|
||||||
|
|
||||||
|
self.assertFalse(transport._protocol_paused)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
|
||||||
|
transport.set_write_buffer_limits(high=0, low=1)
|
||||||
|
|
||||||
|
transport.set_write_buffer_limits(high=1024, low=128)
|
||||||
|
self.assertFalse(transport._protocol_paused)
|
||||||
|
|
||||||
|
transport.set_write_buffer_limits(high=256, low=128)
|
||||||
|
self.assertTrue(transport._protocol_paused)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue