mirror of
https://github.com/python/cpython.git
synced 2025-07-07 11:25:30 +00:00
gh-132775: Fix Interpreter.call() __main__ Visibility (gh-135595)
As noted in the new tests, there are a few situations we must carefully accommodate for functions that get pickled during interp.call(). We do so by running the script from the main interpreter's __main__ module in a hidden module in the other interpreter. That hidden module is used as the function __globals__.
This commit is contained in:
parent
fba5dded6d
commit
269e19e0a7
4 changed files with 419 additions and 246 deletions
|
@ -1356,6 +1356,187 @@ class TestInterpreterCall(TestBase):
|
|||
with self.assertRaises(interpreters.NotShareableError):
|
||||
interp.call(defs.spam_returns_arg, arg)
|
||||
|
||||
def test_func_in___main___hidden(self):
|
||||
# When a top-level function that uses global variables is called
|
||||
# through Interpreter.call(), it will be pickled, sent over,
|
||||
# and unpickled. That requires that it be found in the other
|
||||
# interpreter's __main__ module. However, the original script
|
||||
# that defined the function is only run in the main interpreter,
|
||||
# so pickle.loads() would normally fail.
|
||||
#
|
||||
# We work around this by running the script in the other
|
||||
# interpreter. However, this is a one-off solution for the sake
|
||||
# of unpickling, so we avoid modifying that interpreter's
|
||||
# __main__ module by running the script in a hidden module.
|
||||
#
|
||||
# In this test we verify that the function runs with the hidden
|
||||
# module as its __globals__ when called in the other interpreter,
|
||||
# and that the interpreter's __main__ module is unaffected.
|
||||
text = dedent("""
|
||||
eggs = True
|
||||
|
||||
def spam(*, explicit=False):
|
||||
if explicit:
|
||||
import __main__
|
||||
ns = __main__.__dict__
|
||||
else:
|
||||
# For now we have to have a LOAD_GLOBAL in the
|
||||
# function in order for globals() to actually return
|
||||
# spam.__globals__. Maybe it doesn't go through pickle?
|
||||
# XXX We will fix this later.
|
||||
spam
|
||||
ns = globals()
|
||||
|
||||
func = ns.get('spam')
|
||||
return [
|
||||
id(ns),
|
||||
ns.get('__name__'),
|
||||
ns.get('__file__'),
|
||||
id(func),
|
||||
None if func is None else repr(func),
|
||||
ns.get('eggs'),
|
||||
ns.get('ham'),
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
from concurrent import interpreters
|
||||
interp = interpreters.create()
|
||||
|
||||
ham = True
|
||||
print([
|
||||
[
|
||||
spam(explicit=True),
|
||||
spam(),
|
||||
],
|
||||
[
|
||||
interp.call(spam, explicit=True),
|
||||
interp.call(spam),
|
||||
],
|
||||
])
|
||||
""")
|
||||
with os_helper.temp_dir() as tempdir:
|
||||
filename = script_helper.make_script(tempdir, 'my-script', text)
|
||||
res = script_helper.assert_python_ok(filename)
|
||||
stdout = res.out.decode('utf-8').strip()
|
||||
local, remote = eval(stdout)
|
||||
|
||||
# In the main interpreter.
|
||||
main, unpickled = local
|
||||
nsid, _, _, funcid, func, _, _ = main
|
||||
self.assertEqual(main, [
|
||||
nsid,
|
||||
'__main__',
|
||||
filename,
|
||||
funcid,
|
||||
func,
|
||||
True,
|
||||
True,
|
||||
])
|
||||
self.assertIsNot(func, None)
|
||||
self.assertRegex(func, '^<function spam at 0x.*>$')
|
||||
self.assertEqual(unpickled, main)
|
||||
|
||||
# In the subinterpreter.
|
||||
main, unpickled = remote
|
||||
nsid1, _, _, funcid1, _, _, _ = main
|
||||
self.assertEqual(main, [
|
||||
nsid1,
|
||||
'__main__',
|
||||
None,
|
||||
funcid1,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
])
|
||||
nsid2, _, _, funcid2, func, _, _ = unpickled
|
||||
self.assertEqual(unpickled, [
|
||||
nsid2,
|
||||
'<fake __main__>',
|
||||
filename,
|
||||
funcid2,
|
||||
func,
|
||||
True,
|
||||
None,
|
||||
])
|
||||
self.assertIsNot(func, None)
|
||||
self.assertRegex(func, '^<function spam at 0x.*>$')
|
||||
self.assertNotEqual(nsid2, nsid1)
|
||||
self.assertNotEqual(funcid2, funcid1)
|
||||
|
||||
def test_func_in___main___uses_globals(self):
|
||||
# See the note in test_func_in___main___hidden about pickle
|
||||
# and the __main__ module.
|
||||
#
|
||||
# Additionally, the solution to that problem must provide
|
||||
# for global variables on which a pickled function might rely.
|
||||
#
|
||||
# To check that, we run a script that has two global functions
|
||||
# and a global variable in the __main__ module. One of the
|
||||
# functions sets the global variable and the other returns
|
||||
# the value.
|
||||
#
|
||||
# The script calls those functions multiple times in another
|
||||
# interpreter, to verify the following:
|
||||
#
|
||||
# * the global variable is properly initialized
|
||||
# * the global variable retains state between calls
|
||||
# * the setter modifies that persistent variable
|
||||
# * the getter uses the variable
|
||||
# * the calls in the other interpreter do not modify
|
||||
# the main interpreter
|
||||
# * those calls don't modify the interpreter's __main__ module
|
||||
# * the functions and variable do not actually show up in the
|
||||
# other interpreter's __main__ module
|
||||
text = dedent("""
|
||||
count = 0
|
||||
|
||||
def inc(x=1):
|
||||
global count
|
||||
count += x
|
||||
|
||||
def get_count():
|
||||
return count
|
||||
|
||||
if __name__ == "__main__":
|
||||
counts = []
|
||||
results = [count, counts]
|
||||
|
||||
from concurrent import interpreters
|
||||
interp = interpreters.create()
|
||||
|
||||
val = interp.call(get_count)
|
||||
counts.append(val)
|
||||
|
||||
interp.call(inc)
|
||||
val = interp.call(get_count)
|
||||
counts.append(val)
|
||||
|
||||
interp.call(inc, 3)
|
||||
val = interp.call(get_count)
|
||||
counts.append(val)
|
||||
|
||||
results.append(count)
|
||||
|
||||
modified = {name: interp.call(eval, f'{name!r} in vars()')
|
||||
for name in ('count', 'inc', 'get_count')}
|
||||
results.append(modified)
|
||||
|
||||
print(results)
|
||||
""")
|
||||
with os_helper.temp_dir() as tempdir:
|
||||
filename = script_helper.make_script(tempdir, 'my-script', text)
|
||||
res = script_helper.assert_python_ok(filename)
|
||||
stdout = res.out.decode('utf-8').strip()
|
||||
before, counts, after, modified = eval(stdout)
|
||||
self.assertEqual(modified, {
|
||||
'count': False,
|
||||
'inc': False,
|
||||
'get_count': False,
|
||||
})
|
||||
self.assertEqual(before, 0)
|
||||
self.assertEqual(after, 0)
|
||||
self.assertEqual(counts, [0, 1, 4])
|
||||
|
||||
def test_raises(self):
|
||||
interp = interpreters.create()
|
||||
with self.assertRaises(ExecutionFailed):
|
||||
|
|
|
@ -601,6 +601,7 @@ _make_call(struct interp_call *call,
|
|||
unwrap_not_shareable(tstate, failure);
|
||||
return -1;
|
||||
}
|
||||
assert(!_PyErr_Occurred(tstate));
|
||||
|
||||
// Make the call.
|
||||
PyObject *resobj = PyObject_Call(func, args, kwargs);
|
||||
|
|
|
@ -7,9 +7,12 @@
|
|||
#include "pycore_ceval.h" // _Py_simple_func
|
||||
#include "pycore_crossinterp.h" // _PyXIData_t
|
||||
#include "pycore_function.h" // _PyFunction_VerifyStateless()
|
||||
#include "pycore_global_strings.h" // _Py_ID()
|
||||
#include "pycore_import.h" // _PyImport_SetModule()
|
||||
#include "pycore_initconfig.h" // _PyStatus_OK()
|
||||
#include "pycore_namespace.h" // _PyNamespace_New()
|
||||
#include "pycore_pythonrun.h" // _Py_SourceAsString()
|
||||
#include "pycore_runtime.h" // _PyRuntime
|
||||
#include "pycore_setobject.h" // _PySet_NextEntry()
|
||||
#include "pycore_typeobject.h" // _PyStaticType_InitBuiltin()
|
||||
|
||||
|
@ -22,6 +25,7 @@ _Py_GetMainfile(char *buffer, size_t maxlen)
|
|||
PyThreadState *tstate = _PyThreadState_GET();
|
||||
PyObject *module = _Py_GetMainModule(tstate);
|
||||
if (_Py_CheckMainModule(module) < 0) {
|
||||
Py_XDECREF(module);
|
||||
return -1;
|
||||
}
|
||||
Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen);
|
||||
|
@ -30,27 +34,6 @@ _Py_GetMainfile(char *buffer, size_t maxlen)
|
|||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
import_get_module(PyThreadState *tstate, const char *modname)
|
||||
{
|
||||
PyObject *module = NULL;
|
||||
if (strcmp(modname, "__main__") == 0) {
|
||||
module = _Py_GetMainModule(tstate);
|
||||
if (_Py_CheckMainModule(module) < 0) {
|
||||
assert(_PyErr_Occurred(tstate));
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
module = PyImport_ImportModule(modname);
|
||||
if (module == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
runpy_run_path(const char *filename, const char *modname)
|
||||
{
|
||||
|
@ -81,97 +64,181 @@ set_exc_with_cause(PyObject *exctype, const char *msg)
|
|||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
pyerr_get_message(PyObject *exc)
|
||||
{
|
||||
assert(!PyErr_Occurred());
|
||||
PyObject *args = PyException_GetArgs(exc);
|
||||
if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
|
||||
return NULL;
|
||||
}
|
||||
if (PyUnicode_Check(args)) {
|
||||
return args;
|
||||
}
|
||||
PyObject *msg = PySequence_GetItem(args, 0);
|
||||
Py_DECREF(args);
|
||||
if (msg == NULL) {
|
||||
PyErr_Clear();
|
||||
return NULL;
|
||||
}
|
||||
if (!PyUnicode_Check(msg)) {
|
||||
Py_DECREF(msg);
|
||||
return NULL;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
/****************************/
|
||||
/* module duplication utils */
|
||||
/****************************/
|
||||
|
||||
#define MAX_MODNAME (255)
|
||||
#define MAX_ATTRNAME (255)
|
||||
|
||||
struct attributeerror_info {
|
||||
char modname[MAX_MODNAME+1];
|
||||
char attrname[MAX_ATTRNAME+1];
|
||||
struct sync_module_result {
|
||||
PyObject *module;
|
||||
PyObject *loaded;
|
||||
PyObject *failed;
|
||||
};
|
||||
|
||||
static int
|
||||
_parse_attributeerror(PyObject *exc, struct attributeerror_info *info)
|
||||
struct sync_module {
|
||||
const char *filename;
|
||||
char _filename[MAXPATHLEN+1];
|
||||
struct sync_module_result cached;
|
||||
};
|
||||
|
||||
static void
|
||||
sync_module_clear(struct sync_module *data)
|
||||
{
|
||||
assert(exc != NULL);
|
||||
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
|
||||
int res = -1;
|
||||
|
||||
PyObject *msgobj = pyerr_get_message(exc);
|
||||
if (msgobj == NULL) {
|
||||
return -1;
|
||||
}
|
||||
const char *err = PyUnicode_AsUTF8(msgobj);
|
||||
|
||||
if (strncmp(err, "module '", 8) != 0) {
|
||||
goto finally;
|
||||
}
|
||||
err += 8;
|
||||
|
||||
const char *matched = strchr(err, '\'');
|
||||
if (matched == NULL) {
|
||||
goto finally;
|
||||
}
|
||||
Py_ssize_t len = matched - err;
|
||||
if (len > MAX_MODNAME) {
|
||||
goto finally;
|
||||
}
|
||||
(void)strncpy(info->modname, err, len);
|
||||
info->modname[len] = '\0';
|
||||
err = matched;
|
||||
|
||||
if (strncmp(err, "' has no attribute '", 20) != 0) {
|
||||
goto finally;
|
||||
}
|
||||
err += 20;
|
||||
|
||||
matched = strchr(err, '\'');
|
||||
if (matched == NULL) {
|
||||
goto finally;
|
||||
}
|
||||
len = matched - err;
|
||||
if (len > MAX_ATTRNAME) {
|
||||
goto finally;
|
||||
}
|
||||
(void)strncpy(info->attrname, err, len);
|
||||
info->attrname[len] = '\0';
|
||||
err = matched + 1;
|
||||
|
||||
if (strlen(err) > 0) {
|
||||
goto finally;
|
||||
}
|
||||
res = 0;
|
||||
|
||||
finally:
|
||||
Py_DECREF(msgobj);
|
||||
return res;
|
||||
data->filename = NULL;
|
||||
Py_CLEAR(data->cached.module);
|
||||
Py_CLEAR(data->cached.loaded);
|
||||
Py_CLEAR(data->cached.failed);
|
||||
}
|
||||
|
||||
#undef MAX_MODNAME
|
||||
#undef MAX_ATTRNAME
|
||||
static void
|
||||
sync_module_capture_exc(PyThreadState *tstate, struct sync_module *data)
|
||||
{
|
||||
assert(_PyErr_Occurred(tstate));
|
||||
PyObject *context = data->cached.failed;
|
||||
PyObject *exc = _PyErr_GetRaisedException(tstate);
|
||||
_PyErr_SetRaisedException(tstate, Py_NewRef(exc));
|
||||
if (context != NULL) {
|
||||
PyException_SetContext(exc, context);
|
||||
}
|
||||
data->cached.failed = exc;
|
||||
}
|
||||
|
||||
|
||||
static int
|
||||
ensure_isolated_main(PyThreadState *tstate, struct sync_module *main)
|
||||
{
|
||||
// Load the module from the original file (or from a cache).
|
||||
|
||||
// First try the local cache.
|
||||
if (main->cached.failed != NULL) {
|
||||
// We'll deal with this in apply_isolated_main().
|
||||
assert(main->cached.module == NULL);
|
||||
assert(main->cached.loaded == NULL);
|
||||
return 0;
|
||||
}
|
||||
else if (main->cached.loaded != NULL) {
|
||||
assert(main->cached.module != NULL);
|
||||
return 0;
|
||||
}
|
||||
assert(main->cached.module == NULL);
|
||||
|
||||
if (main->filename == NULL) {
|
||||
_PyErr_SetString(tstate, PyExc_NotImplementedError, "");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// It wasn't in the local cache so we'll need to populate it.
|
||||
PyObject *mod = _Py_GetMainModule(tstate);
|
||||
if (_Py_CheckMainModule(mod) < 0) {
|
||||
// This is probably unrecoverable, so don't bother caching the error.
|
||||
assert(_PyErr_Occurred(tstate));
|
||||
Py_XDECREF(mod);
|
||||
return -1;
|
||||
}
|
||||
PyObject *loaded = NULL;
|
||||
|
||||
// Try the per-interpreter cache for the loaded module.
|
||||
// XXX Store it in sys.modules?
|
||||
PyObject *interpns = PyInterpreterState_GetDict(tstate->interp);
|
||||
assert(interpns != NULL);
|
||||
PyObject *key = PyUnicode_FromString("CACHED_MODULE_NS___main__");
|
||||
if (key == NULL) {
|
||||
// It's probably unrecoverable, so don't bother caching the error.
|
||||
Py_DECREF(mod);
|
||||
return -1;
|
||||
}
|
||||
else if (PyDict_GetItemRef(interpns, key, &loaded) < 0) {
|
||||
// It's probably unrecoverable, so don't bother caching the error.
|
||||
Py_DECREF(mod);
|
||||
Py_DECREF(key);
|
||||
return -1;
|
||||
}
|
||||
else if (loaded == NULL) {
|
||||
// It wasn't already loaded from file.
|
||||
loaded = PyModule_NewObject(&_Py_ID(__main__));
|
||||
if (loaded == NULL) {
|
||||
goto error;
|
||||
}
|
||||
PyObject *ns = _PyModule_GetDict(loaded);
|
||||
|
||||
// We don't want to trigger "if __name__ == '__main__':",
|
||||
// so we use a bogus module name.
|
||||
PyObject *loaded_ns =
|
||||
runpy_run_path(main->filename, "<fake __main__>");
|
||||
if (loaded_ns == NULL) {
|
||||
goto error;
|
||||
}
|
||||
int res = PyDict_Update(ns, loaded_ns);
|
||||
Py_DECREF(loaded_ns);
|
||||
if (res < 0) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
// Set the per-interpreter cache entry.
|
||||
if (PyDict_SetItem(interpns, key, loaded) < 0) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
|
||||
Py_DECREF(key);
|
||||
main->cached = (struct sync_module_result){
|
||||
.module = mod,
|
||||
.loaded = loaded,
|
||||
};
|
||||
return 0;
|
||||
|
||||
error:
|
||||
sync_module_capture_exc(tstate, main);
|
||||
Py_XDECREF(loaded);
|
||||
Py_DECREF(mod);
|
||||
Py_XDECREF(key);
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
static int
|
||||
main_mod_matches(PyObject *expected)
|
||||
{
|
||||
PyObject *mod = PyImport_GetModule(&_Py_ID(__main__));
|
||||
Py_XDECREF(mod);
|
||||
return mod == expected;
|
||||
}
|
||||
#endif
|
||||
|
||||
static int
|
||||
apply_isolated_main(PyThreadState *tstate, struct sync_module *main)
|
||||
{
|
||||
assert((main->cached.loaded == NULL) == (main->cached.loaded == NULL));
|
||||
if (main->cached.failed != NULL) {
|
||||
// It must have failed previously.
|
||||
assert(main->cached.loaded == NULL);
|
||||
_PyErr_SetRaisedException(tstate, main->cached.failed);
|
||||
return -1;
|
||||
}
|
||||
assert(main->cached.loaded != NULL);
|
||||
|
||||
assert(main_mod_matches(main->cached.module));
|
||||
if (_PyImport_SetModule(&_Py_ID(__main__), main->cached.loaded) < 0) {
|
||||
sync_module_capture_exc(tstate, main);
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
restore_main(PyThreadState *tstate, struct sync_module *main)
|
||||
{
|
||||
assert(main->cached.failed == NULL);
|
||||
assert(main->cached.module != NULL);
|
||||
assert(main->cached.loaded != NULL);
|
||||
PyObject *exc = _PyErr_GetRaisedException(tstate);
|
||||
assert(main_mod_matches(main->cached.loaded));
|
||||
int res = _PyImport_SetModule(&_Py_ID(__main__), main->cached.module);
|
||||
assert(res == 0);
|
||||
if (res < 0) {
|
||||
PyErr_FormatUnraisable("Exception ignored while restoring __main__");
|
||||
}
|
||||
_PyErr_SetRaisedException(tstate, exc);
|
||||
}
|
||||
|
||||
|
||||
/**************/
|
||||
|
@ -518,28 +585,6 @@ _PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj)
|
|||
}
|
||||
|
||||
|
||||
struct sync_module_result {
|
||||
PyObject *module;
|
||||
PyObject *loaded;
|
||||
PyObject *failed;
|
||||
};
|
||||
|
||||
struct sync_module {
|
||||
const char *filename;
|
||||
char _filename[MAXPATHLEN+1];
|
||||
struct sync_module_result cached;
|
||||
};
|
||||
|
||||
static void
|
||||
sync_module_clear(struct sync_module *data)
|
||||
{
|
||||
data->filename = NULL;
|
||||
Py_CLEAR(data->cached.module);
|
||||
Py_CLEAR(data->cached.loaded);
|
||||
Py_CLEAR(data->cached.failed);
|
||||
}
|
||||
|
||||
|
||||
struct _unpickle_context {
|
||||
PyThreadState *tstate;
|
||||
// We only special-case the __main__ module,
|
||||
|
@ -553,142 +598,86 @@ _unpickle_context_clear(struct _unpickle_context *ctx)
|
|||
sync_module_clear(&ctx->main);
|
||||
}
|
||||
|
||||
static struct sync_module_result
|
||||
_unpickle_context_get_module(struct _unpickle_context *ctx,
|
||||
const char *modname)
|
||||
{
|
||||
if (strcmp(modname, "__main__") == 0) {
|
||||
return ctx->main.cached;
|
||||
}
|
||||
else {
|
||||
return (struct sync_module_result){
|
||||
.failed = PyExc_NotImplementedError,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
static struct sync_module_result
|
||||
_unpickle_context_set_module(struct _unpickle_context *ctx,
|
||||
const char *modname)
|
||||
{
|
||||
struct sync_module_result res = {0};
|
||||
struct sync_module_result *cached = NULL;
|
||||
const char *filename = NULL;
|
||||
const char *run_modname = modname;
|
||||
if (strcmp(modname, "__main__") == 0) {
|
||||
cached = &ctx->main.cached;
|
||||
filename = ctx->main.filename;
|
||||
// We don't want to trigger "if __name__ == '__main__':".
|
||||
run_modname = "<fake __main__>";
|
||||
}
|
||||
else {
|
||||
res.failed = PyExc_NotImplementedError;
|
||||
goto finally;
|
||||
}
|
||||
|
||||
res.module = import_get_module(ctx->tstate, modname);
|
||||
if (res.module == NULL) {
|
||||
res.failed = _PyErr_GetRaisedException(ctx->tstate);
|
||||
assert(res.failed != NULL);
|
||||
goto finally;
|
||||
}
|
||||
|
||||
if (filename == NULL) {
|
||||
Py_CLEAR(res.module);
|
||||
res.failed = PyExc_NotImplementedError;
|
||||
goto finally;
|
||||
}
|
||||
res.loaded = runpy_run_path(filename, run_modname);
|
||||
if (res.loaded == NULL) {
|
||||
Py_CLEAR(res.module);
|
||||
res.failed = _PyErr_GetRaisedException(ctx->tstate);
|
||||
assert(res.failed != NULL);
|
||||
goto finally;
|
||||
}
|
||||
|
||||
finally:
|
||||
if (cached != NULL) {
|
||||
assert(cached->module == NULL);
|
||||
assert(cached->loaded == NULL);
|
||||
assert(cached->failed == NULL);
|
||||
*cached = res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
static int
|
||||
_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc)
|
||||
check_missing___main___attr(PyObject *exc)
|
||||
{
|
||||
// The caller must check if an exception is set or not when -1 is returned.
|
||||
assert(!_PyErr_Occurred(ctx->tstate));
|
||||
assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError));
|
||||
struct attributeerror_info info;
|
||||
if (_parse_attributeerror(exc, &info) < 0) {
|
||||
return -1;
|
||||
assert(!PyErr_Occurred());
|
||||
if (!PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Get the module.
|
||||
struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname);
|
||||
if (mod.failed != NULL) {
|
||||
// It must have failed previously.
|
||||
return -1;
|
||||
// Get the error message.
|
||||
PyObject *args = PyException_GetArgs(exc);
|
||||
if (args == NULL || args == Py_None || PyObject_Size(args) < 1) {
|
||||
assert(!PyErr_Occurred());
|
||||
return 0;
|
||||
}
|
||||
if (mod.module == NULL) {
|
||||
mod = _unpickle_context_set_module(ctx, info.modname);
|
||||
if (mod.failed != NULL) {
|
||||
return -1;
|
||||
PyObject *msgobj = args;
|
||||
if (!PyUnicode_Check(msgobj)) {
|
||||
msgobj = PySequence_GetItem(args, 0);
|
||||
Py_DECREF(args);
|
||||
if (msgobj == NULL) {
|
||||
PyErr_Clear();
|
||||
return 0;
|
||||
}
|
||||
assert(mod.module != NULL);
|
||||
}
|
||||
const char *err = PyUnicode_AsUTF8(msgobj);
|
||||
|
||||
// Bail out if it is unexpectedly set already.
|
||||
if (PyObject_HasAttrString(mod.module, info.attrname)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Try setting the attribute.
|
||||
PyObject *value = NULL;
|
||||
if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) {
|
||||
return -1;
|
||||
}
|
||||
assert(value != NULL);
|
||||
int res = PyObject_SetAttrString(mod.module, info.attrname, value);
|
||||
Py_DECREF(value);
|
||||
if (res < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
// Check if it's a missing __main__ attr.
|
||||
int cmp = strncmp(err, "module '__main__' has no attribute '", 36);
|
||||
Py_DECREF(msgobj);
|
||||
return cmp == 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled)
|
||||
{
|
||||
PyThreadState *tstate = ctx->tstate;
|
||||
|
||||
PyObject *exc = NULL;
|
||||
PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads");
|
||||
if (loads == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Make an initial attempt to unpickle.
|
||||
PyObject *obj = PyObject_CallOneArg(loads, pickled);
|
||||
if (ctx != NULL) {
|
||||
while (obj == NULL) {
|
||||
assert(_PyErr_Occurred(ctx->tstate));
|
||||
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
|
||||
// We leave other failures unhandled.
|
||||
break;
|
||||
}
|
||||
// Try setting the attr if not set.
|
||||
PyObject *exc = _PyErr_GetRaisedException(ctx->tstate);
|
||||
if (_handle_unpickle_missing_attr(ctx, exc) < 0) {
|
||||
// Any resulting exceptions are ignored
|
||||
// in favor of the original.
|
||||
_PyErr_SetRaisedException(ctx->tstate, exc);
|
||||
break;
|
||||
}
|
||||
Py_CLEAR(exc);
|
||||
// Retry with the attribute set.
|
||||
obj = PyObject_CallOneArg(loads, pickled);
|
||||
}
|
||||
if (obj != NULL) {
|
||||
goto finally;
|
||||
}
|
||||
assert(_PyErr_Occurred(tstate));
|
||||
if (ctx == NULL) {
|
||||
goto finally;
|
||||
}
|
||||
exc = _PyErr_GetRaisedException(tstate);
|
||||
if (!check_missing___main___attr(exc)) {
|
||||
goto finally;
|
||||
}
|
||||
|
||||
// Temporarily swap in a fake __main__ loaded from the original
|
||||
// file and cached. Note that functions will use the cached ns
|
||||
// for __globals__, // not the actual module.
|
||||
if (ensure_isolated_main(tstate, &ctx->main) < 0) {
|
||||
goto finally;
|
||||
}
|
||||
if (apply_isolated_main(tstate, &ctx->main) < 0) {
|
||||
goto finally;
|
||||
}
|
||||
|
||||
// Try to unpickle once more.
|
||||
obj = PyObject_CallOneArg(loads, pickled);
|
||||
restore_main(tstate, &ctx->main);
|
||||
if (obj == NULL) {
|
||||
goto finally;
|
||||
}
|
||||
Py_CLEAR(exc);
|
||||
|
||||
finally:
|
||||
if (exc != NULL) {
|
||||
sync_module_capture_exc(tstate, &ctx->main);
|
||||
// We restore the original exception.
|
||||
// It might make sense to chain it (__context__).
|
||||
_PyErr_SetRaisedException(tstate, exc);
|
||||
}
|
||||
Py_DECREF(loads);
|
||||
return obj;
|
||||
|
@ -2889,6 +2878,7 @@ _ensure_main_ns(_PyXI_session *session, _PyXI_failure *failure)
|
|||
// Cache __main__.__dict__.
|
||||
PyObject *main_mod = _Py_GetMainModule(tstate);
|
||||
if (_Py_CheckMainModule(main_mod) < 0) {
|
||||
Py_XDECREF(main_mod);
|
||||
if (failure != NULL) {
|
||||
*failure = (_PyXI_failure){
|
||||
.code = _PyXI_ERR_MAIN_NS_FAILURE,
|
||||
|
|
|
@ -1142,6 +1142,7 @@ _Py_CheckMainModule(PyObject *module)
|
|||
PyObject *msg = PyUnicode_FromString("invalid __main__ module");
|
||||
if (msg != NULL) {
|
||||
(void)PyErr_SetImportError(msg, &_Py_ID(__main__), NULL);
|
||||
Py_DECREF(msg);
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue