GH-130396: Use computed stack limits on linux (GH-130398)

* Implement C recursion protection with limit pointers for Linux, MacOS and Windows

* Remove calls to PyOS_CheckStack

* Add stack protection to parser

* Make tests more robust to low stacks

* Improve error messages for stack overflow
This commit is contained in:
Mark Shannon 2025-02-25 09:24:48 +00:00 committed by GitHub
parent 99088ab081
commit 014223649c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
58 changed files with 1295 additions and 1482 deletions

View file

@ -921,11 +921,7 @@ because the :ref:`call protocol <call>` takes care of recursion handling.
Marks a point where a recursive C-level call is about to be performed. Marks a point where a recursive C-level call is about to be performed.
If :c:macro:`!USE_STACKCHECK` is defined, this function checks if the OS The function then checks if the stack limit is reached. If this is the
stack overflowed using :c:func:`PyOS_CheckStack`. If this is the case, it
sets a :exc:`MemoryError` and returns a nonzero value.
The function then checks if the recursion limit is reached. If this is the
case, a :exc:`RecursionError` is set and a nonzero value is returned. case, a :exc:`RecursionError` is set and a nonzero value is returned.
Otherwise, zero is returned. Otherwise, zero is returned.

View file

@ -487,18 +487,19 @@ PyAPI_FUNC(void) _PyTrash_thread_destroy_chain(PyThreadState *tstate);
* we have headroom above the trigger limit */ * we have headroom above the trigger limit */
#define Py_TRASHCAN_HEADROOM 50 #define Py_TRASHCAN_HEADROOM 50
/* Helper function for Py_TRASHCAN_BEGIN */
PyAPI_FUNC(int) _Py_ReachedRecursionLimitWithMargin(PyThreadState *tstate, int margin_count);
#define Py_TRASHCAN_BEGIN(op, dealloc) \ #define Py_TRASHCAN_BEGIN(op, dealloc) \
do { \ do { \
PyThreadState *tstate = PyThreadState_Get(); \ PyThreadState *tstate = PyThreadState_Get(); \
if (tstate->c_recursion_remaining <= Py_TRASHCAN_HEADROOM && Py_TYPE(op)->tp_dealloc == (destructor)dealloc) { \ if (_Py_ReachedRecursionLimitWithMargin(tstate, 1) && Py_TYPE(op)->tp_dealloc == (destructor)dealloc) { \
_PyTrash_thread_deposit_object(tstate, (PyObject *)op); \ _PyTrash_thread_deposit_object(tstate, (PyObject *)op); \
break; \ break; \
} \ }
tstate->c_recursion_remaining--;
/* The body of the deallocator is here. */ /* The body of the deallocator is here. */
#define Py_TRASHCAN_END \ #define Py_TRASHCAN_END \
tstate->c_recursion_remaining++; \ if (tstate->delete_later && !_Py_ReachedRecursionLimitWithMargin(tstate, 2)) { \
if (tstate->delete_later && tstate->c_recursion_remaining > (Py_TRASHCAN_HEADROOM*2)) { \
_PyTrash_thread_destroy_chain(tstate); \ _PyTrash_thread_destroy_chain(tstate); \
} \ } \
} while (0); } while (0);

View file

@ -112,7 +112,7 @@ struct _ts {
int py_recursion_remaining; int py_recursion_remaining;
int py_recursion_limit; int py_recursion_limit;
int c_recursion_remaining; int c_recursion_remaining; /* Retained for backwards compatibility. Do not use */
int recursion_headroom; /* Allow 50 more calls to handle any errors. */ int recursion_headroom; /* Allow 50 more calls to handle any errors. */
/* 'tracing' keeps track of the execution depth when tracing/profiling. /* 'tracing' keeps track of the execution depth when tracing/profiling.
@ -202,36 +202,7 @@ struct _ts {
PyObject *threading_local_sentinel; PyObject *threading_local_sentinel;
}; };
#ifdef Py_DEBUG
// A debug build is likely built with low optimization level which implies
// higher stack memory usage than a release build: use a lower limit.
# define Py_C_RECURSION_LIMIT 500
#elif defined(__s390x__)
# define Py_C_RECURSION_LIMIT 800
#elif defined(_WIN32) && defined(_M_ARM64)
# define Py_C_RECURSION_LIMIT 1000
#elif defined(_WIN32)
# define Py_C_RECURSION_LIMIT 3000
#elif defined(__ANDROID__)
// On an ARM64 emulator, API level 34 was OK with 10000, but API level 21
// crashed in test_compiler_recursion_limit.
# define Py_C_RECURSION_LIMIT 3000
#elif defined(_Py_ADDRESS_SANITIZER)
# define Py_C_RECURSION_LIMIT 4000
#elif defined(__sparc__)
// test_descr crashed on sparc64 with >7000 but let's keep a margin of error.
# define Py_C_RECURSION_LIMIT 4000
#elif defined(__wasi__)
// Based on wasmtime 16.
# define Py_C_RECURSION_LIMIT 5000 # define Py_C_RECURSION_LIMIT 5000
#elif defined(__hppa__) || defined(__powerpc64__)
// test_descr crashed with >8000 but let's keep a margin of error.
# define Py_C_RECURSION_LIMIT 5000
#else
// This value is duplicated in Lib/test/support/__init__.py
# define Py_C_RECURSION_LIMIT 10000
#endif
/* other API */ /* other API */
@ -246,7 +217,6 @@ _PyThreadState_UncheckedGet(void)
return PyThreadState_GetUnchecked(); return PyThreadState_GetUnchecked();
} }
// Disable tracing and profiling. // Disable tracing and profiling.
PyAPI_FUNC(void) PyThreadState_EnterTracing(PyThreadState *tstate); PyAPI_FUNC(void) PyThreadState_EnterTracing(PyThreadState *tstate);

View file

@ -193,19 +193,29 @@ extern void _PyEval_DeactivateOpCache(void);
/* --- _Py_EnterRecursiveCall() ----------------------------------------- */ /* --- _Py_EnterRecursiveCall() ----------------------------------------- */
#ifdef USE_STACKCHECK #if !_Py__has_builtin(__builtin_frame_address)
/* With USE_STACKCHECK macro defined, trigger stack checks in static uintptr_t return_pointer_as_int(char* p) {
_Py_CheckRecursiveCall() on every 64th call to _Py_EnterRecursiveCall. */ return (uintptr_t)p;
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return (tstate->c_recursion_remaining-- < 0
|| (tstate->c_recursion_remaining & 63) == 0);
}
#else
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return tstate->c_recursion_remaining-- < 0;
} }
#endif #endif
static inline uintptr_t
_Py_get_machine_stack_pointer(void) {
#if _Py__has_builtin(__builtin_frame_address)
return (uintptr_t)__builtin_frame_address(0);
#else
char here;
/* Avoid compiler warning about returning stack address */
return return_pointer_as_int(&here);
#endif
}
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
uintptr_t here_addr = _Py_get_machine_stack_pointer();
_PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
return here_addr < _tstate->c_stack_soft_limit;
}
// Export for '_json' shared extension, used via _Py_EnterRecursiveCall() // Export for '_json' shared extension, used via _Py_EnterRecursiveCall()
// static inline function. // static inline function.
PyAPI_FUNC(int) _Py_CheckRecursiveCall( PyAPI_FUNC(int) _Py_CheckRecursiveCall(
@ -220,23 +230,30 @@ static inline int _Py_EnterRecursiveCallTstate(PyThreadState *tstate,
return (_Py_MakeRecCheck(tstate) && _Py_CheckRecursiveCall(tstate, where)); return (_Py_MakeRecCheck(tstate) && _Py_CheckRecursiveCall(tstate, where));
} }
static inline void _Py_EnterRecursiveCallTstateUnchecked(PyThreadState *tstate) {
assert(tstate->c_recursion_remaining > 0);
tstate->c_recursion_remaining--;
}
static inline int _Py_EnterRecursiveCall(const char *where) { static inline int _Py_EnterRecursiveCall(const char *where) {
PyThreadState *tstate = _PyThreadState_GET(); PyThreadState *tstate = _PyThreadState_GET();
return _Py_EnterRecursiveCallTstate(tstate, where); return _Py_EnterRecursiveCallTstate(tstate, where);
} }
static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) { static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) {
tstate->c_recursion_remaining++; (void)tstate;
}
PyAPI_FUNC(void) _Py_InitializeRecursionLimits(PyThreadState *tstate);
static inline int _Py_ReachedRecursionLimit(PyThreadState *tstate) {
uintptr_t here_addr = _Py_get_machine_stack_pointer();
_PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
if (here_addr > _tstate->c_stack_soft_limit) {
return 0;
}
if (_tstate->c_stack_hard_limit == 0) {
_Py_InitializeRecursionLimits(tstate);
}
return here_addr <= _tstate->c_stack_soft_limit;
} }
static inline void _Py_LeaveRecursiveCall(void) { static inline void _Py_LeaveRecursiveCall(void) {
PyThreadState *tstate = _PyThreadState_GET();
_Py_LeaveRecursiveCallTstate(tstate);
} }
extern struct _PyInterpreterFrame* _PyEval_GetFrame(void); extern struct _PyInterpreterFrame* _PyEval_GetFrame(void);
@ -327,7 +344,6 @@ void _Py_unset_eval_breaker_bit_all(PyInterpreterState *interp, uintptr_t bit);
PyAPI_FUNC(PyObject *) _PyFloat_FromDouble_ConsumeInputs(_PyStackRef left, _PyStackRef right, double value); PyAPI_FUNC(PyObject *) _PyFloat_FromDouble_ConsumeInputs(_PyStackRef left, _PyStackRef right, double value);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View file

@ -82,8 +82,6 @@ struct symtable {
PyObject *st_private; /* name of current class or NULL */ PyObject *st_private; /* name of current class or NULL */
_PyFutureFeatures *st_future; /* module's future features that affect _PyFutureFeatures *st_future; /* module's future features that affect
the symbol table */ the symbol table */
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
}; };
typedef struct _symtable_entry { typedef struct _symtable_entry {

View file

@ -21,6 +21,11 @@ typedef struct _PyThreadStateImpl {
// semi-public fields are in PyThreadState. // semi-public fields are in PyThreadState.
PyThreadState base; PyThreadState base;
// These are addresses, but we need to convert to ints to avoid UB.
uintptr_t c_stack_top;
uintptr_t c_stack_soft_limit;
uintptr_t c_stack_hard_limit;
PyObject *asyncio_running_loop; // Strong reference PyObject *asyncio_running_loop; // Strong reference
PyObject *asyncio_running_task; // Strong reference PyObject *asyncio_running_task; // Strong reference

View file

@ -21,14 +21,23 @@ PyAPI_FUNC(void) PyErr_DisplayException(PyObject *);
/* Stuff with no proper home (yet) */ /* Stuff with no proper home (yet) */
PyAPI_DATA(int) (*PyOS_InputHook)(void); PyAPI_DATA(int) (*PyOS_InputHook)(void);
/* Stack size, in "pointers" (so we get extra safety margins /* Stack size, in "pointers". This must be large enough, so
on 64-bit platforms). On a 32-bit platform, this translates * no two calls to check recursion depth are more than this far
to an 8k margin. */ * apart. In practice, that means it must be larger than the C
#define PYOS_STACK_MARGIN 2048 * stack consumption of PyEval_EvalDefault */
#if defined(_Py_ADDRESS_SANITIZER) || defined(_Py_THREAD_SANITIZER)
# define PYOS_STACK_MARGIN 4096
#elif defined(Py_DEBUG) && defined(WIN32)
# define PYOS_STACK_MARGIN 3072
#elif defined(__wasi__)
/* Web assembly has two stacks, so this isn't really a size */
# define PYOS_STACK_MARGIN 500
#else
# define PYOS_STACK_MARGIN 2048
#endif
#define PYOS_STACK_MARGIN_BYTES (PYOS_STACK_MARGIN * sizeof(void *))
#if defined(WIN32) && !defined(MS_WIN64) && !defined(_M_ARM) && defined(_MSC_VER) && _MSC_VER >= 1300 #if defined(WIN32)
/* Enable stack checking under Microsoft C */
// When changing the platforms, ensure PyOS_CheckStack() docs are still correct
#define USE_STACKCHECK #define USE_STACKCHECK
#endif #endif

View file

@ -6,7 +6,8 @@ import sys
from functools import cmp_to_key from functools import cmp_to_key
from test import seq_tests from test import seq_tests
from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit, skip_emscripten_stack_overflow from test.support import ALWAYS_EQ, NEVER_EQ
from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow
class CommonTest(seq_tests.CommonTest): class CommonTest(seq_tests.CommonTest):
@ -59,10 +60,11 @@ class CommonTest(seq_tests.CommonTest):
self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
@skip_wasi_stack_overflow()
@skip_emscripten_stack_overflow() @skip_emscripten_stack_overflow()
def test_repr_deep(self): def test_repr_deep(self):
a = self.type2test([]) a = self.type2test([])
for i in range(get_c_recursion_limit() + 1): for i in range(200_000):
a = self.type2test([a]) a = self.type2test([a])
self.assertRaises(RecursionError, repr, a) self.assertRaises(RecursionError, repr, a)

View file

@ -1,7 +1,7 @@
# tests common to dict and UserDict # tests common to dict and UserDict
import unittest import unittest
import collections import collections
from test.support import get_c_recursion_limit, skip_emscripten_stack_overflow from test import support
class BasicTestMappingProtocol(unittest.TestCase): class BasicTestMappingProtocol(unittest.TestCase):
@ -622,10 +622,11 @@ class TestHashMappingProtocol(TestMappingProtocol):
d = self._full_mapping({1: BadRepr()}) d = self._full_mapping({1: BadRepr()})
self.assertRaises(Exc, repr, d) self.assertRaises(Exc, repr, d)
@skip_emscripten_stack_overflow() @support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow()
def test_repr_deep(self): def test_repr_deep(self):
d = self._empty_mapping() d = self._empty_mapping()
for i in range(get_c_recursion_limit() + 1): for i in range(support.exceeds_recursion_limit()):
d0 = d d0 = d
d = self._empty_mapping() d = self._empty_mapping()
d[1] = d0 d[1] = d0

View file

@ -684,7 +684,6 @@ def collect_testcapi(info_add):
for name in ( for name in (
'LONG_MAX', # always 32-bit on Windows, 64-bit on 64-bit Unix 'LONG_MAX', # always 32-bit on Windows, 64-bit on 64-bit Unix
'PY_SSIZE_T_MAX', 'PY_SSIZE_T_MAX',
'Py_C_RECURSION_LIMIT',
'SIZEOF_TIME_T', # 32-bit or 64-bit depending on the platform 'SIZEOF_TIME_T', # 32-bit or 64-bit depending on the platform
'SIZEOF_WCHAR_T', # 16-bit or 32-bit depending on the platform 'SIZEOF_WCHAR_T', # 16-bit or 32-bit depending on the platform
): ):

View file

@ -56,8 +56,7 @@ __all__ = [
"run_with_tz", "PGO", "missing_compiler_executable", "run_with_tz", "PGO", "missing_compiler_executable",
"ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST", "ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST",
"LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT", "LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT",
"Py_DEBUG", "exceeds_recursion_limit", "get_c_recursion_limit", "Py_DEBUG", "exceeds_recursion_limit", "skip_on_s390x",
"skip_on_s390x",
"requires_jit_enabled", "requires_jit_enabled",
"requires_jit_disabled", "requires_jit_disabled",
"force_not_colorized", "force_not_colorized",
@ -558,6 +557,9 @@ is_wasi = sys.platform == "wasi"
def skip_emscripten_stack_overflow(): def skip_emscripten_stack_overflow():
return unittest.skipIf(is_emscripten, "Exhausts limited stack on Emscripten") return unittest.skipIf(is_emscripten, "Exhausts limited stack on Emscripten")
def skip_wasi_stack_overflow():
return unittest.skipIf(is_wasi, "Exhausts stack on WASI")
is_apple_mobile = sys.platform in {"ios", "tvos", "watchos"} is_apple_mobile = sys.platform in {"ios", "tvos", "watchos"}
is_apple = is_apple_mobile or sys.platform == "darwin" is_apple = is_apple_mobile or sys.platform == "darwin"
@ -2624,17 +2626,9 @@ def adjust_int_max_str_digits(max_digits):
sys.set_int_max_str_digits(current) sys.set_int_max_str_digits(current)
def get_c_recursion_limit():
try:
import _testcapi
return _testcapi.Py_C_RECURSION_LIMIT
except ImportError:
raise unittest.SkipTest('requires _testcapi')
def exceeds_recursion_limit(): def exceeds_recursion_limit():
"""For recursion tests, easily exceeds default recursion limit.""" """For recursion tests, easily exceeds default recursion limit."""
return get_c_recursion_limit() * 3 return 150_000
# Windows doesn't have os.uname() but it doesn't support s390x. # Windows doesn't have os.uname() but it doesn't support s390x.

View file

@ -18,7 +18,8 @@ except ImportError:
_testinternalcapi = None _testinternalcapi = None
from test import support from test import support
from test.support import os_helper, script_helper, skip_emscripten_stack_overflow from test.support import os_helper, script_helper
from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow
from test.support.ast_helper import ASTTestMixin from test.support.ast_helper import ASTTestMixin
from test.test_ast.utils import to_tuple from test.test_ast.utils import to_tuple
from test.test_ast.snippets import ( from test.test_ast.snippets import (
@ -750,11 +751,11 @@ class AST_Tests(unittest.TestCase):
enum._test_simple_enum(_Precedence, ast._Precedence) enum._test_simple_enum(_Precedence, ast._Precedence)
@support.cpython_only @support.cpython_only
@skip_wasi_stack_overflow()
@skip_emscripten_stack_overflow() @skip_emscripten_stack_overflow()
def test_ast_recursion_limit(self): def test_ast_recursion_limit(self):
fail_depth = support.exceeds_recursion_limit() crash_depth = 500_000
crash_depth = 100_000 success_depth = 200
success_depth = int(support.get_c_recursion_limit() * 0.8)
if _testinternalcapi is not None: if _testinternalcapi is not None:
remaining = _testinternalcapi.get_c_recursion_remaining() remaining = _testinternalcapi.get_c_recursion_remaining()
success_depth = min(success_depth, remaining) success_depth = min(success_depth, remaining)
@ -762,10 +763,10 @@ class AST_Tests(unittest.TestCase):
def check_limit(prefix, repeated): def check_limit(prefix, repeated):
expect_ok = prefix + repeated * success_depth expect_ok = prefix + repeated * success_depth
ast.parse(expect_ok) ast.parse(expect_ok)
for depth in (fail_depth, crash_depth):
broken = prefix + repeated * depth broken = prefix + repeated * crash_depth
details = "Compiling ({!r} + {!r} * {})".format( details = "Compiling ({!r} + {!r} * {})".format(
prefix, repeated, depth) prefix, repeated, crash_depth)
with self.assertRaises(RecursionError, msg=details): with self.assertRaises(RecursionError, msg=details):
with support.infinite_recursion(): with support.infinite_recursion():
ast.parse(broken) ast.parse(broken)

View file

@ -1052,6 +1052,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
f2 = filter(filter_char, "abcdeabcde") f2 = filter(filter_char, "abcdeabcde")
self.check_iter_pickle(f1, list(f2), proto) self.check_iter_pickle(f1, list(f2), proto)
@support.skip_wasi_stack_overflow()
@support.requires_resource('cpu') @support.requires_resource('cpu')
def test_filter_dealloc(self): def test_filter_dealloc(self):
# Tests recursive deallocation of nested filter objects using the # Tests recursive deallocation of nested filter objects using the

View file

@ -1,6 +1,6 @@
import unittest import unittest
from test.support import (cpython_only, is_wasi, requires_limited_api, Py_DEBUG, from test.support import (cpython_only, is_wasi, requires_limited_api, Py_DEBUG,
set_recursion_limit, skip_on_s390x, skip_emscripten_stack_overflow, set_recursion_limit, skip_on_s390x, exceeds_recursion_limit, skip_emscripten_stack_overflow, skip_wasi_stack_overflow,
skip_if_sanitizer, import_helper) skip_if_sanitizer, import_helper)
try: try:
import _testcapi import _testcapi
@ -1040,6 +1040,7 @@ class TestRecursion(unittest.TestCase):
@skip_if_sanitizer("requires deep stack", thread=True) @skip_if_sanitizer("requires deep stack", thread=True)
@unittest.skipIf(_testcapi is None, "requires _testcapi") @unittest.skipIf(_testcapi is None, "requires _testcapi")
@skip_emscripten_stack_overflow() @skip_emscripten_stack_overflow()
@skip_wasi_stack_overflow()
def test_super_deep(self): def test_super_deep(self):
def recurse(n): def recurse(n):
@ -1064,10 +1065,10 @@ class TestRecursion(unittest.TestCase):
recurse(90_000) recurse(90_000)
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
recurse(101_000) recurse(101_000)
c_recurse(100) c_recurse(50)
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
c_recurse(90_000) c_recurse(90_000)
c_py_recurse(90) c_py_recurse(50)
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
c_py_recurse(100_000) c_py_recurse(100_000)

View file

@ -408,7 +408,7 @@ class CAPITest(unittest.TestCase):
# activated when its tp_dealloc is being called by a subclass # activated when its tp_dealloc is being called by a subclass
from _testcapi import MyList from _testcapi import MyList
L = None L = None
for i in range(1000): for i in range(100):
L = MyList((L,)) L = MyList((L,))
@support.requires_resource('cpu') @support.requires_resource('cpu')

View file

@ -2,7 +2,7 @@
import unittest import unittest
from test import support from test import support
from test.support import cpython_only, import_helper, script_helper, skip_emscripten_stack_overflow from test.support import cpython_only, import_helper, script_helper
testmeths = [ testmeths = [
@ -556,7 +556,8 @@ class ClassTests(unittest.TestCase):
self.assertFalse(hasattr(o, "__call__")) self.assertFalse(hasattr(o, "__call__"))
self.assertFalse(hasattr(c, "__call__")) self.assertFalse(hasattr(c, "__call__"))
@skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
@support.skip_wasi_stack_overflow()
def testSFBug532646(self): def testSFBug532646(self):
# Test for SF bug 532646 # Test for SF bug 532646

View file

@ -21,7 +21,7 @@ except ImportError:
from test import support from test import support
from test.support import (script_helper, requires_debug_ranges, run_code, from test.support import (script_helper, requires_debug_ranges, run_code,
requires_specialization, get_c_recursion_limit) requires_specialization)
from test.support.bytecode_helper import instructions_with_positions from test.support.bytecode_helper import instructions_with_positions
from test.support.os_helper import FakePath from test.support.os_helper import FakePath
@ -123,7 +123,7 @@ class TestSpecifics(unittest.TestCase):
@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI") @unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_extended_arg(self): def test_extended_arg(self):
repeat = int(get_c_recursion_limit() * 0.9) repeat = 100
longexpr = 'x = x or ' + '-x' * repeat longexpr = 'x = x or ' + '-x' * repeat
g = {} g = {}
code = textwrap.dedent(''' code = textwrap.dedent('''
@ -712,19 +712,16 @@ class TestSpecifics(unittest.TestCase):
@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI") @unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_compiler_recursion_limit(self): def test_compiler_recursion_limit(self):
# Expected limit is Py_C_RECURSION_LIMIT # Compiler frames are small
limit = get_c_recursion_limit() limit = 100
fail_depth = limit + 1 crash_depth = limit * 5000
crash_depth = limit * 100 success_depth = limit
success_depth = int(limit * 0.8)
def check_limit(prefix, repeated, mode="single"): def check_limit(prefix, repeated, mode="single"):
expect_ok = prefix + repeated * success_depth expect_ok = prefix + repeated * success_depth
compile(expect_ok, '<test>', mode) compile(expect_ok, '<test>', mode)
for depth in (fail_depth, crash_depth): broken = prefix + repeated * crash_depth
broken = prefix + repeated * depth details = f"Compiling ({prefix!r} + {repeated!r} * {crash_depth})"
details = "Compiling ({!r} + {!r} * {})".format(
prefix, repeated, depth)
with self.assertRaises(RecursionError, msg=details): with self.assertRaises(RecursionError, msg=details):
compile(broken, '<test>', mode) compile(broken, '<test>', mode)

View file

@ -3658,6 +3658,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
encoding='latin1', errors='replace') encoding='latin1', errors='replace')
self.assertEqual(ba, b'abc\xbd?') self.assertEqual(ba, b'abc\xbd?')
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_recursive_call(self): def test_recursive_call(self):
# Testing recursive __call__() by setting to instance of class... # Testing recursive __call__() by setting to instance of class...
@ -4518,6 +4519,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
o.whatever = Provoker(o) o.whatever = Provoker(o)
del o del o
@support.skip_wasi_stack_overflow()
@support.requires_resource('cpu') @support.requires_resource('cpu')
def test_wrapper_segfault(self): def test_wrapper_segfault(self):
# SF 927248: deeply nested wrappers could cause stack overflow # SF 927248: deeply nested wrappers could cause stack overflow

View file

@ -8,7 +8,7 @@ import sys
import unittest import unittest
import weakref import weakref
from test import support from test import support
from test.support import import_helper, get_c_recursion_limit from test.support import import_helper
class DictTest(unittest.TestCase): class DictTest(unittest.TestCase):
@ -594,10 +594,11 @@ class DictTest(unittest.TestCase):
d = {1: BadRepr()} d = {1: BadRepr()}
self.assertRaises(Exc, repr, d) self.assertRaises(Exc, repr, d)
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_repr_deep(self): def test_repr_deep(self):
d = {} d = {}
for i in range(get_c_recursion_limit() + 1): for i in range(support.exceeds_recursion_limit()):
d = {1: d} d = {1: d}
self.assertRaises(RecursionError, repr, d) self.assertRaises(RecursionError, repr, d)

View file

@ -2,7 +2,7 @@ import collections.abc
import copy import copy
import pickle import pickle
import unittest import unittest
from test.support import get_c_recursion_limit, skip_emscripten_stack_overflow from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow, exceeds_recursion_limit
class DictSetTest(unittest.TestCase): class DictSetTest(unittest.TestCase):
@ -277,10 +277,11 @@ class DictSetTest(unittest.TestCase):
# Again. # Again.
self.assertIsInstance(r, str) self.assertIsInstance(r, str)
@skip_wasi_stack_overflow()
@skip_emscripten_stack_overflow() @skip_emscripten_stack_overflow()
def test_deeply_nested_repr(self): def test_deeply_nested_repr(self):
d = {} d = {}
for i in range(get_c_recursion_limit()//2 + 100): for i in range(exceeds_recursion_limit()):
d = {42: d.values()} d = {42: d.values()}
self.assertRaises(RecursionError, repr, d) self.assertRaises(RecursionError, repr, d)

View file

@ -4,7 +4,7 @@ import builtins
import sys import sys
import unittest import unittest
from test.support import swap_item, swap_attr, is_wasi, Py_DEBUG from test.support import swap_item, swap_attr, skip_wasi_stack_overflow, Py_DEBUG
class RebindBuiltinsTests(unittest.TestCase): class RebindBuiltinsTests(unittest.TestCase):
@ -134,7 +134,8 @@ class RebindBuiltinsTests(unittest.TestCase):
self.assertEqual(foo(), 7) self.assertEqual(foo(), 7)
@unittest.skipIf(is_wasi and Py_DEBUG, "requires too much stack")
@skip_wasi_stack_overflow()
def test_load_global_specialization_failure_keeps_oparg(self): def test_load_global_specialization_failure_keeps_oparg(self):
# https://github.com/python/cpython/issues/91625 # https://github.com/python/cpython/issues/91625
class MyGlobals(dict): class MyGlobals(dict):

View file

@ -1,7 +1,7 @@
import collections.abc import collections.abc
import types import types
import unittest import unittest
from test.support import get_c_recursion_limit, skip_emscripten_stack_overflow from test.support import skip_emscripten_stack_overflow, exceeds_recursion_limit
class TestExceptionGroupTypeHierarchy(unittest.TestCase): class TestExceptionGroupTypeHierarchy(unittest.TestCase):
def test_exception_group_types(self): def test_exception_group_types(self):
@ -460,7 +460,7 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase):
class DeepRecursionInSplitAndSubgroup(unittest.TestCase): class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
def make_deep_eg(self): def make_deep_eg(self):
e = TypeError(1) e = TypeError(1)
for i in range(get_c_recursion_limit() + 1): for i in range(exceeds_recursion_limit()):
e = ExceptionGroup('eg', [e]) e = ExceptionGroup('eg', [e])
return e return e

View file

@ -1391,7 +1391,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIsInstance(exc, RecursionError, type(exc)) self.assertIsInstance(exc, RecursionError, type(exc))
self.assertIn("maximum recursion depth exceeded", str(exc)) self.assertIn("maximum recursion depth exceeded", str(exc))
@support.skip_wasi_stack_overflow()
@cpython_only @cpython_only
@support.requires_resource('cpu') @support.requires_resource('cpu')
def test_trashcan_recursion(self): def test_trashcan_recursion(self):
@ -1479,7 +1479,7 @@ class ExceptionTests(unittest.TestCase):
""" """
rc, out, err = script_helper.assert_python_failure("-c", code) rc, out, err = script_helper.assert_python_failure("-c", code)
self.assertEqual(rc, 1) self.assertEqual(rc, 1)
expected = b'RecursionError: maximum recursion depth exceeded' expected = b'RecursionError'
self.assertTrue(expected in err, msg=f"{expected!r} not found in {err[:3_000]!r}... (truncated)") self.assertTrue(expected in err, msg=f"{expected!r} not found in {err[:3_000]!r}... (truncated)")
self.assertIn(b'Done.', out) self.assertIn(b'Done.', out)

View file

@ -628,13 +628,23 @@ x = (
r"does not match opening parenthesis '\('", r"does not match opening parenthesis '\('",
["f'{a(4}'", ["f'{a(4}'",
]) ])
self.assertRaises(SyntaxError, eval, "f'{" + "("*500 + "}'") self.assertRaises(SyntaxError, eval, "f'{" + "("*20 + "}'")
@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI") @unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
def test_fstring_nested_too_deeply(self): def test_fstring_nested_too_deeply(self):
self.assertAllRaise(SyntaxError, def raises_syntax_or_memory_error(txt):
"f-string: expressions nested too deeply", try:
['f"{1+2:{1+2:{1+1:{1}}}}"']) eval(txt)
except SyntaxError:
pass
except MemoryError:
pass
except Exception as ex:
self.fail(f"Should raise SyntaxError or MemoryError, not {type(ex)}")
else:
self.fail("No exception raised")
raises_syntax_or_memory_error('f"{1+2:{1+2:{1+1:{1}}}}"')
def create_nested_fstring(n): def create_nested_fstring(n):
if n == 0: if n == 0:
@ -642,9 +652,10 @@ x = (
prev = create_nested_fstring(n-1) prev = create_nested_fstring(n-1)
return f'f"{{{prev}}}"' return f'f"{{{prev}}}"'
self.assertAllRaise(SyntaxError, raises_syntax_or_memory_error(create_nested_fstring(160))
"too many nested f-strings", raises_syntax_or_memory_error("f'{" + "("*100 + "}'")
[create_nested_fstring(160)]) raises_syntax_or_memory_error("f'{" + "("*1000 + "}'")
raises_syntax_or_memory_error("f'{" + "("*10_000 + "}'")
def test_syntax_error_in_nested_fstring(self): def test_syntax_error_in_nested_fstring(self):
# See gh-104016 for more information on this crash # See gh-104016 for more information on this crash

View file

@ -404,6 +404,7 @@ class TestPartial:
self.assertEqual(r, ((1, 2), {})) self.assertEqual(r, ((1, 2), {}))
self.assertIs(type(r[0]), tuple) self.assertIs(type(r[0]), tuple)
@support.skip_if_sanitizer("thread sanitizer crashes in __tsan::FuncEntry", thread=True)
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_recursive_pickle(self): def test_recursive_pickle(self):
with replaced_module('functools', self.module): with replaced_module('functools', self.module):
@ -2087,15 +2088,12 @@ class TestLRU:
return n return n
return fib(n-1) + fib(n-2) return fib(n-1) + fib(n-2)
if not support.Py_DEBUG: fib(100)
depth = support.get_c_recursion_limit()*2//7
with support.infinite_recursion():
fib(depth)
if self.module == c_functools: if self.module == c_functools:
fib.cache_clear() fib.cache_clear()
with support.infinite_recursion(): with support.infinite_recursion():
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
fib(10000) fib(support.exceeds_recursion_limit())
@py_functools.lru_cache() @py_functools.lru_cache()

View file

@ -263,18 +263,18 @@ class TestIsInstanceIsSubclass(unittest.TestCase):
self.assertEqual(True, issubclass(int, (int, (float, int)))) self.assertEqual(True, issubclass(int, (int, (float, int))))
self.assertEqual(True, issubclass(str, (str, (Child, str)))) self.assertEqual(True, issubclass(str, (str, (Child, str))))
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_subclass_recursion_limit(self): def test_subclass_recursion_limit(self):
# make sure that issubclass raises RecursionError before the C stack is # make sure that issubclass raises RecursionError before the C stack is
# blown # blown
with support.infinite_recursion():
self.assertRaises(RecursionError, blowstack, issubclass, str, str) self.assertRaises(RecursionError, blowstack, issubclass, str, str)
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_isinstance_recursion_limit(self): def test_isinstance_recursion_limit(self):
# make sure that issubclass raises RecursionError before the C stack is # make sure that issubclass raises RecursionError before the C stack is
# blown # blown
with support.infinite_recursion():
self.assertRaises(RecursionError, blowstack, isinstance, '', str) self.assertRaises(RecursionError, blowstack, isinstance, '', str)
def test_subclass_with_union(self): def test_subclass_with_union(self):
@ -355,7 +355,8 @@ def blowstack(fxn, arg, compare_to):
# Make sure that calling isinstance with a deeply nested tuple for its # Make sure that calling isinstance with a deeply nested tuple for its
# argument will raise RecursionError eventually. # argument will raise RecursionError eventually.
tuple_arg = (compare_to,) tuple_arg = (compare_to,)
for cnt in range(support.exceeds_recursion_limit()): while True:
for _ in range(100):
tuple_arg = (tuple_arg,) tuple_arg = (tuple_arg,)
fxn(arg, tuple_arg) fxn(arg, tuple_arg)

View file

@ -70,23 +70,25 @@ class TestRecursion:
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_highly_nested_objects_decoding(self): def test_highly_nested_objects_decoding(self):
very_deep = 200000
# test that loading highly-nested objects doesn't segfault when C # test that loading highly-nested objects doesn't segfault when C
# accelerations are used. See #12017 # accelerations are used. See #12017
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
with support.infinite_recursion(): with support.infinite_recursion():
self.loads('{"a":' * 100000 + '1' + '}' * 100000) self.loads('{"a":' * very_deep + '1' + '}' * very_deep)
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
with support.infinite_recursion(): with support.infinite_recursion():
self.loads('{"a":' * 100000 + '[1]' + '}' * 100000) self.loads('{"a":' * very_deep + '[1]' + '}' * very_deep)
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
with support.infinite_recursion(): with support.infinite_recursion():
self.loads('[' * 100000 + '1' + ']' * 100000) self.loads('[' * very_deep + '1' + ']' * very_deep)
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow() @support.skip_emscripten_stack_overflow()
def test_highly_nested_objects_encoding(self): def test_highly_nested_objects_encoding(self):
# See #12051 # See #12051
l, d = [], {} l, d = [], {}
for x in range(100000): for x in range(200_000):
l, d = [l], {'k':d} l, d = [l], {'k':d}
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
with support.infinite_recursion(5000): with support.infinite_recursion(5000):

View file

@ -125,8 +125,7 @@ class CodeTestCase(unittest.TestCase):
def test_many_codeobjects(self): def test_many_codeobjects(self):
# Issue2957: bad recursion count on code objects # Issue2957: bad recursion count on code objects
# more than MAX_MARSHAL_STACK_DEPTH # more than MAX_MARSHAL_STACK_DEPTH
count = support.exceeds_recursion_limit() codes = (ExceptionTestCase.test_exceptions.__code__,) * 10_000
codes = (ExceptionTestCase.test_exceptions.__code__,) * count
marshal.loads(marshal.dumps(codes)) marshal.loads(marshal.dumps(codes))
def test_different_filenames(self): def test_different_filenames(self):

View file

@ -6,6 +6,7 @@ import enum
import inspect import inspect
import sys import sys
import unittest import unittest
from test import support
@dataclasses.dataclass @dataclasses.dataclass
@ -3498,6 +3499,7 @@ class TestTracing(unittest.TestCase):
self.assertListEqual(self._trace(f, 1), [1, 2, 3]) self.assertListEqual(self._trace(f, 1), [1, 2, 3])
self.assertListEqual(self._trace(f, 0), [1, 2, 5, 6]) self.assertListEqual(self._trace(f, 0), [1, 2, 5, 6])
@support.skip_wasi_stack_overflow()
def test_parser_deeply_nested_patterns(self): def test_parser_deeply_nested_patterns(self):
# Deeply nested patterns can cause exponential backtracking when parsing. # Deeply nested patterns can cause exponential backtracking when parsing.
# See gh-93671 for more information. # See gh-93671 for more information.

View file

@ -1090,6 +1090,9 @@ class SysModuleTest(unittest.TestCase):
# about the underlying implementation: the function might # about the underlying implementation: the function might
# return 0 or something greater. # return 0 or something greater.
self.assertGreaterEqual(a, 0) self.assertGreaterEqual(a, 0)
gc.collect()
b = sys.getallocatedblocks()
self.assertLessEqual(b, a)
try: try:
# While we could imagine a Python session where the number of # While we could imagine a Python session where the number of
# multiple buffer objects would exceed the sharing of references, # multiple buffer objects would exceed the sharing of references,
@ -1112,9 +1115,6 @@ class SysModuleTest(unittest.TestCase):
# gettotalrefcount() not available # gettotalrefcount() not available
pass pass
gc.collect() gc.collect()
b = sys.getallocatedblocks()
self.assertLessEqual(b, a)
gc.collect()
c = sys.getallocatedblocks() c = sys.getallocatedblocks()
self.assertIn(c, range(b - 50, b + 50)) self.assertIn(c, range(b - 50, b + 50))

View file

@ -3035,18 +3035,18 @@ class TestExtendedArgs(unittest.TestCase):
def test_trace_lots_of_globals(self): def test_trace_lots_of_globals(self):
count = min(1000, int(support.get_c_recursion_limit() * 0.8)) count = 1000
code = """if 1: code = """if 1:
def f(): def f():
return ( return (
{} {}
) )
""".format("\n+\n".join(f"var{i}\n" for i in range(count))) """.format("\n,\n".join(f"var{i}\n" for i in range(count)))
ns = {f"var{i}": i for i in range(count)} ns = {f"var{i}": i for i in range(count)}
exec(code, ns) exec(code, ns)
counts = self.count_traces(ns["f"]) counts = self.count_traces(ns["f"])
self.assertEqual(counts, {'call': 1, 'line': count * 2, 'return': 1}) self.assertEqual(counts, {'call': 1, 'line': count * 2 + 1, 'return': 1})
class TestEdgeCases(unittest.TestCase): class TestEdgeCases(unittest.TestCase):

View file

@ -3040,6 +3040,7 @@ async def f():
with self.subTest(case=case): with self.subTest(case=case):
self.assertRaises(tokenize.TokenError, get_tokens, case) self.assertRaises(tokenize.TokenError, get_tokens, case)
@support.skip_wasi_stack_overflow()
def test_max_indent(self): def test_max_indent(self):
MAXINDENT = 100 MAXINDENT = 100

View file

@ -213,11 +213,7 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
else: else:
self.fail("g[42] didn't raise KeyError") self.fail("g[42] didn't raise KeyError")
# Decorate existing test with recursion limit, because test_repr_deep = mapping_tests.TestHashMappingProtocol.test_repr_deep
# the test is for C structure, but `UserDict` is a Python structure.
test_repr_deep = support.infinite_recursion(25)(
mapping_tests.TestHashMappingProtocol.test_repr_deep,
)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -69,9 +69,7 @@ class UserListTest(list_tests.CommonTest):
# Decorate existing test with recursion limit, because # Decorate existing test with recursion limit, because
# the test is for C structure, but `UserList` is a Python structure. # the test is for C structure, but `UserList` is a Python structure.
test_repr_deep = support.infinite_recursion(25)( test_repr_deep = list_tests.CommonTest.test_repr_deep
list_tests.CommonTest.test_repr_deep,
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -57,6 +57,7 @@ class MiscTests(unittest.TestCase):
del element.attrib del element.attrib
self.assertEqual(element.attrib, {'A': 'B', 'C': 'D'}) self.assertEqual(element.attrib, {'A': 'B', 'C': 'D'})
@support.skip_wasi_stack_overflow()
@unittest.skipIf(support.is_emscripten, "segfaults") @unittest.skipIf(support.is_emscripten, "segfaults")
def test_trashcan(self): def test_trashcan(self):
# If this test fails, it will most likely die via segfault. # If this test fails, it will most likely die via segfault.

View file

@ -0,0 +1,3 @@
Change C stack overflow protection to consider the amount of stack consumed,
rather than a counter. This allows deeper recursion in many cases, but
remains safe.

View file

@ -0,0 +1,3 @@
Use actual stack limits (from :manpage:`pthread_getattr_np(3)`) for linux, and other
systems with ``_GNU_SOURCE`` defined, when determining limits for C stack
protection.

View file

@ -3239,7 +3239,6 @@ PyInit__testcapi(void)
PyModule_AddObject(m, "instancemethod", (PyObject *)&PyInstanceMethod_Type); PyModule_AddObject(m, "instancemethod", (PyObject *)&PyInstanceMethod_Type);
PyModule_AddIntConstant(m, "the_number_three", 3); PyModule_AddIntConstant(m, "the_number_three", 3);
PyModule_AddIntMacro(m, Py_C_RECURSION_LIMIT);
PyModule_AddObject(m, "INT32_MIN", PyLong_FromInt32(INT32_MIN)); PyModule_AddObject(m, "INT32_MIN", PyLong_FromInt32(INT32_MIN));
PyModule_AddObject(m, "INT32_MAX", PyLong_FromInt32(INT32_MAX)); PyModule_AddObject(m, "INT32_MAX", PyLong_FromInt32(INT32_MAX));
PyModule_AddObject(m, "UINT32_MAX", PyLong_FromUInt32(UINT32_MAX)); PyModule_AddObject(m, "UINT32_MAX", PyLong_FromUInt32(UINT32_MAX));

View file

@ -115,7 +115,10 @@ static PyObject*
get_c_recursion_remaining(PyObject *self, PyObject *Py_UNUSED(args)) get_c_recursion_remaining(PyObject *self, PyObject *Py_UNUSED(args))
{ {
PyThreadState *tstate = _PyThreadState_GET(); PyThreadState *tstate = _PyThreadState_GET();
return PyLong_FromLong(tstate->c_recursion_remaining); uintptr_t here_addr = _Py_get_machine_stack_pointer();
_PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
int remaining = (int)((here_addr - _tstate->c_stack_soft_limit)/PYOS_STACK_MARGIN_BYTES * 50);
return PyLong_FromLong(remaining);
} }

View file

@ -612,12 +612,9 @@ PyObject_Print(PyObject *op, FILE *fp, int flags)
int write_error = 0; int write_error = 0;
if (PyErr_CheckSignals()) if (PyErr_CheckSignals())
return -1; return -1;
#ifdef USE_STACKCHECK if (_Py_EnterRecursiveCall(" printing an object")) {
if (PyOS_CheckStack()) {
PyErr_SetString(PyExc_MemoryError, "stack overflow");
return -1; return -1;
} }
#endif
clearerr(fp); /* Clear any previous error condition */ clearerr(fp); /* Clear any previous error condition */
if (op == NULL) { if (op == NULL) {
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
@ -738,12 +735,6 @@ PyObject_Repr(PyObject *v)
PyObject *res; PyObject *res;
if (PyErr_CheckSignals()) if (PyErr_CheckSignals())
return NULL; return NULL;
#ifdef USE_STACKCHECK
if (PyOS_CheckStack()) {
PyErr_SetString(PyExc_MemoryError, "stack overflow");
return NULL;
}
#endif
if (v == NULL) if (v == NULL)
return PyUnicode_FromString("<NULL>"); return PyUnicode_FromString("<NULL>");
if (Py_TYPE(v)->tp_repr == NULL) if (Py_TYPE(v)->tp_repr == NULL)
@ -786,12 +777,6 @@ PyObject_Str(PyObject *v)
PyObject *res; PyObject *res;
if (PyErr_CheckSignals()) if (PyErr_CheckSignals())
return NULL; return NULL;
#ifdef USE_STACKCHECK
if (PyOS_CheckStack()) {
PyErr_SetString(PyExc_MemoryError, "stack overflow");
return NULL;
}
#endif
if (v == NULL) if (v == NULL)
return PyUnicode_FromString("<NULL>"); return PyUnicode_FromString("<NULL>");
if (PyUnicode_CheckExact(v)) { if (PyUnicode_CheckExact(v)) {
@ -2900,19 +2885,6 @@ _PyTrash_thread_deposit_object(PyThreadState *tstate, PyObject *op)
void void
_PyTrash_thread_destroy_chain(PyThreadState *tstate) _PyTrash_thread_destroy_chain(PyThreadState *tstate)
{ {
/* We need to increase c_recursion_remaining here, otherwise,
_PyTrash_thread_destroy_chain will be called recursively
and then possibly crash. An example that may crash without
increase:
N = 500000 # need to be large enough
ob = object()
tups = [(ob,) for i in range(N)]
for i in range(49):
tups = [(tup,) for tup in tups]
del tups
*/
assert(tstate->c_recursion_remaining > Py_TRASHCAN_HEADROOM);
tstate->c_recursion_remaining--;
while (tstate->delete_later) { while (tstate->delete_later) {
PyObject *op = tstate->delete_later; PyObject *op = tstate->delete_later;
destructor dealloc = Py_TYPE(op)->tp_dealloc; destructor dealloc = Py_TYPE(op)->tp_dealloc;
@ -2934,7 +2906,6 @@ _PyTrash_thread_destroy_chain(PyThreadState *tstate)
_PyObject_ASSERT(op, Py_REFCNT(op) == 0); _PyObject_ASSERT(op, Py_REFCNT(op) == 0);
(*dealloc)(op); (*dealloc)(op);
} }
tstate->c_recursion_remaining++;
} }
void _Py_NO_RETURN void _Py_NO_RETURN

View file

@ -738,7 +738,7 @@ class SequenceConstructorVisitor(EmitVisitor):
class PyTypesDeclareVisitor(PickleVisitor): class PyTypesDeclareVisitor(PickleVisitor):
def visitProduct(self, prod, name): def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0) self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
if prod.attributes: if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0) self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes: for a in prod.attributes:
@ -759,7 +759,7 @@ class PyTypesDeclareVisitor(PickleVisitor):
ptype = "void*" ptype = "void*"
if is_simple(sum): if is_simple(sum):
ptype = get_c_type(name) ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0) self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
for t in sum.types: for t in sum.types:
self.visitConstructor(t, name) self.visitConstructor(t, name)
@ -1734,8 +1734,8 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr
/* Conversion AST -> Python */ /* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq, static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*)) PyObject* (*func)(struct ast_state *state, void*))
{ {
Py_ssize_t i, n = asdl_seq_LEN(seq); Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n); PyObject *result = PyList_New(n);
@ -1743,7 +1743,7 @@ static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate,
if (!result) if (!result)
return NULL; return NULL;
for (i = 0; i < n; i++) { for (i = 0; i < n; i++) {
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i)); value = func(state, asdl_seq_GET_UNTYPED(seq, i));
if (!value) { if (!value) {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
@ -1753,7 +1753,7 @@ static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate,
return result; return result;
} }
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o) static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
{ {
PyObject *op = (PyObject*)o; PyObject *op = (PyObject*)o;
if (!op) { if (!op) {
@ -1765,7 +1765,7 @@ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct valid
#define ast2obj_identifier ast2obj_object #define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object #define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b) static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
{ {
return PyLong_FromLong(b); return PyLong_FromLong(b);
} }
@ -2014,7 +2014,7 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name): def func_begin(self, name):
ctype = get_c_type(name) ctype = get_c_type(name)
self.emit("PyObject*", 0) self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0) self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
self.emit("{", 0) self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1) self.emit("PyObject *result = NULL, *value = NULL;", 1)
@ -2022,17 +2022,15 @@ class ObjVisitor(PickleVisitor):
self.emit('if (!o) {', 1) self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2) self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1) self.emit("}", 1)
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1) self.emit('if (Py_EnterRecursiveCall("during ast construction")) {', 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return NULL;", 2) self.emit("return NULL;", 2)
self.emit("}", 1) self.emit("}", 1)
def func_end(self): def func_end(self):
self.emit("vstate->recursion_depth--;", 1) self.emit("Py_LeaveRecursiveCall();", 1)
self.emit("return result;", 1) self.emit("return result;", 1)
self.emit("failed:", 0) self.emit("failed:", 0)
self.emit("vstate->recursion_depth--;", 1) self.emit("Py_LeaveRecursiveCall();", 1)
self.emit("Py_XDECREF(value);", 1) self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1) self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1) self.emit("return NULL;", 1)
@ -2050,7 +2048,7 @@ class ObjVisitor(PickleVisitor):
self.visitConstructor(t, i + 1, name) self.visitConstructor(t, i + 1, name)
self.emit("}", 1) self.emit("}", 1)
for a in sum.attributes: for a in sum.attributes:
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1) self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1) self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2) self.emit('goto failed;', 2)
@ -2058,7 +2056,7 @@ class ObjVisitor(PickleVisitor):
self.func_end() self.func_end()
def simpleSum(self, sum, name): def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0) self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
self.emit("{", 0) self.emit("{", 0)
self.emit("switch(o) {", 1) self.emit("switch(o) {", 1)
for t in sum.types: for t in sum.types:
@ -2076,7 +2074,7 @@ class ObjVisitor(PickleVisitor):
for field in prod.fields: for field in prod.fields:
self.visitField(field, name, 1, True) self.visitField(field, name, 1, True)
for a in prod.attributes: for a in prod.attributes:
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1) self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1) self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2) self.emit('goto failed;', 2)
@ -2117,7 +2115,7 @@ class ObjVisitor(PickleVisitor):
self.emit("for(i = 0; i < n; i++)", depth+1) self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling # This cannot fail, so no need for error handling
self.emit( self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format( "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type, field.type,
value value
), ),
@ -2126,9 +2124,9 @@ class ObjVisitor(PickleVisitor):
) )
self.emit("}", depth) self.emit("}", depth)
else: else:
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else: else:
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False) self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
class PartingShots(StaticVisitor): class PartingShots(StaticVisitor):
@ -2140,28 +2138,8 @@ PyObject* PyAST_mod2obj(mod_ty t)
if (state == NULL) { if (state == NULL) {
return NULL; return NULL;
} }
PyObject *result = ast2obj_mod(state, t);
int starting_recursion_depth;
/* Be careful here to prevent overflow. */
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
return NULL;
}
struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth;
vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */
if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, vstate.recursion_depth);
return NULL;
}
return result; return result;
} }
@ -2305,11 +2283,6 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "structmember.h" #include "structmember.h"
#include <stddef.h> #include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration // Forward declaration
static int init_types(void *arg); static int init_types(void *arg);

878
Parser/parser.c generated

File diff suppressed because it is too large Load diff

733
Python/Python-ast.c generated

File diff suppressed because it is too large Load diff

View file

@ -9,34 +9,22 @@
#include <assert.h> #include <assert.h>
#include <stdbool.h> #include <stdbool.h>
struct validator { #define ENTER_RECURSIVE() \
int recursion_depth; /* current recursion depth */ if (Py_EnterRecursiveCall(" during compilation")) { \
int recursion_limit; /* recursion limit */
};
#define ENTER_RECURSIVE(ST) \
do { \
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
PyErr_SetString(PyExc_RecursionError, \
"maximum recursion depth exceeded during compilation"); \
return 0; \ return 0; \
} \ }
} while(0)
#define LEAVE_RECURSIVE(ST) \ #define LEAVE_RECURSIVE() Py_LeaveRecursiveCall();
do { \
--(ST)->recursion_depth; \
} while(0)
static int validate_stmts(struct validator *, asdl_stmt_seq *); static int validate_stmts(asdl_stmt_seq *);
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int); static int validate_exprs(asdl_expr_seq *, expr_context_ty, int);
static int validate_patterns(struct validator *, asdl_pattern_seq *, int); static int validate_patterns(asdl_pattern_seq *, int);
static int validate_type_params(struct validator *, asdl_type_param_seq *); static int validate_type_params(asdl_type_param_seq *);
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
static int validate_stmt(struct validator *, stmt_ty); static int validate_stmt(stmt_ty);
static int validate_expr(struct validator *, expr_ty, expr_context_ty); static int validate_expr(expr_ty, expr_context_ty);
static int validate_pattern(struct validator *, pattern_ty, int); static int validate_pattern(pattern_ty, int);
static int validate_typeparam(struct validator *, type_param_ty); static int validate_typeparam(type_param_ty);
#define VALIDATE_POSITIONS(node) \ #define VALIDATE_POSITIONS(node) \
if (node->lineno > node->end_lineno) { \ if (node->lineno > node->end_lineno) { \
@ -80,7 +68,7 @@ validate_name(PyObject *name)
} }
static int static int
validate_comprehension(struct validator *state, asdl_comprehension_seq *gens) validate_comprehension(asdl_comprehension_seq *gens)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
if (!asdl_seq_LEN(gens)) { if (!asdl_seq_LEN(gens)) {
@ -89,32 +77,32 @@ validate_comprehension(struct validator *state, asdl_comprehension_seq *gens)
} }
for (Py_ssize_t i = 0; i < asdl_seq_LEN(gens); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(gens); i++) {
comprehension_ty comp = asdl_seq_GET(gens, i); comprehension_ty comp = asdl_seq_GET(gens, i);
if (!validate_expr(state, comp->target, Store) || if (!validate_expr(comp->target, Store) ||
!validate_expr(state, comp->iter, Load) || !validate_expr(comp->iter, Load) ||
!validate_exprs(state, comp->ifs, Load, 0)) !validate_exprs(comp->ifs, Load, 0))
return 0; return 0;
} }
return 1; return 1;
} }
static int static int
validate_keywords(struct validator *state, asdl_keyword_seq *keywords) validate_keywords(asdl_keyword_seq *keywords)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(keywords); i++) for (Py_ssize_t i = 0; i < asdl_seq_LEN(keywords); i++)
if (!validate_expr(state, (asdl_seq_GET(keywords, i))->value, Load)) if (!validate_expr((asdl_seq_GET(keywords, i))->value, Load))
return 0; return 0;
return 1; return 1;
} }
static int static int
validate_args(struct validator *state, asdl_arg_seq *args) validate_args(asdl_arg_seq *args)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) {
arg_ty arg = asdl_seq_GET(args, i); arg_ty arg = asdl_seq_GET(args, i);
VALIDATE_POSITIONS(arg); VALIDATE_POSITIONS(arg);
if (arg->annotation && !validate_expr(state, arg->annotation, Load)) if (arg->annotation && !validate_expr(arg->annotation, Load))
return 0; return 0;
} }
return 1; return 1;
@ -136,20 +124,20 @@ expr_context_name(expr_context_ty ctx)
} }
static int static int
validate_arguments(struct validator *state, arguments_ty args) validate_arguments(arguments_ty args)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
if (!validate_args(state, args->posonlyargs) || !validate_args(state, args->args)) { if (!validate_args(args->posonlyargs) || !validate_args(args->args)) {
return 0; return 0;
} }
if (args->vararg && args->vararg->annotation if (args->vararg && args->vararg->annotation
&& !validate_expr(state, args->vararg->annotation, Load)) { && !validate_expr(args->vararg->annotation, Load)) {
return 0; return 0;
} }
if (!validate_args(state, args->kwonlyargs)) if (!validate_args(args->kwonlyargs))
return 0; return 0;
if (args->kwarg && args->kwarg->annotation if (args->kwarg && args->kwarg->annotation
&& !validate_expr(state, args->kwarg->annotation, Load)) { && !validate_expr(args->kwarg->annotation, Load)) {
return 0; return 0;
} }
if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) { if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
@ -161,11 +149,11 @@ validate_arguments(struct validator *state, arguments_ty args)
"kw_defaults on arguments"); "kw_defaults on arguments");
return 0; return 0;
} }
return validate_exprs(state, args->defaults, Load, 0) && validate_exprs(state, args->kw_defaults, Load, 1); return validate_exprs(args->defaults, Load, 0) && validate_exprs(args->kw_defaults, Load, 1);
} }
static int static int
validate_constant(struct validator *state, PyObject *value) validate_constant(PyObject *value)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
if (value == Py_None || value == Py_Ellipsis) if (value == Py_None || value == Py_Ellipsis)
@ -180,7 +168,7 @@ validate_constant(struct validator *state, PyObject *value)
return 1; return 1;
if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) { if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
PyObject *it = PyObject_GetIter(value); PyObject *it = PyObject_GetIter(value);
if (it == NULL) if (it == NULL)
@ -196,7 +184,7 @@ validate_constant(struct validator *state, PyObject *value)
break; break;
} }
if (!validate_constant(state, item)) { if (!validate_constant(item)) {
Py_DECREF(it); Py_DECREF(it);
Py_DECREF(item); Py_DECREF(item);
return 0; return 0;
@ -205,7 +193,7 @@ validate_constant(struct validator *state, PyObject *value)
} }
Py_DECREF(it); Py_DECREF(it);
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return 1; return 1;
} }
@ -218,12 +206,12 @@ validate_constant(struct validator *state, PyObject *value)
} }
static int static int
validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx) validate_expr(expr_ty exp, expr_context_ty ctx)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
VALIDATE_POSITIONS(exp); VALIDATE_POSITIONS(exp);
int ret = -1; int ret = -1;
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
int check_ctx = 1; int check_ctx = 1;
expr_context_ty actual_ctx; expr_context_ty actual_ctx;
@ -273,23 +261,23 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values"); PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values");
return 0; return 0;
} }
ret = validate_exprs(state, exp->v.BoolOp.values, Load, 0); ret = validate_exprs(exp->v.BoolOp.values, Load, 0);
break; break;
case BinOp_kind: case BinOp_kind:
ret = validate_expr(state, exp->v.BinOp.left, Load) && ret = validate_expr(exp->v.BinOp.left, Load) &&
validate_expr(state, exp->v.BinOp.right, Load); validate_expr(exp->v.BinOp.right, Load);
break; break;
case UnaryOp_kind: case UnaryOp_kind:
ret = validate_expr(state, exp->v.UnaryOp.operand, Load); ret = validate_expr(exp->v.UnaryOp.operand, Load);
break; break;
case Lambda_kind: case Lambda_kind:
ret = validate_arguments(state, exp->v.Lambda.args) && ret = validate_arguments(exp->v.Lambda.args) &&
validate_expr(state, exp->v.Lambda.body, Load); validate_expr(exp->v.Lambda.body, Load);
break; break;
case IfExp_kind: case IfExp_kind:
ret = validate_expr(state, exp->v.IfExp.test, Load) && ret = validate_expr(exp->v.IfExp.test, Load) &&
validate_expr(state, exp->v.IfExp.body, Load) && validate_expr(exp->v.IfExp.body, Load) &&
validate_expr(state, exp->v.IfExp.orelse, Load); validate_expr(exp->v.IfExp.orelse, Load);
break; break;
case Dict_kind: case Dict_kind:
if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) { if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) {
@ -299,34 +287,34 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
} }
/* null_ok=1 for keys expressions to allow dict unpacking to work in /* null_ok=1 for keys expressions to allow dict unpacking to work in
dict literals, i.e. ``{**{a:b}}`` */ dict literals, i.e. ``{**{a:b}}`` */
ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) && ret = validate_exprs(exp->v.Dict.keys, Load, /*null_ok=*/ 1) &&
validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0); validate_exprs(exp->v.Dict.values, Load, /*null_ok=*/ 0);
break; break;
case Set_kind: case Set_kind:
ret = validate_exprs(state, exp->v.Set.elts, Load, 0); ret = validate_exprs(exp->v.Set.elts, Load, 0);
break; break;
#define COMP(NAME) \ #define COMP(NAME) \
case NAME ## _kind: \ case NAME ## _kind: \
ret = validate_comprehension(state, exp->v.NAME.generators) && \ ret = validate_comprehension(exp->v.NAME.generators) && \
validate_expr(state, exp->v.NAME.elt, Load); \ validate_expr(exp->v.NAME.elt, Load); \
break; break;
COMP(ListComp) COMP(ListComp)
COMP(SetComp) COMP(SetComp)
COMP(GeneratorExp) COMP(GeneratorExp)
#undef COMP #undef COMP
case DictComp_kind: case DictComp_kind:
ret = validate_comprehension(state, exp->v.DictComp.generators) && ret = validate_comprehension(exp->v.DictComp.generators) &&
validate_expr(state, exp->v.DictComp.key, Load) && validate_expr(exp->v.DictComp.key, Load) &&
validate_expr(state, exp->v.DictComp.value, Load); validate_expr(exp->v.DictComp.value, Load);
break; break;
case Yield_kind: case Yield_kind:
ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load); ret = !exp->v.Yield.value || validate_expr(exp->v.Yield.value, Load);
break; break;
case YieldFrom_kind: case YieldFrom_kind:
ret = validate_expr(state, exp->v.YieldFrom.value, Load); ret = validate_expr(exp->v.YieldFrom.value, Load);
break; break;
case Await_kind: case Await_kind:
ret = validate_expr(state, exp->v.Await.value, Load); ret = validate_expr(exp->v.Await.value, Load);
break; break;
case Compare_kind: case Compare_kind:
if (!asdl_seq_LEN(exp->v.Compare.comparators)) { if (!asdl_seq_LEN(exp->v.Compare.comparators)) {
@ -339,52 +327,52 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
"of comparators and operands"); "of comparators and operands");
return 0; return 0;
} }
ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) && ret = validate_exprs(exp->v.Compare.comparators, Load, 0) &&
validate_expr(state, exp->v.Compare.left, Load); validate_expr(exp->v.Compare.left, Load);
break; break;
case Call_kind: case Call_kind:
ret = validate_expr(state, exp->v.Call.func, Load) && ret = validate_expr(exp->v.Call.func, Load) &&
validate_exprs(state, exp->v.Call.args, Load, 0) && validate_exprs(exp->v.Call.args, Load, 0) &&
validate_keywords(state, exp->v.Call.keywords); validate_keywords(exp->v.Call.keywords);
break; break;
case Constant_kind: case Constant_kind:
if (!validate_constant(state, exp->v.Constant.value)) { if (!validate_constant(exp->v.Constant.value)) {
return 0; return 0;
} }
ret = 1; ret = 1;
break; break;
case JoinedStr_kind: case JoinedStr_kind:
ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0); ret = validate_exprs(exp->v.JoinedStr.values, Load, 0);
break; break;
case FormattedValue_kind: case FormattedValue_kind:
if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0) if (validate_expr(exp->v.FormattedValue.value, Load) == 0)
return 0; return 0;
if (exp->v.FormattedValue.format_spec) { if (exp->v.FormattedValue.format_spec) {
ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load); ret = validate_expr(exp->v.FormattedValue.format_spec, Load);
break; break;
} }
ret = 1; ret = 1;
break; break;
case Attribute_kind: case Attribute_kind:
ret = validate_expr(state, exp->v.Attribute.value, Load); ret = validate_expr(exp->v.Attribute.value, Load);
break; break;
case Subscript_kind: case Subscript_kind:
ret = validate_expr(state, exp->v.Subscript.slice, Load) && ret = validate_expr(exp->v.Subscript.slice, Load) &&
validate_expr(state, exp->v.Subscript.value, Load); validate_expr(exp->v.Subscript.value, Load);
break; break;
case Starred_kind: case Starred_kind:
ret = validate_expr(state, exp->v.Starred.value, ctx); ret = validate_expr(exp->v.Starred.value, ctx);
break; break;
case Slice_kind: case Slice_kind:
ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) && ret = (!exp->v.Slice.lower || validate_expr(exp->v.Slice.lower, Load)) &&
(!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) && (!exp->v.Slice.upper || validate_expr(exp->v.Slice.upper, Load)) &&
(!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load)); (!exp->v.Slice.step || validate_expr(exp->v.Slice.step, Load));
break; break;
case List_kind: case List_kind:
ret = validate_exprs(state, exp->v.List.elts, ctx, 0); ret = validate_exprs(exp->v.List.elts, ctx, 0);
break; break;
case Tuple_kind: case Tuple_kind:
ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0); ret = validate_exprs(exp->v.Tuple.elts, ctx, 0);
break; break;
case NamedExpr_kind: case NamedExpr_kind:
if (exp->v.NamedExpr.target->kind != Name_kind) { if (exp->v.NamedExpr.target->kind != Name_kind) {
@ -392,7 +380,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
"NamedExpr target must be a Name"); "NamedExpr target must be a Name");
return 0; return 0;
} }
ret = validate_expr(state, exp->v.NamedExpr.value, Load); ret = validate_expr(exp->v.NamedExpr.value, Load);
break; break;
/* This last case doesn't have any checking. */ /* This last case doesn't have any checking. */
case Name_kind: case Name_kind:
@ -404,7 +392,7 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
PyErr_SetString(PyExc_SystemError, "unexpected expression"); PyErr_SetString(PyExc_SystemError, "unexpected expression");
ret = 0; ret = 0;
} }
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return ret; return ret;
} }
@ -480,10 +468,10 @@ ensure_literal_complex(expr_ty exp)
} }
static int static int
validate_pattern_match_value(struct validator *state, expr_ty exp) validate_pattern_match_value(expr_ty exp)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
if (!validate_expr(state, exp, Load)) { if (!validate_expr(exp, Load)) {
return 0; return 0;
} }
@ -493,7 +481,7 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
/* Ellipsis and immutable sequences are not allowed. /* Ellipsis and immutable sequences are not allowed.
For True, False and None, MatchSingleton() should For True, False and None, MatchSingleton() should
be used */ be used */
if (!validate_expr(state, exp, Load)) { if (!validate_expr(exp, Load)) {
return 0; return 0;
} }
PyObject *literal = exp->v.Constant.value; PyObject *literal = exp->v.Constant.value;
@ -545,15 +533,15 @@ validate_capture(PyObject *name)
} }
static int static int
validate_pattern(struct validator *state, pattern_ty p, int star_ok) validate_pattern(pattern_ty p, int star_ok)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
VALIDATE_POSITIONS(p); VALIDATE_POSITIONS(p);
int ret = -1; int ret = -1;
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
switch (p->kind) { switch (p->kind) {
case MatchValue_kind: case MatchValue_kind:
ret = validate_pattern_match_value(state, p->v.MatchValue.value); ret = validate_pattern_match_value(p->v.MatchValue.value);
break; break;
case MatchSingleton_kind: case MatchSingleton_kind:
ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value); ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
@ -563,7 +551,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
} }
break; break;
case MatchSequence_kind: case MatchSequence_kind:
ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1); ret = validate_patterns(p->v.MatchSequence.patterns, /*star_ok=*/1);
break; break;
case MatchMapping_kind: case MatchMapping_kind:
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) { if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
@ -591,7 +579,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
continue; continue;
} }
} }
if (!validate_pattern_match_value(state, key)) { if (!validate_pattern_match_value(key)) {
ret = 0; ret = 0;
break; break;
} }
@ -599,7 +587,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
if (ret == 0) { if (ret == 0) {
break; break;
} }
ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0); ret = validate_patterns(p->v.MatchMapping.patterns, /*star_ok=*/0);
break; break;
case MatchClass_kind: case MatchClass_kind:
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) { if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
@ -608,7 +596,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
ret = 0; ret = 0;
break; break;
} }
if (!validate_expr(state, p->v.MatchClass.cls, Load)) { if (!validate_expr(p->v.MatchClass.cls, Load)) {
ret = 0; ret = 0;
break; break;
} }
@ -644,12 +632,12 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
break; break;
} }
if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) { if (!validate_patterns(p->v.MatchClass.patterns, /*star_ok=*/0)) {
ret = 0; ret = 0;
break; break;
} }
ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0); ret = validate_patterns(p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
break; break;
case MatchStar_kind: case MatchStar_kind:
if (!star_ok) { if (!star_ok) {
@ -673,7 +661,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
ret = 0; ret = 0;
} }
else { else {
ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0); ret = validate_pattern(p->v.MatchAs.pattern, /*star_ok=*/0);
} }
break; break;
case MatchOr_kind: case MatchOr_kind:
@ -683,7 +671,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
ret = 0; ret = 0;
break; break;
} }
ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0); ret = validate_patterns(p->v.MatchOr.patterns, /*star_ok=*/0);
break; break;
// No default case, so the compiler will emit a warning if new pattern // No default case, so the compiler will emit a warning if new pattern
// kinds are added without being handled here // kinds are added without being handled here
@ -692,7 +680,7 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok)
PyErr_SetString(PyExc_SystemError, "unexpected pattern"); PyErr_SetString(PyExc_SystemError, "unexpected pattern");
ret = 0; ret = 0;
} }
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return ret; return ret;
} }
@ -707,56 +695,56 @@ _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
#define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner) #define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner)
static int static int
validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx) validate_assignlist(asdl_expr_seq *targets, expr_context_ty ctx)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") && return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") &&
validate_exprs(state, targets, ctx, 0); validate_exprs(targets, ctx, 0);
} }
static int static int
validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner) validate_body(asdl_stmt_seq *body, const char *owner)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body); return validate_nonempty_seq(body, "body", owner) && validate_stmts(body);
} }
static int static int
validate_stmt(struct validator *state, stmt_ty stmt) validate_stmt(stmt_ty stmt)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
VALIDATE_POSITIONS(stmt); VALIDATE_POSITIONS(stmt);
int ret = -1; int ret = -1;
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
switch (stmt->kind) { switch (stmt->kind) {
case FunctionDef_kind: case FunctionDef_kind:
ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") && ret = validate_body(stmt->v.FunctionDef.body, "FunctionDef") &&
validate_type_params(state, stmt->v.FunctionDef.type_params) && validate_type_params(stmt->v.FunctionDef.type_params) &&
validate_arguments(state, stmt->v.FunctionDef.args) && validate_arguments(stmt->v.FunctionDef.args) &&
validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) && validate_exprs(stmt->v.FunctionDef.decorator_list, Load, 0) &&
(!stmt->v.FunctionDef.returns || (!stmt->v.FunctionDef.returns ||
validate_expr(state, stmt->v.FunctionDef.returns, Load)); validate_expr(stmt->v.FunctionDef.returns, Load));
break; break;
case ClassDef_kind: case ClassDef_kind:
ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") && ret = validate_body(stmt->v.ClassDef.body, "ClassDef") &&
validate_type_params(state, stmt->v.ClassDef.type_params) && validate_type_params(stmt->v.ClassDef.type_params) &&
validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) && validate_exprs(stmt->v.ClassDef.bases, Load, 0) &&
validate_keywords(state, stmt->v.ClassDef.keywords) && validate_keywords(stmt->v.ClassDef.keywords) &&
validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0); validate_exprs(stmt->v.ClassDef.decorator_list, Load, 0);
break; break;
case Return_kind: case Return_kind:
ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load); ret = !stmt->v.Return.value || validate_expr(stmt->v.Return.value, Load);
break; break;
case Delete_kind: case Delete_kind:
ret = validate_assignlist(state, stmt->v.Delete.targets, Del); ret = validate_assignlist(stmt->v.Delete.targets, Del);
break; break;
case Assign_kind: case Assign_kind:
ret = validate_assignlist(state, stmt->v.Assign.targets, Store) && ret = validate_assignlist(stmt->v.Assign.targets, Store) &&
validate_expr(state, stmt->v.Assign.value, Load); validate_expr(stmt->v.Assign.value, Load);
break; break;
case AugAssign_kind: case AugAssign_kind:
ret = validate_expr(state, stmt->v.AugAssign.target, Store) && ret = validate_expr(stmt->v.AugAssign.target, Store) &&
validate_expr(state, stmt->v.AugAssign.value, Load); validate_expr(stmt->v.AugAssign.value, Load);
break; break;
case AnnAssign_kind: case AnnAssign_kind:
if (stmt->v.AnnAssign.target->kind != Name_kind && if (stmt->v.AnnAssign.target->kind != Name_kind &&
@ -765,10 +753,10 @@ validate_stmt(struct validator *state, stmt_ty stmt)
"AnnAssign with simple non-Name target"); "AnnAssign with simple non-Name target");
return 0; return 0;
} }
ret = validate_expr(state, stmt->v.AnnAssign.target, Store) && ret = validate_expr(stmt->v.AnnAssign.target, Store) &&
(!stmt->v.AnnAssign.value || (!stmt->v.AnnAssign.value ||
validate_expr(state, stmt->v.AnnAssign.value, Load)) && validate_expr(stmt->v.AnnAssign.value, Load)) &&
validate_expr(state, stmt->v.AnnAssign.annotation, Load); validate_expr(stmt->v.AnnAssign.annotation, Load);
break; break;
case TypeAlias_kind: case TypeAlias_kind:
if (stmt->v.TypeAlias.name->kind != Name_kind) { if (stmt->v.TypeAlias.name->kind != Name_kind) {
@ -776,64 +764,64 @@ validate_stmt(struct validator *state, stmt_ty stmt)
"TypeAlias with non-Name name"); "TypeAlias with non-Name name");
return 0; return 0;
} }
ret = validate_expr(state, stmt->v.TypeAlias.name, Store) && ret = validate_expr(stmt->v.TypeAlias.name, Store) &&
validate_type_params(state, stmt->v.TypeAlias.type_params) && validate_type_params(stmt->v.TypeAlias.type_params) &&
validate_expr(state, stmt->v.TypeAlias.value, Load); validate_expr(stmt->v.TypeAlias.value, Load);
break; break;
case For_kind: case For_kind:
ret = validate_expr(state, stmt->v.For.target, Store) && ret = validate_expr(stmt->v.For.target, Store) &&
validate_expr(state, stmt->v.For.iter, Load) && validate_expr(stmt->v.For.iter, Load) &&
validate_body(state, stmt->v.For.body, "For") && validate_body(stmt->v.For.body, "For") &&
validate_stmts(state, stmt->v.For.orelse); validate_stmts(stmt->v.For.orelse);
break; break;
case AsyncFor_kind: case AsyncFor_kind:
ret = validate_expr(state, stmt->v.AsyncFor.target, Store) && ret = validate_expr(stmt->v.AsyncFor.target, Store) &&
validate_expr(state, stmt->v.AsyncFor.iter, Load) && validate_expr(stmt->v.AsyncFor.iter, Load) &&
validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") && validate_body(stmt->v.AsyncFor.body, "AsyncFor") &&
validate_stmts(state, stmt->v.AsyncFor.orelse); validate_stmts(stmt->v.AsyncFor.orelse);
break; break;
case While_kind: case While_kind:
ret = validate_expr(state, stmt->v.While.test, Load) && ret = validate_expr(stmt->v.While.test, Load) &&
validate_body(state, stmt->v.While.body, "While") && validate_body(stmt->v.While.body, "While") &&
validate_stmts(state, stmt->v.While.orelse); validate_stmts(stmt->v.While.orelse);
break; break;
case If_kind: case If_kind:
ret = validate_expr(state, stmt->v.If.test, Load) && ret = validate_expr(stmt->v.If.test, Load) &&
validate_body(state, stmt->v.If.body, "If") && validate_body(stmt->v.If.body, "If") &&
validate_stmts(state, stmt->v.If.orelse); validate_stmts(stmt->v.If.orelse);
break; break;
case With_kind: case With_kind:
if (!validate_nonempty_seq(stmt->v.With.items, "items", "With")) if (!validate_nonempty_seq(stmt->v.With.items, "items", "With"))
return 0; return 0;
for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) {
withitem_ty item = asdl_seq_GET(stmt->v.With.items, i); withitem_ty item = asdl_seq_GET(stmt->v.With.items, i);
if (!validate_expr(state, item->context_expr, Load) || if (!validate_expr(item->context_expr, Load) ||
(item->optional_vars && !validate_expr(state, item->optional_vars, Store))) (item->optional_vars && !validate_expr(item->optional_vars, Store)))
return 0; return 0;
} }
ret = validate_body(state, stmt->v.With.body, "With"); ret = validate_body(stmt->v.With.body, "With");
break; break;
case AsyncWith_kind: case AsyncWith_kind:
if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith")) if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith"))
return 0; return 0;
for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) {
withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i); withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i);
if (!validate_expr(state, item->context_expr, Load) || if (!validate_expr(item->context_expr, Load) ||
(item->optional_vars && !validate_expr(state, item->optional_vars, Store))) (item->optional_vars && !validate_expr(item->optional_vars, Store)))
return 0; return 0;
} }
ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith"); ret = validate_body(stmt->v.AsyncWith.body, "AsyncWith");
break; break;
case Match_kind: case Match_kind:
if (!validate_expr(state, stmt->v.Match.subject, Load) if (!validate_expr(stmt->v.Match.subject, Load)
|| !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) { || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) {
return 0; return 0;
} }
for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i); match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
if (!validate_pattern(state, m->pattern, /*star_ok=*/0) if (!validate_pattern(m->pattern, /*star_ok=*/0)
|| (m->guard && !validate_expr(state, m->guard, Load)) || (m->guard && !validate_expr(m->guard, Load))
|| !validate_body(state, m->body, "match_case")) { || !validate_body(m->body, "match_case")) {
return 0; return 0;
} }
} }
@ -841,8 +829,8 @@ validate_stmt(struct validator *state, stmt_ty stmt)
break; break;
case Raise_kind: case Raise_kind:
if (stmt->v.Raise.exc) { if (stmt->v.Raise.exc) {
ret = validate_expr(state, stmt->v.Raise.exc, Load) && ret = validate_expr(stmt->v.Raise.exc, Load) &&
(!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load)); (!stmt->v.Raise.cause || validate_expr(stmt->v.Raise.cause, Load));
break; break;
} }
if (stmt->v.Raise.cause) { if (stmt->v.Raise.cause) {
@ -852,7 +840,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
ret = 1; ret = 1;
break; break;
case Try_kind: case Try_kind:
if (!validate_body(state, stmt->v.Try.body, "Try")) if (!validate_body(stmt->v.Try.body, "Try"))
return 0; return 0;
if (!asdl_seq_LEN(stmt->v.Try.handlers) && if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
!asdl_seq_LEN(stmt->v.Try.finalbody)) { !asdl_seq_LEN(stmt->v.Try.finalbody)) {
@ -868,17 +856,17 @@ validate_stmt(struct validator *state, stmt_ty stmt)
excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i); excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i);
VALIDATE_POSITIONS(handler); VALIDATE_POSITIONS(handler);
if ((handler->v.ExceptHandler.type && if ((handler->v.ExceptHandler.type &&
!validate_expr(state, handler->v.ExceptHandler.type, Load)) || !validate_expr(handler->v.ExceptHandler.type, Load)) ||
!validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler")) !validate_body(handler->v.ExceptHandler.body, "ExceptHandler"))
return 0; return 0;
} }
ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) || ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) ||
validate_stmts(state, stmt->v.Try.finalbody)) && validate_stmts(stmt->v.Try.finalbody)) &&
(!asdl_seq_LEN(stmt->v.Try.orelse) || (!asdl_seq_LEN(stmt->v.Try.orelse) ||
validate_stmts(state, stmt->v.Try.orelse)); validate_stmts(stmt->v.Try.orelse));
break; break;
case TryStar_kind: case TryStar_kind:
if (!validate_body(state, stmt->v.TryStar.body, "TryStar")) if (!validate_body(stmt->v.TryStar.body, "TryStar"))
return 0; return 0;
if (!asdl_seq_LEN(stmt->v.TryStar.handlers) && if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
!asdl_seq_LEN(stmt->v.TryStar.finalbody)) { !asdl_seq_LEN(stmt->v.TryStar.finalbody)) {
@ -893,18 +881,18 @@ validate_stmt(struct validator *state, stmt_ty stmt)
for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.TryStar.handlers); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.TryStar.handlers); i++) {
excepthandler_ty handler = asdl_seq_GET(stmt->v.TryStar.handlers, i); excepthandler_ty handler = asdl_seq_GET(stmt->v.TryStar.handlers, i);
if ((handler->v.ExceptHandler.type && if ((handler->v.ExceptHandler.type &&
!validate_expr(state, handler->v.ExceptHandler.type, Load)) || !validate_expr(handler->v.ExceptHandler.type, Load)) ||
!validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler")) !validate_body(handler->v.ExceptHandler.body, "ExceptHandler"))
return 0; return 0;
} }
ret = (!asdl_seq_LEN(stmt->v.TryStar.finalbody) || ret = (!asdl_seq_LEN(stmt->v.TryStar.finalbody) ||
validate_stmts(state, stmt->v.TryStar.finalbody)) && validate_stmts(stmt->v.TryStar.finalbody)) &&
(!asdl_seq_LEN(stmt->v.TryStar.orelse) || (!asdl_seq_LEN(stmt->v.TryStar.orelse) ||
validate_stmts(state, stmt->v.TryStar.orelse)); validate_stmts(stmt->v.TryStar.orelse));
break; break;
case Assert_kind: case Assert_kind:
ret = validate_expr(state, stmt->v.Assert.test, Load) && ret = validate_expr(stmt->v.Assert.test, Load) &&
(!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load)); (!stmt->v.Assert.msg || validate_expr(stmt->v.Assert.msg, Load));
break; break;
case Import_kind: case Import_kind:
ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import"); ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import");
@ -923,15 +911,15 @@ validate_stmt(struct validator *state, stmt_ty stmt)
ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal"); ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal");
break; break;
case Expr_kind: case Expr_kind:
ret = validate_expr(state, stmt->v.Expr.value, Load); ret = validate_expr(stmt->v.Expr.value, Load);
break; break;
case AsyncFunctionDef_kind: case AsyncFunctionDef_kind:
ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") && ret = validate_body(stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
validate_type_params(state, stmt->v.AsyncFunctionDef.type_params) && validate_type_params(stmt->v.AsyncFunctionDef.type_params) &&
validate_arguments(state, stmt->v.AsyncFunctionDef.args) && validate_arguments(stmt->v.AsyncFunctionDef.args) &&
validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) && validate_exprs(stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
(!stmt->v.AsyncFunctionDef.returns || (!stmt->v.AsyncFunctionDef.returns ||
validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load)); validate_expr(stmt->v.AsyncFunctionDef.returns, Load));
break; break;
case Pass_kind: case Pass_kind:
case Break_kind: case Break_kind:
@ -944,18 +932,18 @@ validate_stmt(struct validator *state, stmt_ty stmt)
PyErr_SetString(PyExc_SystemError, "unexpected statement"); PyErr_SetString(PyExc_SystemError, "unexpected statement");
ret = 0; ret = 0;
} }
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return ret; return ret;
} }
static int static int
validate_stmts(struct validator *state, asdl_stmt_seq *seq) validate_stmts(asdl_stmt_seq *seq)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) {
stmt_ty stmt = asdl_seq_GET(seq, i); stmt_ty stmt = asdl_seq_GET(seq, i);
if (stmt) { if (stmt) {
if (!validate_stmt(state, stmt)) if (!validate_stmt(stmt))
return 0; return 0;
} }
else { else {
@ -968,13 +956,13 @@ validate_stmts(struct validator *state, asdl_stmt_seq *seq)
} }
static int static int
validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok) validate_exprs(asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(exprs); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(exprs); i++) {
expr_ty expr = asdl_seq_GET(exprs, i); expr_ty expr = asdl_seq_GET(exprs, i);
if (expr) { if (expr) {
if (!validate_expr(state, expr, ctx)) if (!validate_expr(expr, ctx))
return 0; return 0;
} }
else if (!null_ok) { else if (!null_ok) {
@ -988,12 +976,12 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
} }
static int static int
validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok) validate_patterns(asdl_pattern_seq *patterns, int star_ok)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
for (Py_ssize_t i = 0; i < asdl_seq_LEN(patterns); i++) { for (Py_ssize_t i = 0; i < asdl_seq_LEN(patterns); i++) {
pattern_ty pattern = asdl_seq_GET(patterns, i); pattern_ty pattern = asdl_seq_GET(patterns, i);
if (!validate_pattern(state, pattern, star_ok)) { if (!validate_pattern(pattern, star_ok)) {
return 0; return 0;
} }
} }
@ -1001,7 +989,7 @@ validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_
} }
static int static int
validate_typeparam(struct validator *state, type_param_ty tp) validate_typeparam(type_param_ty tp)
{ {
VALIDATE_POSITIONS(tp); VALIDATE_POSITIONS(tp);
int ret = -1; int ret = -1;
@ -1009,32 +997,32 @@ validate_typeparam(struct validator *state, type_param_ty tp)
case TypeVar_kind: case TypeVar_kind:
ret = validate_name(tp->v.TypeVar.name) && ret = validate_name(tp->v.TypeVar.name) &&
(!tp->v.TypeVar.bound || (!tp->v.TypeVar.bound ||
validate_expr(state, tp->v.TypeVar.bound, Load)) && validate_expr(tp->v.TypeVar.bound, Load)) &&
(!tp->v.TypeVar.default_value || (!tp->v.TypeVar.default_value ||
validate_expr(state, tp->v.TypeVar.default_value, Load)); validate_expr(tp->v.TypeVar.default_value, Load));
break; break;
case ParamSpec_kind: case ParamSpec_kind:
ret = validate_name(tp->v.ParamSpec.name) && ret = validate_name(tp->v.ParamSpec.name) &&
(!tp->v.ParamSpec.default_value || (!tp->v.ParamSpec.default_value ||
validate_expr(state, tp->v.ParamSpec.default_value, Load)); validate_expr(tp->v.ParamSpec.default_value, Load));
break; break;
case TypeVarTuple_kind: case TypeVarTuple_kind:
ret = validate_name(tp->v.TypeVarTuple.name) && ret = validate_name(tp->v.TypeVarTuple.name) &&
(!tp->v.TypeVarTuple.default_value || (!tp->v.TypeVarTuple.default_value ||
validate_expr(state, tp->v.TypeVarTuple.default_value, Load)); validate_expr(tp->v.TypeVarTuple.default_value, Load));
break; break;
} }
return ret; return ret;
} }
static int static int
validate_type_params(struct validator *state, asdl_type_param_seq *tps) validate_type_params(asdl_type_param_seq *tps)
{ {
Py_ssize_t i; Py_ssize_t i;
for (i = 0; i < asdl_seq_LEN(tps); i++) { for (i = 0; i < asdl_seq_LEN(tps); i++) {
type_param_ty tp = asdl_seq_GET(tps, i); type_param_ty tp = asdl_seq_GET(tps, i);
if (tp) { if (tp) {
if (!validate_typeparam(state, tp)) if (!validate_typeparam(tp))
return 0; return 0;
} }
} }
@ -1046,34 +1034,20 @@ _PyAST_Validate(mod_ty mod)
{ {
assert(!PyErr_Occurred()); assert(!PyErr_Occurred());
int res = -1; int res = -1;
struct validator state;
PyThreadState *tstate;
int starting_recursion_depth;
/* Setup recursion depth check counters */
tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth;
state.recursion_depth = starting_recursion_depth;
state.recursion_limit = Py_C_RECURSION_LIMIT;
switch (mod->kind) { switch (mod->kind) {
case Module_kind: case Module_kind:
res = validate_stmts(&state, mod->v.Module.body); res = validate_stmts(mod->v.Module.body);
break; break;
case Interactive_kind: case Interactive_kind:
res = validate_stmts(&state, mod->v.Interactive.body); res = validate_stmts(mod->v.Interactive.body);
break; break;
case Expression_kind: case Expression_kind:
res = validate_expr(&state, mod->v.Expression.body, Load); res = validate_expr(mod->v.Expression.body, Load);
break; break;
case FunctionType_kind: case FunctionType_kind:
res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) && res = validate_exprs(mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
validate_expr(&state, mod->v.FunctionType.returns, Load); validate_expr(mod->v.FunctionType.returns, Load);
break; break;
// No default case so compiler emits warning for unhandled cases // No default case so compiler emits warning for unhandled cases
} }
@ -1082,14 +1056,6 @@ _PyAST_Validate(mod_ty mod)
PyErr_SetString(PyExc_SystemError, "impossible module node"); PyErr_SetString(PyExc_SystemError, "impossible module node");
return 0; return 0;
} }
/* Check that the recursion depth counting balanced correctly */
if (res && state.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST validator recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state.recursion_depth);
return 0;
}
return res; return res;
} }

