mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654)
* Work correctly if an additional fresh module imports other additional fresh module which imports a blocked module. * Raises ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported. * Support blocking packages. * Always restore the import state of fresh and blocked modules and their submodules. * Fix test_decimal and test_xml_etree which depended on an undesired side effect of import_fresh_module().
This commit is contained in:
parent
b07fddd527
commit
ec4d917a6a
4 changed files with 32 additions and 52 deletions
|
@ -81,33 +81,13 @@ def import_module(name, deprecated=False, *, required_on=()):
|
||||||
raise unittest.SkipTest(str(msg))
|
raise unittest.SkipTest(str(msg))
|
||||||
|
|
||||||
|
|
||||||
def _save_and_remove_module(name, orig_modules):
|
def _save_and_remove_modules(names):
|
||||||
"""Helper function to save and remove a module from sys.modules
|
orig_modules = {}
|
||||||
|
prefixes = tuple(name + '.' for name in names)
|
||||||
Raise ImportError if the module can't be imported.
|
|
||||||
"""
|
|
||||||
# try to import the module and raise an error if it can't be imported
|
|
||||||
if name not in sys.modules:
|
|
||||||
__import__(name)
|
|
||||||
del sys.modules[name]
|
|
||||||
for modname in list(sys.modules):
|
for modname in list(sys.modules):
|
||||||
if modname == name or modname.startswith(name + '.'):
|
if modname in names or modname.startswith(prefixes):
|
||||||
orig_modules[modname] = sys.modules[modname]
|
orig_modules[modname] = sys.modules.pop(modname)
|
||||||
del sys.modules[modname]
|
return orig_modules
|
||||||
|
|
||||||
|
|
||||||
def _save_and_block_module(name, orig_modules):
|
|
||||||
"""Helper function to save and block a module in sys.modules
|
|
||||||
|
|
||||||
Return True if the module was in sys.modules, False otherwise.
|
|
||||||
"""
|
|
||||||
saved = True
|
|
||||||
try:
|
|
||||||
orig_modules[name] = sys.modules[name]
|
|
||||||
except KeyError:
|
|
||||||
saved = False
|
|
||||||
sys.modules[name] = None
|
|
||||||
return saved
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
|
@ -136,7 +116,8 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
|
||||||
this operation.
|
this operation.
|
||||||
|
|
||||||
*fresh* is an iterable of additional module names that are also removed
|
*fresh* is an iterable of additional module names that are also removed
|
||||||
from the sys.modules cache before doing the import.
|
from the sys.modules cache before doing the import. If one of these
|
||||||
|
modules can't be imported, None is returned.
|
||||||
|
|
||||||
*blocked* is an iterable of module names that are replaced with None
|
*blocked* is an iterable of module names that are replaced with None
|
||||||
in the module cache during the import to ensure that attempts to import
|
in the module cache during the import to ensure that attempts to import
|
||||||
|
@ -160,25 +141,25 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
|
||||||
with _ignore_deprecated_imports(deprecated):
|
with _ignore_deprecated_imports(deprecated):
|
||||||
# Keep track of modules saved for later restoration as well
|
# Keep track of modules saved for later restoration as well
|
||||||
# as those which just need a blocking entry removed
|
# as those which just need a blocking entry removed
|
||||||
orig_modules = {}
|
fresh = list(fresh)
|
||||||
names_to_remove = []
|
blocked = list(blocked)
|
||||||
_save_and_remove_module(name, orig_modules)
|
names = {name, *fresh, *blocked}
|
||||||
|
orig_modules = _save_and_remove_modules(names)
|
||||||
|
for modname in blocked:
|
||||||
|
sys.modules[modname] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for fresh_name in fresh:
|
|
||||||
_save_and_remove_module(fresh_name, orig_modules)
|
|
||||||
for blocked_name in blocked:
|
|
||||||
if not _save_and_block_module(blocked_name, orig_modules):
|
|
||||||
names_to_remove.append(blocked_name)
|
|
||||||
with frozen_modules(usefrozen):
|
with frozen_modules(usefrozen):
|
||||||
fresh_module = importlib.import_module(name)
|
# Return None when one of the "fresh" modules can not be imported.
|
||||||
except ImportError:
|
try:
|
||||||
fresh_module = None
|
for modname in fresh:
|
||||||
|
__import__(modname)
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
return importlib.import_module(name)
|
||||||
finally:
|
finally:
|
||||||
for orig_name, module in orig_modules.items():
|
_save_and_remove_modules(names)
|
||||||
sys.modules[orig_name] = module
|
sys.modules.update(orig_modules)
|
||||||
for name_to_remove in names_to_remove:
|
|
||||||
del sys.modules[name_to_remove]
|
|
||||||
return fresh_module
|
|
||||||
|
|
||||||
|
|
||||||
class CleanImport(object):
|
class CleanImport(object):
|
||||||
|
|
|
@ -62,7 +62,7 @@ if sys.platform == 'darwin':
|
||||||
|
|
||||||
C = import_fresh_module('decimal', fresh=['_decimal'])
|
C = import_fresh_module('decimal', fresh=['_decimal'])
|
||||||
P = import_fresh_module('decimal', blocked=['_decimal'])
|
P = import_fresh_module('decimal', blocked=['_decimal'])
|
||||||
orig_sys_decimal = sys.modules['decimal']
|
import decimal as orig_sys_decimal
|
||||||
|
|
||||||
# fractions module must import the correct decimal module.
|
# fractions module must import the correct decimal module.
|
||||||
cfractions = import_fresh_module('fractions', fresh=['fractions'])
|
cfractions = import_fresh_module('fractions', fresh=['fractions'])
|
||||||
|
|
|
@ -26,7 +26,7 @@ from itertools import product, islice
|
||||||
from test import support
|
from test import support
|
||||||
from test.support import os_helper
|
from test.support import os_helper
|
||||||
from test.support import warnings_helper
|
from test.support import warnings_helper
|
||||||
from test.support import findfile, gc_collect, swap_attr
|
from test.support import findfile, gc_collect, swap_attr, swap_item
|
||||||
from test.support.import_helper import import_fresh_module
|
from test.support.import_helper import import_fresh_module
|
||||||
from test.support.os_helper import TESTFN
|
from test.support.os_helper import TESTFN
|
||||||
|
|
||||||
|
@ -167,12 +167,11 @@ class ElementTestCase:
|
||||||
cls.modules = {pyET, ET}
|
cls.modules = {pyET, ET}
|
||||||
|
|
||||||
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
|
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
|
||||||
save_m = sys.modules[name]
|
|
||||||
try:
|
try:
|
||||||
sys.modules[name] = dumper
|
with swap_item(sys.modules, name, dumper):
|
||||||
temp = pickle.dumps(obj, proto)
|
temp = pickle.dumps(obj, proto)
|
||||||
sys.modules[name] = loader
|
with swap_item(sys.modules, name, loader):
|
||||||
result = pickle.loads(temp)
|
result = pickle.loads(temp)
|
||||||
except pickle.PicklingError as pe:
|
except pickle.PicklingError as pe:
|
||||||
# pyET must be second, because pyET may be (equal to) ET.
|
# pyET must be second, because pyET may be (equal to) ET.
|
||||||
human = dict([(ET, "cET"), (pyET, "pyET")])
|
human = dict([(ET, "cET"), (pyET, "pyET")])
|
||||||
|
@ -180,8 +179,6 @@ class ElementTestCase:
|
||||||
% (obj,
|
% (obj,
|
||||||
human.get(dumper, dumper),
|
human.get(dumper, dumper),
|
||||||
human.get(loader, loader))) from pe
|
human.get(loader, loader))) from pe
|
||||||
finally:
|
|
||||||
sys.modules[name] = save_m
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def assertEqualElements(self, alice, bob):
|
def assertEqualElements(self, alice, bob):
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Fix :func:`test.support.import_helper.import_fresh_module`.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue