mirror of
https://github.com/python/cpython.git
synced 2025-12-23 09:19:18 +00:00
[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)
* Work correctly if an additional fresh module imports other
additional fresh module which imports a blocked module.
* Raises ImportError if the specified module cannot be imported
while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules
and their submodules.
* Fix test_decimal and test_xml_etree which depended on an undesired
side effect of import_fresh_module().
(cherry picked from commit ec4d917a6a)
This commit is contained in:
parent
80285ecc8d
commit
7873884d47
4 changed files with 32 additions and 52 deletions
|
|
@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
|
|||
raise unittest.SkipTest(str(msg))
|
||||
|
||||
|
||||
def _save_and_remove_module(name, orig_modules):
|
||||
"""Helper function to save and remove a module from sys.modules
|
||||
|
||||
Raise ImportError if the module can't be imported.
|
||||
"""
|
||||
# try to import the module and raise an error if it can't be imported
|
||||
if name not in sys.modules:
|
||||
__import__(name)
|
||||
del sys.modules[name]
|
||||
def _save_and_remove_modules(names):
|
||||
orig_modules = {}
|
||||
prefixes = tuple(name + '.' for name in names)
|
||||
for modname in list(sys.modules):
|
||||
if modname == name or modname.startswith(name + '.'):
|
||||
orig_modules[modname] = sys.modules[modname]
|
||||
del sys.modules[modname]
|
||||
|
||||
|
||||
def _save_and_block_module(name, orig_modules):
|
||||
"""Helper function to save and block a module in sys.modules
|
||||
|
||||
Return True if the module was in sys.modules, False otherwise.
|
||||
"""
|
||||
saved = True
|
||||
try:
|
||||
orig_modules[name] = sys.modules[name]
|
||||
except KeyError:
|
||||
saved = False
|
||||
sys.modules[name] = None
|
||||
return saved
|
||||
if modname in names or modname.startswith(prefixes):
|
||||
orig_modules[modname] = sys.modules.pop(modname)
|
||||
return orig_modules
|
||||
|
||||
|
||||
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
|
||||
|
|
@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
|
|||
this operation.
|
||||
|
||||
*fresh* is an iterable of additional module names that are also removed
|
||||
from the sys.modules cache before doing the import.
|
||||
from the sys.modules cache before doing the import. If one of these
|
||||
modules can't be imported, None is returned.
|
||||
|
||||
*blocked* is an iterable of module names that are replaced with None
|
||||
in the module cache during the import to ensure that attempts to import
|
||||
|
|
@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
|
|||
with _ignore_deprecated_imports(deprecated):
|
||||
# Keep track of modules saved for later restoration as well
|
||||
# as those which just need a blocking entry removed
|
||||
orig_modules = {}
|
||||
names_to_remove = []
|
||||
_save_and_remove_module(name, orig_modules)
|
||||
fresh = list(fresh)
|
||||
blocked = list(blocked)
|
||||
names = {name, *fresh, *blocked}
|
||||
orig_modules = _save_and_remove_modules(names)
|
||||
for modname in blocked:
|
||||
sys.modules[modname] = None
|
||||
|
||||
try:
|
||||
for fresh_name in fresh:
|
||||
_save_and_remove_module(fresh_name, orig_modules)
|
||||
for blocked_name in blocked:
|
||||
if not _save_and_block_module(blocked_name, orig_modules):
|
||||
names_to_remove.append(blocked_name)
|
||||
fresh_module = importlib.import_module(name)
|
||||
except ImportError:
|
||||
fresh_module = None
|
||||
# Return None when one of the "fresh" modules can not be imported.
|
||||
try:
|
||||
for modname in fresh:
|
||||
__import__(modname)
|
||||
except ImportError:
|
||||
return None
|
||||
return importlib.import_module(name)
|
||||
finally:
|
||||
for orig_name, module in orig_modules.items():
|
||||
sys.modules[orig_name] = module
|
||||
for name_to_remove in names_to_remove:
|
||||
del sys.modules[name_to_remove]
|
||||
return fresh_module
|
||||
_save_and_remove_modules(names)
|
||||
sys.modules.update(orig_modules)
|
||||
|
||||
|
||||
class CleanImport(object):
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ if sys.platform == 'darwin':
|
|||
|
||||
C = import_fresh_module('decimal', fresh=['_decimal'])
|
||||
P = import_fresh_module('decimal', blocked=['_decimal'])
|
||||
orig_sys_decimal = sys.modules['decimal']
|
||||
import decimal as orig_sys_decimal
|
||||
|
||||
# fractions module must import the correct decimal module.
|
||||
cfractions = import_fresh_module('fractions', fresh=['fractions'])
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from itertools import product, islice
|
|||
from test import support
|
||||
from test.support import os_helper
|
||||
from test.support import warnings_helper
|
||||
from test.support import findfile, gc_collect, swap_attr
|
||||
from test.support import findfile, gc_collect, swap_attr, swap_item
|
||||
from test.support.import_helper import import_fresh_module
|
||||
from test.support.os_helper import TESTFN
|
||||
|
||||
|
|
@ -167,12 +167,11 @@ class ElementTestCase:
|
|||
cls.modules = {pyET, ET}
|
||||
|
||||
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
|
||||
save_m = sys.modules[name]
|
||||
try:
|
||||
sys.modules[name] = dumper
|
||||
temp = pickle.dumps(obj, proto)
|
||||
sys.modules[name] = loader
|
||||
result = pickle.loads(temp)
|
||||
with swap_item(sys.modules, name, dumper):
|
||||
temp = pickle.dumps(obj, proto)
|
||||
with swap_item(sys.modules, name, loader):
|
||||
result = pickle.loads(temp)
|
||||
except pickle.PicklingError as pe:
|
||||
# pyET must be second, because pyET may be (equal to) ET.
|
||||
human = dict([(ET, "cET"), (pyET, "pyET")])
|
||||
|
|
@ -180,8 +179,6 @@ class ElementTestCase:
|
|||
% (obj,
|
||||
human.get(dumper, dumper),
|
||||
human.get(loader, loader))) from pe
|
||||
finally:
|
||||
sys.modules[name] = save_m
|
||||
return result
|
||||
|
||||
def assertEqualElements(self, alice, bob):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
Fix :func:`test.support.import_helper.import_fresh_module`.
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue