[syntax-errors] Invalid syntax in annotations (#17101)

Summary
--

This PR detects the use of invalid syntax in annotation scopes,
including
`yield` and `yield from` expressions and named expressions. I combined a
few
different types of CPython errors here, but I think the resulting error
messages
still make sense and are even preferable to what CPython gives. For
example, we
report `yield expression cannot be used in a type annotation` for both
of these:

```pycon
>>> def f[T](x: (yield 1)): ...
  File "<python-input-26>", line 1
    def f[T](x: (yield 1)): ...
                 ^^^^^^^
SyntaxError: yield expression cannot be used within the definition of a generic
>>> def foo() -> (yield x): ...
  File "<python-input-28>", line 1
    def foo() -> (yield x): ...
                  ^^^^^^^
SyntaxError: 'yield' outside function
```

Fixes https://github.com/astral-sh/ruff/issues/11118.

Test Plan
--

New inline tests, along with some updates to existing tests.
This commit is contained in:
Brent Westbrook 2025-04-03 17:56:55 -04:00 committed by GitHub
parent 24b1b1d52c
commit c2b2e42ad3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 2866 additions and 141 deletions

View file

@ -114,6 +114,120 @@ impl SemanticSyntaxChecker {
}
Self::debug_shadowing(stmt, ctx);
Self::check_annotation(stmt, ctx);
}
fn check_annotation<Ctx: SemanticSyntaxContext>(stmt: &ast::Stmt, ctx: &Ctx) {
match stmt {
Stmt::FunctionDef(ast::StmtFunctionDef {
type_params,
parameters,
returns,
..
}) => {
// test_ok valid_annotation_function
// def f() -> (y := 3): ...
// def g(arg: (x := 1)): ...
// test_err invalid_annotation_function
// def f[T]() -> (y := 3): ...
// def g[T](arg: (x := 1)): ...
// def h[T](x: (yield 1)): ...
// def i(x: (yield 1)): ...
// def j[T]() -> (yield 1): ...
// def k() -> (yield 1): ...
// def l[T](x: (yield from 1)): ...
// def m(x: (yield from 1)): ...
// def n[T]() -> (yield from 1): ...
// def o() -> (yield from 1): ...
// def p[T: (yield 1)](): ... # yield in TypeVar bound
// def q[T = (yield 1)](): ... # yield in TypeVar default
// def r[*Ts = (yield 1)](): ... # yield in TypeVarTuple default
// def s[**Ts = (yield 1)](): ... # yield in ParamSpec default
// def t[T: (x := 1)](): ... # named expr in TypeVar bound
// def u[T = (x := 1)](): ... # named expr in TypeVar default
// def v[*Ts = (x := 1)](): ... # named expr in TypeVarTuple default
// def w[**Ts = (x := 1)](): ... # named expr in ParamSpec default
let is_generic = type_params.is_some();
let mut visitor = InvalidExpressionVisitor {
allow_named_expr: !is_generic,
position: InvalidExpressionPosition::TypeAnnotation,
ctx,
};
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
if is_generic {
visitor.position = InvalidExpressionPosition::GenericDefinition;
} else {
visitor.position = InvalidExpressionPosition::TypeAnnotation;
}
for param in parameters
.iter()
.filter_map(ast::AnyParameterRef::annotation)
{
visitor.visit_expr(param);
}
if let Some(returns) = returns {
visitor.visit_expr(returns);
}
}
Stmt::ClassDef(ast::StmtClassDef {
type_params,
arguments,
..
}) => {
// test_ok valid_annotation_class
// class F(y := list): ...
// test_err invalid_annotation_class
// class F[T](y := list): ...
// class G((yield 1)): ...
// class H((yield from 1)): ...
// class I[T]((yield 1)): ...
// class J[T]((yield from 1)): ...
// class K[T: (yield 1)]: ... # yield in TypeVar
// class L[T: (x := 1)]: ... # named expr in TypeVar
let is_generic = type_params.is_some();
let mut visitor = InvalidExpressionVisitor {
allow_named_expr: !is_generic,
position: InvalidExpressionPosition::TypeAnnotation,
ctx,
};
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
if is_generic {
visitor.position = InvalidExpressionPosition::GenericDefinition;
} else {
visitor.position = InvalidExpressionPosition::BaseClass;
}
if let Some(arguments) = arguments {
visitor.visit_arguments(arguments);
}
}
Stmt::TypeAlias(ast::StmtTypeAlias {
type_params, value, ..
}) => {
// test_err invalid_annotation_type_alias
// type X[T: (yield 1)] = int # TypeVar bound
// type X[T = (yield 1)] = int # TypeVar default
// type X[*Ts = (yield 1)] = int # TypeVarTuple default
// type X[**Ts = (yield 1)] = int # ParamSpec default
// type Y = (yield 1) # yield in value
// type Y = (x := 1) # named expr in value
let mut visitor = InvalidExpressionVisitor {
allow_named_expr: false,
position: InvalidExpressionPosition::TypeAlias,
ctx,
};
visitor.visit_expr(value);
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
}
_ => {}
}
}
/// Emit a [`SemanticSyntaxErrorKind::InvalidStarExpression`] if `expr` is starred.
@ -511,6 +625,15 @@ impl Display for SemanticSyntaxError {
write!(f, "cannot delete `__debug__` on Python {python_version} (syntax was removed in 3.9)")
}
},
SemanticSyntaxErrorKind::InvalidExpression(
kind,
InvalidExpressionPosition::BaseClass,
) => {
write!(f, "{kind} cannot be used as a base class")
}
SemanticSyntaxErrorKind::InvalidExpression(kind, position) => {
write!(f, "{kind} cannot be used within a {position}")
}
SemanticSyntaxErrorKind::DuplicateMatchKey(key) => {
write!(
f,
@ -641,6 +764,21 @@ pub enum SemanticSyntaxErrorKind {
/// [BPO 45000]: https://github.com/python/cpython/issues/89163
WriteToDebug(WriteToDebugKind),
/// Represents the use of an invalid expression kind in one of several locations.
///
/// The kinds include `yield` and `yield from` expressions and named expressions, and locations
/// include type parameter bounds and defaults, type annotations, type aliases, and base class
/// lists.
///
/// ## Examples
///
/// ```python
/// type X[T: (yield 1)] = int
/// type Y = (yield 1)
/// def f[T](x: int) -> (y := 3): return x
/// ```
InvalidExpression(InvalidExpressionKind, InvalidExpressionPosition),
/// Represents a duplicate key in a `match` mapping pattern.
///
/// The [CPython grammar] allows keys in mapping patterns to be literals or attribute accesses:
@ -713,6 +851,48 @@ pub enum SemanticSyntaxErrorKind {
InvalidStarExpression,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InvalidExpressionPosition {
TypeVarBound,
TypeVarDefault,
TypeVarTupleDefault,
ParamSpecDefault,
TypeAnnotation,
BaseClass,
GenericDefinition,
TypeAlias,
}
impl Display for InvalidExpressionPosition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
InvalidExpressionPosition::TypeVarBound => "TypeVar bound",
InvalidExpressionPosition::TypeVarDefault => "TypeVar default",
InvalidExpressionPosition::TypeVarTupleDefault => "TypeVarTuple default",
InvalidExpressionPosition::ParamSpecDefault => "ParamSpec default",
InvalidExpressionPosition::TypeAnnotation => "type annotation",
InvalidExpressionPosition::GenericDefinition => "generic definition",
InvalidExpressionPosition::BaseClass => "base class",
InvalidExpressionPosition::TypeAlias => "type alias",
})
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InvalidExpressionKind {
Yield,
NamedExpr,
}
impl Display for InvalidExpressionKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
InvalidExpressionKind::Yield => "yield expression",
InvalidExpressionKind::NamedExpr => "named expression",
})
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WriteToDebugKind {
Store,
@ -905,6 +1085,83 @@ impl<'a, Ctx: SemanticSyntaxContext> MatchPatternVisitor<'a, Ctx> {
}
}
struct InvalidExpressionVisitor<'a, Ctx> {
/// Allow named expressions (`x := ...`) to appear in annotations.
///
/// These are allowed in non-generic functions, for example:
///
/// ```python
/// def foo(arg: (x := int)): ... # ok
/// def foo[T](arg: (x := int)): ... # syntax error
/// ```
allow_named_expr: bool,
/// Context used for emitting errors.
ctx: &'a Ctx,
position: InvalidExpressionPosition,
}
impl<Ctx> Visitor<'_> for InvalidExpressionVisitor<'_, Ctx>
where
Ctx: SemanticSyntaxContext,
{
fn visit_expr(&mut self, expr: &Expr) {
match expr {
Expr::Named(ast::ExprNamed { range, .. }) if !self.allow_named_expr => {
SemanticSyntaxChecker::add_error(
self.ctx,
SemanticSyntaxErrorKind::InvalidExpression(
InvalidExpressionKind::NamedExpr,
self.position,
),
*range,
);
}
Expr::Yield(ast::ExprYield { range, .. })
| Expr::YieldFrom(ast::ExprYieldFrom { range, .. }) => {
SemanticSyntaxChecker::add_error(
self.ctx,
SemanticSyntaxErrorKind::InvalidExpression(
InvalidExpressionKind::Yield,
self.position,
),
*range,
);
}
_ => {}
}
ast::visitor::walk_expr(self, expr);
}
fn visit_type_param(&mut self, type_param: &ast::TypeParam) {
match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { bound, default, .. }) => {
if let Some(expr) = bound {
self.position = InvalidExpressionPosition::TypeVarBound;
self.visit_expr(expr);
}
if let Some(expr) = default {
self.position = InvalidExpressionPosition::TypeVarDefault;
self.visit_expr(expr);
}
}
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { default, .. }) => {
if let Some(expr) = default {
self.position = InvalidExpressionPosition::TypeVarTupleDefault;
self.visit_expr(expr);
}
}
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { default, .. }) => {
if let Some(expr) = default {
self.position = InvalidExpressionPosition::ParamSpecDefault;
self.visit_expr(expr);
}
}
};
}
}
pub trait SemanticSyntaxContext {
/// Returns `true` if a module's docstring boundary has been passed.
fn seen_docstring_boundary(&self) -> bool;