View file

@ -10,24 +10,14 @@
typedef struct { typedef struct {
int optimize; int optimize;
int ff_features; int ff_features;
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
} _PyASTOptimizeState; } _PyASTOptimizeState;
#define ENTER_RECURSIVE(ST) \ #define ENTER_RECURSIVE() \
do { \ if (Py_EnterRecursiveCall(" during compilation")) { \
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
PyErr_SetString(PyExc_RecursionError, \
"maximum recursion depth exceeded during compilation"); \
return 0; \ return 0; \
} \ }
} while(0)
#define LEAVE_RECURSIVE(ST) \ #define LEAVE_RECURSIVE() Py_LeaveRecursiveCall();
do { \
--(ST)->recursion_depth; \
} while(0)
static int static int
make_const(expr_ty node, PyObject *val, PyArena *arena) make_const(expr_ty node, PyObject *val, PyArena *arena)
@ -424,7 +414,7 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
static int static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{ {
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
switch (node_->kind) { switch (node_->kind) {
case BoolOp_kind: case BoolOp_kind:
CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values); CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
@ -520,7 +510,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
case Name_kind: case Name_kind:
if (node_->v.Name.ctx == Load && if (node_->v.Name.ctx == Load &&
_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) { _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return make_const(node_, PyBool_FromLong(!state->optimize), ctx_); return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
} }
break; break;
@ -533,7 +523,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new expression // No default case, so the compiler will emit a warning if new expression
// kinds are added without being handled here // kinds are added without being handled here
} }
LEAVE_RECURSIVE(state);; LEAVE_RECURSIVE();
return 1; return 1;
} }
@ -578,7 +568,7 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
static int static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{ {
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
switch (node_->kind) { switch (node_->kind) {
case FunctionDef_kind: case FunctionDef_kind:
CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params); CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
@ -700,7 +690,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new statement // No default case, so the compiler will emit a warning if new statement
// kinds are added without being handled here // kinds are added without being handled here
} }
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return 1; return 1;
} }
@ -770,7 +760,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// Currently, this is really only used to form complex/negative numeric // Currently, this is really only used to form complex/negative numeric
// constants in MatchValue and MatchMapping nodes // constants in MatchValue and MatchMapping nodes
// We still recurse into all subexpressions and subpatterns anyway // We still recurse into all subexpressions and subpatterns anyway
ENTER_RECURSIVE(state); ENTER_RECURSIVE();
switch (node_->kind) { switch (node_->kind) {
case MatchValue_kind: case MatchValue_kind:
CALL(fold_const_match_patterns, expr_ty, node_->v.MatchValue.value); CALL(fold_const_match_patterns, expr_ty, node_->v.MatchValue.value);
@ -802,7 +792,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new pattern // No default case, so the compiler will emit a warning if new pattern
// kinds are added without being handled here // kinds are added without being handled here
} }
LEAVE_RECURSIVE(state); LEAVE_RECURSIVE();
return 1; return 1;
} }
@ -840,34 +830,12 @@ astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
int int
_PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize, int ff_features) _PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize, int ff_features)
{ {
PyThreadState *tstate;
int starting_recursion_depth;
_PyASTOptimizeState state; _PyASTOptimizeState state;
state.optimize = optimize; state.optimize = optimize;
state.ff_features = ff_features; state.ff_features = ff_features;
/* Setup recursion depth check counters */
tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth;
state.recursion_depth = starting_recursion_depth;
state.recursion_limit = Py_C_RECURSION_LIMIT;
int ret = astfold_mod(mod, arena, &state); int ret = astfold_mod(mod, arena, &state);
assert(ret || PyErr_Occurred()); assert(ret || PyErr_Occurred());
/* Check that the recursion depth counting balanced correctly */
if (ret && state.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST optimizer recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state.recursion_depth);
return 0;
}
return ret; return ret;
} }

View file

@ -1082,7 +1082,6 @@ dummy_func(
/* Restore previous frame and return. */ /* Restore previous frame and return. */
tstate->current_frame = frame->previous; tstate->current_frame = frame->previous;
assert(!_PyErr_Occurred(tstate)); assert(!_PyErr_Occurred(tstate));
tstate->c_recursion_remaining += PY_EVAL_C_STACK_UNITS;
PyObject *result = PyStackRef_AsPyObjectSteal(retval); PyObject *result = PyStackRef_AsPyObjectSteal(retval);
SYNC_SP(); /* Not strictly necessary, but prevents warnings */ SYNC_SP(); /* Not strictly necessary, but prevents warnings */
return result; return result;
@ -3971,11 +3970,10 @@ dummy_func(
EXIT_IF(!PyCFunction_CheckExact(callable_o)); EXIT_IF(!PyCFunction_CheckExact(callable_o));
EXIT_IF(PyCFunction_GET_FLAGS(callable_o) != METH_O); EXIT_IF(PyCFunction_GET_FLAGS(callable_o) != METH_O);
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
EXIT_IF(tstate->c_recursion_remaining <= 0); EXIT_IF(_Py_ReachedRecursionLimit(tstate));
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = PyCFunction_GET_FUNCTION(callable_o); PyCFunction cfunc = PyCFunction_GET_FUNCTION(callable_o);
_PyStackRef arg = args[0]; _PyStackRef arg = args[0];
_Py_EnterRecursiveCallTstateUnchecked(tstate);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyCFunction_GET_SELF(callable_o), PyStackRef_AsPyObjectBorrow(arg)); PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyCFunction_GET_SELF(callable_o), PyStackRef_AsPyObjectBorrow(arg));
_Py_LeaveRecursiveCallTstate(tstate); _Py_LeaveRecursiveCallTstate(tstate);
assert((res_o != NULL) ^ (_PyErr_Occurred(tstate) != NULL)); assert((res_o != NULL) ^ (_PyErr_Occurred(tstate) != NULL));
@ -4165,14 +4163,13 @@ dummy_func(
PyMethodDef *meth = method->d_method; PyMethodDef *meth = method->d_method;
EXIT_IF(meth->ml_flags != METH_O); EXIT_IF(meth->ml_flags != METH_O);
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
EXIT_IF(tstate->c_recursion_remaining <= 0); EXIT_IF(_Py_ReachedRecursionLimit(tstate));
_PyStackRef arg_stackref = arguments[1]; _PyStackRef arg_stackref = arguments[1];
_PyStackRef self_stackref = arguments[0]; _PyStackRef self_stackref = arguments[0];
EXIT_IF(!Py_IS_TYPE(PyStackRef_AsPyObjectBorrow(self_stackref), EXIT_IF(!Py_IS_TYPE(PyStackRef_AsPyObjectBorrow(self_stackref),
method->d_common.d_type)); method->d_common.d_type));
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = meth->ml_meth; PyCFunction cfunc = meth->ml_meth;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyObject *res_o = _PyCFunction_TrampolineCall(cfunc,
PyStackRef_AsPyObjectBorrow(self_stackref), PyStackRef_AsPyObjectBorrow(self_stackref),
PyStackRef_AsPyObjectBorrow(arg_stackref)); PyStackRef_AsPyObjectBorrow(arg_stackref));
@ -4247,10 +4244,9 @@ dummy_func(
EXIT_IF(!Py_IS_TYPE(self, method->d_common.d_type)); EXIT_IF(!Py_IS_TYPE(self, method->d_common.d_type));
EXIT_IF(meth->ml_flags != METH_NOARGS); EXIT_IF(meth->ml_flags != METH_NOARGS);
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
EXIT_IF(tstate->c_recursion_remaining <= 0); EXIT_IF(_Py_ReachedRecursionLimit(tstate));
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = meth->ml_meth; PyCFunction cfunc = meth->ml_meth;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, self, NULL); PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, self, NULL);
_Py_LeaveRecursiveCallTstate(tstate); _Py_LeaveRecursiveCallTstate(tstate);
assert((res_o != NULL) ^ (_PyErr_Occurred(tstate) != NULL)); assert((res_o != NULL) ^ (_PyErr_Occurred(tstate) != NULL));
@ -5252,7 +5248,6 @@ dummy_func(
if (frame->owner == FRAME_OWNED_BY_INTERPRETER) { if (frame->owner == FRAME_OWNED_BY_INTERPRETER) {
/* Restore previous frame and exit */ /* Restore previous frame and exit */
tstate->current_frame = frame->previous; tstate->current_frame = frame->previous;
tstate->c_recursion_remaining += PY_EVAL_C_STACK_UNITS;
return NULL; return NULL;
} }
next_instr = frame->instr_ptr; next_instr = frame->instr_ptr;

View file

@ -304,36 +304,122 @@ Py_SetRecursionLimit(int new_limit)
_PyEval_StartTheWorld(interp); _PyEval_StartTheWorld(interp);
} }
int
_Py_ReachedRecursionLimitWithMargin(PyThreadState *tstate, int margin_count)
{
uintptr_t here_addr = _Py_get_machine_stack_pointer();
_PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
if (here_addr > _tstate->c_stack_soft_limit + margin_count * PYOS_STACK_MARGIN_BYTES) {
return 0;
}
if (_tstate->c_stack_hard_limit == 0) {
_Py_InitializeRecursionLimits(tstate);
}
return here_addr <= _tstate->c_stack_soft_limit + margin_count * PYOS_STACK_MARGIN_BYTES;
}
void
_Py_EnterRecursiveCallUnchecked(PyThreadState *tstate)
{
uintptr_t here_addr = _Py_get_machine_stack_pointer();
_PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
if (here_addr < _tstate->c_stack_hard_limit) {
Py_FatalError("Unchecked stack overflow.");
}
}
#if defined(__s390x__)
# define Py_C_STACK_SIZE 320000
#elif defined(_WIN32)
// Don't define Py_C_STACK_SIZE, ask the O/S
#elif defined(__ANDROID__)
# define Py_C_STACK_SIZE 1200000
#elif defined(__sparc__)
# define Py_C_STACK_SIZE 1600000
#elif defined(__wasi__)
/* Web assembly has two stacks, so this isn't really the stack depth */
# define Py_C_STACK_SIZE 80000
#elif defined(__hppa__) || defined(__powerpc64__)
# define Py_C_STACK_SIZE 2000000
#else
# define Py_C_STACK_SIZE 4000000
#endif
void
_Py_InitializeRecursionLimits(PyThreadState *tstate)
{
_PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
#ifdef WIN32
ULONG_PTR low, high;
GetCurrentThreadStackLimits(&low, &high);
_tstate->c_stack_top = (uintptr_t)high;
ULONG guarantee = 0;
SetThreadStackGuarantee(&guarantee);
_tstate->c_stack_hard_limit = ((uintptr_t)low) + guarantee + PYOS_STACK_MARGIN_BYTES;
_tstate->c_stack_soft_limit = _tstate->c_stack_hard_limit + PYOS_STACK_MARGIN_BYTES;
#else
uintptr_t here_addr = _Py_get_machine_stack_pointer();
# if defined(HAVE_PTHREAD_GETATTR_NP) && !defined(_AIX)
size_t stack_size, guard_size;
void *stack_addr;
pthread_attr_t attr;
int err = pthread_getattr_np(pthread_self(), &attr);
if (err == 0) {
err = pthread_attr_getguardsize(&attr, &guard_size);
err |= pthread_attr_getstack(&attr, &stack_addr, &stack_size);
err |= pthread_attr_destroy(&attr);
}
if (err == 0) {
uintptr_t base = ((uintptr_t)stack_addr) + guard_size;
_tstate->c_stack_top = base + stack_size;
_tstate->c_stack_soft_limit = base + PYOS_STACK_MARGIN_BYTES * 2;
_tstate->c_stack_hard_limit = base + PYOS_STACK_MARGIN_BYTES;
assert(_tstate->c_stack_soft_limit < here_addr);
assert(here_addr < _tstate->c_stack_top);
return;
}
# endif
_tstate->c_stack_top = _Py_SIZE_ROUND_UP(here_addr, 4096);
_tstate->c_stack_soft_limit = _tstate->c_stack_top - Py_C_STACK_SIZE;
_tstate->c_stack_hard_limit = _tstate->c_stack_top - (Py_C_STACK_SIZE + PYOS_STACK_MARGIN_BYTES);
#endif
}
/* The function _Py_EnterRecursiveCallTstate() only calls _Py_CheckRecursiveCall() /* The function _Py_EnterRecursiveCallTstate() only calls _Py_CheckRecursiveCall()
if the recursion_depth reaches recursion_limit. */ if the recursion_depth reaches recursion_limit. */
int int
_Py_CheckRecursiveCall(PyThreadState *tstate, const char *where) _Py_CheckRecursiveCall(PyThreadState *tstate, const char *where)
{ {
#ifdef USE_STACKCHECK _PyThreadStateImpl *_tstate = (_PyThreadStateImpl *)tstate;
if (PyOS_CheckStack()) { uintptr_t here_addr = _Py_get_machine_stack_pointer();
++tstate->c_recursion_remaining; assert(_tstate->c_stack_soft_limit != 0);
_PyErr_SetString(tstate, PyExc_MemoryError, "Stack overflow"); if (_tstate->c_stack_hard_limit == 0) {
return -1; _Py_InitializeRecursionLimits(tstate);
} }
#endif if (here_addr >= _tstate->c_stack_soft_limit) {
if (tstate->recursion_headroom) { return 0;
if (tstate->c_recursion_remaining < -50) { }
assert(_tstate->c_stack_hard_limit != 0);
if (here_addr < _tstate->c_stack_hard_limit) {
/* Overflowing while handling an overflow. Give up. */ /* Overflowing while handling an overflow. Give up. */
Py_FatalError("Cannot recover from stack overflow."); int kbytes_used = (int)(_tstate->c_stack_top - here_addr)/1024;
char buffer[80];
snprintf(buffer, 80, "Unrecoverable stack overflow (used %d kB)%s", kbytes_used, where);
Py_FatalError(buffer);
} }
if (tstate->recursion_headroom) {
return 0;
} }
else { else {
if (tstate->c_recursion_remaining <= 0) { int kbytes_used = (int)(_tstate->c_stack_top - here_addr)/1024;
tstate->recursion_headroom++; tstate->recursion_headroom++;
_PyErr_Format(tstate, PyExc_RecursionError, _PyErr_Format(tstate, PyExc_RecursionError,
"maximum recursion depth exceeded%s", "Stack overflow (used %d kB)%s",
kbytes_used,
where); where);
tstate->recursion_headroom--; tstate->recursion_headroom--;
++tstate->c_recursion_remaining;
return -1; return -1;
} }
}
return 0;
} }
@ -761,11 +847,6 @@ _PyObjectArray_Free(PyObject **array, PyObject **scratch)
} }
/* _PyEval_EvalFrameDefault() is a *big* function,
* so consume 3 units of C stack */
#define PY_EVAL_C_STACK_UNITS 2
/* _PyEval_EvalFrameDefault is too large to optimize for speed with PGO on MSVC. /* _PyEval_EvalFrameDefault is too large to optimize for speed with PGO on MSVC.
*/ */
#if (defined(_MSC_VER) && \ #if (defined(_MSC_VER) && \
@ -838,8 +919,6 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
frame->previous = &entry_frame; frame->previous = &entry_frame;
tstate->current_frame = frame; tstate->current_frame = frame;
tstate->c_recursion_remaining -= (PY_EVAL_C_STACK_UNITS - 1);
/* support for generator.throw() */ /* support for generator.throw() */
if (throwflag) { if (throwflag) {
if (_Py_EnterRecursivePy(tstate)) { if (_Py_EnterRecursivePy(tstate)) {
@ -998,7 +1077,6 @@ early_exit:
assert(frame->owner == FRAME_OWNED_BY_INTERPRETER); assert(frame->owner == FRAME_OWNED_BY_INTERPRETER);
/* Restore previous frame and exit */ /* Restore previous frame and exit */
tstate->current_frame = frame->previous; tstate->current_frame = frame->previous;
tstate->c_recursion_remaining += PY_EVAL_C_STACK_UNITS;
return NULL; return NULL;
} }
@ -1562,11 +1640,9 @@ clear_thread_frame(PyThreadState *tstate, _PyInterpreterFrame * frame)
// _PyThreadState_PopFrame, since f_code is already cleared at that point: // _PyThreadState_PopFrame, since f_code is already cleared at that point:
assert((PyObject **)frame + _PyFrame_GetCode(frame)->co_framesize == assert((PyObject **)frame + _PyFrame_GetCode(frame)->co_framesize ==
tstate->datastack_top); tstate->datastack_top);
tstate->c_recursion_remaining--;
assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame); assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame);
_PyFrame_ClearExceptCode(frame); _PyFrame_ClearExceptCode(frame);
PyStackRef_CLEAR(frame->f_executable); PyStackRef_CLEAR(frame->f_executable);
tstate->c_recursion_remaining++;
_PyThreadState_PopFrame(tstate, frame); _PyThreadState_PopFrame(tstate, frame);
} }
@ -1579,11 +1655,9 @@ clear_gen_frame(PyThreadState *tstate, _PyInterpreterFrame * frame)
assert(tstate->exc_info == &gen->gi_exc_state); assert(tstate->exc_info == &gen->gi_exc_state);
tstate->exc_info = gen->gi_exc_state.previous_item; tstate->exc_info = gen->gi_exc_state.previous_item;
gen->gi_exc_state.previous_item = NULL; gen->gi_exc_state.previous_item = NULL;
tstate->c_recursion_remaining--;
assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame); assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame);
_PyFrame_ClearExceptCode(frame); _PyFrame_ClearExceptCode(frame);
_PyErr_ClearExcState(&gen->gi_exc_state); _PyErr_ClearExcState(&gen->gi_exc_state);
tstate->c_recursion_remaining++;
frame->previous = NULL; frame->previous = NULL;
} }

View file

@ -4885,6 +4885,9 @@ codegen_with(compiler *c, stmt_ty s, int pos)
static int static int
codegen_visit_expr(compiler *c, expr_ty e) codegen_visit_expr(compiler *c, expr_ty e)
{ {
if (Py_EnterRecursiveCall(" during compilation")) {
return ERROR;
}
location loc = LOC(e); location loc = LOC(e);
switch (e->kind) { switch (e->kind) {
case NamedExpr_kind: case NamedExpr_kind:

View file

@ -5426,14 +5426,13 @@
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
if (tstate->c_recursion_remaining <= 0) { if (_Py_ReachedRecursionLimit(tstate)) {
UOP_STAT_INC(uopcode, miss); UOP_STAT_INC(uopcode, miss);
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = PyCFunction_GET_FUNCTION(callable_o); PyCFunction cfunc = PyCFunction_GET_FUNCTION(callable_o);
_PyStackRef arg = args[0]; _PyStackRef arg = args[0];
_Py_EnterRecursiveCallTstateUnchecked(tstate);
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyCFunction_GET_SELF(callable_o), PyStackRef_AsPyObjectBorrow(arg)); PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyCFunction_GET_SELF(callable_o), PyStackRef_AsPyObjectBorrow(arg));
stack_pointer = _PyFrame_GetStackPointer(frame); stack_pointer = _PyFrame_GetStackPointer(frame);
@ -5813,7 +5812,7 @@
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
if (tstate->c_recursion_remaining <= 0) { if (_Py_ReachedRecursionLimit(tstate)) {
UOP_STAT_INC(uopcode, miss); UOP_STAT_INC(uopcode, miss);
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
@ -5826,7 +5825,6 @@
} }
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = meth->ml_meth; PyCFunction cfunc = meth->ml_meth;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyObject *res_o = _PyCFunction_TrampolineCall(cfunc,
PyStackRef_AsPyObjectBorrow(self_stackref), PyStackRef_AsPyObjectBorrow(self_stackref),
@ -5984,13 +5982,12 @@
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
if (tstate->c_recursion_remaining <= 0) { if (_Py_ReachedRecursionLimit(tstate)) {
UOP_STAT_INC(uopcode, miss); UOP_STAT_INC(uopcode, miss);
JUMP_TO_JUMP_TARGET(); JUMP_TO_JUMP_TARGET();
} }
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = meth->ml_meth; PyCFunction cfunc = meth->ml_meth;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, self, NULL); PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, self, NULL);
stack_pointer = _PyFrame_GetStackPointer(frame); stack_pointer = _PyFrame_GetStackPointer(frame);

View file

@ -2222,7 +2222,7 @@
JUMP_TO_PREDICTED(CALL); JUMP_TO_PREDICTED(CALL);
} }
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
if (tstate->c_recursion_remaining <= 0) { if (_Py_ReachedRecursionLimit(tstate)) {
UPDATE_MISS_STATS(CALL); UPDATE_MISS_STATS(CALL);
assert(_PyOpcode_Deopt[opcode] == (CALL)); assert(_PyOpcode_Deopt[opcode] == (CALL));
JUMP_TO_PREDICTED(CALL); JUMP_TO_PREDICTED(CALL);
@ -2230,7 +2230,6 @@
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = PyCFunction_GET_FUNCTION(callable_o); PyCFunction cfunc = PyCFunction_GET_FUNCTION(callable_o);
_PyStackRef arg = args[0]; _PyStackRef arg = args[0];
_Py_EnterRecursiveCallTstateUnchecked(tstate);
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyCFunction_GET_SELF(callable_o), PyStackRef_AsPyObjectBorrow(arg)); PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyCFunction_GET_SELF(callable_o), PyStackRef_AsPyObjectBorrow(arg));
stack_pointer = _PyFrame_GetStackPointer(frame); stack_pointer = _PyFrame_GetStackPointer(frame);
@ -3599,14 +3598,13 @@
JUMP_TO_PREDICTED(CALL); JUMP_TO_PREDICTED(CALL);
} }
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
if (tstate->c_recursion_remaining <= 0) { if (_Py_ReachedRecursionLimit(tstate)) {
UPDATE_MISS_STATS(CALL); UPDATE_MISS_STATS(CALL);
assert(_PyOpcode_Deopt[opcode] == (CALL)); assert(_PyOpcode_Deopt[opcode] == (CALL));
JUMP_TO_PREDICTED(CALL); JUMP_TO_PREDICTED(CALL);
} }
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = meth->ml_meth; PyCFunction cfunc = meth->ml_meth;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, self, NULL); PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, self, NULL);
stack_pointer = _PyFrame_GetStackPointer(frame); stack_pointer = _PyFrame_GetStackPointer(frame);
@ -3696,7 +3694,7 @@
JUMP_TO_PREDICTED(CALL); JUMP_TO_PREDICTED(CALL);
} }
// CPython promises to check all non-vectorcall function calls. // CPython promises to check all non-vectorcall function calls.
if (tstate->c_recursion_remaining <= 0) { if (_Py_ReachedRecursionLimit(tstate)) {
UPDATE_MISS_STATS(CALL); UPDATE_MISS_STATS(CALL);
assert(_PyOpcode_Deopt[opcode] == (CALL)); assert(_PyOpcode_Deopt[opcode] == (CALL));
JUMP_TO_PREDICTED(CALL); JUMP_TO_PREDICTED(CALL);
@ -3711,7 +3709,6 @@
} }
STAT_INC(CALL, hit); STAT_INC(CALL, hit);
PyCFunction cfunc = meth->ml_meth; PyCFunction cfunc = meth->ml_meth;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
_PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_SetStackPointer(frame, stack_pointer);
PyObject *res_o = _PyCFunction_TrampolineCall(cfunc, PyObject *res_o = _PyCFunction_TrampolineCall(cfunc,
PyStackRef_AsPyObjectBorrow(self_stackref), PyStackRef_AsPyObjectBorrow(self_stackref),
@ -7292,7 +7289,6 @@
/* Restore previous frame and return. */ /* Restore previous frame and return. */
tstate->current_frame = frame->previous; tstate->current_frame = frame->previous;
assert(!_PyErr_Occurred(tstate)); assert(!_PyErr_Occurred(tstate));
tstate->c_recursion_remaining += PY_EVAL_C_STACK_UNITS;
PyObject *result = PyStackRef_AsPyObjectSteal(retval); PyObject *result = PyStackRef_AsPyObjectSteal(retval);
stack_pointer += -1; stack_pointer += -1;
assert(WITHIN_STACK_BOUNDS()); assert(WITHIN_STACK_BOUNDS());
@ -12081,7 +12077,6 @@ JUMP_TO_LABEL(error);
if (frame->owner == FRAME_OWNED_BY_INTERPRETER) { if (frame->owner == FRAME_OWNED_BY_INTERPRETER) {
/* Restore previous frame and exit */ /* Restore previous frame and exit */
tstate->current_frame = frame->previous; tstate->current_frame = frame->previous;
tstate->c_recursion_remaining += PY_EVAL_C_STACK_UNITS;
return NULL; return NULL;
} }
next_instr = frame->instr_ptr; next_instr = frame->instr_ptr;

View file

@ -1490,10 +1490,9 @@ init_threadstate(_PyThreadStateImpl *_tstate,
// thread_id and native_thread_id are set in bind_tstate(). // thread_id and native_thread_id are set in bind_tstate().
tstate->py_recursion_limit = interp->ceval.recursion_limit, tstate->py_recursion_limit = interp->ceval.recursion_limit;
tstate->py_recursion_remaining = interp->ceval.recursion_limit, tstate->py_recursion_remaining = interp->ceval.recursion_limit;
tstate->c_recursion_remaining = Py_C_RECURSION_LIMIT; tstate->c_recursion_remaining = 2;
tstate->exc_info = &tstate->exc_state; tstate->exc_info = &tstate->exc_state;
// PyGILState_Release must not try to delete this thread state. // PyGILState_Release must not try to delete this thread state.
@ -1508,6 +1507,10 @@ init_threadstate(_PyThreadStateImpl *_tstate,
tstate->previous_executor = NULL; tstate->previous_executor = NULL;
tstate->dict_global_version = 0; tstate->dict_global_version = 0;
_tstate->c_stack_soft_limit = UINTPTR_MAX;
_tstate->c_stack_top = 0;
_tstate->c_stack_hard_limit = 0;
_tstate->asyncio_running_loop = NULL; _tstate->asyncio_running_loop = NULL;
_tstate->asyncio_running_task = NULL; _tstate->asyncio_running_task = NULL;

View file

@ -1528,12 +1528,8 @@ _Py_SourceAsString(PyObject *cmd, const char *funcname, const char *what, PyComp
} }
#if defined(USE_STACKCHECK) #if defined(USE_STACKCHECK)
#if defined(WIN32) && defined(_MSC_VER)
/* Stack checking for Microsoft C */ /* Stack checking */
#include <malloc.h>
#include <excpt.h>
/* /*
* Return non-zero when we run out of memory on the stack; zero otherwise. * Return non-zero when we run out of memory on the stack; zero otherwise.
@ -1541,27 +1537,10 @@ _Py_SourceAsString(PyObject *cmd, const char *funcname, const char *what, PyComp
int int
PyOS_CheckStack(void) PyOS_CheckStack(void)
{ {
__try { PyThreadState *tstate = _PyThreadState_GET();
/* alloca throws a stack overflow exception if there's return _Py_ReachedRecursionLimit(tstate);
not enough space left on the stack */
alloca(PYOS_STACK_MARGIN * sizeof(void*));
return 0;
} __except (GetExceptionCode() == STATUS_STACK_OVERFLOW ?
EXCEPTION_EXECUTE_HANDLER :
EXCEPTION_CONTINUE_SEARCH) {
int errcode = _resetstkoflw();
if (errcode == 0)
{
Py_FatalError("Could not reset the stack!");
}
}
return 1;
} }
#endif /* WIN32 && _MSC_VER */
/* Alternate implementations can be added here... */
#endif /* USE_STACKCHECK */ #endif /* USE_STACKCHECK */
/* Deprecated C API functions still provided for binary compatibility */ /* Deprecated C API functions still provided for binary compatibility */

View file

@ -406,7 +406,6 @@ _PySymtable_Build(mod_ty mod, PyObject *filename, _PyFutureFeatures *future)
asdl_stmt_seq *seq; asdl_stmt_seq *seq;
Py_ssize_t i; Py_ssize_t i;
PyThreadState *tstate; PyThreadState *tstate;
int starting_recursion_depth;
if (st == NULL) if (st == NULL)
return NULL; return NULL;
@ -423,11 +422,6 @@ _PySymtable_Build(mod_ty mod, PyObject *filename, _PyFutureFeatures *future)
_PySymtable_Free(st); _PySymtable_Free(st);
return NULL; return NULL;
} }
/* Be careful here to prevent overflow. */
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth;
st->recursion_depth = starting_recursion_depth;
st->recursion_limit = Py_C_RECURSION_LIMIT;
/* Make the initial symbol information gathering pass */ /* Make the initial symbol information gathering pass */
@ -469,14 +463,6 @@ _PySymtable_Build(mod_ty mod, PyObject *filename, _PyFutureFeatures *future)
_PySymtable_Free(st); _PySymtable_Free(st);
return NULL; return NULL;
} }
/* Check that the recursion depth counting balanced correctly */
if (st->recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"symtable analysis recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, st->recursion_depth);
_PySymtable_Free(st);
return NULL;
}
/* Make the second symbol analysis pass */ /* Make the second symbol analysis pass */
if (symtable_analyze(st)) { if (symtable_analyze(st)) {
#if _PY_DUMP_SYMTABLE #if _PY_DUMP_SYMTABLE
@ -1736,19 +1722,12 @@ symtable_enter_type_param_block(struct symtable *st, identifier name,
} \ } \
} while(0) } while(0)
#define ENTER_RECURSIVE(ST) \ #define ENTER_RECURSIVE() \
do { \ if (Py_EnterRecursiveCall(" during compilation")) { \
if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
PyErr_SetString(PyExc_RecursionError, \
"maximum recursion depth exceeded during compilation"); \
return 0; \ return 0; \
} \ }
} while(0)
#define LEAVE_RECURSIVE(ST) \ #define LEAVE_RECURSIVE() Py_LeaveRecursiveCall();
do { \
--(ST)->recursion_depth; \
} while(0)
static int static int
@ -1823,7 +1802,7 @@ maybe_set_ste_coroutine_for_module(struct symtable *st, stmt_ty s)
static int static int
symtable_visit_stmt(struct symtable *st, stmt_ty s) symtable_visit_stmt(struct symtable *st, stmt_ty s)
{ {
ENTER_RECURSIVE(st); ENTER_RECURSIVE();
switch (s->kind) { switch (s->kind) {
case FunctionDef_kind: { case FunctionDef_kind: {
if (!symtable_add_def(st, s->v.FunctionDef.name, DEF_LOCAL, LOCATION(s))) if (!symtable_add_def(st, s->v.FunctionDef.name, DEF_LOCAL, LOCATION(s)))
@ -2235,7 +2214,7 @@ symtable_visit_stmt(struct symtable *st, stmt_ty s)
VISIT_SEQ(st, stmt, s->v.AsyncFor.orelse); VISIT_SEQ(st, stmt, s->v.AsyncFor.orelse);
break; break;
} }
LEAVE_RECURSIVE(st); LEAVE_RECURSIVE();
return 1; return 1;
} }
@ -2358,7 +2337,7 @@ symtable_handle_namedexpr(struct symtable *st, expr_ty e)
static int static int
symtable_visit_expr(struct symtable *st, expr_ty e) symtable_visit_expr(struct symtable *st, expr_ty e)
{ {
ENTER_RECURSIVE(st); ENTER_RECURSIVE();
switch (e->kind) { switch (e->kind) {
case NamedExpr_kind: case NamedExpr_kind:
if (!symtable_raise_if_annotation_block(st, "named expression", e)) { if (!symtable_raise_if_annotation_block(st, "named expression", e)) {
@ -2529,7 +2508,7 @@ symtable_visit_expr(struct symtable *st, expr_ty e)
VISIT_SEQ(st, expr, e->v.Tuple.elts); VISIT_SEQ(st, expr, e->v.Tuple.elts);
break; break;
} }
LEAVE_RECURSIVE(st); LEAVE_RECURSIVE();
return 1; return 1;
} }
@ -2563,7 +2542,7 @@ symtable_visit_type_param_bound_or_default(
static int static int
symtable_visit_type_param(struct symtable *st, type_param_ty tp) symtable_visit_type_param(struct symtable *st, type_param_ty tp)
{ {
ENTER_RECURSIVE(st); ENTER_RECURSIVE();
switch(tp->kind) { switch(tp->kind) {
case TypeVar_kind: case TypeVar_kind:
if (!symtable_add_def(st, tp->v.TypeVar.name, DEF_TYPE_PARAM | DEF_LOCAL, LOCATION(tp))) if (!symtable_add_def(st, tp->v.TypeVar.name, DEF_TYPE_PARAM | DEF_LOCAL, LOCATION(tp)))
@ -2612,14 +2591,14 @@ symtable_visit_type_param(struct symtable *st, type_param_ty tp)
} }
break; break;
} }
LEAVE_RECURSIVE(st); LEAVE_RECURSIVE();
return 1; return 1;
} }
static int static int
symtable_visit_pattern(struct symtable *st, pattern_ty p) symtable_visit_pattern(struct symtable *st, pattern_ty p)
{ {
ENTER_RECURSIVE(st); ENTER_RECURSIVE();
switch (p->kind) { switch (p->kind) {
case MatchValue_kind: case MatchValue_kind:
VISIT(st, expr, p->v.MatchValue.value); VISIT(st, expr, p->v.MatchValue.value);
@ -2668,7 +2647,7 @@ symtable_visit_pattern(struct symtable *st, pattern_ty p)
VISIT_SEQ(st, pattern, p->v.MatchOr.patterns); VISIT_SEQ(st, pattern, p->v.MatchOr.patterns);
break; break;
} }
LEAVE_RECURSIVE(st); LEAVE_RECURSIVE();
return 1; return 1;
} }

View file

@ -653,7 +653,6 @@ NON_ESCAPING_FUNCTIONS = (
"_PyUnicode_JoinArray", "_PyUnicode_JoinArray",
"_Py_CHECK_EMSCRIPTEN_SIGNALS_PERIODICALLY", "_Py_CHECK_EMSCRIPTEN_SIGNALS_PERIODICALLY",
"_Py_DECREF_NO_DEALLOC", "_Py_DECREF_NO_DEALLOC",
"_Py_EnterRecursiveCallTstateUnchecked",
"_Py_ID", "_Py_ID",
"_Py_IsImmortal", "_Py_IsImmortal",
"_Py_LeaveRecursiveCallPy", "_Py_LeaveRecursiveCallPy",
@ -673,6 +672,7 @@ NON_ESCAPING_FUNCTIONS = (
"initial_temperature_backoff_counter", "initial_temperature_backoff_counter",
"JUMP_TO_LABEL", "JUMP_TO_LABEL",
"restart_backoff_counter", "restart_backoff_counter",
"_Py_ReachedRecursionLimit",
) )
def find_stmt_start(node: parser.CodeDef, idx: int) -> lexer.Token: def find_stmt_start(node: parser.CodeDef, idx: int) -> lexer.Token:

View file

@ -44,7 +44,7 @@ EXTENSION_PREFIX = """\
# define MAXSTACK 4000 # define MAXSTACK 4000
# endif # endif
#else #else
# define MAXSTACK 6000 # define MAXSTACK 4000
#endif #endif
""" """
@ -380,7 +380,7 @@ class CParserGenerator(ParserGenerator, GrammarVisitor):
self.cleanup_statements: List[str] = [] self.cleanup_statements: List[str] = []
def add_level(self) -> None: def add_level(self) -> None:
self.print("if (p->level++ == MAXSTACK) {") self.print("if (p->level++ == MAXSTACK || _Py_ReachedRecursionLimitWithMargin(PyThreadState_Get(), 1)) {")
with self.indent(): with self.indent():
self.print("_Pypegen_stack_overflow(p);") self.print("_Pypegen_stack_overflow(p);")
self.print("}") self.print("}")

6
configure generated vendored
View file

@ -19620,6 +19620,12 @@ if test "x$ac_cv_func_pthread_setname_np" = xyes
then : then :
printf "%s\n" "#define HAVE_PTHREAD_SETNAME_NP 1" >>confdefs.h printf "%s\n" "#define HAVE_PTHREAD_SETNAME_NP 1" >>confdefs.h
fi
ac_fn_c_check_func "$LINENO" "pthread_getattr_np" "ac_cv_func_pthread_getattr_np"
if test "x$ac_cv_func_pthread_getattr_np" = xyes
then :
printf "%s\n" "#define HAVE_PTHREAD_GETATTR_NP 1" >>confdefs.h
fi fi
ac_fn_c_check_func "$LINENO" "ptsname" "ac_cv_func_ptsname" ac_fn_c_check_func "$LINENO" "ptsname" "ac_cv_func_ptsname"
if test "x$ac_cv_func_ptsname" = xyes if test "x$ac_cv_func_ptsname" = xyes

View file

@ -5147,7 +5147,7 @@ AC_CHECK_FUNCS([ \
posix_spawn_file_actions_addclosefrom_np \ posix_spawn_file_actions_addclosefrom_np \
pread preadv preadv2 process_vm_readv \ pread preadv preadv2 process_vm_readv \
pthread_cond_timedwait_relative_np pthread_condattr_setclock pthread_init \ pthread_cond_timedwait_relative_np pthread_condattr_setclock pthread_init \
pthread_kill pthread_getname_np pthread_setname_np \ pthread_kill pthread_getname_np pthread_setname_np pthread_getattr_np \
ptsname ptsname_r pwrite pwritev pwritev2 readlink readlinkat readv realpath renameat \ ptsname ptsname_r pwrite pwritev pwritev2 readlink readlinkat readv realpath renameat \
rtpSpawn sched_get_priority_max sched_rr_get_interval sched_setaffinity \ rtpSpawn sched_get_priority_max sched_rr_get_interval sched_setaffinity \
sched_setparam sched_setscheduler sem_clockwait sem_getvalue sem_open \ sched_setparam sched_setscheduler sem_clockwait sem_getvalue sem_open \

View file

@ -984,6 +984,9 @@
/* Defined for Solaris 2.6 bug in pthread header. */ /* Defined for Solaris 2.6 bug in pthread header. */
#undef HAVE_PTHREAD_DESTRUCTOR #undef HAVE_PTHREAD_DESTRUCTOR
/* Define to 1 if you have the 'pthread_getattr_np' function. */
#undef HAVE_PTHREAD_GETATTR_NP
/* Define to 1 if you have the 'pthread_getcpuclockid' function. */ /* Define to 1 if you have the 'pthread_getcpuclockid' function. */
#undef HAVE_PTHREAD_GETCPUCLOCKID #undef HAVE_PTHREAD_GETCPUCLOCKID