mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
bpo-43897: AST validation for pattern matching nodes (GH24771)
This commit is contained in:
parent
53b9458f2e
commit
31bec6f1b1
2 changed files with 265 additions and 32 deletions
|
@ -696,7 +696,7 @@ class AST_Tests(unittest.TestCase):
|
|||
for constant in "True", "False", "None":
|
||||
expr = ast.Expression(ast.Name(constant, ast.Load()))
|
||||
ast.fix_missing_locations(expr)
|
||||
with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"):
|
||||
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
|
||||
compile(expr, "<test>", "eval")
|
||||
|
||||
def test_precedence_enum(self):
|
||||
|
@ -1507,6 +1507,147 @@ class ASTValidatorTests(unittest.TestCase):
|
|||
mod = ast.parse(source, fn)
|
||||
compile(mod, fn, "exec")
|
||||
|
||||
constant_1 = ast.Constant(1)
|
||||
pattern_1 = ast.MatchValue(constant_1)
|
||||
|
||||
constant_x = ast.Constant('x')
|
||||
pattern_x = ast.MatchValue(constant_x)
|
||||
|
||||
constant_true = ast.Constant(True)
|
||||
pattern_true = ast.MatchSingleton(True)
|
||||
|
||||
name_carter = ast.Name('carter', ast.Load())
|
||||
|
||||
_MATCH_PATTERNS = [
|
||||
ast.MatchValue(
|
||||
ast.Attribute(
|
||||
ast.Attribute(
|
||||
ast.Name('x', ast.Store()),
|
||||
'y', ast.Load()
|
||||
),
|
||||
'z', ast.Load()
|
||||
)
|
||||
),
|
||||
ast.MatchValue(
|
||||
ast.Attribute(
|
||||
ast.Attribute(
|
||||
ast.Name('x', ast.Load()),
|
||||
'y', ast.Store()
|
||||
),
|
||||
'z', ast.Load()
|
||||
)
|
||||
),
|
||||
ast.MatchValue(
|
||||
ast.Constant(...)
|
||||
),
|
||||
ast.MatchValue(
|
||||
ast.Constant(True)
|
||||
),
|
||||
ast.MatchValue(
|
||||
ast.Constant((1,2,3))
|
||||
),
|
||||
ast.MatchSingleton('string'),
|
||||
ast.MatchSequence([
|
||||
ast.MatchSingleton('string')
|
||||
]),
|
||||
ast.MatchSequence(
|
||||
[
|
||||
ast.MatchSequence(
|
||||
[
|
||||
ast.MatchSingleton('string')
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
ast.MatchMapping(
|
||||
[constant_1, constant_true],
|
||||
[pattern_x]
|
||||
),
|
||||
ast.MatchMapping(
|
||||
[constant_true, constant_1],
|
||||
[pattern_x, pattern_1],
|
||||
rest='True'
|
||||
),
|
||||
ast.MatchMapping(
|
||||
[constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
|
||||
[pattern_x, pattern_1],
|
||||
rest='legit'
|
||||
),
|
||||
ast.MatchClass(
|
||||
ast.Attribute(
|
||||
ast.Attribute(
|
||||
constant_x,
|
||||
'y', ast.Load()),
|
||||
'z', ast.Load()),
|
||||
patterns=[], kwd_attrs=[], kwd_patterns=[]
|
||||
),
|
||||
ast.MatchClass(
|
||||
name_carter,
|
||||
patterns=[],
|
||||
kwd_attrs=['True'],
|
||||
kwd_patterns=[pattern_1]
|
||||
),
|
||||
ast.MatchClass(
|
||||
name_carter,
|
||||
patterns=[],
|
||||
kwd_attrs=[],
|
||||
kwd_patterns=[pattern_1]
|
||||
),
|
||||
ast.MatchClass(
|
||||
name_carter,
|
||||
patterns=[ast.MatchSingleton('string')],
|
||||
kwd_attrs=[],
|
||||
kwd_patterns=[]
|
||||
),
|
||||
ast.MatchClass(
|
||||
name_carter,
|
||||
patterns=[ast.MatchStar()],
|
||||
kwd_attrs=[],
|
||||
kwd_patterns=[]
|
||||
),
|
||||
ast.MatchClass(
|
||||
name_carter,
|
||||
patterns=[],
|
||||
kwd_attrs=[],
|
||||
kwd_patterns=[ast.MatchStar()]
|
||||
),
|
||||
ast.MatchSequence(
|
||||
[
|
||||
ast.MatchStar("True")
|
||||
]
|
||||
),
|
||||
ast.MatchAs(
|
||||
name='False'
|
||||
),
|
||||
ast.MatchOr(
|
||||
[]
|
||||
),
|
||||
ast.MatchOr(
|
||||
[pattern_1]
|
||||
),
|
||||
ast.MatchOr(
|
||||
[pattern_1, pattern_x, ast.MatchSingleton('xxx')]
|
||||
)
|
||||
]
|
||||
|
||||
def test_match_validation_pattern(self):
|
||||
name_x = ast.Name('x', ast.Load())
|
||||
for pattern in self._MATCH_PATTERNS:
|
||||
with self.subTest(ast.dump(pattern, indent=4)):
|
||||
node = ast.Match(
|
||||
subject=name_x,
|
||||
cases = [
|
||||
ast.match_case(
|
||||
pattern=pattern,
|
||||
body = [ast.Pass()]
|
||||
)
|
||||
]
|
||||
)
|
||||
node = ast.fix_missing_locations(node)
|
||||
module = ast.Module([node], [])
|
||||
with self.assertRaises(ValueError):
|
||||
compile(module, "<test>", "exec")
|
||||
|
||||
|
||||
class ConstantTests(unittest.TestCase):
|
||||
"""Tests on the ast.Constant node type."""
|
||||
|
|
148
Python/ast.c
148
Python/ast.c
|
@ -16,6 +16,7 @@ struct validator {
|
|||
|
||||
static int validate_stmts(struct validator *, asdl_stmt_seq *);
|
||||
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
|
||||
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
|
||||
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
|
||||
static int validate_stmt(struct validator *, stmt_ty);
|
||||
static int validate_expr(struct validator *, expr_ty, expr_context_ty);
|
||||
|
@ -33,7 +34,7 @@ validate_name(PyObject *name)
|
|||
};
|
||||
for (int i = 0; forbidden[i] != NULL; i++) {
|
||||
if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
|
||||
PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]);
|
||||
PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
|
|||
switch (exp->kind)
|
||||
{
|
||||
case Constant_kind:
|
||||
/* Ellipsis and immutable sequences are not allowed.
|
||||
For True, False and None, MatchSingleton() should
|
||||
be used */
|
||||
if (!validate_expr(state, exp, Load)) {
|
||||
return 0;
|
||||
}
|
||||
PyObject *literal = exp->v.Constant.value;
|
||||
if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
|
||||
PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
|
||||
PyUnicode_CheckExact(literal)) {
|
||||
return 1;
|
||||
}
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"unexpected constant inside of a literal pattern");
|
||||
return 0;
|
||||
case Attribute_kind:
|
||||
// Constants and attribute lookups are always permitted
|
||||
return 1;
|
||||
|
@ -465,10 +481,13 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
|
|||
return 1;
|
||||
}
|
||||
break;
|
||||
case JoinedStr_kind:
|
||||
// Handled in the later stages
|
||||
return 1;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
PyErr_SetString(PyExc_SyntaxError,
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"patterns may only match literals and attribute lookups");
|
||||
return 0;
|
||||
}
|
||||
|
@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p)
|
|||
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
|
||||
break;
|
||||
case MatchSingleton_kind:
|
||||
// TODO: Check constant is specifically None, True, or False
|
||||
ret = validate_constant(state, p->v.MatchSingleton.value);
|
||||
ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
|
||||
if (!ret) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchSingleton can only contain True, False and None");
|
||||
}
|
||||
break;
|
||||
case MatchSequence_kind:
|
||||
// TODO: Validate all subpatterns
|
||||
// return validate_patterns(state, p->v.MatchSequence.patterns);
|
||||
ret = 1;
|
||||
ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
|
||||
break;
|
||||
case MatchMapping_kind:
|
||||
// TODO: check "rest" target name is valid
|
||||
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchMapping doesn't have the same number of keys as patterns");
|
||||
return 0;
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
// null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
|
||||
// TODO: replace with more restrictive expression validator, as per MatchValue above
|
||||
if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
|
||||
return 0;
|
||||
|
||||
if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
// TODO: Validate all subpatterns
|
||||
// ret = validate_patterns(state, p->v.MatchMapping.patterns);
|
||||
ret = 1;
|
||||
|
||||
asdl_expr_seq *keys = p->v.MatchMapping.keys;
|
||||
for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
|
||||
expr_ty key = asdl_seq_GET(keys, i);
|
||||
if (key->kind == Constant_kind) {
|
||||
PyObject *literal = key->v.Constant.value;
|
||||
if (literal == Py_None || PyBool_Check(literal)) {
|
||||
/* validate_pattern_match_value will ensure the key
|
||||
doesn't contain True, False and None but it is
|
||||
syntactically valid, so we will pass those on in
|
||||
a special case. */
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!validate_pattern_match_value(state, key)) {
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
|
||||
break;
|
||||
case MatchClass_kind:
|
||||
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchClass doesn't have the same number of keyword attributes as patterns");
|
||||
return 0;
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
// TODO: Restrict cls lookup to being a name or attribute
|
||||
if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
|
||||
expr_ty cls = p->v.MatchClass.cls;
|
||||
while (1) {
|
||||
if (cls->kind == Name_kind) {
|
||||
break;
|
||||
}
|
||||
else if (cls->kind == Attribute_kind) {
|
||||
cls = cls->v.Attribute.value;
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchClass cls field can only contain Name or Attribute nodes.");
|
||||
state->recursion_depth--;
|
||||
return 0;
|
||||
}
|
||||
// TODO: Validate all subpatterns
|
||||
// return validate_patterns(state, p->v.MatchClass.patterns) &&
|
||||
// validate_patterns(state, p->v.MatchClass.kwd_patterns);
|
||||
ret = 1;
|
||||
}
|
||||
|
||||
for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
|
||||
PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
|
||||
if (!validate_name(identifier)) {
|
||||
state->recursion_depth--;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
|
||||
ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
|
||||
break;
|
||||
case MatchStar_kind:
|
||||
// TODO: check target name is valid
|
||||
ret = 1;
|
||||
ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
|
||||
break;
|
||||
case MatchAs_kind:
|
||||
// TODO: check target name is valid
|
||||
if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
if (p->v.MatchAs.pattern == NULL) {
|
||||
ret = 1;
|
||||
}
|
||||
|
@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p)
|
|||
}
|
||||
break;
|
||||
case MatchOr_kind:
|
||||
// TODO: Validate all subpatterns
|
||||
// return validate_patterns(state, p->v.MatchOr.patterns);
|
||||
ret = 1;
|
||||
if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"MatchOr requires at least 2 patterns");
|
||||
ret = 0;
|
||||
break;
|
||||
}
|
||||
ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
|
||||
break;
|
||||
// No default case, so the compiler will emit a warning if new pattern
|
||||
// kinds are added without being handled here
|
||||
|
@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
|
|||
return 1;
|
||||
}
|
||||
|
||||
static int
|
||||
validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
|
||||
{
|
||||
Py_ssize_t i;
|
||||
for (i = 0; i < asdl_seq_LEN(patterns); i++) {
|
||||
pattern_ty pattern = asdl_seq_GET(patterns, i);
|
||||
if (pattern->kind == MatchStar_kind && !star_ok) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Can't use MatchStar within this sequence of patterns");
|
||||
return 0;
|
||||
}
|
||||
if (!validate_pattern(state, pattern)) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
/* See comments in symtable.c. */
|
||||
#define COMPILER_STACK_FRAME_SCALE 3
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue