bpo-45046: Support context managers in unittest (GH-28045)

Add methods enterContext() and enterClassContext() in TestCase.
Add method enterAsyncContext() in IsolatedAsyncioTestCase.
Add function enterModuleContext().
(cherry picked from commit 086c6b1b0f)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Miss Islington (bot) 2022-05-08 08:12:19 -07:00 committed by GitHub
parent a85bdd7e02
commit c63c8ac238
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 307 additions and 92 deletions

View file

@ -1495,6 +1495,16 @@ Test cases
.. versionadded:: 3.1 .. versionadded:: 3.1
.. method:: enterContext(cm)
Enter the supplied :term:`context manager`. If successful, also
add its :meth:`~object.__exit__` method as a cleanup function by
:meth:`addCleanup` and return the result of the
:meth:`~object.__enter__` method.
.. versionadded:: 3.11
.. method:: doCleanups() .. method:: doCleanups()
This method is called unconditionally after :meth:`tearDown`, or This method is called unconditionally after :meth:`tearDown`, or
@ -1510,6 +1520,7 @@ Test cases
.. versionadded:: 3.1 .. versionadded:: 3.1
.. classmethod:: addClassCleanup(function, /, *args, **kwargs) .. classmethod:: addClassCleanup(function, /, *args, **kwargs)
Add a function to be called after :meth:`tearDownClass` to cleanup Add a function to be called after :meth:`tearDownClass` to cleanup
@ -1524,6 +1535,16 @@ Test cases
.. versionadded:: 3.8 .. versionadded:: 3.8
.. classmethod:: enterClassContext(cm)
Enter the supplied :term:`context manager`. If successful, also
add its :meth:`~object.__exit__` method as a cleanup function by
:meth:`addClassCleanup` and return the result of the
:meth:`~object.__enter__` method.
.. versionadded:: 3.11
.. classmethod:: doClassCleanups() .. classmethod:: doClassCleanups()
This method is called unconditionally after :meth:`tearDownClass`, or This method is called unconditionally after :meth:`tearDownClass`, or
@ -1571,6 +1592,16 @@ Test cases
This method accepts a coroutine that can be used as a cleanup function. This method accepts a coroutine that can be used as a cleanup function.
.. coroutinemethod:: enterAsyncContext(cm)
Enter the supplied :term:`asynchronous context manager`. If successful,
also add its :meth:`~object.__aexit__` method as a cleanup function by
:meth:`addAsyncCleanup` and return the result of the
:meth:`~object.__aenter__` method.
.. versionadded:: 3.11
.. method:: run(result=None) .. method:: run(result=None)
Sets up a new event loop to run the test, collecting the result into Sets up a new event loop to run the test, collecting the result into
@ -2465,6 +2496,16 @@ To add cleanup code that must be run even in the case of an exception, use
.. versionadded:: 3.8 .. versionadded:: 3.8
.. classmethod:: enterModuleContext(cm)
Enter the supplied :term:`context manager`. If successful, also
add its :meth:`~object.__exit__` method as a cleanup function by
:func:`addModuleCleanup` and return the result of the
:meth:`~object.__enter__` method.
.. versionadded:: 3.11
.. function:: doModuleCleanups() .. function:: doModuleCleanups()
This function is called unconditionally after :func:`tearDownModule`, or This function is called unconditionally after :func:`tearDownModule`, or
@ -2480,6 +2521,7 @@ To add cleanup code that must be run even in the case of an exception, use
.. versionadded:: 3.8 .. versionadded:: 3.8
Signal Handling Signal Handling
--------------- ---------------

View file

@ -758,6 +758,18 @@ unicodedata
* The Unicode database has been updated to version 14.0.0. (:issue:`45190`). * The Unicode database has been updated to version 14.0.0. (:issue:`45190`).
unittest
--------
* Added methods :meth:`~unittest.TestCase.enterContext` and
:meth:`~unittest.TestCase.enterClassContext` of class
:class:`~unittest.TestCase`, method
:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of
class :class:`~unittest.IsolatedAsyncioTestCase` and function
:func:`unittest.enterModuleContext`.
(Contributed by Serhiy Storchaka in :issue:`45046`.)
venv venv
---- ----

View file

@ -41,9 +41,7 @@ class BuildExtTestCase(TempdirManager,
# bpo-30132: On Windows, a .pdb file may be created in the current # bpo-30132: On Windows, a .pdb file may be created in the current
# working directory. Create a temporary working directory to cleanup # working directory. Create a temporary working directory to cleanup
# everything at the end of the test. # everything at the end of the test.
change_cwd = os_helper.change_cwd(self.tmp_dir) self.enterContext(os_helper.change_cwd(self.tmp_dir))
change_cwd.__enter__()
self.addCleanup(change_cwd.__exit__, None, None, None)
def tearDown(self): def tearDown(self):
import site import site

View file

@ -19,8 +19,7 @@ class Test_OSXSupport(unittest.TestCase):
self.maxDiff = None self.maxDiff = None
self.prog_name = 'bogus_program_xxxx' self.prog_name = 'bogus_program_xxxx'
self.temp_path_dir = os.path.abspath(os.getcwd()) self.temp_path_dir = os.path.abspath(os.getcwd())
self.env = os_helper.EnvironmentVarGuard() self.env = self.enterContext(os_helper.EnvironmentVarGuard())
self.addCleanup(self.env.__exit__)
for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS', for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS',
'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC', 'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC',
'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS', 'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS',

View file

@ -41,9 +41,8 @@ class TestCase(unittest.TestCase):
# The tests assume that line wrapping occurs at 80 columns, but this # The tests assume that line wrapping occurs at 80 columns, but this
# behaviour can be overridden by setting the COLUMNS environment # behaviour can be overridden by setting the COLUMNS environment
# variable. To ensure that this width is used, set COLUMNS to 80. # variable. To ensure that this width is used, set COLUMNS to 80.
env = os_helper.EnvironmentVarGuard() env = self.enterContext(os_helper.EnvironmentVarGuard())
env['COLUMNS'] = '80' env['COLUMNS'] = '80'
self.addCleanup(env.__exit__)
class TempDirMixin(object): class TempDirMixin(object):
@ -3428,9 +3427,8 @@ class TestShortColumns(HelpTestCase):
but we don't want any exceptions thrown in such cases. Only ugly representation. but we don't want any exceptions thrown in such cases. Only ugly representation.
''' '''
def setUp(self): def setUp(self):
env = os_helper.EnvironmentVarGuard() env = self.enterContext(os_helper.EnvironmentVarGuard())
env.set("COLUMNS", '15') env.set("COLUMNS", '15')
self.addCleanup(env.__exit__)
parser_signature = TestHelpBiggerOptionals.parser_signature parser_signature = TestHelpBiggerOptionals.parser_signature
argument_signatures = TestHelpBiggerOptionals.argument_signatures argument_signatures = TestHelpBiggerOptionals.argument_signatures

View file

@ -11,14 +11,10 @@ sentinel = object()
class GetoptTests(unittest.TestCase): class GetoptTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.env = EnvironmentVarGuard() self.env = self.enterContext(EnvironmentVarGuard())
if "POSIXLY_CORRECT" in self.env: if "POSIXLY_CORRECT" in self.env:
del self.env["POSIXLY_CORRECT"] del self.env["POSIXLY_CORRECT"]
def tearDown(self):
self.env.__exit__()
del self.env
def assertError(self, *args, **kwargs): def assertError(self, *args, **kwargs):
self.assertRaises(getopt.GetoptError, *args, **kwargs) self.assertRaises(getopt.GetoptError, *args, **kwargs)

View file

@ -117,6 +117,7 @@ MMOFILE = os.path.join(LOCALEDIR, 'metadata.mo')
class GettextBaseTest(unittest.TestCase): class GettextBaseTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.addCleanup(os_helper.rmtree, os.path.split(LOCALEDIR)[0])
if not os.path.isdir(LOCALEDIR): if not os.path.isdir(LOCALEDIR):
os.makedirs(LOCALEDIR) os.makedirs(LOCALEDIR)
with open(MOFILE, 'wb') as fp: with open(MOFILE, 'wb') as fp:
@ -129,14 +130,10 @@ class GettextBaseTest(unittest.TestCase):
fp.write(base64.decodebytes(UMO_DATA)) fp.write(base64.decodebytes(UMO_DATA))
with open(MMOFILE, 'wb') as fp: with open(MMOFILE, 'wb') as fp:
fp.write(base64.decodebytes(MMO_DATA)) fp.write(base64.decodebytes(MMO_DATA))
self.env = os_helper.EnvironmentVarGuard() self.env = self.enterContext(os_helper.EnvironmentVarGuard())
self.env['LANGUAGE'] = 'xx' self.env['LANGUAGE'] = 'xx'
gettext._translations.clear() gettext._translations.clear()
def tearDown(self):
self.env.__exit__()
del self.env
os_helper.rmtree(os.path.split(LOCALEDIR)[0])
GNU_MO_DATA_ISSUE_17898 = b'''\ GNU_MO_DATA_ISSUE_17898 = b'''\
3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z 3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z

View file

@ -9,14 +9,9 @@ import warnings
class GlobalTests(unittest.TestCase): class GlobalTests(unittest.TestCase):
def setUp(self): def setUp(self):
self._warnings_manager = check_warnings() self.enterContext(check_warnings())
self._warnings_manager.__enter__()
warnings.filterwarnings("error", module="<test string>") warnings.filterwarnings("error", module="<test string>")
def tearDown(self):
self._warnings_manager.__exit__(None, None, None)
def test1(self): def test1(self):
prog_text_1 = """\ prog_text_1 = """\
def wrong1(): def wrong1():
@ -54,9 +49,7 @@ x = 2
def setUpModule(): def setUpModule():
cm = warnings.catch_warnings() unittest.enterModuleContext(warnings.catch_warnings())
cm.__enter__()
unittest.addModuleCleanup(cm.__exit__, None, None, None)
warnings.filterwarnings("error", module="<test string>") warnings.filterwarnings("error", module="<test string>")

View file

@ -157,21 +157,12 @@ class FinderTests(abc.FinderTests):
def test_no_read_directory(self): def test_no_read_directory(self):
# Issue #16730 # Issue #16730
tempdir = tempfile.TemporaryDirectory() tempdir = tempfile.TemporaryDirectory()
self.enterContext(tempdir)
# Since we muck with the permissions, we want to set them back to
# their original values to make sure the directory can be properly
# cleaned up.
original_mode = os.stat(tempdir.name).st_mode original_mode = os.stat(tempdir.name).st_mode
def cleanup(tempdir): self.addCleanup(os.chmod, tempdir.name, original_mode)
"""Cleanup function for the temporary directory.
Since we muck with the permissions, we want to set them back to
their original values to make sure the directory can be properly
cleaned up.
"""
os.chmod(tempdir.name, original_mode)
# If this is not explicitly called then the __del__ method is used,
# but since already mucking around might as well explicitly clean
# up.
tempdir.__exit__(None, None, None)
self.addCleanup(cleanup, tempdir)
os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR) os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR)
finder = self.get_finder(tempdir.name) finder = self.get_finder(tempdir.name)
found = self._find(finder, 'doesnotexist') found = self._find(finder, 'doesnotexist')

View file

@ -65,12 +65,7 @@ class NamespacePackageTest(unittest.TestCase):
self.resolved_paths = [ self.resolved_paths = [
os.path.join(self.root, path) for path in self.paths os.path.join(self.root, path) for path in self.paths
] ]
self.ctx = namespace_tree_context(path=self.resolved_paths) self.enterContext(namespace_tree_context(path=self.resolved_paths))
self.ctx.__enter__()
def tearDown(self):
# TODO: will we ever want to pass exc_info to __exit__?
self.ctx.__exit__(None, None, None)
class SingleNamespacePackage(NamespacePackageTest): class SingleNamespacePackage(NamespacePackageTest):

View file

@ -5650,9 +5650,7 @@ class MiscTestCase(unittest.TestCase):
# why the test does this, but in any case we save the current locale # why the test does this, but in any case we save the current locale
# first and restore it at the end. # first and restore it at the end.
def setUpModule(): def setUpModule():
cm = support.run_with_locale('LC_ALL', '') unittest.enterModuleContext(support.run_with_locale('LC_ALL', ''))
cm.__enter__()
unittest.addModuleCleanup(cm.__exit__, None, None, None)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -1593,8 +1593,7 @@ class LocalServerTests(unittest.TestCase):
self.background.start() self.background.start()
self.addCleanup(self.background.join) self.addCleanup(self.background.join)
self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__() self.nntp = self.enterContext(NNTP(socket_helper.HOST, port, usenetrc=False))
self.addCleanup(self.nntp.__exit__, None, None, None)
def run_server(self, sock): def run_server(self, sock):
# Could be generalized to handle more commands in separate methods # Could be generalized to handle more commands in separate methods

View file

@ -96,9 +96,7 @@ class TestCParser(unittest.TestCase):
self.skipTest("The %r command is not found" % cmd) self.skipTest("The %r command is not found" % cmd)
self.old_cwd = os.getcwd() self.old_cwd = os.getcwd()
self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base) self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base)
change_cwd = os_helper.change_cwd(self.tmp_path) self.enterContext(os_helper.change_cwd(self.tmp_path))
change_cwd.__enter__()
self.addCleanup(change_cwd.__exit__, None, None, None)
def tearDown(self): def tearDown(self):
os.chdir(self.old_cwd) os.chdir(self.old_cwd)

View file

@ -128,8 +128,7 @@ class PollTests(unittest.TestCase):
cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
bufsize=0) bufsize=0)
proc.__enter__() self.enterContext(proc)
self.addCleanup(proc.__exit__, None, None, None)
p = proc.stdout p = proc.stdout
pollster = select.poll() pollster = select.poll()
pollster.register( p, select.POLLIN ) pollster.register( p, select.POLLIN )

View file

@ -53,19 +53,13 @@ class PosixTester(unittest.TestCase):
def setUp(self): def setUp(self):
# create empty file # create empty file
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
with open(os_helper.TESTFN, "wb"): with open(os_helper.TESTFN, "wb"):
pass pass
self.teardown_files = [ os_helper.TESTFN ] self.enterContext(warnings_helper.check_warnings())
self._warnings_manager = warnings_helper.check_warnings()
self._warnings_manager.__enter__()
warnings.filterwarnings('ignore', '.* potential security risk .*', warnings.filterwarnings('ignore', '.* potential security risk .*',
RuntimeWarning) RuntimeWarning)
def tearDown(self):
for teardown_file in self.teardown_files:
os_helper.unlink(teardown_file)
self._warnings_manager.__exit__(None, None, None)
def testNoArgFunctions(self): def testNoArgFunctions(self):
# test posix functions which take no arguments and have # test posix functions which take no arguments and have
# no side-effects which we need to cleanup (e.g., fork, wait, abort) # no side-effects which we need to cleanup (e.g., fork, wait, abort)
@ -973,8 +967,8 @@ class PosixTester(unittest.TestCase):
self.assertTrue(hasattr(testfn_st, 'st_flags')) self.assertTrue(hasattr(testfn_st, 'st_flags'))
self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK)
os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK) os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK)
self.teardown_files.append(_DUMMY_SYMLINK)
dummy_symlink_st = os.lstat(_DUMMY_SYMLINK) dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)
def chflags_nofollow(path, flags): def chflags_nofollow(path, flags):

View file

@ -1022,8 +1022,7 @@ class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
def setUp(self): def setUp(self):
self._warning_filters = warnings_helper.check_warnings() self.enterContext(warnings_helper.check_warnings())
self._warning_filters.__enter__()
warnings.simplefilter('ignore', BytesWarning) warnings.simplefilter('ignore', BytesWarning)
self.case = "string and bytes set" self.case = "string and bytes set"
self.values = ["a", "b", b"a", b"b"] self.values = ["a", "b", b"a", b"b"]
@ -1031,9 +1030,6 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
self.dup = set(self.values) self.dup = set(self.values)
self.length = 4 self.length = 4
def tearDown(self):
self._warning_filters.__exit__(None, None, None)
def test_repr(self): def test_repr(self):
self.check_repr_against_values() self.check_repr_against_values()

View file

@ -338,9 +338,7 @@ class ThreadableTest:
self.server_ready.set() self.server_ready.set()
def _setUp(self): def _setUp(self):
self.wait_threads = threading_helper.wait_threads_exit() self.enterContext(threading_helper.wait_threads_exit())
self.wait_threads.__enter__()
self.addCleanup(self.wait_threads.__exit__, None, None, None)
self.server_ready = threading.Event() self.server_ready = threading.Event()
self.client_ready = threading.Event() self.client_ready = threading.Event()

View file

@ -1999,9 +1999,8 @@ class SimpleBackgroundTests(unittest.TestCase):
self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.server_context.load_cert_chain(SIGNED_CERTFILE) self.server_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=self.server_context) server = ThreadedEchoServer(context=self.server_context)
self.enterContext(server)
self.server_addr = (HOST, server.port) self.server_addr = (HOST, server.port)
server.__enter__()
self.addCleanup(server.__exit__, None, None, None)
def test_connect(self): def test_connect(self):
with test_wrap_socket(socket.socket(socket.AF_INET), with test_wrap_socket(socket.socket(socket.AF_INET),
@ -3713,8 +3712,7 @@ class ThreadedTests(unittest.TestCase):
def test_recv_zero(self): def test_recv_zero(self):
server = ThreadedEchoServer(CERTFILE) server = ThreadedEchoServer(CERTFILE)
server.__enter__() self.enterContext(server)
self.addCleanup(server.__exit__, None, None)
s = socket.create_connection((HOST, server.port)) s = socket.create_connection((HOST, server.port))
self.addCleanup(s.close) self.addCleanup(s.close)
s = test_wrap_socket(s, suppress_ragged_eofs=False) s = test_wrap_socket(s, suppress_ragged_eofs=False)

View file

@ -90,14 +90,10 @@ class BaseTestCase(unittest.TestCase):
b_check = re.compile(br"^[a-z0-9_-]{8}$") b_check = re.compile(br"^[a-z0-9_-]{8}$")
def setUp(self): def setUp(self):
self._warnings_manager = warnings_helper.check_warnings() self.enterContext(warnings_helper.check_warnings())
self._warnings_manager.__enter__()
warnings.filterwarnings("ignore", category=RuntimeWarning, warnings.filterwarnings("ignore", category=RuntimeWarning,
message="mktemp", module=__name__) message="mktemp", module=__name__)
def tearDown(self):
self._warnings_manager.__exit__(None, None, None)
def nameCheck(self, name, dir, pre, suf): def nameCheck(self, name, dir, pre, suf):
(ndir, nbase) = os.path.split(name) (ndir, nbase) = os.path.split(name)
npre = nbase[:len(pre)] npre = nbase[:len(pre)]

View file

@ -232,17 +232,12 @@ class ProxyTests(unittest.TestCase):
def setUp(self): def setUp(self):
# Records changes to env vars # Records changes to env vars
self.env = os_helper.EnvironmentVarGuard() self.env = self.enterContext(os_helper.EnvironmentVarGuard())
# Delete all proxy related env vars # Delete all proxy related env vars
for k in list(os.environ): for k in list(os.environ):
if 'proxy' in k.lower(): if 'proxy' in k.lower():
self.env.unset(k) self.env.unset(k)
def tearDown(self):
# Restore all proxy related env vars
self.env.__exit__()
del self.env
def test_getproxies_environment_keep_no_proxies(self): def test_getproxies_environment_keep_no_proxies(self):
self.env.set('NO_PROXY', 'localhost') self.env.set('NO_PROXY', 'localhost')
proxies = urllib.request.getproxies_environment() proxies = urllib.request.getproxies_environment()

View file

@ -49,7 +49,7 @@ __all__ = ['TestResult', 'TestCase', 'IsolatedAsyncioTestCase', 'TestSuite',
'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
'expectedFailure', 'TextTestResult', 'installHandler', 'expectedFailure', 'TextTestResult', 'installHandler',
'registerResult', 'removeResult', 'removeHandler', 'registerResult', 'removeResult', 'removeHandler',
'addModuleCleanup', 'doModuleCleanups'] 'addModuleCleanup', 'doModuleCleanups', 'enterModuleContext']
# Expose obsolete functions for backwards compatibility # Expose obsolete functions for backwards compatibility
# bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13. # bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13.
@ -59,7 +59,8 @@ __unittest = True
from .result import TestResult from .result import TestResult
from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip, from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
skipIf, skipUnless, expectedFailure, doModuleCleanups) skipIf, skipUnless, expectedFailure, doModuleCleanups,
enterModuleContext)
from .suite import BaseTestSuite, TestSuite from .suite import BaseTestSuite, TestSuite
from .loader import TestLoader, defaultTestLoader from .loader import TestLoader, defaultTestLoader
from .main import TestProgram, main from .main import TestProgram, main

View file

@ -58,6 +58,26 @@ class IsolatedAsyncioTestCase(TestCase):
# 3. Regular "def func()" that returns awaitable object # 3. Regular "def func()" that returns awaitable object
self.addCleanup(*(func, *args), **kwargs) self.addCleanup(*(func, *args), **kwargs)
async def enterAsyncContext(self, cm):
"""Enters the supplied asynchronous context manager.
If successful, also adds its __aexit__ method as a cleanup
function and returns the result of the __aenter__ method.
"""
# We look up the special methods on the type to match the with
# statement.
cls = type(cm)
try:
enter = cls.__aenter__
exit = cls.__aexit__
except AttributeError:
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the asynchronous context manager protocol"
) from None
result = await enter(cm)
self.addAsyncCleanup(exit, cm, None, None, None)
return result
def _callSetUp(self): def _callSetUp(self):
self._asyncioTestContext.run(self.setUp) self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp) self._callAsync(self.asyncSetUp)

View file

@ -102,12 +102,31 @@ def _id(obj):
return obj return obj
def _enter_context(cm, addcleanup):
# We look up the special methods on the type to match the with
# statement.
cls = type(cm)
try:
enter = cls.__enter__
exit = cls.__exit__
except AttributeError:
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the context manager protocol") from None
result = enter(cm)
addcleanup(exit, cm, None, None, None)
return result
_module_cleanups = [] _module_cleanups = []
def addModuleCleanup(function, /, *args, **kwargs): def addModuleCleanup(function, /, *args, **kwargs):
"""Same as addCleanup, except the cleanup items are called even if """Same as addCleanup, except the cleanup items are called even if
setUpModule fails (unlike tearDownModule).""" setUpModule fails (unlike tearDownModule)."""
_module_cleanups.append((function, args, kwargs)) _module_cleanups.append((function, args, kwargs))
def enterModuleContext(cm):
"""Same as enterContext, but module-wide."""
return _enter_context(cm, addModuleCleanup)
def doModuleCleanups(): def doModuleCleanups():
"""Execute all module cleanup functions. Normally called for you after """Execute all module cleanup functions. Normally called for you after
@ -426,12 +445,25 @@ class TestCase(object):
Cleanup items are called even if setUp fails (unlike tearDown).""" Cleanup items are called even if setUp fails (unlike tearDown)."""
self._cleanups.append((function, args, kwargs)) self._cleanups.append((function, args, kwargs))
def enterContext(self, cm):
"""Enters the supplied context manager.
If successful, also adds its __exit__ method as a cleanup
function and returns the result of the __enter__ method.
"""
return _enter_context(cm, self.addCleanup)
@classmethod @classmethod
def addClassCleanup(cls, function, /, *args, **kwargs): def addClassCleanup(cls, function, /, *args, **kwargs):
"""Same as addCleanup, except the cleanup items are called even if """Same as addCleanup, except the cleanup items are called even if
setUpClass fails (unlike tearDownClass).""" setUpClass fails (unlike tearDownClass)."""
cls._class_cleanups.append((function, args, kwargs)) cls._class_cleanups.append((function, args, kwargs))
@classmethod
def enterClassContext(cls, cm):
"""Same as enterContext, but class-wide."""
return _enter_context(cm, cls.addClassCleanup)
def setUp(self): def setUp(self):
"Hook method for setting up the test fixture before exercising it." "Hook method for setting up the test fixture before exercising it."
pass pass

