bpo-34638: Store a weak reference to stream reader to break strong references loop (GH-9201)

Store a weak reference to stream readerfor breaking strong references

It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)
This commit is contained in:
Andrew Svetlov 2018-09-12 11:43:04 -07:00 committed by GitHub
parent aca819fb49
commit a5d1eb8d8b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 160 additions and 10 deletions

View file

@ -46,6 +46,8 @@ class StreamTests(test_utils.TestCase):
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def _basetest_open_connection(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
reader, writer = self.loop.run_until_complete(open_connection_fut)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline()
@ -55,6 +57,7 @@ class StreamTests(test_utils.TestCase):
data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close()
self.assertEqual(messages, [])
def test_open_connection(self):
with test_utils.run_test_server() as httpd:
@ -70,6 +73,8 @@ class StreamTests(test_utils.TestCase):
self._basetest_open_connection(conn_fut)
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
try:
reader, writer = self.loop.run_until_complete(open_connection_fut)
finally:
@ -80,6 +85,7 @@ class StreamTests(test_utils.TestCase):
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close()
self.assertEqual(messages, [])
@unittest.skipIf(ssl is None, 'No ssl module')
def test_open_connection_no_loop_ssl(self):
@ -104,6 +110,8 @@ class StreamTests(test_utils.TestCase):
self._basetest_open_connection_no_loop_ssl(conn_fut)
def _basetest_open_connection_error(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
reader, writer = self.loop.run_until_complete(open_connection_fut)
writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read()
@ -111,6 +119,7 @@ class StreamTests(test_utils.TestCase):
self.loop.run_until_complete(f)
writer.close()
test_utils.run_briefly(self.loop)
self.assertEqual(messages, [])
def test_open_connection_error(self):
with test_utils.run_test_server() as httpd:
@ -621,6 +630,9 @@ class StreamTests(test_utils.TestCase):
writer.close()
return msgback
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
addr = server.start()
@ -637,6 +649,8 @@ class StreamTests(test_utils.TestCase):
server.stop()
self.assertEqual(msg, b"hello world!\n")
self.assertEqual(messages, [])
@support.skip_unless_bind_unix_socket
def test_start_unix_server(self):
@ -685,6 +699,9 @@ class StreamTests(test_utils.TestCase):
writer.close()
return msgback
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
# test the server variant with a coroutine as client handler
with test_utils.unix_socket_path() as path:
server = MyServer(self.loop, path)
@ -703,6 +720,8 @@ class StreamTests(test_utils.TestCase):
server.stop()
self.assertEqual(msg, b"hello world!\n")
self.assertEqual(messages, [])
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example
@ -893,6 +912,58 @@ os.close(fd)
wr.close()
self.loop.run_until_complete(wr.wait_closed())
def test_del_stream_before_sock_closing(self):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
rd, wr = self.loop.run_until_complete(
asyncio.open_connection(*httpd.address, loop=self.loop))
sock = wr.get_extra_info('socket')
self.assertNotEqual(sock.fileno(), -1)
wr.write(b'GET / HTTP/1.0\r\n\r\n')
f = rd.readline()
data = self.loop.run_until_complete(f)
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
# drop refs to reader/writer
del rd
del wr
gc.collect()
# make a chance to close the socket
test_utils.run_briefly(self.loop)
self.assertEqual(1, len(messages))
self.assertEqual(sock.fileno(), -1)
self.assertEqual(1, len(messages))
self.assertEqual('An open stream object is being garbage '
'collected; call "stream.close()" explicitly.',
messages[0]['message'])
def test_del_stream_before_connection_made(self):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
rd = asyncio.StreamReader(loop=self.loop)
pr = asyncio.StreamReaderProtocol(rd, loop=self.loop)
del rd
gc.collect()
tr, _ = self.loop.run_until_complete(
self.loop.create_connection(
lambda: pr, *httpd.address))
sock = tr.get_extra_info('socket')
self.assertEqual(sock.fileno(), -1)
self.assertEqual(1, len(messages))
self.assertEqual('An open stream was garbage collected prior to '
'establishing network connection; '
'call "stream.close()" explicitly.',
messages[0]['message'])
if __name__ == '__main__':
unittest.main()