#3566: Clean up handling of remote server disconnects.

This changeset does two things: introduces a new RemoteDisconnected exception
(that subclasses ConnectionResetError and BadStatusLine) so that a remote
server disconnection can be detected by client code (and provides a better
error message for debugging purposes), and ensures that the client socket is
closed if a ConnectionError happens, so that the automatic re-connection code
can work if the application handles the error and continues on.

Tests are added that confirm that a connection is re-used or not re-used
as appropriate to the various combinations of protocol version and headers.

Patch by Martin Panter, reviewed by Demian Brecht.  (Tweaked only slightly by
me.)
This commit is contained in:
R David Murray 2015-04-05 19:26:29 -04:00
parent 142bf565b4
commit cae7bdb424
4 changed files with 131 additions and 10 deletions

View file

@ -107,6 +107,23 @@ class NoEOFBytesIO(io.BytesIO):
raise AssertionError('caller tried to read past EOF')
return data
class FakeSocketHTTPConnection(client.HTTPConnection):
"""HTTPConnection subclass using FakeSocket; counts connect() calls"""
def __init__(self, *args):
self.connections = 0
super().__init__('example.com')
self.fake_socket_args = args
self._create_connection = self.create_connection
def connect(self):
"""Count the number of times connect() is invoked"""
self.connections += 1
return super().connect()
def create_connection(self, *pos, **kw):
return FakeSocket(*self.fake_socket_args)
class HeaderTests(TestCase):
def test_auto_headers(self):
# Some headers are added automatically, but should not be added by
@ -777,7 +794,7 @@ class BasicTest(TestCase):
response = self # Avoid garbage collector closing the socket
client.HTTPResponse.__init__(self, *pos, **kw)
conn.response_class = Response
conn.sock = FakeSocket('') # Emulate server dropping connection
conn.sock = FakeSocket('Invalid status line')
conn.request('GET', '/')
self.assertRaises(client.BadStatusLine, conn.getresponse)
self.assertTrue(response.closed)
@ -1174,6 +1191,78 @@ class TimeoutTest(TestCase):
httpConn.close()
class PersistenceTest(TestCase):
def test_reuse_reconnect(self):
# Should reuse or reconnect depending on header from server
tests = (
('1.0', '', False),
('1.0', 'Connection: keep-alive\r\n', True),
('1.1', '', True),
('1.1', 'Connection: close\r\n', False),
('1.0', 'Connection: keep-ALIVE\r\n', True),
('1.1', 'Connection: cloSE\r\n', False),
)
for version, header, reuse in tests:
with self.subTest(version=version, header=header):
msg = (
'HTTP/{} 200 OK\r\n'
'{}'
'Content-Length: 12\r\n'
'\r\n'
'Dummy body\r\n'
).format(version, header)
conn = FakeSocketHTTPConnection(msg)
self.assertIsNone(conn.sock)
conn.request('GET', '/open-connection')
with conn.getresponse() as response:
self.assertEqual(conn.sock is None, not reuse)
response.read()
self.assertEqual(conn.sock is None, not reuse)
self.assertEqual(conn.connections, 1)
conn.request('GET', '/subsequent-request')
self.assertEqual(conn.connections, 1 if reuse else 2)
def test_disconnected(self):
def make_reset_reader(text):
"""Return BufferedReader that raises ECONNRESET at EOF"""
stream = io.BytesIO(text)
def readinto(buffer):
size = io.BytesIO.readinto(stream, buffer)
if size == 0:
raise ConnectionResetError()
return size
stream.readinto = readinto
return io.BufferedReader(stream)
tests = (
(io.BytesIO, client.RemoteDisconnected),
(make_reset_reader, ConnectionResetError),
)
for stream_factory, exception in tests:
with self.subTest(exception=exception):
conn = FakeSocketHTTPConnection(b'', stream_factory)
conn.request('GET', '/eof-response')
self.assertRaises(exception, conn.getresponse)
self.assertIsNone(conn.sock)
# HTTPConnection.connect() should be automatically invoked
conn.request('GET', '/reconnect')
self.assertEqual(conn.connections, 2)
def test_100_close(self):
conn = FakeSocketHTTPConnection(
b'HTTP/1.1 100 Continue\r\n'
b'\r\n'
# Missing final response
)
conn.request('GET', '/', headers={'Expect': '100-continue'})
self.assertRaises(client.RemoteDisconnected, conn.getresponse)
self.assertIsNone(conn.sock)
conn.request('GET', '/reconnect')
self.assertEqual(conn.connections, 2)
class HTTPSTest(TestCase):
def setUp(self):
@ -1513,6 +1602,7 @@ class TunnelTests(TestCase):
@support.reap_threads
def test_main(verbose=None):
support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
PersistenceTest,
HTTPSTest, RequestBodyTest, SourceAddressTest,
HTTPResponseTest, ExtendedReadTest,
ExtendedReadTestChunked, TunnelTests)