bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654)

* Work correctly if an additional fresh module imports other
  additional fresh module which imports a blocked module.
* Raises ImportError if the specified module cannot be imported
  while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules
  and their submodules.
* Fix test_decimal and test_xml_etree which depended on an undesired
  side effect of import_fresh_module().
This commit is contained in:
Serhiy Storchaka 2021-09-30 19:20:39 +03:00 committed by GitHub
parent b07fddd527
commit ec4d917a6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 52 deletions

View file

@ -81,33 +81,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg))
def _save_and_remove_module(name, orig_modules):
"""Helper function to save and remove a module from sys.modules
Raise ImportError if the module can't be imported.
"""
# try to import the module and raise an error if it can't be imported
if name not in sys.modules:
__import__(name)
del sys.modules[name]
def _save_and_remove_modules(names):
orig_modules = {}
prefixes = tuple(name + '.' for name in names)
for modname in list(sys.modules):
if modname == name or modname.startswith(name + '.'):
orig_modules[modname] = sys.modules[modname]
del sys.modules[modname]
def _save_and_block_module(name, orig_modules):
"""Helper function to save and block a module in sys.modules
Return True if the module was in sys.modules, False otherwise.
"""
saved = True
try:
orig_modules[name] = sys.modules[name]
except KeyError:
saved = False
sys.modules[name] = None
return saved
if modname in names or modname.startswith(prefixes):
orig_modules[modname] = sys.modules.pop(modname)
return orig_modules
@contextlib.contextmanager
@ -136,7 +116,8 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
this operation.
*fresh* is an iterable of additional module names that are also removed
from the sys.modules cache before doing the import.
from the sys.modules cache before doing the import. If one of these
modules can't be imported, None is returned.
*blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import
@ -160,25 +141,25 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
orig_modules = {}
names_to_remove = []
_save_and_remove_module(name, orig_modules)
fresh = list(fresh)
blocked = list(blocked)
names = {name, *fresh, *blocked}
orig_modules = _save_and_remove_modules(names)
for modname in blocked:
sys.modules[modname] = None
try:
for fresh_name in fresh:
_save_and_remove_module(fresh_name, orig_modules)
for blocked_name in blocked:
if not _save_and_block_module(blocked_name, orig_modules):
names_to_remove.append(blocked_name)
with frozen_modules(usefrozen):
fresh_module = importlib.import_module(name)
# Return None when one of the "fresh" modules can not be imported.
try:
for modname in fresh:
__import__(modname)
except ImportError:
fresh_module = None
return None
return importlib.import_module(name)
finally:
for orig_name, module in orig_modules.items():
sys.modules[orig_name] = module
for name_to_remove in names_to_remove:
del sys.modules[name_to_remove]
return fresh_module
_save_and_remove_modules(names)
sys.modules.update(orig_modules)
class CleanImport(object):

View file

@ -62,7 +62,7 @@ if sys.platform == 'darwin':
C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
orig_sys_decimal = sys.modules['decimal']
import decimal as orig_sys_decimal
# fractions module must import the correct decimal module.
cfractions = import_fresh_module('fractions', fresh=['fractions'])

View file

@ -26,7 +26,7 @@ from itertools import product, islice
from test import support
from test.support import os_helper
from test.support import warnings_helper
from test.support import findfile, gc_collect, swap_attr
from test.support import findfile, gc_collect, swap_attr, swap_item
from test.support.import_helper import import_fresh_module
from test.support.os_helper import TESTFN
@ -167,11 +167,10 @@ class ElementTestCase:
cls.modules = {pyET, ET}
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
save_m = sys.modules[name]
try:
sys.modules[name] = dumper
with swap_item(sys.modules, name, dumper):
temp = pickle.dumps(obj, proto)
sys.modules[name] = loader
with swap_item(sys.modules, name, loader):
result = pickle.loads(temp)
except pickle.PicklingError as pe:
# pyET must be second, because pyET may be (equal to) ET.
@ -180,8 +179,6 @@ class ElementTestCase:
% (obj,
human.get(dumper, dumper),
human.get(loader, loader))) from pe
finally:
sys.modules[name] = save_m
return result
def assertEqualElements(self, alice, bob):

View file

@ -0,0 +1,2 @@
Fix :func:`test.support.import_helper.import_fresh_module`.