bpo-28411: Support other mappings in PyInterpreterState.modules. (#3593)

The concrete PyDict_* API is used to interact with PyInterpreterState.modules in a number of places. This isn't compatible with all dict subclasses, nor with other Mapping implementations. This patch switches the concrete API usage to the corresponding abstract API calls.

We also add a PyImport_GetModule() function (and some other helpers) to reduce a bunch of code duplication.
This commit is contained in:
Eric Snow 2017-09-15 16:35:20 -06:00 committed by GitHub
parent e82c034496
commit 3f9eee6eb4
11 changed files with 216 additions and 113 deletions

View file

@ -1649,13 +1649,40 @@ getattribute(PyObject *obj, PyObject *name, int allow_qualname)
return attr;
}
static int
_checkmodule(PyObject *module_name, PyObject *module,
PyObject *global, PyObject *dotted_path)
{
if (module == Py_None) {
return -1;
}
if (PyUnicode_Check(module_name) &&
_PyUnicode_EqualToASCIIString(module_name, "__main__")) {
return -1;
}
PyObject *candidate = get_deep_attribute(module, dotted_path, NULL);
if (candidate == NULL) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear();
}
return -1;
}
if (candidate != global) {
Py_DECREF(candidate);
return -1;
}
Py_DECREF(candidate);
return 0;
}
static PyObject *
whichmodule(PyObject *global, PyObject *dotted_path)
{
PyObject *module_name;
PyObject *modules_dict;
PyObject *module;
PyObject *module = NULL;
Py_ssize_t i;
PyObject *modules;
_Py_IDENTIFIER(__module__);
_Py_IDENTIFIER(modules);
_Py_IDENTIFIER(__main__);
@ -1678,35 +1705,48 @@ whichmodule(PyObject *global, PyObject *dotted_path)
assert(module_name == NULL);
/* Fallback on walking sys.modules */
modules_dict = _PySys_GetObjectId(&PyId_modules);
if (modules_dict == NULL) {
modules = _PySys_GetObjectId(&PyId_modules);
if (modules == NULL) {
PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules");
return NULL;
}
i = 0;
while (PyDict_Next(modules_dict, &i, &module_name, &module)) {
PyObject *candidate;
if (PyUnicode_Check(module_name) &&
_PyUnicode_EqualToASCIIString(module_name, "__main__"))
continue;
if (module == Py_None)
continue;
candidate = get_deep_attribute(module, dotted_path, NULL);
if (candidate == NULL) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
if (PyDict_CheckExact(modules)) {
i = 0;
while (PyDict_Next(modules, &i, &module_name, &module)) {
if (_checkmodule(module_name, module, global, dotted_path) == 0) {
Py_INCREF(module_name);
return module_name;
}
if (PyErr_Occurred()) {
return NULL;
PyErr_Clear();
continue;
}
}
if (candidate == global) {
Py_INCREF(module_name);
Py_DECREF(candidate);
return module_name;
}
else {
PyObject *iterator = PyObject_GetIter(modules);
if (iterator == NULL) {
return NULL;
}
Py_DECREF(candidate);
while ((module_name = PyIter_Next(iterator))) {
module = PyObject_GetItem(modules, module_name);
if (module == NULL) {
Py_DECREF(module_name);
Py_DECREF(iterator);
return NULL;
}
if (_checkmodule(module_name, module, global, dotted_path) == 0) {
Py_DECREF(module);
Py_DECREF(iterator);
return module_name;
}
Py_DECREF(module);
Py_DECREF(module_name);
if (PyErr_Occurred()) {
Py_DECREF(iterator);
return NULL;
}
}
Py_DECREF(iterator);
}
/* If no module is found, use __main__. */
@ -6424,9 +6464,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self,
/*[clinic end generated code: output=becc08d7f9ed41e3 input=e2e6a865de093ef4]*/
{
PyObject *global;
PyObject *modules_dict;
PyObject *module;
_Py_IDENTIFIER(modules);
/* Try to map the old names used in Python 2.x to the new ones used in
Python 3.x. We do this only with old pickle protocols and when the
@ -6483,25 +6521,16 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self,
}
}
modules_dict = _PySys_GetObjectId(&PyId_modules);
if (modules_dict == NULL) {
PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules");
return NULL;
}
module = PyDict_GetItemWithError(modules_dict, module_name);
module = PyImport_GetModule(module_name);
if (module == NULL) {
if (PyErr_Occurred())
return NULL;
module = PyImport_Import(module_name);
if (module == NULL)
return NULL;
global = getattribute(module, global_name, self->proto >= 4);
Py_DECREF(module);
}
else {
global = getattribute(module, global_name, self->proto >= 4);
}
global = getattribute(module, global_name, self->proto >= 4);
Py_DECREF(module);
return global;
}