View file

@ -14,6 +14,29 @@ def tearDownModule():
asyncio.set_event_loop_policy(None) asyncio.set_event_loop_policy(None)
class TestCM:
def __init__(self, ordering, enter_result=None):
self.ordering = ordering
self.enter_result = enter_result
async def __aenter__(self):
self.ordering.append('enter')
return self.enter_result
async def __aexit__(self, *exc_info):
self.ordering.append('exit')
class LacksEnterAndExit:
pass
class LacksEnter:
async def __aexit__(self, *exc_info):
pass
class LacksExit:
async def __aenter__(self):
pass
VAR = contextvars.ContextVar('VAR', default=()) VAR = contextvars.ContextVar('VAR', default=())
@ -337,6 +360,36 @@ class TestAsyncCase(unittest.TestCase):
output = test.run() output = test.run()
self.assertTrue(cancelled) self.assertTrue(cancelled)
def test_enterAsyncContext(self):
events = []
class Test(unittest.IsolatedAsyncioTestCase):
async def test_func(slf):
slf.addAsyncCleanup(events.append, 'cleanup1')
cm = TestCM(events, 42)
self.assertEqual(await slf.enterAsyncContext(cm), 42)
slf.addAsyncCleanup(events.append, 'cleanup2')
events.append('test')
test = Test('test_func')
output = test.run()
self.assertTrue(output.wasSuccessful(), output)
self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1'])
def test_enterAsyncContext_arg_errors(self):
class Test(unittest.IsolatedAsyncioTestCase):
async def test_func(slf):
with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
await slf.enterAsyncContext(LacksEnterAndExit())
with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
await slf.enterAsyncContext(LacksEnter())
with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
await slf.enterAsyncContext(LacksExit())
test = Test('test_func')
output = test.run()
self.assertTrue(output.wasSuccessful())
def test_debug_cleanup_same_loop(self): def test_debug_cleanup_same_loop(self):
class Test(unittest.IsolatedAsyncioTestCase): class Test(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):

View file

