bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654)

* Work correctly if an additional fresh module imports other
  additional fresh module which imports a blocked module.
* Raises ImportError if the specified module cannot be imported
  while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules
  and their submodules.
* Fix test_decimal and test_xml_etree which depended on an undesired
  side effect of import_fresh_module().
This commit is contained in:
Serhiy Storchaka 2021-09-30 19:20:39 +03:00 committed by GitHub
parent b07fddd527
commit ec4d917a6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 52 deletions

View file

@ -81,33 +81,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg)) raise unittest.SkipTest(str(msg))
def _save_and_remove_module(name, orig_modules): def _save_and_remove_modules(names):
"""Helper function to save and remove a module from sys.modules orig_modules = {}
prefixes = tuple(name + '.' for name in names)
Raise ImportError if the module can't be imported.
"""
# try to import the module and raise an error if it can't be imported
if name not in sys.modules:
__import__(name)
del sys.modules[name]
for modname in list(sys.modules): for modname in list(sys.modules):
if modname == name or modname.startswith(name + '.'): if modname in names or modname.startswith(prefixes):
orig_modules[modname] = sys.modules[modname] orig_modules[modname] = sys.modules.pop(modname)
del sys.modules[modname] return orig_modules
def _save_and_block_module(name, orig_modules):
"""Helper function to save and block a module in sys.modules
Return True if the module was in sys.modules, False otherwise.
"""
saved = True
try:
orig_modules[name] = sys.modules[name]
except KeyError:
saved = False
sys.modules[name] = None
return saved
@contextlib.contextmanager @contextlib.contextmanager
@ -136,7 +116,8 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
this operation. this operation.
*fresh* is an iterable of additional module names that are also removed *fresh* is an iterable of additional module names that are also removed
from the sys.modules cache before doing the import. from the sys.modules cache before doing the import. If one of these
modules can't be imported, None is returned.
*blocked* is an iterable of module names that are replaced with None *blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import in the module cache during the import to ensure that attempts to import
@ -160,25 +141,25 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
with _ignore_deprecated_imports(deprecated): with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well # Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed # as those which just need a blocking entry removed
orig_modules = {} fresh = list(fresh)
names_to_remove = [] blocked = list(blocked)
_save_and_remove_module(name, orig_modules) names = {name, *fresh, *blocked}
orig_modules = _save_and_remove_modules(names)
for modname in blocked:
sys.modules[modname] = None
try: try:
for fresh_name in fresh:
_save_and_remove_module(fresh_name, orig_modules)
for blocked_name in blocked:
if not _save_and_block_module(blocked_name, orig_modules):
names_to_remove.append(blocked_name)
with frozen_modules(usefrozen): with frozen_modules(usefrozen):
fresh_module = importlib.import_module(name) # Return None when one of the "fresh" modules can not be imported.
except ImportError: try:
fresh_module = None for modname in fresh:
__import__(modname)
except ImportError:
return None
return importlib.import_module(name)
finally: finally:
for orig_name, module in orig_modules.items(): _save_and_remove_modules(names)
sys.modules[orig_name] = module sys.modules.update(orig_modules)
for name_to_remove in names_to_remove:
del sys.modules[name_to_remove]
return fresh_module
class CleanImport(object): class CleanImport(object):

View file

@ -62,7 +62,7 @@ if sys.platform == 'darwin':
C = import_fresh_module('decimal', fresh=['_decimal']) C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal']) P = import_fresh_module('decimal', blocked=['_decimal'])
orig_sys_decimal = sys.modules['decimal'] import decimal as orig_sys_decimal
# fractions module must import the correct decimal module. # fractions module must import the correct decimal module.
cfractions = import_fresh_module('fractions', fresh=['fractions']) cfractions = import_fresh_module('fractions', fresh=['fractions'])

View file

@ -26,7 +26,7 @@ from itertools import product, islice
from test import support from test import support
from test.support import os_helper from test.support import os_helper
from test.support import warnings_helper from test.support import warnings_helper
from test.support import findfile, gc_collect, swap_attr from test.support import findfile, gc_collect, swap_attr, swap_item
from test.support.import_helper import import_fresh_module from test.support.import_helper import import_fresh_module
from test.support.os_helper import TESTFN from test.support.os_helper import TESTFN
@ -167,12 +167,11 @@ class ElementTestCase:
cls.modules = {pyET, ET} cls.modules = {pyET, ET}
def pickleRoundTrip(self, obj, name, dumper, loader, proto): def pickleRoundTrip(self, obj, name, dumper, loader, proto):
save_m = sys.modules[name]
try: try:
sys.modules[name] = dumper with swap_item(sys.modules, name, dumper):
temp = pickle.dumps(obj, proto) temp = pickle.dumps(obj, proto)
sys.modules[name] = loader with swap_item(sys.modules, name, loader):
result = pickle.loads(temp) result = pickle.loads(temp)
except pickle.PicklingError as pe: except pickle.PicklingError as pe:
# pyET must be second, because pyET may be (equal to) ET. # pyET must be second, because pyET may be (equal to) ET.
human = dict([(ET, "cET"), (pyET, "pyET")]) human = dict([(ET, "cET"), (pyET, "pyET")])
@ -180,8 +179,6 @@ class ElementTestCase:
% (obj, % (obj,
human.get(dumper, dumper), human.get(dumper, dumper),
human.get(loader, loader))) from pe human.get(loader, loader))) from pe
finally:
sys.modules[name] = save_m
return result return result
def assertEqualElements(self, alice, bob): def assertEqualElements(self, alice, bob):

View file

@ -0,0 +1,2 @@
Fix :func:`test.support.import_helper.import_fresh_module`.