Fixes Issue #14635: telnetlib will use poll() rather than select() when possible

to avoid failing due to the select() file descriptor limit.
This commit is contained in:
Gregory P. Smith 2012-07-15 23:42:26 -07:00
parent 4774946c3b
commit dad5711677
4 changed files with 223 additions and 7 deletions

View file

@ -75,8 +75,8 @@ class GeneralTests(TestCase):
class SocketStub(object):
''' a socket proxy that re-defines sendall() '''
def __init__(self, reads=[]):
self.reads = reads
def __init__(self, reads=()):
self.reads = list(reads) # Intentionally make a copy.
self.writes = []
self.block = False
def sendall(self, data):
@ -102,7 +102,7 @@ class TelnetAlike(telnetlib.Telnet):
self._messages += out.getvalue()
return
def new_select(*s_args):
def mock_select(*s_args):
block = False
for l in s_args:
for fob in l:
@ -113,6 +113,30 @@ def new_select(*s_args):
else:
return s_args
class MockPoller(object):
test_case = None # Set during TestCase setUp.
def __init__(self):
self._file_objs = []
def register(self, fd, eventmask):
self.test_case.assertTrue(hasattr(fd, 'fileno'), fd)
self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI)
self._file_objs.append(fd)
def poll(self, timeout=None):
block = False
for fob in self._file_objs:
if isinstance(fob, TelnetAlike):
block = fob.sock.block
if block:
return []
else:
return zip(self._file_objs, [select.POLLIN]*len(self._file_objs))
def unregister(self, fd):
self._file_objs.remove(fd)
@contextlib.contextmanager
def test_socket(reads):
def new_conn(*ignored):
@ -125,7 +149,7 @@ def test_socket(reads):
socket.create_connection = old_conn
return
def test_telnet(reads=[], cls=TelnetAlike):
def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
''' return a telnetlib.Telnet object that uses a SocketStub with
reads queued up to be read '''
for x in reads:
@ -133,15 +157,28 @@ def test_telnet(reads=[], cls=TelnetAlike):
with test_socket(reads):
telnet = cls('dummy', 0)
telnet._messages = '' # debuglevel output
if use_poll is not None:
if use_poll and not telnet._has_poll:
raise unittest.SkipTest('select.poll() required.')
telnet._has_poll = use_poll
return telnet
class ReadTests(TestCase):
class ExpectAndReadTestCase(TestCase):
def setUp(self):
self.old_select = select.select
select.select = new_select
self.old_poll = select.poll
select.select = mock_select
select.poll = MockPoller
MockPoller.test_case = self
def tearDown(self):
MockPoller.test_case = None
select.poll = self.old_poll
select.select = self.old_select
class ReadTests(ExpectAndReadTestCase):
def test_read_until(self):
"""
read_until(expected, timeout=None)
@ -158,6 +195,21 @@ class ReadTests(TestCase):
data = telnet.read_until(b'match')
self.assertEqual(data, expect)
def test_read_until_with_poll(self):
"""Use select.poll() to implement telnet.read_until()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=True)
select.select = lambda *_: self.fail('unexpected select() call.')
data = telnet.read_until(b'match')
self.assertEqual(data, b''.join(want[:-1]))
def test_read_until_with_select(self):
"""Use select.select() to implement telnet.read_until()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=False)
select.poll = lambda *_: self.fail('unexpected poll() call.')
data = telnet.read_until(b'match')
self.assertEqual(data, b''.join(want[:-1]))
def test_read_all(self):
"""
@ -349,8 +401,38 @@ class OptionTests(TestCase):
self.assertRegex(telnet._messages, r'0.*test')
class ExpectTests(ExpectAndReadTestCase):
def test_expect(self):
"""
expect(expected, [timeout])
Read until the expected string has been seen, or a timeout is
hit (default is no timeout); may block.
"""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want)
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_expect_with_poll(self):
"""Use select.poll() to implement telnet.expect()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=True)
select.select = lambda *_: self.fail('unexpected select() call.')
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_expect_with_select(self):
"""Use select.select() to implement telnet.expect()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=False)
select.poll = lambda *_: self.fail('unexpected poll() call.')
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_main(verbose=None):
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests)
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,
ExpectTests)
if __name__ == '__main__':
test_main()