@ -46,6 +46,29 @@ def cleanup(ordering, blowUp=False):
raise Exception('CleanUpExc') raise Exception('CleanUpExc')
class TestCM:
def __init__(self, ordering, enter_result=None):
self.ordering = ordering
self.enter_result = enter_result
def __enter__(self):
self.ordering.append('enter')
return self.enter_result
def __exit__(self, *exc_info):
self.ordering.append('exit')
class LacksEnterAndExit:
pass
class LacksEnter:
def __exit__(self, *exc_info):
pass
class LacksExit:
def __enter__(self):
pass
class TestCleanUp(unittest.TestCase): class TestCleanUp(unittest.TestCase):
def testCleanUp(self): def testCleanUp(self):
class TestableTest(unittest.TestCase): class TestableTest(unittest.TestCase):
@ -173,6 +196,39 @@ class TestCleanUp(unittest.TestCase):
self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2']) self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2'])
def test_enterContext(self):
class TestableTest(unittest.TestCase):
def testNothing(self):
pass
test = TestableTest('testNothing')
cleanups = []
test.addCleanup(cleanups.append, 'cleanup1')
cm = TestCM(cleanups, 42)
self.assertEqual(test.enterContext(cm), 42)
test.addCleanup(cleanups.append, 'cleanup2')
self.assertTrue(test.doCleanups())
self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
def test_enterContext_arg_errors(self):
class TestableTest(unittest.TestCase):
def testNothing(self):
pass
test = TestableTest('testNothing')
with self.assertRaisesRegex(TypeError, 'the context manager'):
test.enterContext(LacksEnterAndExit())
with self.assertRaisesRegex(TypeError, 'the context manager'):
test.enterContext(LacksEnter())
with self.assertRaisesRegex(TypeError, 'the context manager'):
test.enterContext(LacksExit())
self.assertEqual(test._cleanups, [])
class TestClassCleanup(unittest.TestCase): class TestClassCleanup(unittest.TestCase):
def test_addClassCleanUp(self): def test_addClassCleanUp(self):
class TestableTest(unittest.TestCase): class TestableTest(unittest.TestCase):
@ -451,6 +507,35 @@ class TestClassCleanup(unittest.TestCase):
self.assertEqual(ordering, self.assertEqual(ordering,
['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) ['setUpClass', 'test', 'tearDownClass', 'cleanup_good'])
def test_enterClassContext(self):
class TestableTest(unittest.TestCase):
def testNothing(self):
pass
cleanups = []
TestableTest.addClassCleanup(cleanups.append, 'cleanup1')
cm = TestCM(cleanups, 42)
self.assertEqual(TestableTest.enterClassContext(cm), 42)
TestableTest.addClassCleanup(cleanups.append, 'cleanup2')
TestableTest.doClassCleanups()
self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
def test_enterClassContext_arg_errors(self):
class TestableTest(unittest.TestCase):
def testNothing(self):
pass
with self.assertRaisesRegex(TypeError, 'the context manager'):
TestableTest.enterClassContext(LacksEnterAndExit())
with self.assertRaisesRegex(TypeError, 'the context manager'):
TestableTest.enterClassContext(LacksEnter())
with self.assertRaisesRegex(TypeError, 'the context manager'):
TestableTest.enterClassContext(LacksExit())
self.assertEqual(TestableTest._class_cleanups, [])
class TestModuleCleanUp(unittest.TestCase): class TestModuleCleanUp(unittest.TestCase):
def test_add_and_do_ModuleCleanup(self): def test_add_and_do_ModuleCleanup(self):
@ -1000,6 +1085,31 @@ class TestModuleCleanUp(unittest.TestCase):
'cleanup2', 'setUp2', 'test2', 'tearDown2', 'cleanup2', 'setUp2', 'test2', 'tearDown2',
'cleanup3', 'tearDownModule', 'cleanup1']) 'cleanup3', 'tearDownModule', 'cleanup1'])
def test_enterModuleContext(self):
cleanups = []
unittest.addModuleCleanup(cleanups.append, 'cleanup1')
cm = TestCM(cleanups, 42)
self.assertEqual(unittest.enterModuleContext(cm), 42)
unittest.addModuleCleanup(cleanups.append, 'cleanup2')
unittest.case.doModuleCleanups()
self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
def test_enterModuleContext_arg_errors(self):
class TestableTest(unittest.TestCase):
def testNothing(self):
pass
with self.assertRaisesRegex(TypeError, 'the context manager'):
unittest.enterModuleContext(LacksEnterAndExit())
with self.assertRaisesRegex(TypeError, 'the context manager'):
unittest.enterModuleContext(LacksEnter())
with self.assertRaisesRegex(TypeError, 'the context manager'):
unittest.enterModuleContext(LacksExit())
self.assertEqual(unittest.case._module_cleanups, [])
class Test_TextTestRunner(unittest.TestCase): class Test_TextTestRunner(unittest.TestCase):
"""Tests for TextTestRunner.""" """Tests for TextTestRunner."""

View file

@ -0,0 +1,7 @@
Add support of context managers in :mod:`unittest`: methods
:meth:`~unittest.TestCase.enterContext` and
:meth:`~unittest.TestCase.enterClassContext` of class
:class:`~unittest.TestCase`, method
:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of class
:class:`~unittest.IsolatedAsyncioTestCase` and function
:func:`unittest.enterModuleContext`.