mirror of
https://github.com/python/cpython.git
synced 2025-08-19 00:00:48 +00:00
[3.12] gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035) (GH-113472)
(cherry picked from commit 48c49739f5
)
Co-authored-by: Yilei Yang <yileiyang@google.com>
Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
This commit is contained in:
parent
2c07540e7d
commit
d58a5f453f
4 changed files with 420 additions and 328 deletions
4
Include/internal/pycore_ast_state.h
generated
4
Include/internal/pycore_ast_state.h
generated
|
@ -12,8 +12,8 @@ extern "C" {
|
||||||
|
|
||||||
struct ast_state {
|
struct ast_state {
|
||||||
int initialized;
|
int initialized;
|
||||||
int recursion_depth;
|
int unused_recursion_depth;
|
||||||
int recursion_limit;
|
int unused_recursion_limit;
|
||||||
PyObject *AST_type;
|
PyObject *AST_type;
|
||||||
PyObject *Add_singleton;
|
PyObject *Add_singleton;
|
||||||
PyObject *Add_type;
|
PyObject *Add_type;
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
Use per AST-parser state rather than global state to track recursion depth
|
||||||
|
within the AST parser to prevent potential race condition due to
|
||||||
|
simultaneous parsing.
|
||||||
|
|
||||||
|
The issue primarily showed up in 3.11 by multithreaded users of
|
||||||
|
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
|
||||||
|
triggered prevented the race condition from occurring.
|
|
@ -731,7 +731,7 @@ class SequenceConstructorVisitor(EmitVisitor):
|
||||||
class PyTypesDeclareVisitor(PickleVisitor):
|
class PyTypesDeclareVisitor(PickleVisitor):
|
||||||
|
|
||||||
def visitProduct(self, prod, name):
|
def visitProduct(self, prod, name):
|
||||||
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
|
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
|
||||||
if prod.attributes:
|
if prod.attributes:
|
||||||
self.emit("static const char * const %s_attributes[] = {" % name, 0)
|
self.emit("static const char * const %s_attributes[] = {" % name, 0)
|
||||||
for a in prod.attributes:
|
for a in prod.attributes:
|
||||||
|
@ -752,7 +752,7 @@ class PyTypesDeclareVisitor(PickleVisitor):
|
||||||
ptype = "void*"
|
ptype = "void*"
|
||||||
if is_simple(sum):
|
if is_simple(sum):
|
||||||
ptype = get_c_type(name)
|
ptype = get_c_type(name)
|
||||||
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
|
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
|
||||||
for t in sum.types:
|
for t in sum.types:
|
||||||
self.visitConstructor(t, name)
|
self.visitConstructor(t, name)
|
||||||
|
|
||||||
|
@ -984,7 +984,8 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr
|
||||||
|
|
||||||
/* Conversion AST -> Python */
|
/* Conversion AST -> Python */
|
||||||
|
|
||||||
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
|
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
|
||||||
|
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
|
||||||
{
|
{
|
||||||
Py_ssize_t i, n = asdl_seq_LEN(seq);
|
Py_ssize_t i, n = asdl_seq_LEN(seq);
|
||||||
PyObject *result = PyList_New(n);
|
PyObject *result = PyList_New(n);
|
||||||
|
@ -992,7 +993,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
|
||||||
if (!result)
|
if (!result)
|
||||||
return NULL;
|
return NULL;
|
||||||
for (i = 0; i < n; i++) {
|
for (i = 0; i < n; i++) {
|
||||||
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
|
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
|
||||||
if (!value) {
|
if (!value) {
|
||||||
Py_DECREF(result);
|
Py_DECREF(result);
|
||||||
return NULL;
|
return NULL;
|
||||||
|
@ -1002,7 +1003,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
|
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
|
||||||
{
|
{
|
||||||
PyObject *op = (PyObject*)o;
|
PyObject *op = (PyObject*)o;
|
||||||
if (!op) {
|
if (!op) {
|
||||||
|
@ -1014,7 +1015,7 @@ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
|
||||||
#define ast2obj_identifier ast2obj_object
|
#define ast2obj_identifier ast2obj_object
|
||||||
#define ast2obj_string ast2obj_object
|
#define ast2obj_string ast2obj_object
|
||||||
|
|
||||||
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
|
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
|
||||||
{
|
{
|
||||||
return PyLong_FromLong(b);
|
return PyLong_FromLong(b);
|
||||||
}
|
}
|
||||||
|
@ -1123,8 +1124,6 @@ static int add_ast_fields(struct ast_state *state)
|
||||||
for dfn in mod.dfns:
|
for dfn in mod.dfns:
|
||||||
self.visit(dfn)
|
self.visit(dfn)
|
||||||
self.file.write(textwrap.dedent('''
|
self.file.write(textwrap.dedent('''
|
||||||
state->recursion_depth = 0;
|
|
||||||
state->recursion_limit = 0;
|
|
||||||
state->initialized = 1;
|
state->initialized = 1;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -1265,7 +1264,7 @@ class ObjVisitor(PickleVisitor):
|
||||||
def func_begin(self, name):
|
def func_begin(self, name):
|
||||||
ctype = get_c_type(name)
|
ctype = get_c_type(name)
|
||||||
self.emit("PyObject*", 0)
|
self.emit("PyObject*", 0)
|
||||||
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
|
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
|
||||||
self.emit("{", 0)
|
self.emit("{", 0)
|
||||||
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
|
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
|
||||||
self.emit("PyObject *result = NULL, *value = NULL;", 1)
|
self.emit("PyObject *result = NULL, *value = NULL;", 1)
|
||||||
|
@ -1273,16 +1272,17 @@ class ObjVisitor(PickleVisitor):
|
||||||
self.emit('if (!o) {', 1)
|
self.emit('if (!o) {', 1)
|
||||||
self.emit("Py_RETURN_NONE;", 2)
|
self.emit("Py_RETURN_NONE;", 2)
|
||||||
self.emit("}", 1)
|
self.emit("}", 1)
|
||||||
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
|
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
|
||||||
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
|
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
|
||||||
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
|
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
|
||||||
self.emit("return 0;", 2)
|
self.emit("return 0;", 2)
|
||||||
self.emit("}", 1)
|
self.emit("}", 1)
|
||||||
|
|
||||||
def func_end(self):
|
def func_end(self):
|
||||||
self.emit("state->recursion_depth--;", 1)
|
self.emit("vstate->recursion_depth--;", 1)
|
||||||
self.emit("return result;", 1)
|
self.emit("return result;", 1)
|
||||||
self.emit("failed:", 0)
|
self.emit("failed:", 0)
|
||||||
|
self.emit("vstate->recursion_depth--;", 1)
|
||||||
self.emit("Py_XDECREF(value);", 1)
|
self.emit("Py_XDECREF(value);", 1)
|
||||||
self.emit("Py_XDECREF(result);", 1)
|
self.emit("Py_XDECREF(result);", 1)
|
||||||
self.emit("return NULL;", 1)
|
self.emit("return NULL;", 1)
|
||||||
|
@ -1300,7 +1300,7 @@ class ObjVisitor(PickleVisitor):
|
||||||
self.visitConstructor(t, i + 1, name)
|
self.visitConstructor(t, i + 1, name)
|
||||||
self.emit("}", 1)
|
self.emit("}", 1)
|
||||||
for a in sum.attributes:
|
for a in sum.attributes:
|
||||||
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
|
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
|
||||||
self.emit("if (!value) goto failed;", 1)
|
self.emit("if (!value) goto failed;", 1)
|
||||||
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
|
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
|
||||||
self.emit('goto failed;', 2)
|
self.emit('goto failed;', 2)
|
||||||
|
@ -1308,7 +1308,7 @@ class ObjVisitor(PickleVisitor):
|
||||||
self.func_end()
|
self.func_end()
|
||||||
|
|
||||||
def simpleSum(self, sum, name):
|
def simpleSum(self, sum, name):
|
||||||
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
|
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
|
||||||
self.emit("{", 0)
|
self.emit("{", 0)
|
||||||
self.emit("switch(o) {", 1)
|
self.emit("switch(o) {", 1)
|
||||||
for t in sum.types:
|
for t in sum.types:
|
||||||
|
@ -1326,7 +1326,7 @@ class ObjVisitor(PickleVisitor):
|
||||||
for field in prod.fields:
|
for field in prod.fields:
|
||||||
self.visitField(field, name, 1, True)
|
self.visitField(field, name, 1, True)
|
||||||
for a in prod.attributes:
|
for a in prod.attributes:
|
||||||
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
|
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
|
||||||
self.emit("if (!value) goto failed;", 1)
|
self.emit("if (!value) goto failed;", 1)
|
||||||
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
|
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
|
||||||
self.emit('goto failed;', 2)
|
self.emit('goto failed;', 2)
|
||||||
|
@ -1367,7 +1367,7 @@ class ObjVisitor(PickleVisitor):
|
||||||
self.emit("for(i = 0; i < n; i++)", depth+1)
|
self.emit("for(i = 0; i < n; i++)", depth+1)
|
||||||
# This cannot fail, so no need for error handling
|
# This cannot fail, so no need for error handling
|
||||||
self.emit(
|
self.emit(
|
||||||
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
|
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
|
||||||
field.type,
|
field.type,
|
||||||
value
|
value
|
||||||
),
|
),
|
||||||
|
@ -1376,9 +1376,9 @@ class ObjVisitor(PickleVisitor):
|
||||||
)
|
)
|
||||||
self.emit("}", depth)
|
self.emit("}", depth)
|
||||||
else:
|
else:
|
||||||
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
|
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
|
||||||
else:
|
else:
|
||||||
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
|
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
|
||||||
|
|
||||||
|
|
||||||
class PartingShots(StaticVisitor):
|
class PartingShots(StaticVisitor):
|
||||||
|
@ -1396,21 +1396,22 @@ PyObject* PyAST_mod2obj(mod_ty t)
|
||||||
int COMPILER_STACK_FRAME_SCALE = 2;
|
int COMPILER_STACK_FRAME_SCALE = 2;
|
||||||
PyThreadState *tstate = _PyThreadState_GET();
|
PyThreadState *tstate = _PyThreadState_GET();
|
||||||
if (!tstate) {
|
if (!tstate) {
|
||||||
return 0;
|
return NULL;
|
||||||
}
|
}
|
||||||
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
|
struct validator vstate;
|
||||||
|
vstate.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
|
||||||
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
|
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
|
||||||
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
|
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
|
||||||
state->recursion_depth = starting_recursion_depth;
|
vstate.recursion_depth = starting_recursion_depth;
|
||||||
|
|
||||||
PyObject *result = ast2obj_mod(state, t);
|
PyObject *result = ast2obj_mod(state, &vstate, t);
|
||||||
|
|
||||||
/* Check that the recursion depth counting balanced correctly */
|
/* Check that the recursion depth counting balanced correctly */
|
||||||
if (result && state->recursion_depth != starting_recursion_depth) {
|
if (result && vstate.recursion_depth != starting_recursion_depth) {
|
||||||
PyErr_Format(PyExc_SystemError,
|
PyErr_Format(PyExc_SystemError,
|
||||||
"AST constructor recursion depth mismatch (before=%d, after=%d)",
|
"AST constructor recursion depth mismatch (before=%d, after=%d)",
|
||||||
starting_recursion_depth, state->recursion_depth);
|
starting_recursion_depth, vstate.recursion_depth);
|
||||||
return 0;
|
return NULL;
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -1478,8 +1479,8 @@ class ChainOfVisitors:
|
||||||
def generate_ast_state(module_state, f):
|
def generate_ast_state(module_state, f):
|
||||||
f.write('struct ast_state {\n')
|
f.write('struct ast_state {\n')
|
||||||
f.write(' int initialized;\n')
|
f.write(' int initialized;\n')
|
||||||
f.write(' int recursion_depth;\n')
|
f.write(' int unused_recursion_depth;\n')
|
||||||
f.write(' int recursion_limit;\n')
|
f.write(' int unused_recursion_limit;\n')
|
||||||
for s in module_state:
|
for s in module_state:
|
||||||
f.write(' PyObject *' + s + ';\n')
|
f.write(' PyObject *' + s + ';\n')
|
||||||
f.write('};')
|
f.write('};')
|
||||||
|
@ -1545,6 +1546,11 @@ def generate_module_def(mod, metadata, f, internal_h):
|
||||||
#include "structmember.h"
|
#include "structmember.h"
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
struct validator {
|
||||||
|
int recursion_depth; /* current recursion depth */
|
||||||
|
int recursion_limit; /* recursion limit */
|
||||||
|
};
|
||||||
|
|
||||||
// Forward declaration
|
// Forward declaration
|
||||||
static int init_types(struct ast_state *state);
|
static int init_types(struct ast_state *state);
|
||||||
|
|
||||||
|
|
679
Python/Python-ast.c
generated
679
Python/Python-ast.c
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue