asyncio: Refactor tests: add a base TestCase class

This commit is contained in:
Victor Stinner 2014-06-18 01:36:32 +02:00
parent d6f02fc649
commit c73701de72
13 changed files with 145 additions and 219 deletions

View file

@ -29,14 +29,11 @@ MOCK_ANY = mock.ANY
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(unittest.TestCase):
class SelectorEventLoopSignalTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.set_event_loop(self.loop)
def test_check_signal(self):
self.assertRaises(
@ -208,14 +205,11 @@ class SelectorEventLoopSignalTests(unittest.TestCase):
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.set_event_loop(self.loop)
def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path:
@ -304,10 +298,10 @@ class SelectorEventLoopUnixSocketTests(unittest.TestCase):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase):
class UnixReadPipeTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5
@ -451,7 +445,7 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self):
@ -468,14 +462,14 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
class UnixWritePipeTransportTests(unittest.TestCase):
class UnixWritePipeTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5
@ -737,7 +731,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self):
@ -753,7 +747,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
def test_close(self):
@ -834,7 +828,7 @@ class ChildWatcherTestsMixin:
ignore_warnings = mock.patch.object(log.logger, "warning")
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.running = False
self.zombies = {}
@ -1392,7 +1386,7 @@ class ChildWatcherTestsMixin:
# attach a new loop
old_loop = self.loop
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
patch = mock.patch.object
with patch(old_loop, "remove_signal_handler") as m_old_remove, \
@ -1447,7 +1441,7 @@ class ChildWatcherTestsMixin:
self.assertFalse(callback3.called)
# attach a new loop
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
with mock.patch.object(
self.loop, "add_signal_handler") as m_add_signal_handler:
@ -1505,12 +1499,12 @@ class ChildWatcherTestsMixin:
self.assertFalse(self.watcher._zombies)
class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self):
return asyncio.SafeChildWatcher()
class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self):
return asyncio.FastChildWatcher()