[syntax-errors] Extend annotation checks to await (#17282)

Summary
--

This PR extends the changes in #17101 to include `await` in the same
positions.

I also renamed the `valid_annotation_function` test to include `_py313`
and explicitly passed a Python version to contrast it with the `_py314`
version.

Test Plan
--

New test cases added to existing files.
This commit is contained in:
Brent Westbrook 2025-04-08 08:55:43 -04:00 committed by GitHub
parent b662c3ff7e
commit 127a45622f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1399 additions and 314 deletions

View file

@ -125,7 +125,8 @@ impl SemanticSyntaxChecker {
returns,
..
}) => {
// test_ok valid_annotation_function
// test_ok valid_annotation_function_py313
// # parse_options: {"target-version": "3.13"}
// def f() -> (y := 3): ...
// def g(arg: (x := 1)): ...
// def outer():
@ -133,6 +134,9 @@ impl SemanticSyntaxChecker {
// def k() -> (yield 1): ...
// def m(x: (yield from 1)): ...
// def o() -> (yield from 1): ...
// async def outer():
// def f() -> (await 1): ...
// def g(arg: (await 1)): ...
// test_err invalid_annotation_function_py314
// # parse_options: {"target-version": "3.14"}
@ -143,8 +147,13 @@ impl SemanticSyntaxChecker {
// def k() -> (yield 1): ...
// def m(x: (yield from 1)): ...
// def o() -> (yield from 1): ...
// async def outer():
// def f() -> (await 1): ...
// def g(arg: (await 1)): ...
// test_err invalid_annotation_function
// def d[T]() -> (await 1): ...
// def e[T](arg: (await 1)): ...
// def f[T]() -> (y := 3): ...
// def g[T](arg: (x := 1)): ...
// def h[T](x: (yield 1)): ...
@ -159,6 +168,10 @@ impl SemanticSyntaxChecker {
// def u[T = (x := 1)](): ... # named expr in TypeVar default
// def v[*Ts = (x := 1)](): ... # named expr in TypeVarTuple default
// def w[**Ts = (x := 1)](): ... # named expr in ParamSpec default
// def t[T: (await 1)](): ... # await in TypeVar bound
// def u[T = (await 1)](): ... # await in TypeVar default
// def v[*Ts = (await 1)](): ... # await in TypeVarTuple default
// def w[**Ts = (await 1)](): ... # await in ParamSpec default
let mut visitor = InvalidExpressionVisitor {
position: InvalidExpressionPosition::TypeAnnotation,
ctx,
@ -194,6 +207,8 @@ impl SemanticSyntaxChecker {
// def f():
// class G((yield 1)): ...
// class H((yield from 1)): ...
// async def f():
// class G((await 1)): ...
// test_err invalid_annotation_class
// class F[T](y := list): ...
@ -201,6 +216,8 @@ impl SemanticSyntaxChecker {
// class J[T]((yield from 1)): ...
// class K[T: (yield 1)]: ... # yield in TypeVar
// class L[T: (x := 1)]: ... # named expr in TypeVar
// class M[T]((await 1)): ...
// class N[T: (await 1)]: ...
let mut visitor = InvalidExpressionVisitor {
position: InvalidExpressionPosition::TypeAnnotation,
ctx,
@ -221,6 +238,8 @@ impl SemanticSyntaxChecker {
// type X[**Ts = (yield 1)] = int # ParamSpec default
// type Y = (yield 1) # yield in value
// type Y = (x := 1) # named expr in value
// type Y[T: (await 1)] = int # await in bound
// type Y = (await 1) # await in value
let mut visitor = InvalidExpressionVisitor {
position: InvalidExpressionPosition::TypeAlias,
ctx,
@ -878,6 +897,7 @@ impl Display for InvalidExpressionPosition {
pub enum InvalidExpressionKind {
Yield,
NamedExpr,
Await,
}
impl Display for InvalidExpressionKind {
@ -885,6 +905,7 @@ impl Display for InvalidExpressionKind {
f.write_str(match self {
InvalidExpressionKind::Yield => "yield expression",
InvalidExpressionKind::NamedExpr => "named expression",
InvalidExpressionKind::Await => "await expression",
})
}
}
@ -1115,6 +1136,16 @@ where
*range,
);
}
Expr::Await(ast::ExprAwait { range, .. }) => {
SemanticSyntaxChecker::add_error(
self.ctx,
SemanticSyntaxErrorKind::InvalidExpression(
InvalidExpressionKind::Await,
self.position,
),
*range,
);
}
_ => {}
}
ast::visitor::walk_expr(self, expr);