gh-122313: Clean up deep recursion guarding code in the compiler (GH-122640)

Add ENTER_RECURSIVE and LEAVE_RECURSIVE macros in ast.c, ast_opt.c and
symtable.c. Remove VISIT_QUIT macro in symtable.c.

The current recursion depth counter only needs to be updated during
normal execution -- all functions should just return an error code
if an error occurs.
This commit is contained in:
Serhiy Storchaka 2024-08-03 12:45:45 +03:00 committed by GitHub
parent fe0a28d850
commit efcd65cd84
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 164 additions and 162 deletions

View file

@ -14,6 +14,20 @@ struct validator {
int recursion_limit; /* recursion limit */
};
#define ENTER_RECURSIVE(ST) \
do { \
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
PyErr_SetString(PyExc_RecursionError, \
"maximum recursion depth exceeded during compilation"); \
return 0; \
} \
} while(0)
#define LEAVE_RECURSIVE(ST) \
do { \
--(ST)->recursion_depth; \
} while(0)
static int validate_stmts(struct validator *, asdl_stmt_seq *);
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
@ -166,11 +180,7 @@ validate_constant(struct validator *state, PyObject *value)
return 1;
if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
PyObject *it = PyObject_GetIter(value);
if (it == NULL)
@ -195,7 +205,7 @@ validate_constant(struct validator *state, PyObject *value)
}
Py_DECREF(it);
--state->recursion_depth;
LEAVE_RECURSIVE(state);
return 1;
}
@ -213,11 +223,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
assert(!PyErr_Occurred());
VALIDATE_POSITIONS(exp);
int ret = -1;
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
int check_ctx = 1;
expr_context_ty actual_ctx;
@ -398,7 +404,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
PyErr_SetString(PyExc_SystemError, "unexpected expression");
ret = 0;
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return ret;
}
@ -544,11 +550,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
assert(!PyErr_Occurred());
VALIDATE_POSITIONS(p);
int ret = -1;
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (p->kind) {
case MatchValue_kind:
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
@ -690,7 +692,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
PyErr_SetString(PyExc_SystemError, "unexpected pattern");
ret = 0;
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return ret;
}
@ -725,11 +727,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
assert(!PyErr_Occurred());
VALIDATE_POSITIONS(stmt);
int ret = -1;
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
ENTER_RECURSIVE(state);
switch (stmt->kind) {
case FunctionDef_kind:
ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
@ -946,7 +944,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
PyErr_SetString(PyExc_SystemError, "unexpected statement");
ret = 0;
}
state->recursion_depth--;
LEAVE_RECURSIVE(state);
return ret;
}