mirror of
https://github.com/python/cpython.git
synced 2025-07-08 03:45:36 +00:00
bpo-43892: Make match patterns explicit in the AST (GH-25585)
Co-authored-by: Brandt Bucher <brandtbucher@gmail.com>
This commit is contained in:
parent
e52ab42ced
commit
1e7b858575
20 changed files with 3460 additions and 1377 deletions
238
Python/ast.c
238
Python/ast.c
|
@ -7,6 +7,7 @@
|
|||
#include "pycore_pystate.h" // _PyThreadState_GET()
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
struct validator {
|
||||
int recursion_depth; /* current recursion depth */
|
||||
|
@ -18,6 +19,7 @@ static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, i
|
|||
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
|
||||
static int validate_stmt(struct validator *, stmt_ty);
|
||||
static int validate_expr(struct validator *, expr_ty, expr_context_ty);
|
||||
static int validate_pattern(struct validator *, pattern_ty);
|
||||
|
||||
static int
|
||||
validate_name(PyObject *name)
|
||||
|
@ -88,9 +90,9 @@ expr_context_name(expr_context_ty ctx)
|
|||
return "Store";
|
||||
case Del:
|
||||
return "Del";
|
||||
default:
|
||||
Py_UNREACHABLE();
|
||||
// No default case so compiler emits warning for unhandled cases
|
||||
}
|
||||
Py_UNREACHABLE();
|
||||
}
|
||||
|
||||
static int
|
||||
|
@ -180,7 +182,7 @@ validate_constant(struct validator *state, PyObject *value)
|
|||
static int
|
||||
validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
|
||||
{
|
||||
int ret;
|
||||
int ret = -1;
|
||||
if (++state->recursion_depth > state->recursion_limit) {
|
||||
PyErr_SetString(PyExc_RecursionError,
|
||||
"maximum recursion depth exceeded during compilation");
|
||||
|
@ -351,33 +353,215 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
|
|||
case NamedExpr_kind:
|
||||
ret = validate_expr(state, exp->v.NamedExpr.value, Load);
|
||||
break;
|
||||
case MatchAs_kind:
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchAs is only valid in match_case patterns");
|
||||
return 0;
|
||||
case MatchOr_kind:
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchOr is only valid in match_case patterns");
|
||||
return 0;
|
||||
/* This last case doesn't have any checking. */
|
||||
case Name_kind:
|
||||
ret = 1;
|
||||
break;
|
||||
default:
|
||||
// No default case so compiler emits warning for unhandled cases
|
||||
}
|
||||
if (ret < 0) {
|
||||
PyErr_SetString(PyExc_SystemError, "unexpected expression");
|
||||
return 0;
|
||||
ret = 0;
|
||||
}
|
||||
state->recursion_depth--;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
// Note: the ensure_literal_* functions are only used to validate a restricted
|
||||
// set of non-recursive literals that have already been checked with
|
||||
// validate_expr, so they don't accept the validator state
|
||||
static int
|
||||
validate_pattern(expr_ty p)
|
||||
ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary)
|
||||
{
|
||||
// Coming soon (thanks Batuhan)!
|
||||
assert(exp->kind == Constant_kind);
|
||||
PyObject *value = exp->v.Constant.value;
|
||||
return (allow_real && PyFloat_CheckExact(value)) ||
|
||||
(allow_real && PyLong_CheckExact(value)) ||
|
||||
(allow_imaginary && PyComplex_CheckExact(value));
|
||||
}
|
||||
|
||||
static int
|
||||
ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary)
|
||||
{
|
||||
assert(exp->kind == UnaryOp_kind);
|
||||
// Must be negation ...
|
||||
if (exp->v.UnaryOp.op != USub) {
|
||||
return 0;
|
||||
}
|
||||
// ... of a constant ...
|
||||
expr_ty operand = exp->v.UnaryOp.operand;
|
||||
if (operand->kind != Constant_kind) {
|
||||
return 0;
|
||||
}
|
||||
// ... number
|
||||
return ensure_literal_number(operand, allow_real, allow_imaginary);
|
||||
}
|
||||
|
||||
static int
|
||||
ensure_literal_complex(expr_ty exp)
|
||||
{
|
||||
assert(exp->kind == BinOp_kind);
|
||||
expr_ty left = exp->v.BinOp.left;
|
||||
expr_ty right = exp->v.BinOp.right;
|
||||
// Ensure op is addition or subtraction
|
||||
if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) {
|
||||
return 0;
|
||||
}
|
||||
// Check LHS is a real number (potentially signed)
|
||||
switch (left->kind)
|
||||
{
|
||||
case Constant_kind:
|
||||
if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) {
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
case UnaryOp_kind:
|
||||
if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) {
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
// Check RHS is an imaginary number (no separate sign allowed)
|
||||
switch (right->kind)
|
||||
{
|
||||
case Constant_kind:
|
||||
if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) {
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int
|
||||
validate_pattern_match_value(struct validator *state, expr_ty exp)
|
||||
{
|
||||
if (!validate_expr(state, exp, Load)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
switch (exp->kind)
|
||||
{
|
||||
case Constant_kind:
|
||||
case Attribute_kind:
|
||||
// Constants and attribute lookups are always permitted
|
||||
return 1;
|
||||
case UnaryOp_kind:
|
||||
// Negated numbers are permitted (whether real or imaginary)
|
||||
// Compiler will complain if AST folding doesn't create a constant
|
||||
if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) {
|
||||
return 1;
|
||||
}
|
||||
break;
|
||||
case BinOp_kind:
|
||||
// Complex literals are permitted
|
||||
// Compiler will complain if AST folding doesn't create a constant
|
||||
if (ensure_literal_complex(exp)) {
|
||||
return 1;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
PyErr_SetString(PyExc_SyntaxError,
|
||||
"patterns may only match literals and attribute lookups");
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int
|
||||
validate_pattern(struct validator *state, pattern_ty p)
|
||||
{
|
||||
int ret = -1;
|
||||
if (++state->recursion_depth > state->recursion_limit) {
|
||||
PyErr_SetString(PyExc_RecursionError,
|
||||
"maximum recursion depth exceeded during compilation");
|
||||
return 0;
|
||||
}
|
||||
// Coming soon: https://bugs.python.org/issue43897 (thanks Batuhan)!
|
||||
// TODO: Ensure no subnodes use "_" as an ordinary identifier
|
||||
switch (p->kind) {
|
||||
case MatchValue_kind:
|
||||
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
|
||||
break;
|
||||
case MatchSingleton_kind:
|
||||
// TODO: Check constant is specifically None, True, or False
|
||||
ret = validate_constant(state, p->v.MatchSingleton.value);
|
||||
break;
|
||||
case MatchSequence_kind:
|
||||
// TODO: Validate all subpatterns
|
||||
// return validate_patterns(state, p->v.MatchSequence.patterns);
|
||||
ret = 1;
|
||||
break;
|
||||
case MatchMapping_kind:
|
||||
// TODO: check "rest" target name is valid
|
||||
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchMapping doesn't have the same number of keys as patterns");
|
||||
return 0;
|
||||
}
|
||||
// null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
|
||||
// TODO: replace with more restrictive expression validator, as per MatchValue above
|
||||
if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
|
||||
return 0;
|
||||
}
|
||||
// TODO: Validate all subpatterns
|
||||
// ret = validate_patterns(state, p->v.MatchMapping.patterns);
|
||||
ret = 1;
|
||||
break;
|
||||
case MatchClass_kind:
|
||||
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchClass doesn't have the same number of keyword attributes as patterns");
|
||||
return 0;
|
||||
}
|
||||
// TODO: Restrict cls lookup to being a name or attribute
|
||||
if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
|
||||
return 0;
|
||||
}
|
||||
// TODO: Validate all subpatterns
|
||||
// return validate_patterns(state, p->v.MatchClass.patterns) &&
|
||||
// validate_patterns(state, p->v.MatchClass.kwd_patterns);
|
||||
ret = 1;
|
||||
break;
|
||||
case MatchStar_kind:
|
||||
// TODO: check target name is valid
|
||||
ret = 1;
|
||||
break;
|
||||
case MatchAs_kind:
|
||||
// TODO: check target name is valid
|
||||
if (p->v.MatchAs.pattern == NULL) {
|
||||
ret = 1;
|
||||
}
|
||||
else if (p->v.MatchAs.name == NULL) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchAs must specify a target name if a pattern is given");
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
ret = validate_pattern(state, p->v.MatchAs.pattern);
|
||||
}
|
||||
break;
|
||||
case MatchOr_kind:
|
||||
// TODO: Validate all subpatterns
|
||||
// return validate_patterns(state, p->v.MatchOr.patterns);
|
||||
ret = 1;
|
||||
break;
|
||||
// No default case, so the compiler will emit a warning if new pattern
|
||||
// kinds are added without being handled here
|
||||
}
|
||||
if (ret < 0) {
|
||||
PyErr_SetString(PyExc_SystemError, "unexpected pattern");
|
||||
ret = 0;
|
||||
}
|
||||
state->recursion_depth--;
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int
|
||||
_validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
|
||||
{
|
||||
|
@ -404,7 +588,7 @@ validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner)
|
|||
static int
|
||||
validate_stmt(struct validator *state, stmt_ty stmt)
|
||||
{
|
||||
int ret;
|
||||
int ret = -1;
|
||||
Py_ssize_t i;
|
||||
if (++state->recursion_depth > state->recursion_limit) {
|
||||
PyErr_SetString(PyExc_RecursionError,
|
||||
|
@ -502,7 +686,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
|
|||
}
|
||||
for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
|
||||
match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
|
||||
if (!validate_pattern(m->pattern)
|
||||
if (!validate_pattern(state, m->pattern)
|
||||
|| (m->guard && !validate_expr(state, m->guard, Load))
|
||||
|| !validate_body(state, m->body, "match_case")) {
|
||||
return 0;
|
||||
|
@ -582,9 +766,11 @@ validate_stmt(struct validator *state, stmt_ty stmt)
|
|||
case Continue_kind:
|
||||
ret = 1;
|
||||
break;
|
||||
default:
|
||||
// No default case so compiler emits warning for unhandled cases
|
||||
}
|
||||
if (ret < 0) {
|
||||
PyErr_SetString(PyExc_SystemError, "unexpected statement");
|
||||
return 0;
|
||||
ret = 0;
|
||||
}
|
||||
state->recursion_depth--;
|
||||
return ret;
|
||||
|
@ -635,7 +821,7 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
|
|||
int
|
||||
_PyAST_Validate(mod_ty mod)
|
||||
{
|
||||
int res = 0;
|
||||
int res = -1;
|
||||
struct validator state;
|
||||
PyThreadState *tstate;
|
||||
int recursion_limit = Py_GetRecursionLimit();
|
||||
|
@ -663,10 +849,16 @@ _PyAST_Validate(mod_ty mod)
|
|||
case Expression_kind:
|
||||
res = validate_expr(&state, mod->v.Expression.body, Load);
|
||||
break;
|
||||
default:
|
||||
PyErr_SetString(PyExc_SystemError, "impossible module node");
|
||||
res = 0;
|
||||
case FunctionType_kind:
|
||||
res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
|
||||
validate_expr(&state, mod->v.FunctionType.returns, Load);
|
||||
break;
|
||||
// No default case so compiler emits warning for unhandled cases
|
||||
}
|
||||
|
||||
if (res < 0) {
|
||||
PyErr_SetString(PyExc_SystemError, "impossible module node");
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Check that the recursion depth counting balanced correctly */
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue