mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
gh-132775: Add _PyPickle_GetXIData() (gh-133107)
There's some extra complexity due to making sure we we get things right when handling functions and classes defined in the __main__ module. This is also reflected in the tests, including the addition of extra functions in test.support.import_helper.
This commit is contained in:
parent
6c522debc2
commit
cb35c11d82
5 changed files with 1056 additions and 55 deletions
|
@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped(
|
|||
xid_newobjfunc,
|
||||
_PyXIData_t *);
|
||||
|
||||
// _PyObject_GetXIData() for pickle
|
||||
PyAPI_DATA(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *);
|
||||
PyAPI_FUNC(int) _PyPickle_GetXIData(
|
||||
PyThreadState *,
|
||||
PyObject *,
|
||||
_PyXIData_t *);
|
||||
|
||||
// _PyObject_GetXIData() for marshal
|
||||
PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *);
|
||||
PyAPI_FUNC(int) _PyMarshal_GetXIData(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import contextlib
|
||||
import _imp
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
import os
|
||||
import shutil
|
||||
|
@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block):
|
|||
)
|
||||
from .script_helper import assert_python_ok
|
||||
assert_python_ok("-S", "-c", script)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def module_restored(name):
|
||||
"""A context manager that restores a module to the original state."""
|
||||
missing = object()
|
||||
orig = sys.modules.get(name, missing)
|
||||
if orig is None:
|
||||
mod = importlib.import_module(name)
|
||||
else:
|
||||
mod = type(sys)(name)
|
||||
mod.__dict__.update(orig.__dict__)
|
||||
sys.modules[name] = mod
|
||||
try:
|
||||
yield mod
|
||||
finally:
|
||||
if orig is missing:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = orig
|
||||
|
||||
|
||||
def create_module(name, loader=None, *, ispkg=False):
|
||||
"""Return a new, empty module."""
|
||||
spec = importlib.machinery.ModuleSpec(
|
||||
name,
|
||||
loader,
|
||||
origin='<import_helper>',
|
||||
is_package=ispkg,
|
||||
)
|
||||
return importlib.util.module_from_spec(spec)
|
||||
|
||||
|
||||
def _ensure_module(name, ispkg, addparent, clearnone):
|
||||
try:
|
||||
mod = orig = sys.modules[name]
|
||||
except KeyError:
|
||||
mod = orig = None
|
||||
missing = True
|
||||
else:
|
||||
missing = False
|
||||
if mod is not None:
|
||||
# It was already imported.
|
||||
return mod, orig, missing
|
||||
# Otherwise, None means it was explicitly disabled.
|
||||
|
||||
assert name != '__main__'
|
||||
if not missing:
|
||||
assert orig is None, (name, sys.modules[name])
|
||||
if not clearnone:
|
||||
raise ModuleNotFoundError(name)
|
||||
del sys.modules[name]
|
||||
# Try normal import, then fall back to adding the module.
|
||||
try:
|
||||
mod = importlib.import_module(name)
|
||||
except ModuleNotFoundError:
|
||||
if addparent and not clearnone:
|
||||
addparent = None
|
||||
mod = _add_module(name, ispkg, addparent)
|
||||
return mod, orig, missing
|
||||
|
||||
|
||||
def _add_module(spec, ispkg, addparent):
|
||||
if isinstance(spec, str):
|
||||
name = spec
|
||||
mod = create_module(name, ispkg=ispkg)
|
||||
spec = mod.__spec__
|
||||
else:
|
||||
name = spec.name
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
sys.modules[name] = mod
|
||||
if addparent is not False and spec.parent:
|
||||
_ensure_module(spec.parent, True, addparent, bool(addparent))
|
||||
return mod
|
||||
|
||||
|
||||
def add_module(spec, *, parents=True):
|
||||
"""Return the module after creating it and adding it to sys.modules.
|
||||
|
||||
If parents is True then also create any missing parents.
|
||||
"""
|
||||
return _add_module(spec, False, parents)
|
||||
|
||||
|
||||
def add_package(spec, *, parents=True):
|
||||
"""Return the module after creating it and adding it to sys.modules.
|
||||
|
||||
If parents is True then also create any missing parents.
|
||||
"""
|
||||
return _add_module(spec, True, parents)
|
||||
|
||||
|
||||
def ensure_module_imported(name, *, clearnone=True):
|
||||
"""Return the corresponding module.
|
||||
|
||||
If it was already imported then return that. Otherwise, try
|
||||
importing it (optionally clear it first if None). If that fails
|
||||
then create a new empty module.
|
||||
|
||||
It can be helpful to combine this with ready_to_import() and/or
|
||||
isolated_modules().
|
||||
"""
|
||||
if sys.modules.get(name) is not None:
|
||||
mod = sys.modules[name]
|
||||
else:
|
||||
mod, _, _ = _force_import(name, False, True, clearnone)
|
||||
return mod
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import contextlib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import itertools
|
||||
import sys
|
||||
import types
|
||||
|
@ -9,7 +12,7 @@ _testinternalcapi = import_helper.import_module('_testinternalcapi')
|
|||
_interpreters = import_helper.import_module('_interpreters')
|
||||
from _interpreters import NotShareableError
|
||||
|
||||
|
||||
from test import _code_definitions as code_defs
|
||||
from test import _crossinterp_definitions as defs
|
||||
|
||||
|
||||
|
@ -21,6 +24,88 @@ OTHER_TYPES = [o for n, o in vars(types).items()
|
|||
if (isinstance(o, type) and
|
||||
n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
|
||||
|
||||
DEFS = defs
|
||||
with open(code_defs.__file__) as infile:
|
||||
_code_defs_text = infile.read()
|
||||
with open(DEFS.__file__) as infile:
|
||||
_defs_text = infile.read()
|
||||
_defs_text = _defs_text.replace('from ', '# from ')
|
||||
DEFS_TEXT = f"""
|
||||
#######################################
|
||||
# from {code_defs.__file__}
|
||||
|
||||
{_code_defs_text}
|
||||
|
||||
#######################################
|
||||
# from {defs.__file__}
|
||||
|
||||
{_defs_text}
|
||||
"""
|
||||
del infile, _code_defs_text, _defs_text
|
||||
|
||||
|
||||
def load_defs(module=None):
|
||||
"""Return a new copy of the test._crossinterp_definitions module.
|
||||
|
||||
The module's __name__ matches the "module" arg, which is either
|
||||
a str or a module.
|
||||
|
||||
If the "module" arg is a module then the just-loaded defs are also
|
||||
copied into that module.
|
||||
|
||||
Note that the new module is not added to sys.modules.
|
||||
"""
|
||||
if module is None:
|
||||
modname = DEFS.__name__
|
||||
elif isinstance(module, str):
|
||||
modname = module
|
||||
module = None
|
||||
else:
|
||||
modname = module.__name__
|
||||
# Create the new module and populate it.
|
||||
defs = import_helper.create_module(modname)
|
||||
defs.__file__ = DEFS.__file__
|
||||
exec(DEFS_TEXT, defs.__dict__)
|
||||
# Copy the defs into the module arg, if any.
|
||||
if module is not None:
|
||||
for name, value in defs.__dict__.items():
|
||||
if name.startswith('_'):
|
||||
continue
|
||||
assert not hasattr(module, name), (name, getattr(module, name))
|
||||
setattr(module, name, value)
|
||||
return defs
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using___main__():
|
||||
"""Make sure __main__ module exists (and clean up after)."""
|
||||
modname = '__main__'
|
||||
if modname not in sys.modules:
|
||||
with import_helper.isolated_modules():
|
||||
yield import_helper.add_module(modname)
|
||||
else:
|
||||
with import_helper.module_restored(modname) as mod:
|
||||
yield mod
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_module(modname):
|
||||
"""Create the module and add to sys.modules, then remove it after."""
|
||||
assert modname not in sys.modules, (modname,)
|
||||
with import_helper.isolated_modules():
|
||||
yield import_helper.add_module(modname)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def missing_defs_module(modname, *, prep=False):
|
||||
assert modname not in sys.modules, (modname,)
|
||||
if prep:
|
||||
with import_helper.ready_to_import(modname, DEFS_TEXT):
|
||||
yield modname
|
||||
else:
|
||||
with import_helper.isolated_modules():
|
||||
yield modname
|
||||
|
||||
|
||||
class _GetXIDataTests(unittest.TestCase):
|
||||
|
||||
|
@ -32,52 +117,49 @@ class _GetXIDataTests(unittest.TestCase):
|
|||
|
||||
def get_roundtrip(self, obj, *, mode=None):
|
||||
mode = self._resolve_mode(mode)
|
||||
xid =_testinternalcapi.get_crossinterp_data(obj, mode)
|
||||
return self._get_roundtrip(obj, mode)
|
||||
|
||||
def _get_roundtrip(self, obj, mode):
|
||||
xid = _testinternalcapi.get_crossinterp_data(obj, mode)
|
||||
return _testinternalcapi.restore_crossinterp_data(xid)
|
||||
|
||||
def iter_roundtrip_values(self, values, *, mode=None):
|
||||
def assert_roundtrip_identical(self, values, *, mode=None):
|
||||
mode = self._resolve_mode(mode)
|
||||
for obj in values:
|
||||
with self.subTest(obj):
|
||||
xid = _testinternalcapi.get_crossinterp_data(obj, mode)
|
||||
got = _testinternalcapi.restore_crossinterp_data(xid)
|
||||
yield obj, got
|
||||
|
||||
def assert_roundtrip_identical(self, values, *, mode=None):
|
||||
for obj, got in self.iter_roundtrip_values(values, mode=mode):
|
||||
# XXX What about between interpreters?
|
||||
got = self._get_roundtrip(obj, mode)
|
||||
self.assertIs(got, obj)
|
||||
|
||||
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
|
||||
for obj, got in self.iter_roundtrip_values(values, mode=mode):
|
||||
mode = self._resolve_mode(mode)
|
||||
for obj in values:
|
||||
with self.subTest(obj):
|
||||
got = self._get_roundtrip(obj, mode)
|
||||
self.assertEqual(got, obj)
|
||||
self.assertIs(type(got),
|
||||
type(obj) if expecttype is None else expecttype)
|
||||
|
||||
# def assert_roundtrip_equal_not_identical(self, values, *,
|
||||
# mode=None, expecttype=None):
|
||||
# mode = self._resolve_mode(mode)
|
||||
# for obj in values:
|
||||
# cls = type(obj)
|
||||
# with self.subTest(obj):
|
||||
# got = self._get_roundtrip(obj, mode)
|
||||
# self.assertIsNot(got, obj)
|
||||
# self.assertIs(type(got), type(obj))
|
||||
# self.assertEqual(got, obj)
|
||||
# self.assertIs(type(got),
|
||||
# cls if expecttype is None else expecttype)
|
||||
#
|
||||
# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None):
|
||||
# mode = self._resolve_mode(mode)
|
||||
# for obj in values:
|
||||
# cls = type(obj)
|
||||
# with self.subTest(obj):
|
||||
# got = self._get_roundtrip(obj, mode)
|
||||
# self.assertIsNot(got, obj)
|
||||
# self.assertIs(type(got), type(obj))
|
||||
# self.assertNotEqual(got, obj)
|
||||
# self.assertIs(type(got),
|
||||
# cls if expecttype is None else expecttype)
|
||||
def assert_roundtrip_equal_not_identical(self, values, *,
|
||||
mode=None, expecttype=None):
|
||||
mode = self._resolve_mode(mode)
|
||||
for obj in values:
|
||||
with self.subTest(obj):
|
||||
got = self._get_roundtrip(obj, mode)
|
||||
self.assertIsNot(got, obj)
|
||||
self.assertIs(type(got),
|
||||
type(obj) if expecttype is None else expecttype)
|
||||
self.assertEqual(got, obj)
|
||||
|
||||
def assert_roundtrip_not_equal(self, values, *,
|
||||
mode=None, expecttype=None):
|
||||
mode = self._resolve_mode(mode)
|
||||
for obj in values:
|
||||
with self.subTest(obj):
|
||||
got = self._get_roundtrip(obj, mode)
|
||||
self.assertIsNot(got, obj)
|
||||
self.assertIs(type(got),
|
||||
type(obj) if expecttype is None else expecttype)
|
||||
self.assertNotEqual(got, obj)
|
||||
|
||||
def assert_not_shareable(self, values, exctype=None, *, mode=None):
|
||||
mode = self._resolve_mode(mode)
|
||||
|
@ -95,6 +177,363 @@ class _GetXIDataTests(unittest.TestCase):
|
|||
return mode
|
||||
|
||||
|
||||
class PickleTests(_GetXIDataTests):
|
||||
|
||||
MODE = 'pickle'
|
||||
|
||||
def test_shareable(self):
|
||||
self.assert_roundtrip_equal([
|
||||
# singletons
|
||||
None,
|
||||
True,
|
||||
False,
|
||||
# bytes
|
||||
*(i.to_bytes(2, 'little', signed=True)
|
||||
for i in range(-1, 258)),
|
||||
# str
|
||||
'hello world',
|
||||
'你好世界',
|
||||
'',
|
||||
# int
|
||||
sys.maxsize,
|
||||
-sys.maxsize - 1,
|
||||
*range(-1, 258),
|
||||
# float
|
||||
0.0,
|
||||
1.1,
|
||||
-1.0,
|
||||
0.12345678,
|
||||
-0.12345678,
|
||||
# tuple
|
||||
(),
|
||||
(1,),
|
||||
("hello", "world", ),
|
||||
(1, True, "hello"),
|
||||
((1,),),
|
||||
((1, 2), (3, 4)),
|
||||
((1, 2), (3, 4), (5, 6)),
|
||||
])
|
||||
# not shareable using xidata
|
||||
self.assert_roundtrip_equal([
|
||||
# int
|
||||
sys.maxsize + 1,
|
||||
-sys.maxsize - 2,
|
||||
2**1000,
|
||||
# tuple
|
||||
(0, 1.0, []),
|
||||
(0, 1.0, {}),
|
||||
(0, 1.0, ([],)),
|
||||
(0, 1.0, ({},)),
|
||||
])
|
||||
|
||||
def test_list(self):
|
||||
self.assert_roundtrip_equal_not_identical([
|
||||
[],
|
||||
[1, 2, 3],
|
||||
[[1], (2,), {3: 4}],
|
||||
])
|
||||
|
||||
def test_dict(self):
|
||||
self.assert_roundtrip_equal_not_identical([
|
||||
{},
|
||||
{1: 7, 2: 8, 3: 9},
|
||||
{1: [1], 2: (2,), 3: {3: 4}},
|
||||
])
|
||||
|
||||
def test_set(self):
|
||||
self.assert_roundtrip_equal_not_identical([
|
||||
set(),
|
||||
{1, 2, 3},
|
||||
{frozenset({1}), (2,)},
|
||||
])
|
||||
|
||||
# classes
|
||||
|
||||
def assert_class_defs_same(self, defs):
|
||||
# Unpickle relative to the unchanged original module.
|
||||
self.assert_roundtrip_identical(defs.TOP_CLASSES)
|
||||
|
||||
instances = []
|
||||
for cls, args in defs.TOP_CLASSES.items():
|
||||
if cls in defs.CLASSES_WITHOUT_EQUALITY:
|
||||
continue
|
||||
instances.append(cls(*args))
|
||||
self.assert_roundtrip_equal_not_identical(instances)
|
||||
|
||||
# these don't compare equal
|
||||
instances = []
|
||||
for cls, args in defs.TOP_CLASSES.items():
|
||||
if cls not in defs.CLASSES_WITHOUT_EQUALITY:
|
||||
continue
|
||||
instances.append(cls(*args))
|
||||
self.assert_roundtrip_not_equal(instances)
|
||||
|
||||
def assert_class_defs_other_pickle(self, defs, mod):
|
||||
# Pickle relative to a different module than the original.
|
||||
for cls in defs.TOP_CLASSES:
|
||||
assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
|
||||
self.assert_not_shareable(defs.TOP_CLASSES)
|
||||
|
||||
instances = []
|
||||
for cls, args in defs.TOP_CLASSES.items():
|
||||
instances.append(cls(*args))
|
||||
self.assert_not_shareable(instances)
|
||||
|
||||
def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False):
|
||||
# Unpickle relative to a different module than the original.
|
||||
for cls in defs.TOP_CLASSES:
|
||||
assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
|
||||
|
||||
instances = []
|
||||
for cls, args in defs.TOP_CLASSES.items():
|
||||
with self.subTest(cls):
|
||||
setattr(mod, cls.__name__, cls)
|
||||
xid = self.get_xidata(cls)
|
||||
inst = cls(*args)
|
||||
instxid = self.get_xidata(inst)
|
||||
instances.append(
|
||||
(cls, xid, inst, instxid))
|
||||
|
||||
for cls, xid, inst, instxid in instances:
|
||||
with self.subTest(cls):
|
||||
delattr(mod, cls.__name__)
|
||||
if fail:
|
||||
with self.assertRaises(NotShareableError):
|
||||
_testinternalcapi.restore_crossinterp_data(xid)
|
||||
continue
|
||||
got = _testinternalcapi.restore_crossinterp_data(xid)
|
||||
self.assertIsNot(got, cls)
|
||||
self.assertNotEqual(got, cls)
|
||||
|
||||
gotcls = got
|
||||
got = _testinternalcapi.restore_crossinterp_data(instxid)
|
||||
self.assertIsNot(got, inst)
|
||||
self.assertIs(type(got), gotcls)
|
||||
if cls in defs.CLASSES_WITHOUT_EQUALITY:
|
||||
self.assertNotEqual(got, inst)
|
||||
elif cls in defs.BUILTIN_SUBCLASSES:
|
||||
self.assertEqual(got, inst)
|
||||
else:
|
||||
self.assertNotEqual(got, inst)
|
||||
|
||||
def assert_class_defs_not_shareable(self, defs):
|
||||
self.assert_not_shareable(defs.TOP_CLASSES)
|
||||
|
||||
instances = []
|
||||
for cls, args in defs.TOP_CLASSES.items():
|
||||
instances.append(cls(*args))
|
||||
self.assert_not_shareable(instances)
|
||||
|
||||
def test_user_class_normal(self):
|
||||
self.assert_class_defs_same(defs)
|
||||
|
||||
def test_user_class_in___main__(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs(mod)
|
||||
self.assert_class_defs_same(defs)
|
||||
|
||||
def test_user_class_not_in___main___with_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
assert defs.__file__
|
||||
mod.__file__ = defs.__file__
|
||||
self.assert_class_defs_not_shareable(defs)
|
||||
|
||||
def test_user_class_not_in___main___without_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
defs.__file__ = None
|
||||
mod.__file__ = None
|
||||
self.assert_class_defs_not_shareable(defs)
|
||||
|
||||
def test_user_class_not_in___main___unpickle_with_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
assert defs.__file__
|
||||
mod.__file__ = defs.__file__
|
||||
self.assert_class_defs_other_unpickle(defs, mod)
|
||||
|
||||
def test_user_class_not_in___main___unpickle_without_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
defs.__file__ = None
|
||||
mod.__file__ = None
|
||||
self.assert_class_defs_other_unpickle(defs, mod, fail=True)
|
||||
|
||||
def test_user_class_in_module(self):
|
||||
with temp_module('__spam__') as mod:
|
||||
defs = load_defs(mod)
|
||||
self.assert_class_defs_same(defs)
|
||||
|
||||
def test_user_class_not_in_module_with_filename(self):
|
||||
with temp_module('__spam__') as mod:
|
||||
defs = load_defs(mod.__name__)
|
||||
assert defs.__file__
|
||||
# For now, we only address this case for __main__.
|
||||
self.assert_class_defs_not_shareable(defs)
|
||||
|
||||
def test_user_class_not_in_module_without_filename(self):
|
||||
with temp_module('__spam__') as mod:
|
||||
defs = load_defs(mod.__name__)
|
||||
defs.__file__ = None
|
||||
self.assert_class_defs_not_shareable(defs)
|
||||
|
||||
def test_user_class_module_missing_then_imported(self):
|
||||
with missing_defs_module('__spam__', prep=True) as modname:
|
||||
defs = load_defs(modname)
|
||||
# For now, we only address this case for __main__.
|
||||
self.assert_class_defs_not_shareable(defs)
|
||||
|
||||
def test_user_class_module_missing_not_available(self):
|
||||
with missing_defs_module('__spam__') as modname:
|
||||
defs = load_defs(modname)
|
||||
self.assert_class_defs_not_shareable(defs)
|
||||
|
||||
def test_nested_class(self):
|
||||
eggs = defs.EggsNested()
|
||||
with self.assertRaises(NotShareableError):
|
||||
self.get_roundtrip(eggs)
|
||||
|
||||
# functions
|
||||
|
||||
def assert_func_defs_same(self, defs):
|
||||
# Unpickle relative to the unchanged original module.
|
||||
self.assert_roundtrip_identical(defs.TOP_FUNCTIONS)
|
||||
|
||||
def assert_func_defs_other_pickle(self, defs, mod):
|
||||
# Pickle relative to a different module than the original.
|
||||
for func in defs.TOP_FUNCTIONS:
|
||||
assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
|
||||
self.assert_not_shareable(defs.TOP_FUNCTIONS)
|
||||
|
||||
def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
|
||||
# Unpickle relative to a different module than the original.
|
||||
for func in defs.TOP_FUNCTIONS:
|
||||
assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
|
||||
|
||||
captured = []
|
||||
for func in defs.TOP_FUNCTIONS:
|
||||
with self.subTest(func):
|
||||
setattr(mod, func.__name__, func)
|
||||
xid = self.get_xidata(func)
|
||||
captured.append(
|
||||
(func, xid))
|
||||
|
||||
for func, xid in captured:
|
||||
with self.subTest(func):
|
||||
delattr(mod, func.__name__)
|
||||
if fail:
|
||||
with self.assertRaises(NotShareableError):
|
||||
_testinternalcapi.restore_crossinterp_data(xid)
|
||||
continue
|
||||
got = _testinternalcapi.restore_crossinterp_data(xid)
|
||||
self.assertIsNot(got, func)
|
||||
self.assertNotEqual(got, func)
|
||||
|
||||
def assert_func_defs_not_shareable(self, defs):
|
||||
self.assert_not_shareable(defs.TOP_FUNCTIONS)
|
||||
|
||||
def test_user_function_normal(self):
|
||||
# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
|
||||
self.assert_func_defs_same(defs)
|
||||
|
||||
def test_user_func_in___main__(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs(mod)
|
||||
self.assert_func_defs_same(defs)
|
||||
|
||||
def test_user_func_not_in___main___with_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
assert defs.__file__
|
||||
mod.__file__ = defs.__file__
|
||||
self.assert_func_defs_not_shareable(defs)
|
||||
|
||||
def test_user_func_not_in___main___without_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
defs.__file__ = None
|
||||
mod.__file__ = None
|
||||
self.assert_func_defs_not_shareable(defs)
|
||||
|
||||
def test_user_func_not_in___main___unpickle_with_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
assert defs.__file__
|
||||
mod.__file__ = defs.__file__
|
||||
self.assert_func_defs_other_unpickle(defs, mod)
|
||||
|
||||
def test_user_func_not_in___main___unpickle_without_filename(self):
|
||||
with using___main__() as mod:
|
||||
defs = load_defs('__main__')
|
||||
defs.__file__ = None
|
||||
mod.__file__ = None
|
||||
self.assert_func_defs_other_unpickle(defs, mod, fail=True)
|
||||
|
||||
def test_user_func_in_module(self):
|
||||
with temp_module('__spam__') as mod:
|
||||
defs = load_defs(mod)
|
||||
self.assert_func_defs_same(defs)
|
||||
|
||||
def test_user_func_not_in_module_with_filename(self):
|
||||
with temp_module('__spam__') as mod:
|
||||
defs = load_defs(mod.__name__)
|
||||
assert defs.__file__
|
||||
# For now, we only address this case for __main__.
|
||||
self.assert_func_defs_not_shareable(defs)
|
||||
|
||||
def test_user_func_not_in_module_without_filename(self):
|
||||
with temp_module('__spam__') as mod:
|
||||
defs = load_defs(mod.__name__)
|
||||
defs.__file__ = None
|
||||
self.assert_func_defs_not_shareable(defs)
|
||||
|
||||
def test_user_func_module_missing_then_imported(self):
|
||||
with missing_defs_module('__spam__', prep=True) as modname:
|
||||
defs = load_defs(modname)
|
||||
# For now, we only address this case for __main__.
|
||||
self.assert_func_defs_not_shareable(defs)
|
||||
|
||||
def test_user_func_module_missing_not_available(self):
|
||||
with missing_defs_module('__spam__') as modname:
|
||||
defs = load_defs(modname)
|
||||
self.assert_func_defs_not_shareable(defs)
|
||||
|
||||
def test_nested_function(self):
|
||||
self.assert_not_shareable(defs.NESTED_FUNCTIONS)
|
||||
|
||||
# exceptions
|
||||
|
||||
def test_user_exception_normal(self):
|
||||
self.assert_roundtrip_not_equal([
|
||||
defs.MimimalError('error!'),
|
||||
])
|
||||
self.assert_roundtrip_equal_not_identical([
|
||||
defs.RichError('error!', 42),
|
||||
])
|
||||
|
||||
def test_builtin_exception(self):
|
||||
msg = 'error!'
|
||||
try:
|
||||
raise Exception
|
||||
except Exception as exc:
|
||||
caught = exc
|
||||
special = {
|
||||
BaseExceptionGroup: (msg, [caught]),
|
||||
ExceptionGroup: (msg, [caught]),
|
||||
# UnicodeError: (None, msg, None, None, None),
|
||||
UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
|
||||
UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
|
||||
UnicodeTranslateError: ('', 1, 3, msg),
|
||||
}
|
||||
exceptions = []
|
||||
for cls in EXCEPTION_TYPES:
|
||||
args = special.get(cls) or (msg,)
|
||||
exceptions.append(cls(*args))
|
||||
|
||||
self.assert_roundtrip_not_equal(exceptions)
|
||||
|
||||
|
||||
class MarshalTests(_GetXIDataTests):
|
||||
|
||||
MODE = 'marshal'
|
||||
|
@ -444,22 +883,12 @@ class ShareableTypeTests(_GetXIDataTests):
|
|||
])
|
||||
|
||||
def test_class(self):
|
||||
self.assert_not_shareable([
|
||||
defs.Spam,
|
||||
defs.SpamOkay,
|
||||
defs.SpamFull,
|
||||
defs.SubSpamFull,
|
||||
defs.SubTuple,
|
||||
defs.EggsNested,
|
||||
])
|
||||
self.assert_not_shareable([
|
||||
defs.Spam(),
|
||||
defs.SpamOkay(),
|
||||
defs.SpamFull(1, 2, 3),
|
||||
defs.SubSpamFull(1, 2, 3),
|
||||
defs.SubTuple([1, 2, 3]),
|
||||
defs.EggsNested(),
|
||||
])
|
||||
self.assert_not_shareable(defs.CLASSES)
|
||||
|
||||
instances = []
|
||||
for cls, args in defs.CLASSES.items():
|
||||
instances.append(cls(*args))
|
||||
self.assert_not_shareable(instances)
|
||||
|
||||
def test_builtin_type(self):
|
||||
self.assert_not_shareable([
|
||||
|
|
|
@ -1939,6 +1939,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs)
|
|||
goto error;
|
||||
}
|
||||
}
|
||||
else if (strcmp(mode, "pickle") == 0) {
|
||||
if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
else if (strcmp(mode, "marshal") == 0) {
|
||||
if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) {
|
||||
goto error;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "Python.h"
|
||||
#include "marshal.h" // PyMarshal_WriteObjectToString()
|
||||
#include "osdefs.h" // MAXPATHLEN
|
||||
#include "pycore_ceval.h" // _Py_simple_func
|
||||
#include "pycore_crossinterp.h" // _PyXIData_t
|
||||
#include "pycore_initconfig.h" // _PyStatus_OK()
|
||||
|
@ -10,6 +11,155 @@
|
|||
#include "pycore_typeobject.h" // _PyStaticType_InitBuiltin()
|
||||
|
||||
|
||||
static Py_ssize_t
|
||||
_Py_GetMainfile(char *buffer, size_t maxlen)
|
||||
{
|
||||
// We don't expect subinterpreters to have the __main__ module's
|
||||
// __name__ set, but proceed just in case.
|
||||
PyThreadState *tstate = _PyThreadState_GET();
|
||||
PyObject *module = _Py_GetMainModule(tstate);
|
||||
if (_Py_CheckMainModule(module) < 0) {
|
||||
return -1;
|
||||
}
|
||||
Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen);
|
||||
Py_DECREF(module);
|
||||
return size;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
import_get_module(PyThreadState *tstate, const char *modname)
|
||||
{
|
||||
PyObject *module = NULL;
|
||||
if (strcmp(modname, "__main__") == 0) {
|
||||
module = _Py_GetMainModule(tstate);
|
||||
if (_Py_CheckMainModule(module) < 0) {
|
||||
assert(_PyErr_Occurred(tstate));
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
module = PyImport_ImportModule(modname);
|
||||
if (module == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
runpy_run_path(const char *filename, const char *modname)
|
||||
{
|
||||
PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path");
|
||||
if (run_path == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname);
|
||||
if (args == NULL) {
|
||||
Py_DECREF(run_path);
|
||||
return NULL;
|
||||
}
|
||||
PyObject *ns = PyObject_Call(run_path, args, NULL);
|
||||
Py_DECREF(run_path);
|
||||
Py_DECREF(args);
|
||||
return ns;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
pyerr_get_message(PyObject *exc)
|
||||
{
|
||||
assert(!PyErr_Occurred());
|
||||
PyObject *args = PyException_GetArgs(exc);
|
||||
if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
|
||||
return NULL;
|
||||
}
|
||||
if (PyUnicode_Check(args)) {
|
||||
return args;
|
||||
}
|
||||
PyObject *msg = PySequence_GetItem(args, 0);
|
||||
Py_DECREF(args);
|
||||
if (msg == NULL) {
|
||||
PyErr_Clear();
|
||||
return NULL;
|
||||
}
|
||||
if (!PyUnicode_Check(msg)) {
|
||||
Py_DECREF(msg);
|
||||
return NULL;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
#define MAX_MODNAME (255)
|
||||
#define MAX_ATTRNAME (255)
|
||||
|
||||
struct attributeerror_info {
|
||||
char modname[MAX_MODNAME+1];
|
||||
char attrname[MAX_ATTRNAME+1];
|
||||
};
|
||||
|
||||
static int
|
||||
_parse_attributeerror(PyObject *exc, struct attributeerror_info *info)
|
||||
{
|
||||
assert(exc != NULL);
|
||||
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
|
||||
int res = -1;
|
||||
|
||||
PyObject *msgobj = pyerr_get_message(exc);
|
||||
if (msgobj == NULL) {
|
||||
return -1;
|
||||
}
|
||||
const char *err = PyUnicode_AsUTF8(msgobj);
|
||||
|
||||
if (strncmp(err, "module '", 8) != 0) {
|
||||
goto finally;
|
||||
}
|
||||
err += 8;
|
||||
|
||||
const char *matched = strchr(err, '\'');
|
||||
if (matched == NULL) {
|
||||
goto finally;
|
||||
}
|
||||
Py_ssize_t len = matched - err;
|
||||
if (len > MAX_MODNAME) {
|
||||
goto finally;
|
||||
}
|
||||
(void)strncpy(info->modname, err, len);
|
||||
info->modname[len] = '\0';
|
||||
err = matched;
|
||||
|
||||
if (strncmp(err, "' has no attribute '", 20) != 0) {
|
||||
goto finally;
|
||||
}
|
||||
err += 20;
|
||||
|
||||
matched = strchr(err, '\'');
|
||||
if (matched == NULL) {
|
||||
goto finally;
|
||||
}
|
||||
len = matched - err;
|
||||
if (len > MAX_ATTRNAME) {
|
||||
goto finally;
|
||||
}
|
||||
(void)strncpy(info->attrname, err, len);
|
||||
info->attrname[len] = '\0';
|
||||
err = matched + 1;
|
||||
|
||||
if (strlen(err) > 0) {
|
||||
goto finally;
|
||||
}
|
||||
res = 0;
|
||||
|
||||
finally:
|
||||
Py_DECREF(msgobj);
|
||||
return res;
|
||||
}
|
||||
|
||||
#undef MAX_MODNAME
|
||||
#undef MAX_ATTRNAME
|
||||
|
||||
|
||||
/**************/
|
||||
/* exceptions */
|
||||
/**************/
|
||||
|
@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate,
|
|||
}
|
||||
|
||||
|
||||
/* pickle C-API */
|
||||
|
||||
struct _pickle_context {
|
||||
PyThreadState *tstate;
|
||||
};
|
||||
|
||||
static PyObject *
|
||||
_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj)
|
||||
{
|
||||
PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps");
|
||||
if (dumps == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *bytes = PyObject_CallOneArg(dumps, obj);
|
||||
Py_DECREF(dumps);
|
||||
return bytes;
|
||||
}
|
||||
|
||||
|
||||
struct sync_module_result {
|
||||
PyObject *module;
|
||||
PyObject *loaded;
|
||||
PyObject *failed;
|
||||
};
|
||||
|
||||
struct sync_module {
|
||||
const char *filename;
|
||||
char _filename[MAXPATHLEN+1];
|
||||
struct sync_module_result cached;
|
||||
};
|
||||
|
||||
static void
|
||||
sync_module_clear(struct sync_module *data)
|
||||
{
|
||||
data->filename = NULL;
|
||||
Py_CLEAR(data->cached.module);
|
||||
Py_CLEAR(data->cached.loaded);
|
||||
Py_CLEAR(data->cached.failed);
|
||||
}
|
||||
|
||||
|
||||
struct _unpickle_context {
|
||||
PyThreadState *tstate;
|
||||
// We only special-case the __main__ module,
|
||||
// since other modules behave consistently.
|
||||
struct sync_module main;
|
||||
};
|
||||
|
||||
static void
|
||||
_unpickle_context_clear(struct _unpickle_context *ctx)
|
||||
{
|
||||
sync_module_clear(&ctx->main);
|
||||
}
|
||||
|
||||
static struct sync_module_result
|
||||
_unpickle_context_get_module(struct _unpickle_context *ctx,
|
||||
const char *modname)
|
||||
{
|
||||
if (strcmp(modname, "__main__") == 0) {
|
||||
return ctx->main.cached;
|
||||
}
|
||||
else {
|
||||
return (struct sync_module_result){
|
||||
.failed = PyExc_NotImplementedError,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
static struct sync_module_result
|
||||
_unpickle_context_set_module(struct _unpickle_context *ctx,
|
||||
const char *modname)
|
||||
{
|
||||
struct sync_module_result res = {0};
|
||||
struct sync_module_result *cached = NULL;
|
||||
const char *filename = NULL;
|
||||
if (strcmp(modname, "__main__") == 0) {
|
||||
cached = &ctx->main.cached;
|
||||
filename = ctx->main.filename;
|
||||
}
|
||||
else {
|
||||
res.failed = PyExc_NotImplementedError;
|
||||
goto finally;
|
||||
}
|
||||
|
||||
res.module = import_get_module(ctx->tstate, modname);
|
||||
if (res.module == NULL) {
|
||||
res.failed = _PyErr_GetRaisedException(ctx->tstate);
|
||||
assert(res.failed != NULL);
|
||||
goto finally;
|
||||
}
|
||||
|
||||
if (filename == NULL) {
|
||||
Py_CLEAR(res.module);
|
||||
res.failed = PyExc_NotImplementedError;
|
||||
goto finally;
|
||||
}
|
||||
res.loaded = runpy_run_path(filename, modname);
|
||||
if (res.loaded == NULL) {
|
||||
Py_CLEAR(res.module);
|
||||
res.failed = _PyErr_GetRaisedException(ctx->tstate);
|
||||
assert(res.failed != NULL);
|
||||
goto finally;
|
||||
}
|
||||
|
||||
finally:
|
||||
if (cached != NULL) {
|
||||
assert(cached->module == NULL);
|
||||
assert(cached->loaded == NULL);
|
||||
assert(cached->failed == NULL);
|
||||
*cached = res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
static int
|
||||
_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc)
|
||||
{
|
||||
// The caller must check if an exception is set or not when -1 is returned.
|
||||
assert(!_PyErr_Occurred(ctx->tstate));
|
||||
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
|
||||
struct attributeerror_info info;
|
||||
if (_parse_attributeerror(exc, &info) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the module.
|
||||
struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname);
|
||||
if (mod.failed != NULL) {
|
||||
// It must have failed previously.
|
||||
return -1;
|
||||
}
|
||||
if (mod.module == NULL) {
|
||||
mod = _unpickle_context_set_module(ctx, info.modname);
|
||||
if (mod.failed != NULL) {
|
||||
return -1;
|
||||
}
|
||||
assert(mod.module != NULL);
|
||||
}
|
||||
|
||||
// Bail out if it is unexpectedly set already.
|
||||
if (PyObject_HasAttrString(mod.module, info.attrname)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Try setting the attribute.
|
||||
PyObject *value = NULL;
|
||||
if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) {
|
||||
return -1;
|
||||
}
|
||||
assert(value != NULL);
|
||||
int res = PyObject_SetAttrString(mod.module, info.attrname, value);
|
||||
Py_DECREF(value);
|
||||
if (res < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled)
|
||||
{
|
||||
PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads");
|
||||
if (loads == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *obj = PyObject_CallOneArg(loads, pickled);
|
||||
if (ctx != NULL) {
|
||||
while (obj == NULL) {
|
||||
assert(_PyErr_Occurred(ctx->tstate));
|
||||
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
|
||||
// We leave other failures unhandled.
|
||||
break;
|
||||
}
|
||||
// Try setting the attr if not set.
|
||||
PyObject *exc = _PyErr_GetRaisedException(ctx->tstate);
|
||||
if (_handle_unpickle_missing_attr(ctx, exc) < 0) {
|
||||
// Any resulting exceptions are ignored
|
||||
// in favor of the original.
|
||||
_PyErr_SetRaisedException(ctx->tstate, exc);
|
||||
break;
|
||||
}
|
||||
Py_CLEAR(exc);
|
||||
// Retry with the attribute set.
|
||||
obj = PyObject_CallOneArg(loads, pickled);
|
||||
}
|
||||
}
|
||||
Py_DECREF(loads);
|
||||
return obj;
|
||||
}
|
||||
|
||||
|
||||
/* pickle wrapper */
|
||||
|
||||
struct _pickle_xid_context {
|
||||
// __main__.__file__
|
||||
struct {
|
||||
const char *utf8;
|
||||
size_t len;
|
||||
char _utf8[MAXPATHLEN+1];
|
||||
} mainfile;
|
||||
};
|
||||
|
||||
static int
|
||||
_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx)
|
||||
{
|
||||
// Set mainfile if possible.
|
||||
Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN);
|
||||
if (len < 0) {
|
||||
// For now we ignore any exceptions.
|
||||
PyErr_Clear();
|
||||
}
|
||||
else if (len > 0) {
|
||||
ctx->mainfile.utf8 = ctx->mainfile._utf8;
|
||||
ctx->mainfile.len = (size_t)len;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
struct _shared_pickle_data {
|
||||
_PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData().
|
||||
struct _pickle_xid_context ctx;
|
||||
};
|
||||
|
||||
PyObject *
|
||||
_PyPickle_LoadFromXIData(_PyXIData_t *xidata)
|
||||
{
|
||||
PyThreadState *tstate = _PyThreadState_GET();
|
||||
struct _shared_pickle_data *shared =
|
||||
(struct _shared_pickle_data *)xidata->data;
|
||||
// We avoid copying the pickled data by wrapping it in a memoryview.
|
||||
// The alternative is to get a bytes object using _PyBytes_FromXIData().
|
||||
PyObject *pickled = PyMemoryView_FromMemory(
|
||||
(char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ);
|
||||
if (pickled == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Unpickle the object.
|
||||
struct _unpickle_context ctx = {
|
||||
.tstate = tstate,
|
||||
.main = {
|
||||
.filename = shared->ctx.mainfile.utf8,
|
||||
},
|
||||
};
|
||||
PyObject *obj = _PyPickle_Loads(&ctx, pickled);
|
||||
Py_DECREF(pickled);
|
||||
_unpickle_context_clear(&ctx);
|
||||
if (obj == NULL) {
|
||||
PyObject *cause = _PyErr_GetRaisedException(tstate);
|
||||
assert(cause != NULL);
|
||||
_set_xid_lookup_failure(
|
||||
tstate, NULL, "object could not be unpickled", cause);
|
||||
Py_DECREF(cause);
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata)
|
||||
{
|
||||
// Pickle the object.
|
||||
struct _pickle_context ctx = {
|
||||
.tstate = tstate,
|
||||
};
|
||||
PyObject *bytes = _PyPickle_Dumps(&ctx, obj);
|
||||
if (bytes == NULL) {
|
||||
PyObject *cause = _PyErr_GetRaisedException(tstate);
|
||||
assert(cause != NULL);
|
||||
_set_xid_lookup_failure(
|
||||
tstate, NULL, "object could not be pickled", cause);
|
||||
Py_DECREF(cause);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// If we had an "unwrapper" mechnanism, we could call
|
||||
// _PyObject_GetXIData() on the bytes object directly and add
|
||||
// a simple unwrapper to call pickle.loads() on the bytes.
|
||||
size_t size = sizeof(struct _shared_pickle_data);
|
||||
struct _shared_pickle_data *shared =
|
||||
(struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped(
|
||||
tstate, bytes, size, _PyPickle_LoadFromXIData, xidata);
|
||||
Py_DECREF(bytes);
|
||||
if (shared == NULL) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// If it mattered, we could skip getting __main__.__file__
|
||||
// when "__main__" doesn't show up in the pickle bytes.
|
||||
if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) {
|
||||
_xidata_clear(xidata);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
/* marshal wrapper */
|
||||
|
||||
PyObject *
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue