[red-knot] Detect semantic syntax errors (#17463)

Summary
--

This PR extends semantic syntax error detection to red-knot. The main
changes here are:

1. Adding `SemanticSyntaxChecker` and `Vec<SemanticSyntaxError>` fields
to the `SemanticIndexBuilder`
2. Calling `SemanticSyntaxChecker::visit_stmt` and `visit_expr` in the
`SemanticIndexBuilder`'s `visit_stmt` and `visit_expr` methods
3. Implementing `SemanticSyntaxContext` for `SemanticIndexBuilder`
4. Adding new mdtests to test the context implementation and show
diagnostics

(3) is definitely the trickiest and required (I think) a minor addition
to the `SemanticIndexBuilder`. I tried to look around for existing code
performing the necessary checks, but I definitely could have missed
something or misused the existing code even when I found it.

There's still one TODO around `global` statement handling. I don't think
there's an existing way to look this up, but I'm happy to work on that
here or in a separate PR. This currently only affects detection of one
error (`LoadBeforeGlobalDeclaration` or
[PLE0118](https://docs.astral.sh/ruff/rules/load-before-global-declaration/)
in ruff), so it's not too big of a problem even if we leave the TODO.

Test Plan
--

New mdtests, as well as new errors for existing mdtests

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Brent Westbrook 2025-04-23 09:52:58 -04:00 committed by GitHub
parent 624f5c6c22
commit e7f38fe74b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 431 additions and 66 deletions

View file

@ -56,40 +56,41 @@ def _(
def bar() -> None:
return None
def _(
a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"
b: 2.3, # error: [invalid-type-form] "Float literals are not allowed in type expressions"
c: 4j, # error: [invalid-type-form] "Complex literals are not allowed in type expressions"
d: True, # error: [invalid-type-form] "Boolean literals are not allowed in this context in a type expression"
e: int | b"foo", # error: [invalid-type-form] "Bytes literals are not allowed in this context in a type expression"
f: 1 and 2, # error: [invalid-type-form] "Boolean operations are not allowed in type expressions"
g: 1 or 2, # error: [invalid-type-form] "Boolean operations are not allowed in type expressions"
h: (foo := 1), # error: [invalid-type-form] "Named expressions are not allowed in type expressions"
i: not 1, # error: [invalid-type-form] "Unary operations are not allowed in type expressions"
j: lambda: 1, # error: [invalid-type-form] "`lambda` expressions are not allowed in type expressions"
k: 1 if True else 2, # error: [invalid-type-form] "`if` expressions are not allowed in type expressions"
l: await 1, # error: [invalid-type-form] "`await` expressions are not allowed in type expressions"
m: (yield 1), # error: [invalid-type-form] "`yield` expressions are not allowed in type expressions"
n: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions"
o: 1 < 2, # error: [invalid-type-form] "Comparison expressions are not allowed in type expressions"
p: bar(), # error: [invalid-type-form] "Function calls are not allowed in type expressions"
q: int | f"foo", # error: [invalid-type-form] "F-strings are not allowed in type expressions"
r: [1, 2, 3][1:2], # error: [invalid-type-form] "Slices are not allowed in type expressions"
):
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown
reveal_type(c) # revealed: Unknown
reveal_type(d) # revealed: Unknown
reveal_type(e) # revealed: int | Unknown
reveal_type(f) # revealed: Unknown
reveal_type(g) # revealed: Unknown
reveal_type(h) # revealed: Unknown
reveal_type(i) # revealed: Unknown
reveal_type(j) # revealed: Unknown
reveal_type(k) # revealed: Unknown
reveal_type(p) # revealed: Unknown
reveal_type(q) # revealed: int | Unknown
reveal_type(r) # revealed: @Todo(unknown type subscript)
async def outer(): # avoid unrelated syntax errors on yield, yield from, and await
def _(
a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"
b: 2.3, # error: [invalid-type-form] "Float literals are not allowed in type expressions"
c: 4j, # error: [invalid-type-form] "Complex literals are not allowed in type expressions"
d: True, # error: [invalid-type-form] "Boolean literals are not allowed in this context in a type expression"
e: int | b"foo", # error: [invalid-type-form] "Bytes literals are not allowed in this context in a type expression"
f: 1 and 2, # error: [invalid-type-form] "Boolean operations are not allowed in type expressions"
g: 1 or 2, # error: [invalid-type-form] "Boolean operations are not allowed in type expressions"
h: (foo := 1), # error: [invalid-type-form] "Named expressions are not allowed in type expressions"
i: not 1, # error: [invalid-type-form] "Unary operations are not allowed in type expressions"
j: lambda: 1, # error: [invalid-type-form] "`lambda` expressions are not allowed in type expressions"
k: 1 if True else 2, # error: [invalid-type-form] "`if` expressions are not allowed in type expressions"
l: await 1, # error: [invalid-type-form] "`await` expressions are not allowed in type expressions"
m: (yield 1), # error: [invalid-type-form] "`yield` expressions are not allowed in type expressions"
n: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions"
o: 1 < 2, # error: [invalid-type-form] "Comparison expressions are not allowed in type expressions"
p: bar(), # error: [invalid-type-form] "Function calls are not allowed in type expressions"
q: int | f"foo", # error: [invalid-type-form] "F-strings are not allowed in type expressions"
r: [1, 2, 3][1:2], # error: [invalid-type-form] "Slices are not allowed in type expressions"
):
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown
reveal_type(c) # revealed: Unknown
reveal_type(d) # revealed: Unknown
reveal_type(e) # revealed: int | Unknown
reveal_type(f) # revealed: Unknown
reveal_type(g) # revealed: Unknown
reveal_type(h) # revealed: Unknown
reveal_type(i) # revealed: Unknown
reveal_type(j) # revealed: Unknown
reveal_type(k) # revealed: Unknown
reveal_type(p) # revealed: Unknown
reveal_type(q) # revealed: int | Unknown
reveal_type(r) # revealed: @Todo(unknown type subscript)
```
## Invalid Collection based AST nodes

View file

@ -127,8 +127,9 @@ class AsyncIterable:
def __aiter__(self) -> AsyncIterator:
return AsyncIterator()
# revealed: @Todo(async iterables/iterators)
[reveal_type(x) async for x in AsyncIterable()]
async def _():
# revealed: @Todo(async iterables/iterators)
[reveal_type(x) async for x in AsyncIterable()]
```
### Invalid async comprehension
@ -145,6 +146,7 @@ class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
# revealed: @Todo(async iterables/iterators)
[reveal_type(x) async for x in Iterable()]
async def _():
# revealed: @Todo(async iterables/iterators)
[reveal_type(x) async for x in Iterable()]
```

View file

@ -0,0 +1,165 @@
# Semantic syntax error diagnostics
## `async` comprehensions in synchronous comprehensions
### Python 3.10
<!-- snapshot-diagnostics -->
Before Python 3.11, `async` comprehensions could not be used within outer sync comprehensions, even
within an `async` function ([CPython issue](https://github.com/python/cpython/issues/77527)):
```toml
[environment]
python-version = "3.10"
```
```py
async def elements(n):
yield n
async def f():
# error: 19 [invalid-syntax] "cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.10 (syntax was added in 3.11)"
return {n: [x async for x in elements(n)] for n in range(3)}
```
If all of the comprehensions are `async`, on the other hand, the code was still valid:
```py
async def test():
return [[x async for x in elements(n)] async for n in range(3)]
```
These are a couple of tricky but valid cases to check that nested scope handling is wired up
correctly in the `SemanticSyntaxContext` trait:
```py
async def f():
[x for x in [1]] and [x async for x in elements(1)]
async def f():
def g():
pass
[x async for x in elements(1)]
```
### Python 3.11
All of these same examples are valid after Python 3.11:
```toml
[environment]
python-version = "3.11"
```
```py
async def elements(n):
yield n
async def f():
return {n: [x async for x in elements(n)] for n in range(3)}
```
## Late `__future__` import
```py
from collections import namedtuple
# error: [invalid-syntax] "__future__ imports must be at the top of the file"
from __future__ import print_function
```
## Invalid annotation
This one might be a bit redundant with the `invalid-type-form` error.
```toml
[environment]
python-version = "3.12"
```
```py
from __future__ import annotations
# error: [invalid-type-form] "Named expressions are not allowed in type expressions"
# error: [invalid-syntax] "named expression cannot be used within a type annotation"
def f() -> (y := 3): ...
```
## Duplicate `match` key
```toml
[environment]
python-version = "3.10"
```
```py
match 2:
# error: [invalid-syntax] "mapping pattern checks duplicate key `"x"`"
case {"x": 1, "x": 2}:
...
```
## `return`, `yield`, `yield from`, and `await` outside function
```py
# error: [invalid-syntax] "`return` statement outside of a function"
return
# error: [invalid-syntax] "`yield` statement outside of a function"
yield
# error: [invalid-syntax] "`yield from` statement outside of a function"
yield from []
# error: [invalid-syntax] "`await` statement outside of a function"
# error: [invalid-syntax] "`await` outside of an asynchronous function"
await 1
def f():
# error: [invalid-syntax] "`await` outside of an asynchronous function"
await 1
```
Generators are evaluated lazily, so `await` is allowed, even outside of a function.
```py
async def g():
yield 1
(x async for x in g())
```
## `await` outside async function
This error includes `await`, `async for`, `async with`, and `async` comprehensions.
```python
async def elements(n):
yield n
def _():
# error: [invalid-syntax] "`await` outside of an asynchronous function"
await 1
# error: [invalid-syntax] "`async for` outside of an asynchronous function"
async for _ in elements(1):
...
# error: [invalid-syntax] "`async with` outside of an asynchronous function"
async with elements(1) as x:
...
# error: [invalid-syntax] "cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.9 (syntax was added in 3.11)"
# error: [invalid-syntax] "asynchronous comprehension outside of an asynchronous function"
[x async for x in elements(1)]
```
## Load before `global` declaration
This should be an error, but it's not yet.
TODO implement `SemanticSyntaxContext::global`
```py
def f():
x = 1
global x
```

View file

@ -189,7 +189,7 @@ match 42:
...
case [O]:
...
case P | Q:
case P | Q: # error: [invalid-syntax] "name capture `P` makes remaining patterns unreachable"
...
case object(foo=R):
...
@ -289,7 +289,7 @@ match 42:
...
case [D]:
...
case E | F:
case E | F: # error: [invalid-syntax] "name capture `E` makes remaining patterns unreachable"
...
case object(foo=G):
...
@ -357,7 +357,7 @@ match 42:
...
case [D]:
...
case E | F:
case E | F: # error: [invalid-syntax] "name capture `E` makes remaining patterns unreachable"
...
case object(foo=G):
...

View file

@ -0,0 +1,46 @@
---
source: crates/red_knot_test/src/lib.rs
expression: snapshot
---
---
mdtest name: semantic_syntax_errors.md - Semantic syntax error diagnostics - `async` comprehensions in synchronous comprehensions - Python 3.10
mdtest path: crates/red_knot_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md
---
# Python source files
## mdtest_snippet.py
```
1 | async def elements(n):
2 | yield n
3 |
4 | async def f():
5 | # error: 19 [invalid-syntax] "cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.10 (syntax was added in 3.11)"
6 | return {n: [x async for x in elements(n)] for n in range(3)}
7 | async def test():
8 | return [[x async for x in elements(n)] async for n in range(3)]
9 | async def f():
10 | [x for x in [1]] and [x async for x in elements(1)]
11 |
12 | async def f():
13 | def g():
14 | pass
15 | [x async for x in elements(1)]
```
# Diagnostics
```
error: invalid-syntax
--> /src/mdtest_snippet.py:6:19
|
4 | async def f():
5 | # error: 19 [invalid-syntax] "cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.10 (syntax...
6 | return {n: [x async for x in elements(n)] for n in range(3)}
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot use an asynchronous comprehension outside of an asynchronous function on Python 3.10 (syntax was added in 3.11)
7 | async def test():
8 | return [[x async for x in elements(n)] async for n in range(3)]
|
```

View file

@ -5,6 +5,7 @@ use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_index::{IndexSlice, IndexVec};
use ruff_python_parser::semantic_errors::SemanticSyntaxError;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use salsa::plumbing::AsId;
use salsa::Update;
@ -175,6 +176,9 @@ pub(crate) struct SemanticIndex<'db> {
/// Map of all of the eager bindings that appear in this file.
eager_bindings: FxHashMap<EagerBindingsKey, ScopedEagerBindingsId>,
/// List of all semantic syntax errors in this file.
semantic_syntax_errors: Vec<SemanticSyntaxError>,
}
impl<'db> SemanticIndex<'db> {
@ -399,6 +403,10 @@ impl<'db> SemanticIndex<'db> {
None => EagerBindingsResult::NotFound,
}
}
pub(crate) fn semantic_syntax_errors(&self) -> &[SemanticSyntaxError] {
&self.semantic_syntax_errors
}
}
pub struct AncestorsIter<'a> {

View file

@ -1,3 +1,4 @@
use std::cell::{OnceCell, RefCell};
use std::sync::Arc;
use except_handlers::TryNodeContextStackManager;
@ -5,10 +6,15 @@ use rustc_hash::{FxHashMap, FxHashSet};
use ruff_db::files::File;
use ruff_db::parsed::ParsedModule;
use ruff_db::source::{source_text, SourceText};
use ruff_index::IndexVec;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor};
use ruff_python_ast::{self as ast};
use ruff_python_ast::{self as ast, PythonVersion};
use ruff_python_parser::semantic_errors::{
SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError,
};
use ruff_text_size::TextRange;
use crate::ast_node_ref::AstNodeRef;
use crate::module_name::ModuleName;
@ -32,8 +38,8 @@ use crate::semantic_index::predicate::{
};
use crate::semantic_index::re_exports::exported_names;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopeKind, ScopedSymbolId,
SymbolTableBuilder,
FileScopeId, NodeWithScopeKey, NodeWithScopeKind, NodeWithScopeRef, Scope, ScopeId, ScopeKind,
ScopedSymbolId, SymbolTableBuilder,
};
use crate::semantic_index::use_def::{
EagerBindingsKey, FlowSnapshot, ScopedEagerBindingsId, UseDefMapBuilder,
@ -43,7 +49,7 @@ use crate::semantic_index::visibility_constraints::{
};
use crate::semantic_index::SemanticIndex;
use crate::unpack::{Unpack, UnpackKind, UnpackPosition, UnpackValue};
use crate::Db;
use crate::{Db, Program};
mod except_handlers;
@ -85,6 +91,11 @@ pub(super) struct SemanticIndexBuilder<'db> {
/// Flags about the file's global scope
has_future_annotations: bool,
// Used for checking semantic syntax errors
python_version: PythonVersion,
source_text: OnceCell<SourceText>,
semantic_checker: SemanticSyntaxChecker,
// Semantic Index fields
scopes: IndexVec<FileScopeId, Scope>,
scope_ids_by_scope: IndexVec<FileScopeId, ScopeId<'db>>,
@ -98,6 +109,8 @@ pub(super) struct SemanticIndexBuilder<'db> {
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
imported_modules: FxHashSet<ModuleName>,
eager_bindings: FxHashMap<EagerBindingsKey, ScopedEagerBindingsId>,
/// Errors collected by the `semantic_checker`.
semantic_syntax_errors: RefCell<Vec<SemanticSyntaxError>>,
}
impl<'db> SemanticIndexBuilder<'db> {
@ -129,6 +142,11 @@ impl<'db> SemanticIndexBuilder<'db> {
imported_modules: FxHashSet::default(),
eager_bindings: FxHashMap::default(),
python_version: Program::get(db).python_version(db),
source_text: OnceCell::new(),
semantic_checker: SemanticSyntaxChecker::default(),
semantic_syntax_errors: RefCell::default(),
};
builder.push_scope_with_parent(
@ -156,10 +174,6 @@ impl<'db> SemanticIndexBuilder<'db> {
self.current_scope_info().file_scope_id
}
fn current_scope_is_global_scope(&self) -> bool {
self.scope_stack.len() == 1
}
/// Returns the scope ID of the surrounding class body scope if the current scope
/// is a method inside a class body. Returns `None` otherwise, e.g. if the current
/// scope is a function body outside of a class, or if the current scope is not a
@ -1050,8 +1064,20 @@ impl<'db> SemanticIndexBuilder<'db> {
imported_modules: Arc::new(self.imported_modules),
has_future_annotations: self.has_future_annotations,
eager_bindings: self.eager_bindings,
semantic_syntax_errors: self.semantic_syntax_errors.into_inner(),
}
}
fn with_semantic_checker(&mut self, f: impl FnOnce(&mut SemanticSyntaxChecker, &Self)) {
let mut checker = std::mem::take(&mut self.semantic_checker);
f(&mut checker, self);
self.semantic_checker = checker;
}
fn source_text(&self) -> &SourceText {
self.source_text
.get_or_init(|| source_text(self.db.upcast(), self.file))
}
}
impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db>
@ -1059,6 +1085,8 @@ where
'ast: 'db,
{
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context));
match stmt {
ast::Stmt::FunctionDef(function_def) => {
let ast::StmtFunctionDef {
@ -1254,7 +1282,7 @@ where
// Wildcard imports are invalid syntax everywhere except the top-level scope,
// and thus do not bind any definitions anywhere else
if !self.current_scope_is_global_scope() {
if !self.in_module_scope() {
continue;
}
@ -1809,6 +1837,8 @@ where
}
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context));
self.scopes_by_expression
.insert(expr.into(), self.current_scope());
self.current_ast_ids().record_expression(expr);
@ -2268,6 +2298,99 @@ where
}
}
impl SemanticSyntaxContext for SemanticIndexBuilder<'_> {
fn future_annotations_or_stub(&self) -> bool {
self.has_future_annotations
}
fn python_version(&self) -> PythonVersion {
self.python_version
}
fn source(&self) -> &str {
self.source_text().as_str()
}
// TODO(brent) handle looking up `global` bindings
fn global(&self, _name: &str) -> Option<TextRange> {
None
}
fn in_async_context(&self) -> bool {
for scope_info in self.scope_stack.iter().rev() {
let scope = &self.scopes[scope_info.file_scope_id];
match scope.kind() {
ScopeKind::Class | ScopeKind::Lambda => return false,
ScopeKind::Function => {
return scope.node().expect_function().is_async;
}
ScopeKind::Comprehension
| ScopeKind::Module
| ScopeKind::TypeAlias
| ScopeKind::Annotation => {}
}
}
false
}
fn in_await_allowed_context(&self) -> bool {
for scope_info in self.scope_stack.iter().rev() {
let scope = &self.scopes[scope_info.file_scope_id];
match scope.kind() {
ScopeKind::Class => return false,
ScopeKind::Function | ScopeKind::Lambda => return true,
ScopeKind::Comprehension
| ScopeKind::Module
| ScopeKind::TypeAlias
| ScopeKind::Annotation => {}
}
}
false
}
fn in_sync_comprehension(&self) -> bool {
for scope_info in self.scope_stack.iter().rev() {
let scope = &self.scopes[scope_info.file_scope_id];
let generators = match scope.node() {
NodeWithScopeKind::ListComprehension(node) => &node.generators,
NodeWithScopeKind::SetComprehension(node) => &node.generators,
NodeWithScopeKind::DictComprehension(node) => &node.generators,
_ => continue,
};
if generators.iter().all(|gen| !gen.is_async) {
return true;
}
}
false
}
fn in_module_scope(&self) -> bool {
self.scope_stack.len() == 1
}
fn in_function_scope(&self) -> bool {
let kind = self.scopes[self.current_scope()].kind();
matches!(kind, ScopeKind::Function | ScopeKind::Lambda)
}
fn in_generator_scope(&self) -> bool {
matches!(
self.scopes[self.current_scope()].node(),
NodeWithScopeKind::GeneratorExpression(_)
)
}
fn in_notebook(&self) -> bool {
self.source_text().is_notebook()
}
fn report_semantic_error(&self, error: SemanticSyntaxError) {
if self.db.is_file_open(self.file) {
self.semantic_syntax_errors.borrow_mut().push(error);
}
}
}
#[derive(Copy, Clone, Debug, PartialEq)]
enum CurrentAssignment<'a> {
Assign {

View file

@ -10,6 +10,7 @@ use diagnostic::{
CALL_POSSIBLY_UNBOUND_METHOD, INVALID_CONTEXT_MANAGER, INVALID_SUPER_ARGUMENT, NOT_ITERABLE,
UNAVAILABLE_IMPLICIT_SUPER_ARGUMENTS,
};
use ruff_db::diagnostic::create_semantic_syntax_diagnostic;
use ruff_db::files::{File, FileRange};
use ruff_python_ast::name::Name;
use ruff_python_ast::{self as ast, AnyNodeRef};
@ -90,6 +91,13 @@ pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics {
diagnostics.extend(result.diagnostics());
}
diagnostics.extend_diagnostics(
index
.semantic_syntax_errors()
.iter()
.map(|error| create_semantic_syntax_diagnostic(file, error)),
);
check_suppressions(db, file, &mut diagnostics);
diagnostics

View file

@ -1021,6 +1021,10 @@ impl TypeCheckDiagnostics {
self.used_suppressions.extend(&other.used_suppressions);
}
pub(super) fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator<Item = Diagnostic>) {
self.diagnostics.extend(diagnostics);
}
pub(crate) fn mark_used(&mut self, suppression_id: FileSuppressionId) {
self.used_suppressions.insert(suppression_id);
}

View file

@ -845,3 +845,16 @@ pub fn create_unsupported_syntax_diagnostic(
diag.annotate(Annotation::primary(span).message(err.to_string()));
diag
}
/// Creates a `Diagnostic` from a semantic syntax error.
///
/// See [`create_parse_diagnostic`] for more details.
pub fn create_semantic_syntax_diagnostic(
file: File,
err: &ruff_python_parser::semantic_errors::SemanticSyntaxError,
) -> Diagnostic {
let mut diag = Diagnostic::new(DiagnosticId::InvalidSyntax, Severity::Error, "");
let span = Span::from(file).with_range(err.range);
diag.annotate(Annotation::primary(span).message(err.to_string()));
diag
}

View file

@ -557,10 +557,6 @@ impl<'a> Checker<'a> {
}
impl SemanticSyntaxContext for Checker<'_> {
fn seen_docstring_boundary(&self) -> bool {
self.semantic.seen_module_docstring_boundary()
}
fn python_version(&self) -> PythonVersion {
self.target_version
}

View file

@ -32,6 +32,10 @@ pub struct SemanticSyntaxChecker {
/// Python considers it a syntax error to import from `__future__` after any other
/// non-`__future__`-importing statements.
seen_futures_boundary: bool,
/// The checker has traversed past the module docstring boundary (i.e. seen any statement in the
/// module).
seen_module_docstring_boundary: bool,
}
impl SemanticSyntaxChecker {
@ -506,7 +510,7 @@ impl SemanticSyntaxChecker {
// update internal state
match stmt {
Stmt::Expr(StmtExpr { value, .. })
if !ctx.seen_docstring_boundary() && value.is_string_literal_expr() => {}
if !self.seen_module_docstring_boundary && value.is_string_literal_expr() => {}
Stmt::ImportFrom(StmtImportFrom { module, .. }) => {
// Allow __future__ imports until we see a non-__future__ import.
if !matches!(module.as_deref(), Some("__future__")) {
@ -520,6 +524,8 @@ impl SemanticSyntaxChecker {
self.seen_futures_boundary = true;
}
}
self.seen_module_docstring_boundary = true;
}
/// Check `expr` for semantic syntax errors and update the checker's internal state.
@ -881,7 +887,7 @@ impl Display for SemanticSyntaxError {
f.write_str("`return` statement outside of a function")
}
SemanticSyntaxErrorKind::AwaitOutsideAsyncFunction(kind) => {
write!(f, "`{kind}` outside of an asynchronous function")
write!(f, "{kind} outside of an asynchronous function")
}
}
}
@ -1207,9 +1213,9 @@ pub enum AwaitOutsideAsyncFunctionKind {
impl Display for AwaitOutsideAsyncFunctionKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
AwaitOutsideAsyncFunctionKind::Await => "await",
AwaitOutsideAsyncFunctionKind::AsyncFor => "async for",
AwaitOutsideAsyncFunctionKind::AsyncWith => "async with",
AwaitOutsideAsyncFunctionKind::Await => "`await`",
AwaitOutsideAsyncFunctionKind::AsyncFor => "`async for`",
AwaitOutsideAsyncFunctionKind::AsyncWith => "`async with`",
AwaitOutsideAsyncFunctionKind::AsyncComprehension => "asynchronous comprehension",
})
}
@ -1584,9 +1590,6 @@ where
/// x # here, classes break function scopes
/// ```
pub trait SemanticSyntaxContext {
/// Returns `true` if a module's docstring boundary has been passed.
fn seen_docstring_boundary(&self) -> bool;
/// Returns `true` if `__future__`-style type annotations are enabled.
fn future_annotations_or_stub(&self) -> bool;

View file

@ -504,10 +504,6 @@ impl<'a> SemanticSyntaxCheckerVisitor<'a> {
}
impl SemanticSyntaxContext for SemanticSyntaxCheckerVisitor<'_> {
fn seen_docstring_boundary(&self) -> bool {
false
}
fn future_annotations_or_stub(&self) -> bool {
false
}