[3.12] gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035) (GH-113472)

(cherry picked from commit 48c49739f5)

Co-authored-by: Yilei Yang <yileiyang@google.com>
Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
This commit is contained in:
Serhiy Storchaka 2023-12-25 21:20:07 +02:00 committed by GitHub
parent 2c07540e7d
commit d58a5f453f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 420 additions and 328 deletions

View file

@ -12,8 +12,8 @@ extern "C" {
struct ast_state { struct ast_state {
int initialized; int initialized;
int recursion_depth; int unused_recursion_depth;
int recursion_limit; int unused_recursion_limit;
PyObject *AST_type; PyObject *AST_type;
PyObject *Add_singleton; PyObject *Add_singleton;
PyObject *Add_type; PyObject *Add_type;

View file

@ -0,0 +1,7 @@
Use per AST-parser state rather than global state to track recursion depth
within the AST parser to prevent potential race condition due to
simultaneous parsing.
The issue primarily showed up in 3.11 by multithreaded users of
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
triggered prevented the race condition from occurring.

View file

@ -731,7 +731,7 @@ class SequenceConstructorVisitor(EmitVisitor):
class PyTypesDeclareVisitor(PickleVisitor): class PyTypesDeclareVisitor(PickleVisitor):
def visitProduct(self, prod, name): def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0) self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
if prod.attributes: if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0) self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes: for a in prod.attributes:
@ -752,7 +752,7 @@ class PyTypesDeclareVisitor(PickleVisitor):
ptype = "void*" ptype = "void*"
if is_simple(sum): if is_simple(sum):
ptype = get_c_type(name) ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0) self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
for t in sum.types: for t in sum.types:
self.visitConstructor(t, name) self.visitConstructor(t, name)
@ -984,7 +984,8 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr
/* Conversion AST -> Python */ /* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*)) static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
{ {
Py_ssize_t i, n = asdl_seq_LEN(seq); Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n); PyObject *result = PyList_New(n);
@ -992,7 +993,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
if (!result) if (!result)
return NULL; return NULL;
for (i = 0; i < n; i++) { for (i = 0; i < n; i++) {
value = func(state, asdl_seq_GET_UNTYPED(seq, i)); value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
if (!value) { if (!value) {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
@ -1002,7 +1003,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
return result; return result;
} }
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o) static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
{ {
PyObject *op = (PyObject*)o; PyObject *op = (PyObject*)o;
if (!op) { if (!op) {
@ -1014,7 +1015,7 @@ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
#define ast2obj_identifier ast2obj_object #define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object #define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b) static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
{ {
return PyLong_FromLong(b); return PyLong_FromLong(b);
} }
@ -1123,8 +1124,6 @@ static int add_ast_fields(struct ast_state *state)
for dfn in mod.dfns: for dfn in mod.dfns:
self.visit(dfn) self.visit(dfn)
self.file.write(textwrap.dedent(''' self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
state->initialized = 1; state->initialized = 1;
return 1; return 1;
} }
@ -1265,7 +1264,7 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name): def func_begin(self, name):
ctype = get_c_type(name) ctype = get_c_type(name)
self.emit("PyObject*", 0) self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0) self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("{", 0) self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1) self.emit("PyObject *result = NULL, *value = NULL;", 1)
@ -1273,16 +1272,17 @@ class ObjVisitor(PickleVisitor):
self.emit('if (!o) {', 1) self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2) self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1) self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1) self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2) self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3) self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return 0;", 2) self.emit("return 0;", 2)
self.emit("}", 1) self.emit("}", 1)
def func_end(self): def func_end(self):
self.emit("state->recursion_depth--;", 1) self.emit("vstate->recursion_depth--;", 1)
self.emit("return result;", 1) self.emit("return result;", 1)
self.emit("failed:", 0) self.emit("failed:", 0)
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_XDECREF(value);", 1) self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1) self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1) self.emit("return NULL;", 1)
@ -1300,7 +1300,7 @@ class ObjVisitor(PickleVisitor):
self.visitConstructor(t, i + 1, name) self.visitConstructor(t, i + 1, name)
self.emit("}", 1) self.emit("}", 1)
for a in sum.attributes: for a in sum.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1) self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2) self.emit('goto failed;', 2)
@ -1308,7 +1308,7 @@ class ObjVisitor(PickleVisitor):
self.func_end() self.func_end()
def simpleSum(self, sum, name): def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0) self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("{", 0) self.emit("{", 0)
self.emit("switch(o) {", 1) self.emit("switch(o) {", 1)
for t in sum.types: for t in sum.types:
@ -1326,7 +1326,7 @@ class ObjVisitor(PickleVisitor):
for field in prod.fields: for field in prod.fields:
self.visitField(field, name, 1, True) self.visitField(field, name, 1, True)
for a in prod.attributes: for a in prod.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1) self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2) self.emit('goto failed;', 2)
@ -1367,7 +1367,7 @@ class ObjVisitor(PickleVisitor):
self.emit("for(i = 0; i < n; i++)", depth+1) self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling # This cannot fail, so no need for error handling
self.emit( self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format( "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type, field.type,
value value
), ),
@ -1376,9 +1376,9 @@ class ObjVisitor(PickleVisitor):
) )
self.emit("}", depth) self.emit("}", depth)
else: else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else: else:
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False) self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
class PartingShots(StaticVisitor): class PartingShots(StaticVisitor):
@ -1396,21 +1396,22 @@ PyObject* PyAST_mod2obj(mod_ty t)
int COMPILER_STACK_FRAME_SCALE = 2; int COMPILER_STACK_FRAME_SCALE = 2;
PyThreadState *tstate = _PyThreadState_GET(); PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) { if (!tstate) {
return 0; return NULL;
} }
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE; struct validator vstate;
vstate.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining; int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE; starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth; vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t); PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */ /* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) { if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError, PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)", "AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth); starting_recursion_depth, vstate.recursion_depth);
return 0; return NULL;
} }
return result; return result;
} }
@ -1478,8 +1479,8 @@ class ChainOfVisitors:
def generate_ast_state(module_state, f): def generate_ast_state(module_state, f):
f.write('struct ast_state {\n') f.write('struct ast_state {\n')
f.write(' int initialized;\n') f.write(' int initialized;\n')
f.write(' int recursion_depth;\n') f.write(' int unused_recursion_depth;\n')
f.write(' int recursion_limit;\n') f.write(' int unused_recursion_limit;\n')
for s in module_state: for s in module_state:
f.write(' PyObject *' + s + ';\n') f.write(' PyObject *' + s + ';\n')
f.write('};') f.write('};')
@ -1545,6 +1546,11 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "structmember.h" #include "structmember.h"
#include <stddef.h> #include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration // Forward declaration
static int init_types(struct ast_state *state); static int init_types(struct ast_state *state);

679
Python/Python-ast.c generated

File diff suppressed because it is too large Load diff