gh-99741: Implement Multi-Phase Init for the _xxsubinterpreters Module (gh-99742)

_xxsubinterpreters is an internal module used for testing.

https://github.com/python/cpython/issues/99741
This commit is contained in:
Eric Snow 2022-12-05 13:40:20 -07:00 committed by GitHub
parent 51ee0a29e9
commit 530cc9dbb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 384 additions and 190 deletions

View file

@ -96,18 +96,20 @@ add_new_exception(PyObject *mod, const char *name, PyObject *base)
add_new_exception(MOD, MODULE_NAME "." Py_STRINGIFY(NAME), BASE)
static PyTypeObject *
add_new_type(PyObject *mod, PyTypeObject *cls, crossinterpdatafunc shared)
add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared)
{
if (PyType_Ready(cls) != 0) {
PyTypeObject *cls = (PyTypeObject *)PyType_FromMetaclass(
NULL, mod, spec, NULL);
if (cls == NULL) {
return NULL;
}
if (PyModule_AddType(mod, cls) != 0) {
// XXX When this becomes a heap type, we need to decref here.
if (PyModule_AddType(mod, cls) < 0) {
Py_DECREF(cls);
return NULL;
}
if (shared != NULL) {
if (_PyCrossInterpreterData_RegisterClass(cls, shared)) {
// XXX When this becomes a heap type, we need to decref here.
Py_DECREF(cls);
return NULL;
}
}
@ -135,12 +137,7 @@ _release_xid_data(_PyCrossInterpreterData *data, int ignoreexc)
* shareable types are all very basic, with no GC.
* That said, it becomes much messier once interpreters
* no longer share a GIL, so this needs to be fixed before then. */
// We do what _release_xidata() does in pystate.c.
if (data->free != NULL) {
data->free(data->data);
data->data = NULL;
}
Py_CLEAR(data->obj);
_PyCrossInterpreterData_Clear(NULL, data);
if (ignoreexc) {
// XXX Emit a warning?
PyErr_Clear();
@ -153,6 +150,69 @@ _release_xid_data(_PyCrossInterpreterData *data, int ignoreexc)
}
/* module state *************************************************************/
/* Per-module state.
 *
 * Holds the objects this module creates, so that they live in the module's
 * own state storage instead of in C statics.  Every member is visited by
 * traverse_module_state() and released by clear_module_state(). */
typedef struct {
/* heap type for ChannelID objects; assigned in module_exec() */
PyTypeObject *ChannelIDType;
/* interpreter exceptions */
PyObject *RunFailedError;
/* channel exceptions */
/* ChannelError is the base class passed for the four exceptions below
 * (see channel_exceptions_init()). */
PyObject *ChannelError;
PyObject *ChannelNotFoundError;
PyObject *ChannelClosedError;
PyObject *ChannelEmptyError;
PyObject *ChannelNotEmptyError;
} module_state;
/* Fetch the per-module state attached to `mod`.
 *
 * Both the module and the state it carries are asserted to be non-NULL;
 * the state pointer is returned as-is. */
static inline module_state *
get_module_state(PyObject *mod)
{
    assert(mod != NULL);
    module_state *st = (module_state *)PyModule_GetState(mod);
    assert(st != NULL);
    return st;
}
/* Report every object held in the module state to the garbage collector.
 *
 * Each non-NULL member is handed to `visit` together with `arg`.  The first
 * non-zero result from `visit` is returned immediately; otherwise 0.
 * The parameter names `visit` and `arg` are required by Py_VISIT(). */
static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
    /* Visited in struct order: the heap type, the interpreter exception,
     * then the channel exceptions. */
    PyObject *const held[] = {
        (PyObject *)state->ChannelIDType,
        state->RunFailedError,
        state->ChannelError,
        state->ChannelNotFoundError,
        state->ChannelClosedError,
        state->ChannelEmptyError,
        state->ChannelNotEmptyError,
    };
    for (size_t i = 0; i < sizeof(held) / sizeof(held[0]); i++) {
        Py_VISIT(held[i]);
    }
    return 0;
}
/* Release every object held in the module state.
 *
 * Safe to call more than once: both module_clear() and module_free() call
 * it, so on the second call every member is already NULL.  Always
 * returns 0. */
static int
clear_module_state(module_state *state)
{
    /* heap types */
    /* Only unregister while we still hold the type; after the first call
     * the pointer is NULL and there is nothing left to unregister. */
    if (state->ChannelIDType != NULL) {
        (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
    }
    Py_CLEAR(state->ChannelIDType);

    /* interpreter exceptions */
    Py_CLEAR(state->RunFailedError);

    /* channel exceptions */
    Py_CLEAR(state->ChannelError);
    Py_CLEAR(state->ChannelNotFoundError);
    Py_CLEAR(state->ChannelClosedError);
    Py_CLEAR(state->ChannelEmptyError);
    Py_CLEAR(state->ChannelNotEmptyError);

    return 0;
}
/* data-sharing-specific code ***********************************************/
struct _sharednsitem {
@ -420,82 +480,80 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
#define ERR_CHANNELS_MUTEX_INIT -8
#define ERR_NO_NEXT_CHANNEL_ID -9
static PyObject *ChannelError;
static PyObject *ChannelNotFoundError;
static PyObject *ChannelClosedError;
static PyObject *ChannelEmptyError;
static PyObject *ChannelNotEmptyError;
static int
channel_exceptions_init(PyObject *mod)
{
// XXX Move the exceptions into per-module memory?
module_state *state = get_module_state(mod);
if (state == NULL) {
return -1;
}
#define ADD(NAME, BASE) \
do { \
if (NAME == NULL) { \
NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \
if (NAME == NULL) { \
return -1; \
} \
assert(state->NAME == NULL); \
state->NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \
if (state->NAME == NULL) { \
return -1; \
} \
} while (0)
// A channel-related operation failed.
ADD(ChannelError, PyExc_RuntimeError);
// An operation tried to use a channel that doesn't exist.
ADD(ChannelNotFoundError, ChannelError);
ADD(ChannelNotFoundError, state->ChannelError);
// An operation tried to use a closed channel.
ADD(ChannelClosedError, ChannelError);
ADD(ChannelClosedError, state->ChannelError);
// An operation tried to pop from an empty channel.
ADD(ChannelEmptyError, ChannelError);
ADD(ChannelEmptyError, state->ChannelError);
// An operation tried to close a non-empty channel.
ADD(ChannelNotEmptyError, ChannelError);
ADD(ChannelNotEmptyError, state->ChannelError);
#undef ADD
return 0;
}
static int
handle_channel_error(int err, PyObject *Py_UNUSED(mod), int64_t cid)
handle_channel_error(int err, PyObject *mod, int64_t cid)
{
if (err == 0) {
assert(!PyErr_Occurred());
return 0;
}
assert(err < 0);
module_state *state = get_module_state(mod);
assert(state != NULL);
if (err == ERR_CHANNEL_NOT_FOUND) {
PyErr_Format(ChannelNotFoundError,
PyErr_Format(state->ChannelNotFoundError,
"channel %" PRId64 " not found", cid);
}
else if (err == ERR_CHANNEL_CLOSED) {
PyErr_Format(ChannelClosedError,
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is closed", cid);
}
else if (err == ERR_CHANNEL_INTERP_CLOSED) {
PyErr_Format(ChannelClosedError,
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is already closed", cid);
}
else if (err == ERR_CHANNEL_EMPTY) {
PyErr_Format(ChannelEmptyError,
PyErr_Format(state->ChannelEmptyError,
"channel %" PRId64 " is empty", cid);
}
else if (err == ERR_CHANNEL_NOT_EMPTY) {
PyErr_Format(ChannelNotEmptyError,
PyErr_Format(state->ChannelNotEmptyError,
"channel %" PRId64 " may not be closed "
"if not empty (try force=True)",
cid);
}
else if (err == ERR_CHANNEL_MUTEX_INIT) {
PyErr_SetString(ChannelError,
PyErr_SetString(state->ChannelError,
"can't initialize mutex for new channel");
}
else if (err == ERR_CHANNELS_MUTEX_INIT) {
PyErr_SetString(ChannelError,
PyErr_SetString(state->ChannelError,
"can't initialize mutex for channel management");
}
else if (err == ERR_NO_NEXT_CHANNEL_ID) {
PyErr_SetString(ChannelError,
PyErr_SetString(state->ChannelError,
"failed to get a channel ID");
}
else {
@ -1604,8 +1662,6 @@ _channel_is_associated(_channels *channels, int64_t cid, int64_t interp,
/* ChannelID class */
static PyTypeObject ChannelIDType;
typedef struct channelid {
PyObject_HEAD
int64_t id;
@ -1624,7 +1680,9 @@ channel_id_converter(PyObject *arg, void *ptr)
{
int64_t cid;
struct channel_id_converter_data *data = ptr;
if (PyObject_TypeCheck(arg, &ChannelIDType)) {
module_state *state = get_module_state(data->module);
assert(state != NULL);
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
cid = ((channelid *)arg)->id;
}
else if (PyIndex_Check(arg)) {
@ -1731,11 +1789,20 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
}
static void
channelid_dealloc(PyObject *v)
channelid_dealloc(PyObject *self)
{
int64_t cid = ((channelid *)v)->id;
_channels *channels = ((channelid *)v)->channels;
Py_TYPE(v)->tp_free(v);
int64_t cid = ((channelid *)self)->id;
_channels *channels = ((channelid *)self)->channels;
PyTypeObject *tp = Py_TYPE(self);
tp->tp_free(self);
/* "Instances of heap-allocated types hold a reference to their type."
* See: https://docs.python.org/3.11/howto/isolating-extensions.html#garbage-collection-protocol
* See: https://docs.python.org/3.11/c-api/typeobj.html#c.PyTypeObject.tp_traverse
*/
// XXX Why don't we implement Py_TPFLAGS_HAVE_GC, e.g. Py_tp_traverse,
// like we do for _abc._abc_data?
Py_DECREF(tp);
_channels_drop_id_object(channels, cid);
}
@ -1774,11 +1841,6 @@ channelid_int(PyObject *self)
return PyLong_FromLongLong(cid->id);
}
static PyNumberMethods channelid_as_number = {
.nb_int = (unaryfunc)channelid_int, /* nb_int */
.nb_index = (unaryfunc)channelid_int, /* nb_index */
};
static Py_hash_t
channelid_hash(PyObject *self)
{
@ -1804,15 +1866,19 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
if (mod == NULL) {
return NULL;
}
module_state *state = get_module_state(mod);
if (state == NULL) {
goto done;
}
if (!PyObject_TypeCheck(self, &ChannelIDType)) {
if (!PyObject_TypeCheck(self, state->ChannelIDType)) {
res = Py_NewRef(Py_NotImplemented);
goto done;
}
channelid *cid = (channelid *)self;
int equal;
if (PyObject_TypeCheck(other, &ChannelIDType)) {
if (PyObject_TypeCheck(other, state->ChannelIDType)) {
channelid *othercid = (channelid *)other;
equal = (cid->end == othercid->end) && (cid->id == othercid->id);
}
@ -1892,10 +1958,14 @@ _channelid_from_xid(_PyCrossInterpreterData *data)
if (mod == NULL) {
return NULL;
}
module_state *state = get_module_state(mod);
if (state == NULL) {
return NULL;
}
// Note that we do not preserve the "resolve" flag.
PyObject *cid = NULL;
int err = newchannelid(&ChannelIDType, xid->id, xid->end,
int err = newchannelid(state->ChannelIDType, xid->id, xid->end,
_global_channels(), 0, 0,
(channelid **)&cid);
if (err != 0) {
@ -1926,20 +1996,20 @@ done:
}
static int
_channelid_shared(PyObject *obj, _PyCrossInterpreterData *data)
_channelid_shared(PyThreadState *tstate, PyObject *obj,
_PyCrossInterpreterData *data)
{
struct _channelid_xid *xid = PyMem_NEW(struct _channelid_xid, 1);
if (xid == NULL) {
if (_PyCrossInterpreterData_InitWithSize(
data, tstate->interp, sizeof(struct _channelid_xid), obj,
_channelid_from_xid
) < 0)
{
return -1;
}
struct _channelid_xid *xid = (struct _channelid_xid *)data->data;
xid->id = ((channelid *)obj)->id;
xid->end = ((channelid *)obj)->end;
xid->resolve = ((channelid *)obj)->resolve;
data->data = xid;
data->obj = Py_NewRef(obj);
data->new_object = _channelid_from_xid;
data->free = PyMem_Free;
return 0;
}
@ -1992,61 +2062,45 @@ static PyGetSetDef channelid_getsets[] = {
PyDoc_STRVAR(channelid_doc,
"A channel ID identifies a channel and may be used as an int.");
static PyTypeObject ChannelIDType = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"_xxsubinterpreters.ChannelID", /* tp_name */
sizeof(channelid), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)channelid_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_as_async */
(reprfunc)channelid_repr, /* tp_repr */
&channelid_as_number, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
channelid_hash, /* tp_hash */
0, /* tp_call */
(reprfunc)channelid_str, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
// Use Py_TPFLAGS_DISALLOW_INSTANTIATION so the type cannot be instantiated
// from Python code. We do this because there is a strong relationship
// between channel IDs and the channel lifecycle, so this limitation avoids
// related complications. Use the _channel_id() function instead.
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
| Py_TPFLAGS_DISALLOW_INSTANTIATION, /* tp_flags */
channelid_doc, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
channelid_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
channelid_getsets, /* tp_getset */
/* Slot table for the ChannelID heap type; referenced by ChannelIDType_spec.
 * The {0, NULL} entry terminates the table. */
static PyType_Slot ChannelIDType_slots[] = {
{Py_tp_dealloc, (destructor)channelid_dealloc},
{Py_tp_doc, (void *)channelid_doc},
{Py_tp_repr, (reprfunc)channelid_repr},
{Py_tp_str, (reprfunc)channelid_str},
{Py_tp_hash, channelid_hash},
{Py_tp_richcompare, channelid_richcompare},
{Py_tp_getset, channelid_getsets},
// number slots
// (int() and index conversion both go through channelid_int)
{Py_nb_int, (unaryfunc)channelid_int},
{Py_nb_index, (unaryfunc)channelid_int},
{0, NULL},
};
/* Spec from which module_exec() creates the ChannelID type
 * (via add_new_type()). */
static PyType_Spec ChannelIDType_spec = {
.name = "_xxsubinterpreters.ChannelID",
.basicsize = sizeof(channelid),
// Py_TPFLAGS_DISALLOW_INSTANTIATION: the type is not meant to be
// instantiated directly from Python code.
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_DISALLOW_INSTANTIATION | Py_TPFLAGS_IMMUTABLETYPE),
.slots = ChannelIDType_slots,
};
/* interpreter-specific code ************************************************/
static PyObject * RunFailedError = NULL;
static int
interp_exceptions_init(PyObject *mod)
{
// XXX Move the exceptions into per-module memory?
module_state *state = get_module_state(mod);
if (state == NULL) {
return -1;
}
#define ADD(NAME, BASE) \
do { \
if (NAME == NULL) { \
NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \
if (NAME == NULL) { \
return -1; \
} \
assert(state->NAME == NULL); \
state->NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \
if (state->NAME == NULL) { \
return -1; \
} \
} while (0)
@ -2167,9 +2221,10 @@ _run_script_in_interpreter(PyObject *mod, PyInterpreterState *interp,
if (_ensure_not_running(interp) < 0) {
return -1;
}
module_state *state = get_module_state(mod);
int needs_import = 0;
_sharedns *shared = _get_shared_ns(shareables, &ChannelIDType,
_sharedns *shared = _get_shared_ns(shareables, state->ChannelIDType,
&needs_import);
if (shared == NULL && PyErr_Occurred()) {
return -1;
@ -2195,7 +2250,8 @@ _run_script_in_interpreter(PyObject *mod, PyInterpreterState *interp,
// Propagate any exception out to the caller.
if (exc != NULL) {
_sharedexception_apply(exc, RunFailedError);
assert(state != NULL);
_sharedexception_apply(exc, state->RunFailedError);
_sharedexception_free(exc);
}
else if (result != 0) {
@ -2530,8 +2586,12 @@ channel_create(PyObject *self, PyObject *Py_UNUSED(ignored))
(void)handle_channel_error(cid, self, -1);
return NULL;
}
module_state *state = get_module_state(self);
if (state == NULL) {
return NULL;
}
PyObject *id = NULL;
int err = newchannelid(&ChannelIDType, cid, 0,
int err = newchannelid(state->ChannelIDType, cid, 0,
&_globals.channels, 0, 0,
(channelid **)&id);
if (handle_channel_error(err, self, cid)) {
@ -2594,10 +2654,16 @@ channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
if (ids == NULL) {
goto finally;
}
module_state *state = get_module_state(self);
if (state == NULL) {
Py_DECREF(ids);
ids = NULL;
goto finally;
}
int64_t *cur = cids;
for (int64_t i=0; i < count; cur++, i++) {
PyObject *id = NULL;
int err = newchannelid(&ChannelIDType, *cur, 0,
int err = newchannelid(state->ChannelIDType, *cur, 0,
&_globals.channels, 0, 0,
(channelid **)&id);
if (handle_channel_error(err, self, *cur)) {
@ -2850,7 +2916,11 @@ ends are closed. Closing an already closed end is a noop.");
static PyObject *
channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
{
PyTypeObject *cls = &ChannelIDType;
module_state *state = get_module_state(self);
if (state == NULL) {
return NULL;
}
PyTypeObject *cls = state->ChannelIDType;
PyObject *mod = get_module_from_owned_type(cls);
if (mod == NULL) {
return NULL;
@ -2924,9 +2994,16 @@ module_exec(PyObject *mod)
}
/* Add other types */
if (add_new_type(mod, &ChannelIDType, _channelid_shared) == NULL) {
module_state *state = get_module_state(mod);
// ChannelID
state->ChannelIDType = add_new_type(
mod, &ChannelIDType_spec, _channelid_shared);
if (state->ChannelIDType == NULL) {
goto error;
}
// PyInterpreterID
if (PyModule_AddType(mod, &_PyInterpreterID_Type) < 0) {
goto error;
}
@ -2934,31 +3011,57 @@ module_exec(PyObject *mod)
return 0;
error:
(void)_PyCrossInterpreterData_UnregisterClass(&ChannelIDType);
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
_globals_fini();
return -1;
}
/* Module definition slots (referenced from moduledef.m_slots).
 * module_exec() is registered as the module's exec step; the {0, NULL}
 * entry terminates the table. */
static struct PyModuleDef_Slot module_slots[] = {
{Py_mod_exec, module_exec},
{0, NULL},
};
/* m_traverse hook: visit every object held in the module state.
 *
 * Returns whatever traverse_module_state() returns, so a non-zero result
 * from `visit` is propagated to the caller instead of being dropped. */
static int
module_traverse(PyObject *mod, visitproc visit, void *arg)
{
    module_state *state = get_module_state(mod);
    assert(state != NULL);
    return traverse_module_state(state, visit, arg);
}
/* m_clear hook: drop every object held in the module state.
 * Always returns 0. */
static int
module_clear(PyObject *mod)
{
    module_state *st = get_module_state(mod);
    assert(st != NULL);
    (void)clear_module_state(st);
    return 0;
}
static void
module_free(void *mod)
{
module_state *state = get_module_state(mod);
assert(state != NULL);
clear_module_state(state);
_globals_fini();
}
static struct PyModuleDef moduledef = {
.m_base = PyModuleDef_HEAD_INIT,
.m_name = MODULE_NAME,
.m_doc = module_doc,
.m_size = -1,
.m_size = sizeof(module_state),
.m_methods = module_functions,
.m_slots = module_slots,
.m_traverse = module_traverse,
.m_clear = module_clear,
.m_free = (freefunc)module_free,
};
PyMODINIT_FUNC
PyInit__xxsubinterpreters(void)
{
/* Create the module */
PyObject *mod = PyModule_Create(&moduledef);
if (mod == NULL) {
return NULL;
}
if (module_exec(mod) < 0) {
Py_DECREF(mod);
return NULL;
}
return mod;
return PyModuleDef_Init(&moduledef);
}