mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654)
* Work correctly if an additional fresh module imports other additional fresh module which imports a blocked module. * Raises ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported. * Support blocking packages. * Always restore the import state of fresh and blocked modules and their submodules. * Fix test_decimal and test_xml_etree which depended on an undesired side effect of import_fresh_module().
This commit is contained in:
parent
b07fddd527
commit
ec4d917a6a
4 changed files with 32 additions and 52 deletions
|
@ -81,33 +81,13 @@ def import_module(name, deprecated=False, *, required_on=()):
|
|||
raise unittest.SkipTest(str(msg))
|
||||
|
||||
|
||||
def _save_and_remove_module(name, orig_modules):
|
||||
"""Helper function to save and remove a module from sys.modules
|
||||
|
||||
Raise ImportError if the module can't be imported.
|
||||
"""
|
||||
# try to import the module and raise an error if it can't be imported
|
||||
if name not in sys.modules:
|
||||
__import__(name)
|
||||
del sys.modules[name]
|
||||
def _save_and_remove_modules(names):
|
||||
orig_modules = {}
|
||||
prefixes = tuple(name + '.' for name in names)
|
||||
for modname in list(sys.modules):
|
||||
if modname == name or modname.startswith(name + '.'):
|
||||
orig_modules[modname] = sys.modules[modname]
|
||||
del sys.modules[modname]
|
||||
|
||||
|
||||
def _save_and_block_module(name, orig_modules):
|
||||
"""Helper function to save and block a module in sys.modules
|
||||
|
||||
Return True if the module was in sys.modules, False otherwise.
|
||||
"""
|
||||
saved = True
|
||||
try:
|
||||
orig_modules[name] = sys.modules[name]
|
||||
except KeyError:
|
||||
saved = False
|
||||
sys.modules[name] = None
|
||||
return saved
|
||||
if modname in names or modname.startswith(prefixes):
|
||||
orig_modules[modname] = sys.modules.pop(modname)
|
||||
return orig_modules
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -136,7 +116,8 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
|
|||
this operation.
|
||||
|
||||
*fresh* is an iterable of additional module names that are also removed
|
||||
from the sys.modules cache before doing the import.
|
||||
from the sys.modules cache before doing the import. If one of these
|
||||
modules can't be imported, None is returned.
|
||||
|
||||
*blocked* is an iterable of module names that are replaced with None
|
||||
in the module cache during the import to ensure that attempts to import
|
||||
|
@ -160,25 +141,25 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
|
|||
with _ignore_deprecated_imports(deprecated):
|
||||
# Keep track of modules saved for later restoration as well
|
||||
# as those which just need a blocking entry removed
|
||||
orig_modules = {}
|
||||
names_to_remove = []
|
||||
_save_and_remove_module(name, orig_modules)
|
||||
fresh = list(fresh)
|
||||
blocked = list(blocked)
|
||||
names = {name, *fresh, *blocked}
|
||||
orig_modules = _save_and_remove_modules(names)
|
||||
for modname in blocked:
|
||||
sys.modules[modname] = None
|
||||
|
||||
try:
|
||||
for fresh_name in fresh:
|
||||
_save_and_remove_module(fresh_name, orig_modules)
|
||||
for blocked_name in blocked:
|
||||
if not _save_and_block_module(blocked_name, orig_modules):
|
||||
names_to_remove.append(blocked_name)
|
||||
with frozen_modules(usefrozen):
|
||||
fresh_module = importlib.import_module(name)
|
||||
# Return None when one of the "fresh" modules can not be imported.
|
||||
try:
|
||||
for modname in fresh:
|
||||
__import__(modname)
|
||||
except ImportError:
|
||||
fresh_module = None
|
||||
return None
|
||||
return importlib.import_module(name)
|
||||
finally:
|
||||
for orig_name, module in orig_modules.items():
|
||||
sys.modules[orig_name] = module
|
||||
for name_to_remove in names_to_remove:
|
||||
del sys.modules[name_to_remove]
|
||||
return fresh_module
|
||||
_save_and_remove_modules(names)
|
||||
sys.modules.update(orig_modules)
|
||||
|
||||
|
||||
class CleanImport(object):
|
||||
|
|
|
@ -62,7 +62,7 @@ if sys.platform == 'darwin':
|
|||
|
||||
C = import_fresh_module('decimal', fresh=['_decimal'])
|
||||
P = import_fresh_module('decimal', blocked=['_decimal'])
|
||||
orig_sys_decimal = sys.modules['decimal']
|
||||
import decimal as orig_sys_decimal
|
||||
|
||||
# fractions module must import the correct decimal module.
|
||||
cfractions = import_fresh_module('fractions', fresh=['fractions'])
|
||||
|
|
|
@ -26,7 +26,7 @@ from itertools import product, islice
|
|||
from test import support
|
||||
from test.support import os_helper
|
||||
from test.support import warnings_helper
|
||||
from test.support import findfile, gc_collect, swap_attr
|
||||
from test.support import findfile, gc_collect, swap_attr, swap_item
|
||||
from test.support.import_helper import import_fresh_module
|
||||
from test.support.os_helper import TESTFN
|
||||
|
||||
|
@ -167,11 +167,10 @@ class ElementTestCase:
|
|||
cls.modules = {pyET, ET}
|
||||
|
||||
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
|
||||
save_m = sys.modules[name]
|
||||
try:
|
||||
sys.modules[name] = dumper
|
||||
with swap_item(sys.modules, name, dumper):
|
||||
temp = pickle.dumps(obj, proto)
|
||||
sys.modules[name] = loader
|
||||
with swap_item(sys.modules, name, loader):
|
||||
result = pickle.loads(temp)
|
||||
except pickle.PicklingError as pe:
|
||||
# pyET must be second, because pyET may be (equal to) ET.
|
||||
|
@ -180,8 +179,6 @@ class ElementTestCase:
|
|||
% (obj,
|
||||
human.get(dumper, dumper),
|
||||
human.get(loader, loader))) from pe
|
||||
finally:
|
||||
sys.modules[name] = save_m
|
||||
return result
|
||||
|
||||
def assertEqualElements(self, alice, bob):
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Fix :func:`test.support.import_helper.import_fresh_module`.
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue