gh-132775: Add _PyPickle_GetXIData() (gh-133107)

There's some extra complexity due to making sure we we get things right when handling functions and classes defined in the __main__ module.  This is also reflected in the tests, including the addition of extra functions in test.support.import_helper.
This commit is contained in:
Eric Snow 2025-04-30 17:34:05 -06:00 committed by GitHub
parent 6c522debc2
commit cb35c11d82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 1056 additions and 55 deletions

View file

@ -3,6 +3,7 @@
#include "Python.h"
#include "marshal.h" // PyMarshal_WriteObjectToString()
#include "osdefs.h" // MAXPATHLEN
#include "pycore_ceval.h" // _Py_simple_func
#include "pycore_crossinterp.h" // _PyXIData_t
#include "pycore_initconfig.h" // _PyStatus_OK()
@ -10,6 +11,155 @@
#include "pycore_typeobject.h" // _PyStaticType_InitBuiltin()
static Py_ssize_t
_Py_GetMainfile(char *buffer, size_t maxlen)
{
// We don't expect subinterpreters to have the __main__ module's
// __name__ set, but proceed just in case.
PyThreadState *tstate = _PyThreadState_GET();
PyObject *module = _Py_GetMainModule(tstate);
if (_Py_CheckMainModule(module) < 0) {
return -1;
}
Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen);
Py_DECREF(module);
return size;
}
static PyObject *
import_get_module(PyThreadState *tstate, const char *modname)
{
PyObject *module = NULL;
if (strcmp(modname, "__main__") == 0) {
module = _Py_GetMainModule(tstate);
if (_Py_CheckMainModule(module) < 0) {
assert(_PyErr_Occurred(tstate));
return NULL;
}
}
else {
module = PyImport_ImportModule(modname);
if (module == NULL) {
return NULL;
}
}
return module;
}
static PyObject *
runpy_run_path(const char *filename, const char *modname)
{
PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path");
if (run_path == NULL) {
return NULL;
}
PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname);
if (args == NULL) {
Py_DECREF(run_path);
return NULL;
}
PyObject *ns = PyObject_Call(run_path, args, NULL);
Py_DECREF(run_path);
Py_DECREF(args);
return ns;
}
static PyObject *
pyerr_get_message(PyObject *exc)
{
assert(!PyErr_Occurred());
PyObject *args = PyException_GetArgs(exc);
if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
return NULL;
}
if (PyUnicode_Check(args)) {
return args;
}
PyObject *msg = PySequence_GetItem(args, 0);
Py_DECREF(args);
if (msg == NULL) {
PyErr_Clear();
return NULL;
}
if (!PyUnicode_Check(msg)) {
Py_DECREF(msg);
return NULL;
}
return msg;
}
#define MAX_MODNAME (255)
#define MAX_ATTRNAME (255)
struct attributeerror_info {
char modname[MAX_MODNAME+1];
char attrname[MAX_ATTRNAME+1];
};
static int
_parse_attributeerror(PyObject *exc, struct attributeerror_info *info)
{
assert(exc != NULL);
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
int res = -1;
PyObject *msgobj = pyerr_get_message(exc);
if (msgobj == NULL) {
return -1;
}
const char *err = PyUnicode_AsUTF8(msgobj);
if (strncmp(err, "module '", 8) != 0) {
goto finally;
}
err += 8;
const char *matched = strchr(err, '\'');
if (matched == NULL) {
goto finally;
}
Py_ssize_t len = matched - err;
if (len > MAX_MODNAME) {
goto finally;
}
(void)strncpy(info->modname, err, len);
info->modname[len] = '\0';
err = matched;
if (strncmp(err, "' has no attribute '", 20) != 0) {
goto finally;
}
err += 20;
matched = strchr(err, '\'');
if (matched == NULL) {
goto finally;
}
len = matched - err;
if (len > MAX_ATTRNAME) {
goto finally;
}
(void)strncpy(info->attrname, err, len);
info->attrname[len] = '\0';
err = matched + 1;
if (strlen(err) > 0) {
goto finally;
}
res = 0;
finally:
Py_DECREF(msgobj);
return res;
}
#undef MAX_MODNAME
#undef MAX_ATTRNAME
/**************/
/* exceptions */
/**************/
@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate,
}
/* pickle C-API */
struct _pickle_context {
PyThreadState *tstate;
};
static PyObject *
_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj)
{
PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps");
if (dumps == NULL) {
return NULL;
}
PyObject *bytes = PyObject_CallOneArg(dumps, obj);
Py_DECREF(dumps);
return bytes;
}
struct sync_module_result {
PyObject *module;
PyObject *loaded;
PyObject *failed;
};
struct sync_module {
const char *filename;
char _filename[MAXPATHLEN+1];
struct sync_module_result cached;
};
static void
sync_module_clear(struct sync_module *data)
{
data->filename = NULL;
Py_CLEAR(data->cached.module);
Py_CLEAR(data->cached.loaded);
Py_CLEAR(data->cached.failed);
}
struct _unpickle_context {
PyThreadState *tstate;
// We only special-case the __main__ module,
// since other modules behave consistently.
struct sync_module main;
};
static void
_unpickle_context_clear(struct _unpickle_context *ctx)
{
sync_module_clear(&ctx->main);
}
static struct sync_module_result
_unpickle_context_get_module(struct _unpickle_context *ctx,
const char *modname)
{
if (strcmp(modname, "__main__") == 0) {
return ctx->main.cached;
}
else {
return (struct sync_module_result){
.failed = PyExc_NotImplementedError,
};
}
}
static struct sync_module_result
_unpickle_context_set_module(struct _unpickle_context *ctx,
const char *modname)
{
struct sync_module_result res = {0};
struct sync_module_result *cached = NULL;
const char *filename = NULL;
if (strcmp(modname, "__main__") == 0) {
cached = &ctx->main.cached;
filename = ctx->main.filename;
}
else {
res.failed = PyExc_NotImplementedError;
goto finally;
}
res.module = import_get_module(ctx->tstate, modname);
if (res.module == NULL) {
res.failed = _PyErr_GetRaisedException(ctx->tstate);
assert(res.failed != NULL);
goto finally;
}
if (filename == NULL) {
Py_CLEAR(res.module);
res.failed = PyExc_NotImplementedError;
goto finally;
}
res.loaded = runpy_run_path(filename, modname);
if (res.loaded == NULL) {
Py_CLEAR(res.module);
res.failed = _PyErr_GetRaisedException(ctx->tstate);
assert(res.failed != NULL);
goto finally;
}
finally:
if (cached != NULL) {
assert(cached->module == NULL);
assert(cached->loaded == NULL);
assert(cached->failed == NULL);
*cached = res;
}
return res;
}
static int
_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc)
{
// The caller must check if an exception is set or not when -1 is returned.
assert(!_PyErr_Occurred(ctx->tstate));
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
struct attributeerror_info info;
if (_parse_attributeerror(exc, &info) < 0) {
return -1;
}
// Get the module.
struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname);
if (mod.failed != NULL) {
// It must have failed previously.
return -1;
}
if (mod.module == NULL) {
mod = _unpickle_context_set_module(ctx, info.modname);
if (mod.failed != NULL) {
return -1;
}
assert(mod.module != NULL);
}
// Bail out if it is unexpectedly set already.
if (PyObject_HasAttrString(mod.module, info.attrname)) {
return -1;
}
// Try setting the attribute.
PyObject *value = NULL;
if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) {
return -1;
}
assert(value != NULL);
int res = PyObject_SetAttrString(mod.module, info.attrname, value);
Py_DECREF(value);
if (res < 0) {
return -1;
}
return 0;
}
static PyObject *
_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled)
{
PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads");
if (loads == NULL) {
return NULL;
}
PyObject *obj = PyObject_CallOneArg(loads, pickled);
if (ctx != NULL) {
while (obj == NULL) {
assert(_PyErr_Occurred(ctx->tstate));
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
// We leave other failures unhandled.
break;
}
// Try setting the attr if not set.
PyObject *exc = _PyErr_GetRaisedException(ctx->tstate);
if (_handle_unpickle_missing_attr(ctx, exc) < 0) {
// Any resulting exceptions are ignored
// in favor of the original.
_PyErr_SetRaisedException(ctx->tstate, exc);
break;
}
Py_CLEAR(exc);
// Retry with the attribute set.
obj = PyObject_CallOneArg(loads, pickled);
}
}
Py_DECREF(loads);
return obj;
}
/* pickle wrapper */
struct _pickle_xid_context {
// __main__.__file__
struct {
const char *utf8;
size_t len;
char _utf8[MAXPATHLEN+1];
} mainfile;
};
static int
_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx)
{
// Set mainfile if possible.
Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN);
if (len < 0) {
// For now we ignore any exceptions.
PyErr_Clear();
}
else if (len > 0) {
ctx->mainfile.utf8 = ctx->mainfile._utf8;
ctx->mainfile.len = (size_t)len;
}
return 0;
}
struct _shared_pickle_data {
_PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData().
struct _pickle_xid_context ctx;
};
PyObject *
_PyPickle_LoadFromXIData(_PyXIData_t *xidata)
{
PyThreadState *tstate = _PyThreadState_GET();
struct _shared_pickle_data *shared =
(struct _shared_pickle_data *)xidata->data;
// We avoid copying the pickled data by wrapping it in a memoryview.
// The alternative is to get a bytes object using _PyBytes_FromXIData().
PyObject *pickled = PyMemoryView_FromMemory(
(char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ);
if (pickled == NULL) {
return NULL;
}
// Unpickle the object.
struct _unpickle_context ctx = {
.tstate = tstate,
.main = {
.filename = shared->ctx.mainfile.utf8,
},
};
PyObject *obj = _PyPickle_Loads(&ctx, pickled);
Py_DECREF(pickled);
_unpickle_context_clear(&ctx);
if (obj == NULL) {
PyObject *cause = _PyErr_GetRaisedException(tstate);
assert(cause != NULL);
_set_xid_lookup_failure(
tstate, NULL, "object could not be unpickled", cause);
Py_DECREF(cause);
}
return obj;
}
int
_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata)
{
// Pickle the object.
struct _pickle_context ctx = {
.tstate = tstate,
};
PyObject *bytes = _PyPickle_Dumps(&ctx, obj);
if (bytes == NULL) {
PyObject *cause = _PyErr_GetRaisedException(tstate);
assert(cause != NULL);
_set_xid_lookup_failure(
tstate, NULL, "object could not be pickled", cause);
Py_DECREF(cause);
return -1;
}
// If we had an "unwrapper" mechnanism, we could call
// _PyObject_GetXIData() on the bytes object directly and add
// a simple unwrapper to call pickle.loads() on the bytes.
size_t size = sizeof(struct _shared_pickle_data);
struct _shared_pickle_data *shared =
(struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped(
tstate, bytes, size, _PyPickle_LoadFromXIData, xidata);
Py_DECREF(bytes);
if (shared == NULL) {
return -1;
}
// If it mattered, we could skip getting __main__.__file__
// when "__main__" doesn't show up in the pickle bytes.
if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) {
_xidata_clear(xidata);
return -1;
}
return 0;
}
/* marshal wrapper */
PyObject *