[syntax-errors] Allow yield in base classes and annotations (#17206)

Summary
--

This PR fixes the issue pointed out by @JelleZijlstra in
https://github.com/astral-sh/ruff/pull/17101#issuecomment-2777480204.
Namely, I conflated two very different errors from CPython:

```pycon
>>> def m[T](x: (yield from 1)): ...
  File "<python-input-310>", line 1
    def m[T](x: (yield from 1)): ...
                 ^^^^^^^^^^^^
SyntaxError: yield expression cannot be used within the definition of a generic
>>> def m(x: (yield from 1)): ...
  File "<python-input-311>", line 1
    def m(x: (yield from 1)): ...
              ^^^^^^^^^^^^
SyntaxError: 'yield from' outside function
>>> def outer():
...     def m(x: (yield from 1)): ...
...
>>>
```

I thought the second error was the same as the first, but `yield` (and
`yield from`) is actually valid in this position when inside a function
scope. The same is true for base classes, as pointed out in the original
comment.

We don't currently raise an error for `yield` outside of a function, but
that should be handled separately.

On the upside, this had the benefit of removing the
`InvalidExpressionPosition::BaseClass` variant and the
`allow_named_expr` field from the visitor because they were both no
longer used.

Test Plan
--

Updated inline tests.
This commit is contained in:
Brent Westbrook 2025-04-04 13:48:28 -04:00 committed by GitHub
parent 33a56f198b
commit acc5662e8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1081 additions and 686 deletions

View file

@ -128,18 +128,29 @@ impl SemanticSyntaxChecker {
// test_ok valid_annotation_function
// def f() -> (y := 3): ...
// def g(arg: (x := 1)): ...
// def outer():
// def i(x: (yield 1)): ...
// def k() -> (yield 1): ...
// def m(x: (yield from 1)): ...
// def o() -> (yield from 1): ...
// test_err invalid_annotation_function_py314
// # parse_options: {"target-version": "3.14"}
// def f() -> (y := 3): ...
// def g(arg: (x := 1)): ...
// def outer():
// def i(x: (yield 1)): ...
// def k() -> (yield 1): ...
// def m(x: (yield from 1)): ...
// def o() -> (yield from 1): ...
// test_err invalid_annotation_function
// def f[T]() -> (y := 3): ...
// def g[T](arg: (x := 1)): ...
// def h[T](x: (yield 1)): ...
// def i(x: (yield 1)): ...
// def j[T]() -> (yield 1): ...
// def k() -> (yield 1): ...
// def l[T](x: (yield from 1)): ...
// def m(x: (yield from 1)): ...
// def n[T]() -> (yield from 1): ...
// def o() -> (yield from 1): ...
// def p[T: (yield 1)](): ... # yield in TypeVar bound
// def q[T = (yield 1)](): ... # yield in TypeVar default
// def r[*Ts = (yield 1)](): ... # yield in TypeVarTuple default
@ -148,19 +159,20 @@ impl SemanticSyntaxChecker {
// def u[T = (x := 1)](): ... # named expr in TypeVar default
// def v[*Ts = (x := 1)](): ... # named expr in TypeVarTuple default
// def w[**Ts = (x := 1)](): ... # named expr in ParamSpec default
let is_generic = type_params.is_some();
let mut visitor = InvalidExpressionVisitor {
allow_named_expr: !is_generic,
position: InvalidExpressionPosition::TypeAnnotation,
ctx,
};
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
if is_generic {
// the __future__ annotation error takes precedence over the generic error
if ctx.future_annotations_or_stub() || ctx.python_version() > PythonVersion::PY313 {
visitor.position = InvalidExpressionPosition::TypeAnnotation;
} else if type_params.is_some() {
visitor.position = InvalidExpressionPosition::GenericDefinition;
} else {
visitor.position = InvalidExpressionPosition::TypeAnnotation;
return;
}
for param in parameters
.iter()
@ -173,36 +185,29 @@ impl SemanticSyntaxChecker {
}
}
Stmt::ClassDef(ast::StmtClassDef {
type_params,
type_params: Some(type_params),
arguments,
..
}) => {
// test_ok valid_annotation_class
// class F(y := list): ...
// def f():
// class G((yield 1)): ...
// class H((yield from 1)): ...
// test_err invalid_annotation_class
// class F[T](y := list): ...
// class G((yield 1)): ...
// class H((yield from 1)): ...
// class I[T]((yield 1)): ...
// class J[T]((yield from 1)): ...
// class K[T: (yield 1)]: ... # yield in TypeVar
// class L[T: (x := 1)]: ... # named expr in TypeVar
let is_generic = type_params.is_some();
let mut visitor = InvalidExpressionVisitor {
allow_named_expr: !is_generic,
position: InvalidExpressionPosition::TypeAnnotation,
ctx,
};
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
if is_generic {
visitor.position = InvalidExpressionPosition::GenericDefinition;
} else {
visitor.position = InvalidExpressionPosition::BaseClass;
}
visitor.visit_type_params(type_params);
if let Some(arguments) = arguments {
visitor.position = InvalidExpressionPosition::GenericDefinition;
visitor.visit_arguments(arguments);
}
}
@ -217,7 +222,6 @@ impl SemanticSyntaxChecker {
// type Y = (yield 1) # yield in value
// type Y = (x := 1) # named expr in value
let mut visitor = InvalidExpressionVisitor {
allow_named_expr: false,
position: InvalidExpressionPosition::TypeAlias,
ctx,
};
@ -625,12 +629,6 @@ impl Display for SemanticSyntaxError {
write!(f, "cannot delete `__debug__` on Python {python_version} (syntax was removed in 3.9)")
}
},
SemanticSyntaxErrorKind::InvalidExpression(
kind,
InvalidExpressionPosition::BaseClass,
) => {
write!(f, "{kind} cannot be used as a base class")
}
SemanticSyntaxErrorKind::InvalidExpression(kind, position) => {
write!(f, "{kind} cannot be used within a {position}")
}
@ -858,7 +856,6 @@ pub enum InvalidExpressionPosition {
TypeVarTupleDefault,
ParamSpecDefault,
TypeAnnotation,
BaseClass,
GenericDefinition,
TypeAlias,
}
@ -872,7 +869,6 @@ impl Display for InvalidExpressionPosition {
InvalidExpressionPosition::ParamSpecDefault => "ParamSpec default",
InvalidExpressionPosition::TypeAnnotation => "type annotation",
InvalidExpressionPosition::GenericDefinition => "generic definition",
InvalidExpressionPosition::BaseClass => "base class",
InvalidExpressionPosition::TypeAlias => "type alias",
})
}
@ -1086,16 +1082,6 @@ impl<'a, Ctx: SemanticSyntaxContext> MatchPatternVisitor<'a, Ctx> {
}
struct InvalidExpressionVisitor<'a, Ctx> {
/// Allow named expressions (`x := ...`) to appear in annotations.
///
/// These are allowed in non-generic functions, for example:
///
/// ```python
/// def foo(arg: (x := int)): ... # ok
/// def foo[T](arg: (x := int)): ... # syntax error
/// ```
allow_named_expr: bool,
/// Context used for emitting errors.
ctx: &'a Ctx,
@ -1108,7 +1094,7 @@ where
{
fn visit_expr(&mut self, expr: &Expr) {
match expr {
Expr::Named(ast::ExprNamed { range, .. }) if !self.allow_named_expr => {
Expr::Named(ast::ExprNamed { range, .. }) => {
SemanticSyntaxChecker::add_error(
self.ctx,
SemanticSyntaxErrorKind::InvalidExpression(
@ -1166,6 +1152,9 @@ pub trait SemanticSyntaxContext {
/// Returns `true` if a module's docstring boundary has been passed.
fn seen_docstring_boundary(&self) -> bool;
/// Returns `true` if `__future__`-style type annotations are enabled.
fn future_annotations_or_stub(&self) -> bool;
/// The target Python version for detecting backwards-incompatible syntax changes.
fn python_version(&self) -> PythonVersion;