gh-91051: allow setting a callback hook on PyType_Modified (GH-97875)

This commit is contained in:
Carl Meyer 2022-10-21 07:41:51 -06:00 committed by GitHub
parent 8367ca136e
commit 82ccbf69a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 462 additions and 5 deletions

View file

@ -372,6 +372,83 @@ _PyTypes_Fini(PyInterpreterState *interp)
static PyObject * lookup_subclasses(PyTypeObject *);
int
PyType_AddWatcher(PyType_WatchCallback callback)
{
PyInterpreterState *interp = _PyInterpreterState_GET();
for (int i = 0; i < TYPE_MAX_WATCHERS; i++) {
if (!interp->type_watchers[i]) {
interp->type_watchers[i] = callback;
return i;
}
}
PyErr_SetString(PyExc_RuntimeError, "no more type watcher IDs available");
return -1;
}
static inline int
validate_watcher_id(PyInterpreterState *interp, int watcher_id)
{
if (watcher_id < 0 || watcher_id >= TYPE_MAX_WATCHERS) {
PyErr_Format(PyExc_ValueError, "Invalid type watcher ID %d", watcher_id);
return -1;
}
if (!interp->type_watchers[watcher_id]) {
PyErr_Format(PyExc_ValueError, "No type watcher set for ID %d", watcher_id);
return -1;
}
return 0;
}
int
PyType_ClearWatcher(int watcher_id)
{
PyInterpreterState *interp = _PyInterpreterState_GET();
if (validate_watcher_id(interp, watcher_id) < 0) {
return -1;
}
interp->type_watchers[watcher_id] = NULL;
return 0;
}
static int assign_version_tag(PyTypeObject *type);
int
PyType_Watch(int watcher_id, PyObject* obj)
{
if (!PyType_Check(obj)) {
PyErr_SetString(PyExc_ValueError, "Cannot watch non-type");
return -1;
}
PyTypeObject *type = (PyTypeObject *)obj;
PyInterpreterState *interp = _PyInterpreterState_GET();
if (validate_watcher_id(interp, watcher_id) < 0) {
return -1;
}
// ensure we will get a callback on the next modification
assign_version_tag(type);
type->tp_watched |= (1 << watcher_id);
return 0;
}
int
PyType_Unwatch(int watcher_id, PyObject* obj)
{
if (!PyType_Check(obj)) {
PyErr_SetString(PyExc_ValueError, "Cannot watch non-type");
return -1;
}
PyTypeObject *type = (PyTypeObject *)obj;
PyInterpreterState *interp = _PyInterpreterState_GET();
if (validate_watcher_id(interp, watcher_id)) {
return -1;
}
type->tp_watched &= ~(1 << watcher_id);
return 0;
}
void
PyType_Modified(PyTypeObject *type)
{
@ -409,6 +486,23 @@ PyType_Modified(PyTypeObject *type)
}
}
if (type->tp_watched) {
PyInterpreterState *interp = _PyInterpreterState_GET();
int bits = type->tp_watched;
int i = 0;
while(bits && i < TYPE_MAX_WATCHERS) {
if (bits & 1) {
PyType_WatchCallback cb = interp->type_watchers[i];
if (cb && (cb(type) < 0)) {
PyErr_WriteUnraisable((PyObject *)type);
}
}
i += 1;
bits >>= 1;
}
}
type->tp_flags &= ~Py_TPFLAGS_VALID_VERSION_TAG;
type->tp_version_tag = 0; /* 0 is not a valid version tag */
}
@ -467,7 +561,7 @@ type_mro_modified(PyTypeObject *type, PyObject *bases) {
}
static int
assign_version_tag(struct type_cache *cache, PyTypeObject *type)
assign_version_tag(PyTypeObject *type)
{
/* Ensure that the tp_version_tag is valid and set
Py_TPFLAGS_VALID_VERSION_TAG. To respect the invariant, this
@ -492,7 +586,7 @@ assign_version_tag(struct type_cache *cache, PyTypeObject *type)
Py_ssize_t n = PyTuple_GET_SIZE(bases);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *b = PyTuple_GET_ITEM(bases, i);
if (!assign_version_tag(cache, _PyType_CAST(b)))
if (!assign_version_tag(_PyType_CAST(b)))
return 0;
}
type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
@ -4111,7 +4205,7 @@ _PyType_Lookup(PyTypeObject *type, PyObject *name)
return NULL;
}
if (MCACHE_CACHEABLE_NAME(name) && assign_version_tag(cache, type)) {
if (MCACHE_CACHEABLE_NAME(name) && assign_version_tag(type)) {
h = MCACHE_HASH_METHOD(type, name);
struct type_cache_entry *entry = &cache->hashtable[h];
entry->version = type->tp_version_tag;