From 144484d46c270fe386df4a7efd11abb17b927ee1 Mon Sep 17 00:00:00 2001 From: Brent Westbrook <36778786+ntBre@users.noreply.github.com> Date: Wed, 9 Apr 2025 14:23:29 -0400 Subject: [PATCH] Refactor semantic syntax error scope handling (#17314) ## Summary Based on the discussion in https://github.com/astral-sh/ruff/pull/17298#discussion_r2033975460, we decided to move the scope handling out of the `SemanticSyntaxChecker` and into the `SemanticSyntaxContext` trait. This PR implements that refactor by: - Reverting all of the `Checkpoint` and `in_async_context` code in the `SemanticSyntaxChecker` - Adding four new methods to the `SemanticSyntaxContext` trait - `in_async_context`: matches `SemanticModel::in_async_context` and only detects the nearest enclosing function - `in_sync_comprehension`: uses the new `is_async` tracking on `Generator` scopes to detect any enclosing sync comprehension - `in_module_scope`: reports whether we're at the top-level scope - `in_notebook`: reports whether we're in a Jupyter notebook - In-lining the `TestContext` directly into the `SemanticSyntaxCheckerVisitor` - This allows modifying the context as the visitor traverses the AST, which wasn't possible before One potential question here is "why not add a single method returning a `Scope` or `Scopes` to the context?" The main reason is that the `Scope` type is defined in the `ruff_python_semantic` crate, which is not currently a dependency of the parser. It also doesn't appear to be used in red-knot. So it seemed best to use these more granular methods instead of trying to access `Scope` in `ruff_python_parser` (and red-knot). ## Test Plan Existing parser and linter tests. --- crates/ruff_linter/src/checkers/ast/mod.rs | 68 +++--- .../rules/pylint/rules/await_outside_async.rs | 5 +- .../ok/nested_async_comprehension_py310.py | 2 - .../ruff_python_parser/src/semantic_errors.rs | 221 ++++++------------ crates/ruff_python_parser/tests/fixtures.rs | 114 +++++++-- ...x@nested_async_comprehension_py310.py.snap | 78 +++---- crates/ruff_python_semantic/src/scope.rs | 5 +- 7 files changed, 255 insertions(+), 238 deletions(-) diff --git a/crates/ruff_linter/src/checkers/ast/mod.rs b/crates/ruff_linter/src/checkers/ast/mod.rs index 04d261bc7a..1fbb4f78fe 100644 --- a/crates/ruff_linter/src/checkers/ast/mod.rs +++ b/crates/ruff_linter/src/checkers/ast/mod.rs @@ -27,8 +27,7 @@ use std::path::Path; use itertools::Itertools; use log::debug; use ruff_python_parser::semantic_errors::{ - Checkpoint, SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError, - SemanticSyntaxErrorKind, + SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError, SemanticSyntaxErrorKind, }; use rustc_hash::{FxHashMap, FxHashSet}; @@ -283,7 +282,7 @@ impl<'a> Checker<'a> { last_stmt_end: TextSize::default(), docstring_state: DocstringState::default(), target_version, - semantic_checker: SemanticSyntaxChecker::new(source_type), + semantic_checker: SemanticSyntaxChecker::new(), semantic_errors: RefCell::default(), } } @@ -526,14 +525,10 @@ impl<'a> Checker<'a> { self.target_version } - fn with_semantic_checker( - &mut self, - f: impl FnOnce(&mut SemanticSyntaxChecker, &Checker) -> Checkpoint, - ) -> Checkpoint { + fn with_semantic_checker(&mut self, f: impl FnOnce(&mut SemanticSyntaxChecker, &Checker)) { let mut checker = std::mem::take(&mut self.semantic_checker); - let checkpoint = f(&mut checker, self); + f(&mut checker, self); self.semantic_checker = checker; - checkpoint } } @@ -597,17 +592,43 @@ impl SemanticSyntaxContext for Checker<'_> { fn future_annotations_or_stub(&self) -> bool { self.semantic.future_annotations_or_stub() } + + fn in_async_context(&self) -> bool { + self.semantic.in_async_context() + } + + fn in_sync_comprehension(&self) -> bool { + for scope in self.semantic.current_scopes() { + if let ScopeKind::Generator { + kind: + GeneratorKind::ListComprehension + | GeneratorKind::DictComprehension + | GeneratorKind::SetComprehension, + is_async: false, + } = scope.kind + { + return true; + } + } + false + } + + fn in_module_scope(&self) -> bool { + self.semantic.current_scope().kind.is_module() + } + + fn in_notebook(&self) -> bool { + self.source_type.is_ipynb() + } } impl<'a> Visitor<'a> for Checker<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { // For functions, defer semantic syntax error checks until the body of the function is // visited - let checkpoint = if stmt.is_function_def_stmt() { - None - } else { - Some(self.with_semantic_checker(|semantic, context| semantic.enter_stmt(stmt, context))) - }; + if !stmt.is_function_def_stmt() { + self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context)); + } // Step 0: Pre-processing self.semantic.push_node(stmt); @@ -1210,10 +1231,6 @@ impl<'a> Visitor<'a> for Checker<'a> { self.semantic.flags = flags_snapshot; self.semantic.pop_node(); self.last_stmt_end = stmt.end(); - - if let Some(checkpoint) = checkpoint { - self.semantic_checker.exit_stmt(checkpoint); - } } fn visit_annotation(&mut self, expr: &'a Expr) { @@ -1224,8 +1241,7 @@ impl<'a> Visitor<'a> for Checker<'a> { } fn visit_expr(&mut self, expr: &'a Expr) { - let checkpoint = - self.with_semantic_checker(|semantic, context| semantic.enter_expr(expr, context)); + self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context)); // Step 0: Pre-processing if self.source_type.is_stub() @@ -1772,8 +1788,6 @@ impl<'a> Visitor<'a> for Checker<'a> { self.semantic.flags = flags_snapshot; analyze::expression(expr, self); self.semantic.pop_node(); - - self.semantic_checker.exit_expr(checkpoint); } fn visit_except_handler(&mut self, except_handler: &'a ExceptHandler) { @@ -2012,7 +2026,10 @@ impl<'a> Checker<'a> { // while all subsequent reads and writes are evaluated in the inner scope. In particular, // `x` is local to `foo`, and the `T` in `y=T` skips the class scope when resolving. self.visit_expr(&generator.iter); - self.semantic.push_scope(ScopeKind::Generator(kind)); + self.semantic.push_scope(ScopeKind::Generator { + kind, + is_async: generators.iter().any(|gen| gen.is_async), + }); self.visit_expr(&generator.target); self.semantic.flags = flags; @@ -2618,15 +2635,12 @@ impl<'a> Checker<'a> { unreachable!("Expected Stmt::FunctionDef") }; - let checkpoint = self - .with_semantic_checker(|semantic, context| semantic.enter_stmt(stmt, context)); + self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context)); self.visit_parameters(parameters); // Set the docstring state before visiting the function body. self.docstring_state = DocstringState::Expected(ExpectedDocstringKind::Function); self.visit_body(body); - - self.semantic_checker.exit_stmt(checkpoint); } } self.semantic.restore(snapshot); diff --git a/crates/ruff_linter/src/rules/pylint/rules/await_outside_async.rs b/crates/ruff_linter/src/rules/pylint/rules/await_outside_async.rs index 98a4c26160..5d4646d0e3 100644 --- a/crates/ruff_linter/src/rules/pylint/rules/await_outside_async.rs +++ b/crates/ruff_linter/src/rules/pylint/rules/await_outside_async.rs @@ -73,7 +73,10 @@ pub(crate) fn await_outside_async(checker: &Checker, node: T) { // ``` if matches!( checker.semantic().current_scope().kind, - ScopeKind::Generator(GeneratorKind::Generator) + ScopeKind::Generator { + kind: GeneratorKind::Generator, + .. + } ) { return; } diff --git a/crates/ruff_python_parser/resources/inline/ok/nested_async_comprehension_py310.py b/crates/ruff_python_parser/resources/inline/ok/nested_async_comprehension_py310.py index 3df095c1b4..9d4441be94 100644 --- a/crates/ruff_python_parser/resources/inline/ok/nested_async_comprehension_py310.py +++ b/crates/ruff_python_parser/resources/inline/ok/nested_async_comprehension_py310.py @@ -1,9 +1,7 @@ # parse_options: {"target-version": "3.10"} -# this case fails if exit_expr doesn't run async def f(): [_ for n in range(3)] [_ async for n in range(3)] -# and this fails without exit_stmt async def f(): def g(): ... [_ async for n in range(3)] diff --git a/crates/ruff_python_parser/src/semantic_errors.rs b/crates/ruff_python_parser/src/semantic_errors.rs index 4bed95c45c..c21f6c93e2 100644 --- a/crates/ruff_python_parser/src/semantic_errors.rs +++ b/crates/ruff_python_parser/src/semantic_errors.rs @@ -1,28 +1,20 @@ //! [`SemanticSyntaxChecker`] for AST-based syntax errors. //! //! This checker is not responsible for traversing the AST itself. Instead, its -//! [`SemanticSyntaxChecker::enter_stmt`] and [`SemanticSyntaxChecker::enter_expr`] methods should -//! be called in a parent `Visitor`'s `visit_stmt` and `visit_expr` methods, respectively, and -//! followed by matching calls to [`SemanticSyntaxChecker::exit_stmt`] and -//! [`SemanticSyntaxChecker::exit_expr`]. - +//! [`SemanticSyntaxChecker::visit_stmt`] and [`SemanticSyntaxChecker::visit_expr`] methods should +//! be called in a parent `Visitor`'s `visit_stmt` and `visit_expr` methods, respectively. use std::fmt::Display; use ruff_python_ast::{ self as ast, comparable::ComparableExpr, visitor::{walk_expr, Visitor}, - Expr, ExprContext, IrrefutablePatternKind, Pattern, PySourceType, PythonVersion, Stmt, - StmtExpr, StmtImportFrom, + Expr, ExprContext, IrrefutablePatternKind, Pattern, PythonVersion, Stmt, StmtExpr, + StmtImportFrom, }; use ruff_text_size::{Ranged, TextRange, TextSize}; use rustc_hash::FxHashSet; -#[derive(Debug)] -pub struct Checkpoint { - in_async_context: bool, -} - #[derive(Debug, Default)] pub struct SemanticSyntaxChecker { /// The checker has traversed past the `__future__` import boundary. @@ -40,21 +32,11 @@ pub struct SemanticSyntaxChecker { /// Python considers it a syntax error to import from `__future__` after any other /// non-`__future__`-importing statements. seen_futures_boundary: bool, - - /// The checker is currently in an `async` context: either the body of an `async` function or an - /// `async` comprehension. - /// - /// Note that this should be updated *after* checking the current statement or expression - /// because the parent context is what matters. - in_async_context: bool, } impl SemanticSyntaxChecker { - pub fn new(source_type: PySourceType) -> Self { - Self { - seen_futures_boundary: false, - in_async_context: source_type.is_ipynb(), - } + pub fn new() -> Self { + Self::default() } } @@ -490,25 +472,15 @@ impl SemanticSyntaxChecker { /// Check `stmt` for semantic syntax errors and update the checker's internal state. /// - /// This should be followed by a call to [`SemanticSyntaxChecker::exit_stmt`] to reset any state - /// specific to scopes introduced by `stmt`, such as whether the body of a function is async. - /// /// Note that this method should only be called when traversing `stmt` *and* its children. For - /// example, if traversal of function bodies needs to be deferred, avoid calling `enter_stmt` on - /// the function itself until the deferred body is visited too. Failing to defer `enter_stmt` in + /// example, if traversal of function bodies needs to be deferred, avoid calling `visit_stmt` on + /// the function itself until the deferred body is visited too. Failing to defer `visit_stmt` in /// this case will break any internal state that depends on function scopes, such as `async` /// context detection. - #[must_use] - pub fn enter_stmt( - &mut self, - stmt: &ast::Stmt, - ctx: &Ctx, - ) -> Checkpoint { + pub fn visit_stmt(&mut self, stmt: &ast::Stmt, ctx: &Ctx) { // check for errors self.check_stmt(stmt, ctx); - let checkpoint = self.checkpoint(); - // update internal state match stmt { Stmt::Expr(StmtExpr { value, .. }) @@ -519,48 +491,17 @@ impl SemanticSyntaxChecker { self.seen_futures_boundary = true; } } - Stmt::FunctionDef(ast::StmtFunctionDef { is_async, .. }) => { - self.in_async_context = *is_async; + Stmt::FunctionDef(_) => { self.seen_futures_boundary = true; } _ => { self.seen_futures_boundary = true; } } - - checkpoint - } - - pub fn exit_stmt(&mut self, checkpoint: Checkpoint) { - self.restore_checkpoint(checkpoint); } /// Check `expr` for semantic syntax errors and update the checker's internal state. - /// - /// This should be followed by a call to [`SemanticSyntaxChecker::exit_expr`] to reset any state - /// specific to scopes introduced by `expr`, such as whether the body of a comprehension is - /// async. - #[must_use] - pub fn enter_expr(&mut self, expr: &Expr, ctx: &Ctx) -> Checkpoint { - self.check_expr(expr, ctx); - let checkpoint = self.checkpoint(); - match expr { - Expr::ListComp(ast::ExprListComp { generators, .. }) - | Expr::SetComp(ast::ExprSetComp { generators, .. }) - | Expr::DictComp(ast::ExprDictComp { generators, .. }) => { - self.in_async_context = generators.iter().any(|g| g.is_async); - } - _ => {} - } - - checkpoint - } - - pub fn exit_expr(&mut self, checkpoint: Checkpoint) { - self.restore_checkpoint(checkpoint); - } - - fn check_expr(&mut self, expr: &Expr, ctx: &Ctx) { + pub fn visit_expr(&mut self, expr: &Expr, ctx: &Ctx) { match expr { Expr::ListComp(ast::ExprListComp { elt, generators, .. @@ -569,7 +510,7 @@ impl SemanticSyntaxChecker { elt, generators, .. }) => { Self::check_generator_expr(elt, generators, ctx); - self.async_comprehension_outside_async_function(ctx, generators); + Self::async_comprehension_outside_async_function(ctx, generators); } Expr::Generator(ast::ExprGenerator { elt, generators, .. @@ -584,7 +525,7 @@ impl SemanticSyntaxChecker { }) => { Self::check_generator_expr(key, generators, ctx); Self::check_generator_expr(value, generators, ctx); - self.async_comprehension_outside_async_function(ctx, generators); + Self::async_comprehension_outside_async_function(ctx, generators); } Expr::Name(ast::ExprName { range, @@ -698,7 +639,6 @@ impl SemanticSyntaxChecker { } fn async_comprehension_outside_async_function( - &self, ctx: &Ctx, generators: &[ast::Comprehension], ) { @@ -706,54 +646,46 @@ impl SemanticSyntaxChecker { if python_version >= PythonVersion::PY311 { return; } - for generator in generators { - if generator.is_async && !self.in_async_context { - // test_ok nested_async_comprehension_py311 - // # parse_options: {"target-version": "3.11"} - // async def f(): return [[x async for x in foo(n)] for n in range(3)] # list - // async def g(): return [{x: 1 async for x in foo(n)} for n in range(3)] # dict - // async def h(): return [{x async for x in foo(n)} for n in range(3)] # set - - // test_ok nested_async_comprehension_py310 - // # parse_options: {"target-version": "3.10"} - // # this case fails if exit_expr doesn't run - // async def f(): - // [_ for n in range(3)] - // [_ async for n in range(3)] - // # and this fails without exit_stmt - // async def f(): - // def g(): ... - // [_ async for n in range(3)] - - // test_ok all_async_comprehension_py310 - // # parse_options: {"target-version": "3.10"} - // async def test(): return [[x async for x in elements(n)] async for n in range(3)] - - // test_err nested_async_comprehension_py310 - // # parse_options: {"target-version": "3.10"} - // async def f(): return [[x async for x in foo(n)] for n in range(3)] # list - // async def g(): return [{x: 1 async for x in foo(n)} for n in range(3)] # dict - // async def h(): return [{x async for x in foo(n)} for n in range(3)] # set - // async def i(): return [([y async for y in range(1)], [z for z in range(2)]) for x in range(5)] - // async def j(): return [([y for y in range(1)], [z async for z in range(2)]) for x in range(5)] - Self::add_error( - ctx, - SemanticSyntaxErrorKind::AsyncComprehensionOutsideAsyncFunction(python_version), - generator.range, - ); - } + // async allowed at notebook top-level + if ctx.in_notebook() && ctx.in_module_scope() { + return; } - } - - fn checkpoint(&self) -> Checkpoint { - Checkpoint { - in_async_context: self.in_async_context, + if ctx.in_async_context() && !ctx.in_sync_comprehension() { + return; } - } + for generator in generators.iter().filter(|gen| gen.is_async) { + // test_ok nested_async_comprehension_py311 + // # parse_options: {"target-version": "3.11"} + // async def f(): return [[x async for x in foo(n)] for n in range(3)] # list + // async def g(): return [{x: 1 async for x in foo(n)} for n in range(3)] # dict + // async def h(): return [{x async for x in foo(n)} for n in range(3)] # set - #[allow(clippy::needless_pass_by_value)] - fn restore_checkpoint(&mut self, checkpoint: Checkpoint) { - self.in_async_context = checkpoint.in_async_context; + // test_ok nested_async_comprehension_py310 + // # parse_options: {"target-version": "3.10"} + // async def f(): + // [_ for n in range(3)] + // [_ async for n in range(3)] + // async def f(): + // def g(): ... + // [_ async for n in range(3)] + + // test_ok all_async_comprehension_py310 + // # parse_options: {"target-version": "3.10"} + // async def test(): return [[x async for x in elements(n)] async for n in range(3)] + + // test_err nested_async_comprehension_py310 + // # parse_options: {"target-version": "3.10"} + // async def f(): return [[x async for x in foo(n)] for n in range(3)] # list + // async def g(): return [{x: 1 async for x in foo(n)} for n in range(3)] # dict + // async def h(): return [{x async for x in foo(n)} for n in range(3)] # set + // async def i(): return [([y async for y in range(1)], [z for z in range(2)]) for x in range(5)] + // async def j(): return [([y for y in range(1)], [z async for z in range(2)]) for x in range(5)] + Self::add_error( + ctx, + SemanticSyntaxErrorKind::AsyncComprehensionOutsideAsyncFunction(python_version), + generator.range, + ); + } } } @@ -1410,45 +1342,26 @@ pub trait SemanticSyntaxContext { /// Return the [`TextRange`] at which a name is declared as `global` in the current scope. fn global(&self, name: &str) -> Option; + /// Returns `true` if the visitor is currently in an async context, i.e. an async function. + fn in_async_context(&self) -> bool; + + /// Returns `true` if the visitor is currently inside of a synchronous comprehension. + /// + /// This method is necessary because `in_async_context` only checks for the nearest, enclosing + /// function to determine the (a)sync context. Instead, this method will search all enclosing + /// scopes until it finds a sync comprehension. As a result, the two methods will typically be + /// used together. + fn in_sync_comprehension(&self) -> bool; + + /// Returns `true` if the visitor is at the top-level module scope. + fn in_module_scope(&self) -> bool; + + /// Returns `true` if the source file is a Jupyter notebook. + fn in_notebook(&self) -> bool; + fn report_semantic_error(&self, error: SemanticSyntaxError); } -#[derive(Default)] -pub struct SemanticSyntaxCheckerVisitor { - checker: SemanticSyntaxChecker, - context: Ctx, -} - -impl SemanticSyntaxCheckerVisitor { - pub fn new(context: Ctx) -> Self { - Self { - checker: SemanticSyntaxChecker::new(PySourceType::Python), - context, - } - } - - pub fn into_context(self) -> Ctx { - self.context - } -} - -impl Visitor<'_> for SemanticSyntaxCheckerVisitor -where - Ctx: SemanticSyntaxContext, -{ - fn visit_stmt(&mut self, stmt: &'_ Stmt) { - let checkpoint = self.checker.enter_stmt(stmt, &self.context); - ruff_python_ast::visitor::walk_stmt(self, stmt); - self.checker.exit_stmt(checkpoint); - } - - fn visit_expr(&mut self, expr: &'_ Expr) { - let checkpoint = self.checker.enter_expr(expr, &self.context); - ruff_python_ast::visitor::walk_expr(self, expr); - self.checker.exit_expr(checkpoint); - } -} - /// Modified version of [`std::str::EscapeDefault`] that does not escape single or double quotes. struct EscapeDefault<'a>(&'a str); diff --git a/crates/ruff_python_parser/tests/fixtures.rs b/crates/ruff_python_parser/tests/fixtures.rs index 0fbe6597d1..5c9534f53c 100644 --- a/crates/ruff_python_parser/tests/fixtures.rs +++ b/crates/ruff_python_parser/tests/fixtures.rs @@ -7,9 +7,9 @@ use std::path::Path; use ruff_annotate_snippets::{Level, Renderer, Snippet}; use ruff_python_ast::visitor::source_order::{walk_module, SourceOrderVisitor, TraversalSignal}; use ruff_python_ast::visitor::Visitor; -use ruff_python_ast::{AnyNodeRef, Mod, PythonVersion}; +use ruff_python_ast::{self as ast, AnyNodeRef, Mod, PythonVersion}; use ruff_python_parser::semantic_errors::{ - SemanticSyntaxCheckerVisitor, SemanticSyntaxContext, SemanticSyntaxError, + SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError, }; use ruff_python_parser::{parse_unchecked, Mode, ParseErrorType, ParseOptions, Token}; use ruff_source_file::{LineIndex, OneIndexed, SourceCode}; @@ -88,15 +88,14 @@ fn test_valid_syntax(input_path: &Path) { let parsed = parsed.try_into_module().expect("Parsed with Mode::Module"); - let mut visitor = SemanticSyntaxCheckerVisitor::new( - TestContext::new(&source).with_python_version(options.target_version()), - ); + let mut visitor = + SemanticSyntaxCheckerVisitor::new(&source).with_python_version(options.target_version()); for stmt in parsed.suite() { visitor.visit_stmt(stmt); } - let semantic_syntax_errors = visitor.into_context().diagnostics.into_inner(); + let semantic_syntax_errors = visitor.into_diagnostics(); if !semantic_syntax_errors.is_empty() { let mut message = "Expected no semantic syntax errors for a valid program:\n".to_string(); @@ -184,15 +183,14 @@ fn test_invalid_syntax(input_path: &Path) { let parsed = parsed.try_into_module().expect("Parsed with Mode::Module"); - let mut visitor = SemanticSyntaxCheckerVisitor::new( - TestContext::new(&source).with_python_version(options.target_version()), - ); + let mut visitor = + SemanticSyntaxCheckerVisitor::new(&source).with_python_version(options.target_version()); for stmt in parsed.suite() { visitor.visit_stmt(stmt); } - let semantic_syntax_errors = visitor.into_context().diagnostics.into_inner(); + let semantic_syntax_errors = visitor.into_diagnostics(); assert!( parsed.has_syntax_errors() || !semantic_syntax_errors.is_empty(), @@ -462,19 +460,28 @@ impl<'ast> SourceOrderVisitor<'ast> for ValidateAstVisitor<'ast> { } } -#[derive(Debug)] -struct TestContext<'a> { +enum Scope { + Module, + Function { is_async: bool }, + Comprehension { is_async: bool }, +} + +struct SemanticSyntaxCheckerVisitor<'a> { + checker: SemanticSyntaxChecker, diagnostics: RefCell>, python_version: PythonVersion, source: &'a str, + scopes: Vec, } -impl<'a> TestContext<'a> { +impl<'a> SemanticSyntaxCheckerVisitor<'a> { fn new(source: &'a str) -> Self { Self { + checker: SemanticSyntaxChecker::new(), diagnostics: RefCell::default(), python_version: PythonVersion::default(), source, + scopes: vec![Scope::Module], } } @@ -483,9 +490,19 @@ impl<'a> TestContext<'a> { self.python_version = python_version; self } + + fn into_diagnostics(self) -> Vec { + self.diagnostics.into_inner() + } + + fn with_semantic_checker(&mut self, f: impl FnOnce(&mut SemanticSyntaxChecker, &Self)) { + let mut checker = std::mem::take(&mut self.checker); + f(&mut checker, self); + self.checker = checker; + } } -impl SemanticSyntaxContext for TestContext<'_> { +impl SemanticSyntaxContext for SemanticSyntaxCheckerVisitor<'_> { fn seen_docstring_boundary(&self) -> bool { false } @@ -509,4 +526,73 @@ impl SemanticSyntaxContext for TestContext<'_> { fn global(&self, _name: &str) -> Option { None } + + fn in_async_context(&self) -> bool { + for scope in &self.scopes { + if let Scope::Function { is_async } = scope { + return *is_async; + } + } + false + } + + fn in_sync_comprehension(&self) -> bool { + for scope in &self.scopes { + if let Scope::Comprehension { is_async: false } = scope { + return true; + } + } + false + } + + fn in_module_scope(&self) -> bool { + self.scopes + .last() + .is_some_and(|scope| matches!(scope, Scope::Module)) + } + + fn in_notebook(&self) -> bool { + false + } +} + +impl Visitor<'_> for SemanticSyntaxCheckerVisitor<'_> { + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context)); + match stmt { + ast::Stmt::FunctionDef(ast::StmtFunctionDef { is_async, .. }) => { + self.scopes.push(Scope::Function { + is_async: *is_async, + }); + ast::visitor::walk_stmt(self, stmt); + self.scopes.pop().unwrap(); + } + _ => { + ast::visitor::walk_stmt(self, stmt); + } + } + } + + fn visit_expr(&mut self, expr: &ast::Expr) { + self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context)); + match expr { + ast::Expr::Lambda(_) => { + self.scopes.push(Scope::Function { is_async: false }); + ast::visitor::walk_expr(self, expr); + self.scopes.pop().unwrap(); + } + ast::Expr::ListComp(ast::ExprListComp { generators, .. }) + | ast::Expr::SetComp(ast::ExprSetComp { generators, .. }) + | ast::Expr::DictComp(ast::ExprDictComp { generators, .. }) => { + self.scopes.push(Scope::Comprehension { + is_async: generators.iter().any(|gen| gen.is_async), + }); + ast::visitor::walk_expr(self, expr); + self.scopes.pop().unwrap(); + } + _ => { + ast::visitor::walk_expr(self, expr); + } + } + } } diff --git a/crates/ruff_python_parser/tests/snapshots/valid_syntax@nested_async_comprehension_py310.py.snap b/crates/ruff_python_parser/tests/snapshots/valid_syntax@nested_async_comprehension_py310.py.snap index e143cc966b..2fb0f626fe 100644 --- a/crates/ruff_python_parser/tests/snapshots/valid_syntax@nested_async_comprehension_py310.py.snap +++ b/crates/ruff_python_parser/tests/snapshots/valid_syntax@nested_async_comprehension_py310.py.snap @@ -7,20 +7,20 @@ input_file: crates/ruff_python_parser/resources/inline/ok/nested_async_comprehen ``` Module( ModModule { - range: 0..259, + range: 0..181, body: [ FunctionDef( StmtFunctionDef { - range: 87..159, + range: 44..116, is_async: true, decorator_list: [], name: Identifier { id: Name("f"), - range: 97..98, + range: 54..55, }, type_params: None, parameters: Parameters { - range: 98..100, + range: 55..57, posonlyargs: [], args: [], vararg: None, @@ -31,43 +31,43 @@ Module( body: [ Expr( StmtExpr { - range: 106..127, + range: 63..84, value: ListComp( ExprListComp { - range: 106..127, + range: 63..84, elt: Name( ExprName { - range: 107..108, + range: 64..65, id: Name("_"), ctx: Load, }, ), generators: [ Comprehension { - range: 109..126, + range: 66..83, target: Name( ExprName { - range: 113..114, + range: 70..71, id: Name("n"), ctx: Store, }, ), iter: Call( ExprCall { - range: 118..126, + range: 75..83, func: Name( ExprName { - range: 118..123, + range: 75..80, id: Name("range"), ctx: Load, }, ), arguments: Arguments { - range: 123..126, + range: 80..83, args: [ NumberLiteral( ExprNumberLiteral { - range: 124..125, + range: 81..82, value: Int( 3, ), @@ -88,43 +88,43 @@ Module( ), Expr( StmtExpr { - range: 132..159, + range: 89..116, value: ListComp( ExprListComp { - range: 132..159, + range: 89..116, elt: Name( ExprName { - range: 133..134, + range: 90..91, id: Name("_"), ctx: Load, }, ), generators: [ Comprehension { - range: 135..158, + range: 92..115, target: Name( ExprName { - range: 145..146, + range: 102..103, id: Name("n"), ctx: Store, }, ), iter: Call( ExprCall { - range: 150..158, + range: 107..115, func: Name( ExprName { - range: 150..155, + range: 107..112, id: Name("range"), ctx: Load, }, ), arguments: Arguments { - range: 155..158, + range: 112..115, args: [ NumberLiteral( ExprNumberLiteral { - range: 156..157, + range: 113..114, value: Int( 3, ), @@ -148,16 +148,16 @@ Module( ), FunctionDef( StmtFunctionDef { - range: 195..258, + range: 117..180, is_async: true, decorator_list: [], name: Identifier { id: Name("f"), - range: 205..206, + range: 127..128, }, type_params: None, parameters: Parameters { - range: 206..208, + range: 128..130, posonlyargs: [], args: [], vararg: None, @@ -168,16 +168,16 @@ Module( body: [ FunctionDef( StmtFunctionDef { - range: 214..226, + range: 136..148, is_async: false, decorator_list: [], name: Identifier { id: Name("g"), - range: 218..219, + range: 140..141, }, type_params: None, parameters: Parameters { - range: 219..221, + range: 141..143, posonlyargs: [], args: [], vararg: None, @@ -188,10 +188,10 @@ Module( body: [ Expr( StmtExpr { - range: 223..226, + range: 145..148, value: EllipsisLiteral( ExprEllipsisLiteral { - range: 223..226, + range: 145..148, }, ), }, @@ -201,43 +201,43 @@ Module( ), Expr( StmtExpr { - range: 231..258, + range: 153..180, value: ListComp( ExprListComp { - range: 231..258, + range: 153..180, elt: Name( ExprName { - range: 232..233, + range: 154..155, id: Name("_"), ctx: Load, }, ), generators: [ Comprehension { - range: 234..257, + range: 156..179, target: Name( ExprName { - range: 244..245, + range: 166..167, id: Name("n"), ctx: Store, }, ), iter: Call( ExprCall { - range: 249..257, + range: 171..179, func: Name( ExprName { - range: 249..254, + range: 171..176, id: Name("range"), ctx: Load, }, ), arguments: Arguments { - range: 254..257, + range: 176..179, args: [ NumberLiteral( ExprNumberLiteral { - range: 255..256, + range: 177..178, value: Int( 3, ), diff --git a/crates/ruff_python_semantic/src/scope.rs b/crates/ruff_python_semantic/src/scope.rs index 4ea1b52429..93cf49bba6 100644 --- a/crates/ruff_python_semantic/src/scope.rs +++ b/crates/ruff_python_semantic/src/scope.rs @@ -170,7 +170,10 @@ bitflags! { pub enum ScopeKind<'a> { Class(&'a ast::StmtClassDef), Function(&'a ast::StmtFunctionDef), - Generator(GeneratorKind), + Generator { + kind: GeneratorKind, + is_async: bool, + }, Module, /// A Python 3.12+ [annotation scope](https://docs.python.org/3/reference/executionmodel.html#annotation-scopes) Type,