GH-91079: Decouple C stack overflow checks from Python recursion checks. (GH-96510)

This commit is contained in:
Mark Shannon 2022-10-05 01:34:03 +01:00 committed by GitHub
parent 0ff8fd6583
commit 76449350b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 165 additions and 99 deletions

9
Python/Python-ast.c generated
View file

@ -12315,7 +12315,6 @@ PyObject* PyAST_mod2obj(mod_ty t)
return NULL;
}
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Be careful here to prevent overflow. */
int COMPILER_STACK_FRAME_SCALE = 3;
@ -12323,11 +12322,9 @@ PyObject* PyAST_mod2obj(mod_ty t)
if (!tstate) {
return 0;
}
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t);

View file

@ -975,7 +975,6 @@ _PyAST_Validate(mod_ty mod)
int res = -1;
struct validator state;
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Setup recursion depth check counters */
@ -984,12 +983,10 @@ _PyAST_Validate(mod_ty mod)
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state.recursion_depth = starting_recursion_depth;
state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
state.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
switch (mod->kind) {
case Module_kind:

View file

@ -1080,7 +1080,6 @@ int
_PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
{
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Setup recursion depth check counters */
@ -1089,12 +1088,10 @@ _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int ret = astfold_mod(mod, arena, state);
assert(ret || PyErr_Occurred());

View file

@ -257,9 +257,9 @@ Py_SetRecursionLimit(int new_limit)
PyInterpreterState *interp = _PyInterpreterState_GET();
interp->ceval.recursion_limit = new_limit;
for (PyThreadState *p = interp->threads.head; p != NULL; p = p->next) {
int depth = p->recursion_limit - p->recursion_remaining;
p->recursion_limit = new_limit;
p->recursion_remaining = new_limit - depth;
int depth = p->py_recursion_limit - p->py_recursion_remaining;
p->py_recursion_limit = new_limit;
p->py_recursion_remaining = new_limit - depth;
}
}
@ -268,35 +268,27 @@ Py_SetRecursionLimit(int new_limit)
int
_Py_CheckRecursiveCall(PyThreadState *tstate, const char *where)
{
/* Check against global limit first. */
int depth = tstate->recursion_limit - tstate->recursion_remaining;
if (depth < tstate->interp->ceval.recursion_limit) {
tstate->recursion_limit = tstate->interp->ceval.recursion_limit;
tstate->recursion_remaining = tstate->recursion_limit - depth;
assert(tstate->recursion_remaining > 0);
return 0;
}
#ifdef USE_STACKCHECK
if (PyOS_CheckStack()) {
++tstate->recursion_remaining;
++tstate->c_recursion_remaining;
_PyErr_SetString(tstate, PyExc_MemoryError, "Stack overflow");
return -1;
}
#endif
if (tstate->recursion_headroom) {
if (tstate->recursion_remaining < -50) {
if (tstate->c_recursion_remaining < -50) {
/* Overflowing while handling an overflow. Give up. */
Py_FatalError("Cannot recover from stack overflow.");
}
}
else {
if (tstate->recursion_remaining <= 0) {
if (tstate->c_recursion_remaining <= 0) {
tstate->recursion_headroom++;
_PyErr_Format(tstate, PyExc_RecursionError,
"maximum recursion depth exceeded%s",
where);
tstate->recursion_headroom--;
++tstate->recursion_remaining;
++tstate->c_recursion_remaining;
return -1;
}
}
@ -983,6 +975,39 @@ pop_frame(PyThreadState *tstate, _PyInterpreterFrame *frame)
return prev_frame;
}
int _Py_CheckRecursiveCallPy(
PyThreadState *tstate)
{
if (tstate->recursion_headroom) {
if (tstate->py_recursion_remaining < -50) {
/* Overflowing while handling an overflow. Give up. */
Py_FatalError("Cannot recover from Python stack overflow.");
}
}
else {
if (tstate->py_recursion_remaining <= 0) {
tstate->recursion_headroom++;
_PyErr_Format(tstate, PyExc_RecursionError,
"maximum recursion depth exceeded");
tstate->recursion_headroom--;
return -1;
}
}
return 0;
}
static inline int _Py_EnterRecursivePy(PyThreadState *tstate) {
return (tstate->py_recursion_remaining-- <= 0) &&
_Py_CheckRecursiveCallPy(tstate);
}
static inline void _Py_LeaveRecursiveCallPy(PyThreadState *tstate) {
tstate->py_recursion_remaining++;
}
/* It is only between the KW_NAMES instruction and the following CALL,
* that this has any meaning.
*/
@ -1037,10 +1062,15 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
frame->previous = prev_cframe->current_frame;
cframe.current_frame = frame;
if (_Py_EnterRecursiveCallTstate(tstate, "")) {
tstate->c_recursion_remaining--;
tstate->py_recursion_remaining--;
goto exit_unwind;
}
/* support for generator.throw() */
if (throwflag) {
if (_Py_EnterRecursiveCallTstate(tstate, "")) {
tstate->recursion_remaining--;
if (_Py_EnterRecursivePy(tstate)) {
goto exit_unwind;
}
TRACE_FUNCTION_THROW_ENTRY();
@ -1079,8 +1109,7 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
start_frame:
if (_Py_EnterRecursiveCallTstate(tstate, "")) {
tstate->recursion_remaining--;
if (_Py_EnterRecursivePy(tstate)) {
goto exit_unwind;
}
@ -1830,12 +1859,13 @@ handle_eval_breaker:
_PyFrame_SetStackPointer(frame, stack_pointer);
TRACE_FUNCTION_EXIT();
DTRACE_FUNCTION_EXIT();
_Py_LeaveRecursiveCallTstate(tstate);
_Py_LeaveRecursiveCallPy(tstate);
if (!frame->is_entry) {
frame = cframe.current_frame = pop_frame(tstate, frame);
_PyFrame_StackPush(frame, retval);
goto resume_frame;
}
_Py_LeaveRecursiveCallTstate(tstate);
/* Restore previous cframe and return. */
tstate->cframe = cframe.previous;
tstate->cframe->use_tracing = cframe.use_tracing;
@ -2046,6 +2076,7 @@ handle_eval_breaker:
_PyFrame_SetStackPointer(frame, stack_pointer);
TRACE_FUNCTION_EXIT();
DTRACE_FUNCTION_EXIT();
_Py_LeaveRecursiveCallPy(tstate);
_Py_LeaveRecursiveCallTstate(tstate);
/* Restore previous cframe and return. */
tstate->cframe = cframe.previous;
@ -4800,7 +4831,7 @@ handle_eval_breaker:
assert(frame->frame_obj == NULL);
gen->gi_frame_state = FRAME_CREATED;
gen_frame->owner = FRAME_OWNED_BY_GENERATOR;
_Py_LeaveRecursiveCallTstate(tstate);
_Py_LeaveRecursiveCallPy(tstate);
if (!frame->is_entry) {
_PyInterpreterFrame *prev = frame->previous;
_PyThreadState_PopFrame(tstate, frame);
@ -4808,6 +4839,7 @@ handle_eval_breaker:
_PyFrame_StackPush(frame, (PyObject *)gen);
goto resume_frame;
}
_Py_LeaveRecursiveCallTstate(tstate);
/* Make sure that frame is in a valid state */
frame->stacktop = 0;
frame->f_locals = NULL;
@ -5178,12 +5210,13 @@ exception_unwind:
exit_unwind:
assert(_PyErr_Occurred(tstate));
_Py_LeaveRecursiveCallTstate(tstate);
_Py_LeaveRecursiveCallPy(tstate);
if (frame->is_entry) {
/* Restore previous cframe and exit */
tstate->cframe = cframe.previous;
tstate->cframe->use_tracing = cframe.use_tracing;
assert(tstate->cframe->current_frame == frame->previous);
_Py_LeaveRecursiveCallTstate(tstate);
return NULL;
}
frame = cframe.current_frame = pop_frame(tstate, frame);
@ -5755,11 +5788,11 @@ _PyEvalFrameClearAndPop(PyThreadState *tstate, _PyInterpreterFrame * frame)
// _PyThreadState_PopFrame, since f_code is already cleared at that point:
assert((PyObject **)frame + frame->f_code->co_framesize ==
tstate->datastack_top);
tstate->recursion_remaining--;
tstate->c_recursion_remaining--;
assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame);
assert(frame->owner == FRAME_OWNED_BY_THREAD);
_PyFrame_Clear(frame);
tstate->recursion_remaining++;
tstate->c_recursion_remaining++;
_PyThreadState_PopFrame(tstate, frame);
}

View file

@ -792,8 +792,9 @@ init_threadstate(PyThreadState *tstate,
tstate->native_thread_id = PyThread_get_thread_native_id();
#endif
tstate->recursion_limit = interp->ceval.recursion_limit,
tstate->recursion_remaining = interp->ceval.recursion_limit,
tstate->py_recursion_limit = interp->ceval.recursion_limit,
tstate->py_recursion_remaining = interp->ceval.recursion_limit,
tstate->c_recursion_remaining = C_RECURSION_LIMIT;
tstate->exc_info = &tstate->exc_state;

View file

@ -278,7 +278,6 @@ _PySymtable_Build(mod_ty mod, PyObject *filename, PyFutureFeatures *future)
asdl_stmt_seq *seq;
int i;
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
if (st == NULL)
@ -298,12 +297,10 @@ _PySymtable_Build(mod_ty mod, PyObject *filename, PyFutureFeatures *future)
return NULL;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
st->recursion_depth = starting_recursion_depth;
st->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
st->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
/* Make the initial symbol information gathering pass */
if (!symtable_enter_block(st, &_Py_ID(top), ModuleBlock, (void *)mod, 0, 0, 0, 0)) {

View file

@ -1218,7 +1218,7 @@ sys_setrecursionlimit_impl(PyObject *module, int new_limit)
/* Reject too low new limit if the current recursion depth is higher than
the new low-water mark. */
int depth = tstate->recursion_limit - tstate->recursion_remaining;
int depth = tstate->py_recursion_limit - tstate->py_recursion_remaining;
if (depth >= new_limit) {
_PyErr_Format(tstate, PyExc_RecursionError,
"cannot set the recursion limit to %i at "