gh-132775: Add _PyMarshal_GetXIData() (gh-133108)

Note that the bulk of this change is tests.
This commit is contained in:
Eric Snow 2025-04-28 17:23:46 -06:00 committed by GitHub
parent 68a737691b
commit bdd23c0bb9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 337 additions and 10 deletions

View file

@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped(
xid_newobjfunc, xid_newobjfunc,
_PyXIData_t *); _PyXIData_t *);
// _PyObject_GetXIData() for marshal
// Restore function for marshal-based XIData: unmarshals the bytes held by
// the given XIData and returns the resulting object, or NULL on failure.
PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *);
// Marshals the given object and wraps the resulting bytes in the given
// XIData.  Returns 0 on success and -1 on failure.
PyAPI_FUNC(int) _PyMarshal_GetXIData(
PyThreadState *,
PyObject *,
_PyXIData_t *);
/* using cross-interpreter data */ /* using cross-interpreter data */

View file

@ -100,7 +100,7 @@ ham_C_nested, *_ = eggs_closure_N(2)
ham_C_closure, *_ = eggs_closure_C(2) ham_C_closure, *_ = eggs_closure_C(2)
FUNCTIONS = [ TOP_FUNCTIONS = [
# shallow # shallow
spam_minimal, spam_minimal,
spam_full, spam_full,
@ -112,6 +112,8 @@ FUNCTIONS = [
spam_NC, spam_NC,
spam_CN, spam_CN,
spam_CC, spam_CC,
]
NESTED_FUNCTIONS = [
# inner func # inner func
eggs_nested, eggs_nested,
eggs_closure, eggs_closure,
@ -125,6 +127,10 @@ FUNCTIONS = [
ham_C_nested, ham_C_nested,
ham_C_closure, ham_C_closure,
] ]
FUNCTIONS = [
*TOP_FUNCTIONS,
*NESTED_FUNCTIONS,
]
####################################### #######################################
@ -157,8 +163,10 @@ FUNCTION_LIKE = [
gen_spam_1, gen_spam_1,
gen_spam_2, gen_spam_2,
async_spam, async_spam,
coro_spam, # actually FunctionType?
asyncgen_spam, asyncgen_spam,
]
FUNCTION_LIKE_APPLIED = [
coro_spam, # actually FunctionType?
asynccoro_spam, # actually FunctionType? asynccoro_spam, # actually FunctionType?
] ]
@ -202,6 +210,13 @@ class SpamFull:
# __str__ # __str__
# ... # ...
def __eq__(self, other):
    # Only instances of SpamFull are comparable; anything else is left
    # to the other operand.
    if isinstance(other, SpamFull):
        # Compare a, b and c in order, stopping at the first mismatch.
        result = self.a == other.a
        if result:
            result = self.b == other.b
            if result:
                result = self.c == other.c
        return result
    return NotImplemented
@property @property
def prop(self): def prop(self):
return True return True
@ -222,9 +237,47 @@ def class_eggs_inner():
EggsNested = class_eggs_inner() EggsNested = class_eggs_inner()
# Each mapping below pairs a class with the positional arguments that can be
# used to create an instance of it, i.e. cls(*args).
TOP_CLASSES = {
Spam: (),
SpamOkay: (),
SpamFull: (1, 2, 3),
SubSpamFull: (1, 2, 3),
SubTuple: ([1, 2, 3],),
}
# NOTE(review): presumably the classes above that do not define __eq__,
# so instances only compare equal by identity -- confirm against the classes.
CLASSES_WITHOUT_EQUALITY = [
Spam,
SpamOkay,
]
# NOTE(review): presumably the classes above that subclass a builtin type.
BUILTIN_SUBCLASSES = [
SubTuple,
]
# Classes that are created by calling a function rather than being defined
# directly at the top level of the module.
NESTED_CLASSES = {
EggsNested: (),
}
# All of the classes above, top-level and nested.
CLASSES = {
**TOP_CLASSES,
**NESTED_CLASSES,
}
####################################### #######################################
# exceptions # exceptions
class MimimalError(Exception): class MimimalError(Exception):
pass pass
class RichError(Exception):
    # An exception that carries extra state (msg and value) and supports
    # equality comparison based on that state.

    def __init__(self, msg, value=None):
        self.msg = msg
        self.value = value
        # Both values also end up in self.args.
        super().__init__(msg, value)

    def __eq__(self, other):
        if isinstance(other, RichError):
            differs = self.msg != other.msg or self.value != other.value
            return not differs
        return NotImplemented

View file

@ -17,6 +17,9 @@ BUILTIN_TYPES = [o for _, o in __builtins__.items()
if isinstance(o, type)] if isinstance(o, type)]
EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES
if issubclass(cls, BaseException)] if issubclass(cls, BaseException)]
# Every class exposed as an attribute of the ``types`` module, apart from
# the two names excluded explicitly below (the reason is not recorded here).
OTHER_TYPES = [o for n, o in vars(types).items()
if (isinstance(o, type) and
n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
class _GetXIDataTests(unittest.TestCase): class _GetXIDataTests(unittest.TestCase):
@ -40,16 +43,42 @@ class _GetXIDataTests(unittest.TestCase):
got = _testinternalcapi.restore_crossinterp_data(xid) got = _testinternalcapi.restore_crossinterp_data(xid)
yield obj, got yield obj, got
def assert_roundtrip_equal(self, values, *, mode=None):
    # Every value must come back equal to what was sent, and with exactly
    # the same type.
    roundtrips = self.iter_roundtrip_values(values, mode=mode)
    for sent, received in roundtrips:
        self.assertEqual(received, sent)
        self.assertIs(type(received), type(sent))
def assert_roundtrip_identical(self, values, *, mode=None): def assert_roundtrip_identical(self, values, *, mode=None):
for obj, got in self.iter_roundtrip_values(values, mode=mode): for obj, got in self.iter_roundtrip_values(values, mode=mode):
# XXX What about between interpreters? # XXX What about between interpreters?
self.assertIs(got, obj) self.assertIs(got, obj)
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
    # Every value must come back equal to what was sent.  The restored
    # object's type must be exactly `expecttype` when one is given, and
    # otherwise exactly the type of the original object.
    roundtrips = self.iter_roundtrip_values(values, mode=mode)
    for sent, received in roundtrips:
        self.assertEqual(received, sent)
        if expecttype is None:
            wanted = type(sent)
        else:
            wanted = expecttype
        self.assertIs(type(received), wanted)
# def assert_roundtrip_equal_not_identical(self, values, *,
# mode=None, expecttype=None):
# mode = self._resolve_mode(mode)
# for obj in values:
# cls = type(obj)
# with self.subTest(obj):
# got = self._get_roundtrip(obj, mode)
# self.assertIsNot(got, obj)
# self.assertIs(type(got), type(obj))
# self.assertEqual(got, obj)
# self.assertIs(type(got),
# cls if expecttype is None else expecttype)
#
# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None):
# mode = self._resolve_mode(mode)
# for obj in values:
# cls = type(obj)
# with self.subTest(obj):
# got = self._get_roundtrip(obj, mode)
# self.assertIsNot(got, obj)
# self.assertIs(type(got), type(obj))
# self.assertNotEqual(got, obj)
# self.assertIs(type(got),
# cls if expecttype is None else expecttype)
def assert_not_shareable(self, values, exctype=None, *, mode=None): def assert_not_shareable(self, values, exctype=None, *, mode=None):
mode = self._resolve_mode(mode) mode = self._resolve_mode(mode)
for obj in values: for obj in values:
@ -66,6 +95,197 @@ class _GetXIDataTests(unittest.TestCase):
return mode return mode
class MarshalTests(_GetXIDataTests):
    # Round-trip tests for cross-interpreter data created in 'marshal' mode.
    #
    # Comments are used here instead of docstrings so that unittest keeps
    # reporting the test method names in its verbose output.

    MODE = 'marshal'

    def test_simple_builtin_singletons(self):
        # These must come back as the very same objects.
        self.assert_roundtrip_identical([
            True,
            False,
            None,
            Ellipsis,
        ])
        self.assert_not_shareable([
            NotImplemented,
        ])

    def test_simple_builtin_objects(self):
        self.assert_roundtrip_equal([
            # int
            *range(-1, 258),
            sys.maxsize + 1,
            sys.maxsize,
            -sys.maxsize - 1,
            -sys.maxsize - 2,
            2**1000,
            # complex
            1+2j,
            # float
            0.0,
            1.1,
            -1.0,
            0.12345678,
            -0.12345678,
            # bytes
            *(i.to_bytes(2, 'little', signed=True)
              for i in range(-1, 258)),
            b'hello world',
            # str
            'hello world',
            '你好世界',
            '',
        ])
        self.assert_not_shareable([
            object(),
            types.SimpleNamespace(),
        ])

    def test_bytearray(self):
        # bytearray is special because it unmarshals to bytes, not bytearray.
        self.assert_roundtrip_equal([
            bytearray(),
            bytearray(b'hello world'),
        ], expecttype=bytes)

    def test_compound_immutable_builtin_objects(self):
        self.assert_roundtrip_equal([
            # tuple
            (),
            (1,),
            ("hello", "world"),
            (1, True, "hello"),
            # frozenset
            frozenset([1, 2, 3]),
        ])
        # nested
        self.assert_roundtrip_equal([
            # tuple
            ((1,),),
            ((1, 2), (3, 4)),
            ((1, 2), (3, 4), (5, 6)),
            # frozenset
            frozenset([frozenset([1]), frozenset([2]), frozenset([3])]),
        ])

    def test_compound_mutable_builtin_objects(self):
        self.assert_roundtrip_equal([
            # list
            [],
            [1, 2, 3],
            # dict
            {},
            {1: 7, 2: 8, 3: 9},
            # set
            set(),
            {1, 2, 3},
        ])
        # nested
        self.assert_roundtrip_equal([
            [[1], [2], [3]],
            {1: {'a': True}, 2: {'b': False}},
            {(1, 2, 3,)},
        ])

    def test_compound_builtin_objects_with_bad_items(self):
        # A single unsupported item makes the whole container unshareable.
        bogus = object()
        self.assert_not_shareable([
            (bogus,),
            frozenset([bogus]),
            [bogus],
            {bogus: True},
            {True: bogus},
            {bogus},
        ])

    def test_builtin_code(self):
        self.assert_roundtrip_equal([
            *(f.__code__ for f in defs.FUNCTIONS),
            *(f.__code__ for f in defs.FUNCTION_LIKE),
        ])

    def test_builtin_type(self):
        shareable = [
            StopIteration,
        ]
        # Deliberately not named "types", which would shadow the module.
        all_types = [
            *BUILTIN_TYPES,
            *OTHER_TYPES,
        ]
        self.assert_not_shareable(cls for cls in all_types
                                  if cls not in shareable)
        self.assert_roundtrip_identical(cls for cls in all_types
                                        if cls in shareable)

    def test_builtin_function(self):
        functions = [
            len,
            sys.is_finalizing,
            sys.exit,
            _testinternalcapi.get_crossinterp_data,
        ]
        # Sanity check: make sure we are testing what we think we are.
        for func in functions:
            assert type(func) is types.BuiltinFunctionType, func

        self.assert_not_shareable(functions)

    def test_builtin_exception(self):
        msg = 'error!'
        try:
            raise Exception
        except Exception as exc:
            caught = exc
        # Exception types whose constructors need more than a message.
        special = {
            BaseExceptionGroup: (msg, [caught]),
            ExceptionGroup: (msg, [caught]),
#            UnicodeError: (None, msg, None, None, None),
            UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
            UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
            UnicodeTranslateError: ('', 1, 3, msg),
        }
        exceptions = []
        for cls in EXCEPTION_TYPES:
            args = special.get(cls) or (msg,)
            exceptions.append(cls(*args))

        self.assert_not_shareable(exceptions)
        # Note that StopIteration (the type) can be marshalled,
        # but its instances cannot.

    def test_module(self):
        # Sanity checks; each failure message reports the offending type.
        assert type(sys) is types.ModuleType, type(sys)
        assert type(defs) is types.ModuleType, type(defs)
        assert type(unittest) is types.ModuleType, type(unittest)

        assert 'emptymod' not in sys.modules
        with import_helper.ready_to_import('emptymod', ''):
            import emptymod

        self.assert_not_shareable([
            sys,
            defs,
            unittest,
            emptymod,
        ])

    def test_user_class(self):
        # Iterating over the dict yields the classes themselves.
        self.assert_not_shareable(defs.TOP_CLASSES)

        instances = []
        for cls, args in defs.TOP_CLASSES.items():
            instances.append(cls(*args))
        self.assert_not_shareable(instances)

    def test_user_function(self):
        self.assert_not_shareable(defs.TOP_FUNCTIONS)

    def test_user_exception(self):
        self.assert_not_shareable([
            defs.MimimalError('error!'),
            defs.RichError('error!', 42),
        ])
class ShareableTypeTests(_GetXIDataTests): class ShareableTypeTests(_GetXIDataTests):
MODE = 'xidata' MODE = 'xidata'
@ -184,6 +404,7 @@ class ShareableTypeTests(_GetXIDataTests):
def test_function_like(self): def test_function_like(self):
self.assert_not_shareable(defs.FUNCTION_LIKE) self.assert_not_shareable(defs.FUNCTION_LIKE)
self.assert_not_shareable(defs.FUNCTION_LIKE_APPLIED)
def test_builtin_wrapper(self): def test_builtin_wrapper(self):
_wrappers = { _wrappers = {
@ -243,9 +464,7 @@ class ShareableTypeTests(_GetXIDataTests):
def test_builtin_type(self): def test_builtin_type(self):
self.assert_not_shareable([ self.assert_not_shareable([
*BUILTIN_TYPES, *BUILTIN_TYPES,
*(o for n, o in vars(types).items() *OTHER_TYPES,
if (isinstance(o, type) and
n not in ('DynamicClassAttribute', '_GeneratorWrapper'))),
]) ])
def test_exception(self): def test_exception(self):

View file

@ -1730,6 +1730,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs)
goto error; goto error;
} }
} }
else if (strcmp(mode, "marshal") == 0) {
if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) {
goto error;
}
}
else { else {
PyErr_Format(PyExc_ValueError, "unsupported mode %R", modeobj); PyErr_Format(PyExc_ValueError, "unsupported mode %R", modeobj);
goto error; goto error;

View file

@ -2,6 +2,7 @@
/* API for managing interactions between isolated interpreters */ /* API for managing interactions between isolated interpreters */
#include "Python.h" #include "Python.h"
#include "marshal.h" // PyMarshal_WriteObjectToString()
#include "pycore_ceval.h" // _Py_simple_func #include "pycore_ceval.h" // _Py_simple_func
#include "pycore_crossinterp.h" // _PyXIData_t #include "pycore_crossinterp.h" // _PyXIData_t
#include "pycore_initconfig.h" // _PyStatus_OK() #include "pycore_initconfig.h" // _PyStatus_OK()
@ -286,6 +287,48 @@ _PyObject_GetXIData(PyThreadState *tstate,
} }
/* marshal wrapper */
// Rebuild an object from the marshalled bytes stored in `xidata`.
//
// Returns the unmarshalled object, or NULL on failure.  On failure the
// exception raised by marshal is handed to _set_xid_lookup_failure() as
// the cause.
PyObject *
_PyMarshal_ReadObjectFromXIData(_PyXIData_t *xidata)
{
    PyThreadState *tstate = _PyThreadState_GET();
    _PyBytes_data_t *raw = (_PyBytes_data_t *)xidata->data;
    PyObject *result = PyMarshal_ReadObjectFromString(raw->bytes, raw->len);
    if (result != NULL) {
        return result;
    }

    // Take over the pending exception so it can be passed along as the cause.
    PyObject *cause = _PyErr_GetRaisedException(tstate);
    assert(cause != NULL);
    _set_xid_lookup_failure(
                tstate, NULL, "object could not be unmarshalled", cause);
    Py_DECREF(cause);
    return NULL;
}
// Marshal `obj` and wrap the resulting bytes in `xidata`, registering
// _PyMarshal_ReadObjectFromXIData() as the function that restores it.
//
// Returns 0 on success and -1 on failure.  If marshalling itself fails,
// the exception it raised is handed to _set_xid_lookup_failure() as the
// cause.
int
_PyMarshal_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata)
{
    PyObject *data = PyMarshal_WriteObjectToString(obj, Py_MARSHAL_VERSION);
    if (data != NULL) {
        _PyBytes_data_t *wrapped = _PyBytes_GetXIDataWrapped(
                tstate, data, sizeof(_PyBytes_data_t),
                _PyMarshal_ReadObjectFromXIData, xidata);
        // Our reference is dropped whether or not wrapping succeeded.
        Py_DECREF(data);
        return (wrapped == NULL) ? -1 : 0;
    }

    // Take over the pending exception so it can be passed along as the cause.
    PyObject *cause = _PyErr_GetRaisedException(tstate);
    assert(cause != NULL);
    _set_xid_lookup_failure(
                tstate, NULL, "object could not be marshalled", cause);
    Py_DECREF(cause);
    return -1;
}
/* using cross-interpreter data */ /* using cross-interpreter data */
PyObject * PyObject *