gh-132775: Fix Interpreter.call() __main__ Visibility (gh-135595)

As noted in the new tests, there are a few situations we must carefully accommodate
for functions that get pickled during interp.call().  We do so by running the script
from the main interpreter's __main__ module in a hidden module in the other
interpreter.  That hidden module is used as the function __globals__.
This commit is contained in:
Eric Snow 2025-06-17 13:16:59 -06:00 committed by GitHub
parent fba5dded6d
commit 269e19e0a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 419 additions and 246 deletions

View file

@ -7,9 +7,12 @@
#include "pycore_ceval.h" // _Py_simple_func
#include "pycore_crossinterp.h" // _PyXIData_t
#include "pycore_function.h" // _PyFunction_VerifyStateless()
#include "pycore_global_strings.h" // _Py_ID()
#include "pycore_import.h" // _PyImport_SetModule()
#include "pycore_initconfig.h" // _PyStatus_OK()
#include "pycore_namespace.h" // _PyNamespace_New()
#include "pycore_pythonrun.h" // _Py_SourceAsString()
#include "pycore_runtime.h" // _PyRuntime
#include "pycore_setobject.h" // _PySet_NextEntry()
#include "pycore_typeobject.h" // _PyStaticType_InitBuiltin()
@ -22,6 +25,7 @@ _Py_GetMainfile(char *buffer, size_t maxlen)
PyThreadState *tstate = _PyThreadState_GET();
PyObject *module = _Py_GetMainModule(tstate);
if (_Py_CheckMainModule(module) < 0) {
Py_XDECREF(module);
return -1;
}
Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen);
@ -30,27 +34,6 @@ _Py_GetMainfile(char *buffer, size_t maxlen)
}
static PyObject *
import_get_module(PyThreadState *tstate, const char *modname)
{
PyObject *module = NULL;
if (strcmp(modname, "__main__") == 0) {
module = _Py_GetMainModule(tstate);
if (_Py_CheckMainModule(module) < 0) {
assert(_PyErr_Occurred(tstate));
return NULL;
}
}
else {
module = PyImport_ImportModule(modname);
if (module == NULL) {
return NULL;
}
}
return module;
}
static PyObject *
runpy_run_path(const char *filename, const char *modname)
{
@ -81,97 +64,181 @@ set_exc_with_cause(PyObject *exctype, const char *msg)
}
static PyObject *
pyerr_get_message(PyObject *exc)
{
assert(!PyErr_Occurred());
PyObject *args = PyException_GetArgs(exc);
if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
return NULL;
}
if (PyUnicode_Check(args)) {
return args;
}
PyObject *msg = PySequence_GetItem(args, 0);
Py_DECREF(args);
if (msg == NULL) {
PyErr_Clear();
return NULL;
}
if (!PyUnicode_Check(msg)) {
Py_DECREF(msg);
return NULL;
}
return msg;
}
/****************************/
/* module duplication utils */
/****************************/
#define MAX_MODNAME (255)
#define MAX_ATTRNAME (255)
struct attributeerror_info {
char modname[MAX_MODNAME+1];
char attrname[MAX_ATTRNAME+1];
struct sync_module_result {
PyObject *module;
PyObject *loaded;
PyObject *failed;
};
static int
_parse_attributeerror(PyObject *exc, struct attributeerror_info *info)
struct sync_module {
const char *filename;
char _filename[MAXPATHLEN+1];
struct sync_module_result cached;
};
static void
sync_module_clear(struct sync_module *data)
{
assert(exc != NULL);
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
int res = -1;
PyObject *msgobj = pyerr_get_message(exc);
if (msgobj == NULL) {
return -1;
}
const char *err = PyUnicode_AsUTF8(msgobj);
if (strncmp(err, "module '", 8) != 0) {
goto finally;
}
err += 8;
const char *matched = strchr(err, '\'');
if (matched == NULL) {
goto finally;
}
Py_ssize_t len = matched - err;
if (len > MAX_MODNAME) {
goto finally;
}
(void)strncpy(info->modname, err, len);
info->modname[len] = '\0';
err = matched;
if (strncmp(err, "' has no attribute '", 20) != 0) {
goto finally;
}
err += 20;
matched = strchr(err, '\'');
if (matched == NULL) {
goto finally;
}
len = matched - err;
if (len > MAX_ATTRNAME) {
goto finally;
}
(void)strncpy(info->attrname, err, len);
info->attrname[len] = '\0';
err = matched + 1;
if (strlen(err) > 0) {
goto finally;
}
res = 0;
finally:
Py_DECREF(msgobj);
return res;
data->filename = NULL;
Py_CLEAR(data->cached.module);
Py_CLEAR(data->cached.loaded);
Py_CLEAR(data->cached.failed);
}
#undef MAX_MODNAME
#undef MAX_ATTRNAME
static void
sync_module_capture_exc(PyThreadState *tstate, struct sync_module *data)
{
assert(_PyErr_Occurred(tstate));
PyObject *context = data->cached.failed;
PyObject *exc = _PyErr_GetRaisedException(tstate);
_PyErr_SetRaisedException(tstate, Py_NewRef(exc));
if (context != NULL) {
PyException_SetContext(exc, context);
}
data->cached.failed = exc;
}
static int
ensure_isolated_main(PyThreadState *tstate, struct sync_module *main)
{
// Load the module from the original file (or from a cache).
// First try the local cache.
if (main->cached.failed != NULL) {
// We'll deal with this in apply_isolated_main().
assert(main->cached.module == NULL);
assert(main->cached.loaded == NULL);
return 0;
}
else if (main->cached.loaded != NULL) {
assert(main->cached.module != NULL);
return 0;
}
assert(main->cached.module == NULL);
if (main->filename == NULL) {
_PyErr_SetString(tstate, PyExc_NotImplementedError, "");
return -1;
}
// It wasn't in the local cache so we'll need to populate it.
PyObject *mod = _Py_GetMainModule(tstate);
if (_Py_CheckMainModule(mod) < 0) {
// This is probably unrecoverable, so don't bother caching the error.
assert(_PyErr_Occurred(tstate));
Py_XDECREF(mod);
return -1;
}
PyObject *loaded = NULL;
// Try the per-interpreter cache for the loaded module.
// XXX Store it in sys.modules?
PyObject *interpns = PyInterpreterState_GetDict(tstate->interp);
assert(interpns != NULL);
PyObject *key = PyUnicode_FromString("CACHED_MODULE_NS___main__");
if (key == NULL) {
// It's probably unrecoverable, so don't bother caching the error.
Py_DECREF(mod);
return -1;
}
else if (PyDict_GetItemRef(interpns, key, &loaded) < 0) {
// It's probably unrecoverable, so don't bother caching the error.
Py_DECREF(mod);
Py_DECREF(key);
return -1;
}
else if (loaded == NULL) {
// It wasn't already loaded from file.
loaded = PyModule_NewObject(&_Py_ID(__main__));
if (loaded == NULL) {
goto error;
}
PyObject *ns = _PyModule_GetDict(loaded);
// We don't want to trigger "if __name__ == '__main__':",
// so we use a bogus module name.
PyObject *loaded_ns =
runpy_run_path(main->filename, "<fake __main__>");
if (loaded_ns == NULL) {
goto error;
}
int res = PyDict_Update(ns, loaded_ns);
Py_DECREF(loaded_ns);
if (res < 0) {
goto error;
}
// Set the per-interpreter cache entry.
if (PyDict_SetItem(interpns, key, loaded) < 0) {
goto error;
}
}
Py_DECREF(key);
main->cached = (struct sync_module_result){
.module = mod,
.loaded = loaded,
};
return 0;
error:
sync_module_capture_exc(tstate, main);
Py_XDECREF(loaded);
Py_DECREF(mod);
Py_XDECREF(key);
return -1;
}
#ifndef NDEBUG
static int
main_mod_matches(PyObject *expected)
{
PyObject *mod = PyImport_GetModule(&_Py_ID(__main__));
Py_XDECREF(mod);
return mod == expected;
}
#endif
static int
apply_isolated_main(PyThreadState *tstate, struct sync_module *main)
{
assert((main->cached.loaded == NULL) == (main->cached.loaded == NULL));
if (main->cached.failed != NULL) {
// It must have failed previously.
assert(main->cached.loaded == NULL);
_PyErr_SetRaisedException(tstate, main->cached.failed);
return -1;
}
assert(main->cached.loaded != NULL);
assert(main_mod_matches(main->cached.module));
if (_PyImport_SetModule(&_Py_ID(__main__), main->cached.loaded) < 0) {
sync_module_capture_exc(tstate, main);
return -1;
}
return 0;
}
static void
restore_main(PyThreadState *tstate, struct sync_module *main)
{
assert(main->cached.failed == NULL);
assert(main->cached.module != NULL);
assert(main->cached.loaded != NULL);
PyObject *exc = _PyErr_GetRaisedException(tstate);
assert(main_mod_matches(main->cached.loaded));
int res = _PyImport_SetModule(&_Py_ID(__main__), main->cached.module);
assert(res == 0);
if (res < 0) {
PyErr_FormatUnraisable("Exception ignored while restoring __main__");
}
_PyErr_SetRaisedException(tstate, exc);
}
/**************/
@ -518,28 +585,6 @@ _PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj)
}
struct sync_module_result {
PyObject *module;
PyObject *loaded;
PyObject *failed;
};
struct sync_module {
const char *filename;
char _filename[MAXPATHLEN+1];
struct sync_module_result cached;
};
static void
sync_module_clear(struct sync_module *data)
{
data->filename = NULL;
Py_CLEAR(data->cached.module);
Py_CLEAR(data->cached.loaded);
Py_CLEAR(data->cached.failed);
}
struct _unpickle_context {
PyThreadState *tstate;
// We only special-case the __main__ module,
@ -553,142 +598,86 @@ _unpickle_context_clear(struct _unpickle_context *ctx)
sync_module_clear(&ctx->main);
}
static struct sync_module_result
_unpickle_context_get_module(struct _unpickle_context *ctx,
const char *modname)
{
if (strcmp(modname, "__main__") == 0) {
return ctx->main.cached;
}
else {
return (struct sync_module_result){
.failed = PyExc_NotImplementedError,
};
}
}
static struct sync_module_result
_unpickle_context_set_module(struct _unpickle_context *ctx,
const char *modname)
{
struct sync_module_result res = {0};
struct sync_module_result *cached = NULL;
const char *filename = NULL;
const char *run_modname = modname;
if (strcmp(modname, "__main__") == 0) {
cached = &ctx->main.cached;
filename = ctx->main.filename;
// We don't want to trigger "if __name__ == '__main__':".
run_modname = "<fake __main__>";
}
else {
res.failed = PyExc_NotImplementedError;
goto finally;
}
res.module = import_get_module(ctx->tstate, modname);
if (res.module == NULL) {
res.failed = _PyErr_GetRaisedException(ctx->tstate);
assert(res.failed != NULL);
goto finally;
}
if (filename == NULL) {
Py_CLEAR(res.module);
res.failed = PyExc_NotImplementedError;
goto finally;
}
res.loaded = runpy_run_path(filename, run_modname);
if (res.loaded == NULL) {
Py_CLEAR(res.module);
res.failed = _PyErr_GetRaisedException(ctx->tstate);
assert(res.failed != NULL);
goto finally;
}
finally:
if (cached != NULL) {
assert(cached->module == NULL);
assert(cached->loaded == NULL);
assert(cached->failed == NULL);
*cached = res;
}
return res;
}
static int
_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc)
check_missing___main___attr(PyObject *exc)
{
// The caller must check if an exception is set or not when -1 is returned.
assert(!_PyErr_Occurred(ctx->tstate));
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
struct attributeerror_info info;
if (_parse_attributeerror(exc, &info) < 0) {
return -1;
assert(!PyErr_Occurred());
if (!PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)) {
return 0;
}
// Get the module.
struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname);
if (mod.failed != NULL) {
// It must have failed previously.
return -1;
// Get the error message.
PyObject *args = PyException_GetArgs(exc);
if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
assert(!PyErr_Occurred());
return 0;
}
if (mod.module == NULL) {
mod = _unpickle_context_set_module(ctx, info.modname);
if (mod.failed != NULL) {
return -1;
PyObject *msgobj = args;
if (!PyUnicode_Check(msgobj)) {
msgobj = PySequence_GetItem(args, 0);
Py_DECREF(args);
if (msgobj == NULL) {
PyErr_Clear();
return 0;
}
assert(mod.module != NULL);
}
const char *err = PyUnicode_AsUTF8(msgobj);
// Bail out if it is unexpectedly set already.
if (PyObject_HasAttrString(mod.module, info.attrname)) {
return -1;
}
// Try setting the attribute.
PyObject *value = NULL;
if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) {
return -1;
}
assert(value != NULL);
int res = PyObject_SetAttrString(mod.module, info.attrname, value);
Py_DECREF(value);
if (res < 0) {
return -1;
}
return 0;
// Check if it's a missing __main__ attr.
int cmp = strncmp(err, "module '__main__' has no attribute '", 36);
Py_DECREF(msgobj);
return cmp == 0;
}
static PyObject *
_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled)
{
PyThreadState *tstate = ctx->tstate;
PyObject *exc = NULL;
PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads");
if (loads == NULL) {
return NULL;
}
// Make an initial attempt to unpickle.
PyObject *obj = PyObject_CallOneArg(loads, pickled);
if (ctx != NULL) {
while (obj == NULL) {
assert(_PyErr_Occurred(ctx->tstate));
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
// We leave other failures unhandled.
break;
}
// Try setting the attr if not set.
PyObject *exc = _PyErr_GetRaisedException(ctx->tstate);
if (_handle_unpickle_missing_attr(ctx, exc) < 0) {
// Any resulting exceptions are ignored
// in favor of the original.
_PyErr_SetRaisedException(ctx->tstate, exc);
break;
}
Py_CLEAR(exc);
// Retry with the attribute set.
obj = PyObject_CallOneArg(loads, pickled);
}
if (obj != NULL) {
goto finally;
}
assert(_PyErr_Occurred(tstate));
if (ctx == NULL) {
goto finally;
}
exc = _PyErr_GetRaisedException(tstate);
if (!check_missing___main___attr(exc)) {
goto finally;
}
// Temporarily swap in a fake __main__ loaded from the original
// file and cached. Note that functions will use the cached ns
// for __globals__, // not the actual module.
if (ensure_isolated_main(tstate, &ctx->main) < 0) {
goto finally;
}
if (apply_isolated_main(tstate, &ctx->main) < 0) {
goto finally;
}
// Try to unpickle once more.
obj = PyObject_CallOneArg(loads, pickled);
restore_main(tstate, &ctx->main);
if (obj == NULL) {
goto finally;
}
Py_CLEAR(exc);
finally:
if (exc != NULL) {
sync_module_capture_exc(tstate, &ctx->main);
// We restore the original exception.
// It might make sense to chain it (__context__).
_PyErr_SetRaisedException(tstate, exc);
}
Py_DECREF(loads);
return obj;
@ -2889,6 +2878,7 @@ _ensure_main_ns(_PyXI_session *session, _PyXI_failure *failure)
// Cache __main__.__dict__.
PyObject *main_mod = _Py_GetMainModule(tstate);
if (_Py_CheckMainModule(main_mod) < 0) {
Py_XDECREF(main_mod);
if (failure != NULL) {
*failure = (_PyXI_failure){
.code = _PyXI_ERR_MAIN_NS_FAILURE,

View file

@ -1142,6 +1142,7 @@ _Py_CheckMainModule(PyObject *module)
PyObject *msg = PyUnicode_FromString("invalid __main__ module");
if (msg != NULL) {
(void)PyErr_SetImportError(msg, &_Py_ID(__main__), NULL);
Py_DECREF(msg);
}
return -1;
}