diff --git a/crates/ruff/src/checkers/imports.rs b/crates/ruff/src/checkers/imports.rs index 964e720acd..a74bfee724 100644 --- a/crates/ruff/src/checkers/imports.rs +++ b/crates/ruff/src/checkers/imports.rs @@ -8,7 +8,7 @@ use ruff_diagnostics::Diagnostic; use ruff_python_ast::helpers::to_module_path; use ruff_python_ast::imports::{ImportMap, ModuleImport}; use ruff_python_ast::source_code::{Indexer, Locator, Stylist}; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_stdlib::path::is_python_stub_file; use crate::directives::IsortDirectives; diff --git a/crates/ruff/src/doc_lines.rs b/crates/ruff/src/doc_lines.rs index ab6b0cb752..a170f184ed 100644 --- a/crates/ruff/src/doc_lines.rs +++ b/crates/ruff/src/doc_lines.rs @@ -10,8 +10,7 @@ use rustpython_parser::Tok; use ruff_python_ast::newlines::UniversalNewlineIterator; use ruff_python_ast::source_code::Locator; -use ruff_python_ast::visitor; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; /// Extract doc lines (standalone comments) from a token sequence. pub fn doc_lines_from_tokens<'a>(lxr: &'a [LexResult], locator: &'a Locator<'a>) -> DocLines<'a> { @@ -75,7 +74,7 @@ struct StringLinesVisitor<'a> { locator: &'a Locator<'a>, } -impl Visitor<'_> for StringLinesVisitor<'_> { +impl StatementVisitor<'_> for StringLinesVisitor<'_> { fn visit_stmt(&mut self, stmt: &Stmt) { if let StmtKind::Expr { value } = &stmt.node { if let ExprKind::Constant { @@ -91,7 +90,7 @@ impl Visitor<'_> for StringLinesVisitor<'_> { } } } - visitor::walk_stmt(self, stmt); + walk_stmt(self, stmt); } } diff --git a/crates/ruff/src/rules/flake8_annotations/rules.rs b/crates/ruff/src/rules/flake8_annotations/rules.rs index 5b2c5a0f69..5b9dcbbce5 100644 --- a/crates/ruff/src/rules/flake8_annotations/rules.rs +++ b/crates/ruff/src/rules/flake8_annotations/rules.rs @@ -3,7 +3,7 @@ use rustpython_parser::ast::{Constant, Expr, ExprKind, Stmt}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::ReturnStatementVisitor; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_ast::{cast, helpers}; use ruff_python_semantic::analyze::visibility; use ruff_python_semantic::analyze::visibility::Visibility; @@ -416,9 +416,7 @@ impl Violation for AnyType { fn is_none_returning(body: &[Stmt]) -> bool { let mut visitor = ReturnStatementVisitor::default(); - for stmt in body { - visitor.visit_stmt(stmt); - } + visitor.visit_body(body); for expr in visitor.returns.into_iter().flatten() { if !matches!( expr.node, diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/raise_without_from_inside_except.rs b/crates/ruff/src/rules/flake8_bugbear/rules/raise_without_from_inside_except.rs index 60ff0b6b99..4f90f46b86 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/raise_without_from_inside_except.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/raise_without_from_inside_except.rs @@ -3,7 +3,7 @@ use rustpython_parser::ast::{ExprKind, Stmt}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::RaiseStatementVisitor; -use ruff_python_ast::visitor; +use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_stdlib::str::is_lower; use crate::checkers::ast::Checker; @@ -25,7 +25,7 @@ impl Violation for RaiseWithoutFromInsideExcept { pub fn raise_without_from_inside_except(checker: &mut Checker, body: &[Stmt]) { let raises = { let mut visitor = RaiseStatementVisitor::default(); - visitor::walk_body(&mut visitor, body); + visitor.visit_body(body); visitor.raises }; diff --git a/crates/ruff/src/rules/isort/track.rs b/crates/ruff/src/rules/isort/track.rs index cb590ef5c6..6b04a85cf4 100644 --- a/crates/ruff/src/rules/isort/track.rs +++ b/crates/ruff/src/rules/isort/track.rs @@ -1,12 +1,8 @@ use ruff_text_size::{TextRange, TextSize}; -use rustpython_parser::ast::{ - Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler, - ExcepthandlerKind, Expr, ExprContext, Keyword, MatchCase, Operator, Pattern, Stmt, StmtKind, - Unaryop, Withitem, -}; +use rustpython_parser::ast::{Excepthandler, ExcepthandlerKind, MatchCase, Stmt, StmtKind}; use ruff_python_ast::source_code::Locator; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::StatementVisitor; use crate::directives::IsortDirectives; use crate::rules::isort::helpers; @@ -111,7 +107,7 @@ impl<'a> ImportTracker<'a> { } } -impl<'a, 'b> Visitor<'b> for ImportTracker<'a> +impl<'a, 'b> StatementVisitor<'b> for ImportTracker<'a> where 'b: 'a, { @@ -226,8 +222,7 @@ where } self.finalize(None); } - StmtKind::Match { subject, cases } => { - self.visit_expr(subject); + StmtKind::Match { cases, .. } => { for match_case in cases { self.visit_match_case(match_case); } @@ -268,24 +263,6 @@ where self.nested = prev_nested; } - fn visit_annotation(&mut self, _: &'b Expr) {} - - fn visit_expr(&mut self, _: &'b Expr) {} - - fn visit_constant(&mut self, _: &'b Constant) {} - - fn visit_expr_context(&mut self, _: &'b ExprContext) {} - - fn visit_boolop(&mut self, _: &'b Boolop) {} - - fn visit_operator(&mut self, _: &'b Operator) {} - - fn visit_unaryop(&mut self, _: &'b Unaryop) {} - - fn visit_cmpop(&mut self, _: &'b Cmpop) {} - - fn visit_comprehension(&mut self, _: &'b Comprehension) {} - fn visit_excepthandler(&mut self, excepthandler: &'b Excepthandler) { let prev_nested = self.nested; self.nested = true; @@ -299,22 +276,10 @@ where self.nested = prev_nested; } - fn visit_arguments(&mut self, _: &'b Arguments) {} - - fn visit_arg(&mut self, _: &'b Arg) {} - - fn visit_keyword(&mut self, _: &'b Keyword) {} - - fn visit_alias(&mut self, _: &'b Alias) {} - - fn visit_withitem(&mut self, _: &'b Withitem) {} - fn visit_match_case(&mut self, match_case: &'b MatchCase) { for stmt in &match_case.body { self.visit_stmt(stmt); } self.finalize(None); } - - fn visit_pattern(&mut self, _: &'b Pattern) {} } diff --git a/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs b/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs index 50d486b0ec..4fbcd3866c 100644 --- a/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs +++ b/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs @@ -7,9 +7,8 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::comparable::ComparableExpr; use ruff_python_ast::helpers::unparse_expr; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use ruff_python_ast::types::Node; -use ruff_python_ast::visitor; -use ruff_python_ast::visitor::Visitor; use ruff_python_semantic::context::Context; use crate::checkers::ast::Checker; @@ -146,7 +145,7 @@ struct InnerForWithAssignTargetsVisitor<'a> { assignment_targets: Vec>, } -impl<'a, 'b> Visitor<'b> for InnerForWithAssignTargetsVisitor<'a> +impl<'a, 'b> StatementVisitor<'b> for InnerForWithAssignTargetsVisitor<'a> where 'b: 'a, { @@ -225,7 +224,7 @@ where StmtKind::FunctionDef { .. } => {} // Otherwise, do recurse. _ => { - visitor::walk_stmt(self, stmt); + walk_stmt(self, stmt); } } } diff --git a/crates/ruff/src/rules/pylint/rules/too_many_return_statements.rs b/crates/ruff/src/rules/pylint/rules/too_many_return_statements.rs index 013ce6f9f4..5b6d43dd2b 100644 --- a/crates/ruff/src/rules/pylint/rules/too_many_return_statements.rs +++ b/crates/ruff/src/rules/pylint/rules/too_many_return_statements.rs @@ -4,7 +4,7 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::{identifier_range, ReturnStatementVisitor}; use ruff_python_ast::source_code::Locator; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::StatementVisitor; #[violation] pub struct TooManyReturnStatements { diff --git a/crates/ruff/src/rules/pylint/rules/useless_return.rs b/crates/ruff/src/rules/pylint/rules/useless_return.rs index e53289c5f8..65f87cfb92 100644 --- a/crates/ruff/src/rules/pylint/rules/useless_return.rs +++ b/crates/ruff/src/rules/pylint/rules/useless_return.rs @@ -4,8 +4,8 @@ use rustpython_parser::ast::{Constant, Expr, ExprKind, Stmt, StmtKind}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::{is_const_none, ReturnStatementVisitor}; +use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_ast::types::RefEquality; -use ruff_python_ast::visitor::Visitor; use crate::autofix::actions::delete_stmt; use crate::checkers::ast::Checker; diff --git a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs index ab847b146c..7c5e9690c3 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs @@ -3,9 +3,10 @@ use rustpython_parser::ast::{Expr, ExprContext, ExprKind, Stmt, StmtKind}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_ast::types::RefEquality; -use ruff_python_ast::visitor; use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::{statement_visitor, visitor}; use crate::checkers::ast::Checker; use crate::registry::AsRule; @@ -59,7 +60,7 @@ struct YieldFromVisitor<'a> { yields: Vec>, } -impl<'a> Visitor<'a> for YieldFromVisitor<'a> { +impl<'a> StatementVisitor<'a> for YieldFromVisitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { match &stmt.node { StmtKind::For { @@ -97,20 +98,7 @@ impl<'a> Visitor<'a> for YieldFromVisitor<'a> { | StmtKind::ClassDef { .. } => { // Don't recurse into anything that defines a new scope. } - _ => visitor::walk_stmt(self, stmt), - } - } - - fn visit_expr(&mut self, expr: &'a Expr) { - match &expr.node { - ExprKind::ListComp { .. } - | ExprKind::SetComp { .. } - | ExprKind::DictComp { .. } - | ExprKind::GeneratorExp { .. } - | ExprKind::Lambda { .. } => { - // Don't recurse into anything that defines a new scope. - } - _ => visitor::walk_expr(self, expr), + _ => statement_visitor::walk_stmt(self, stmt), } } } diff --git a/crates/ruff/src/rules/tryceratops/rules/raise_within_try.rs b/crates/ruff/src/rules/tryceratops/rules/raise_within_try.rs index 53c4731906..3e88717162 100644 --- a/crates/ruff/src/rules/tryceratops/rules/raise_within_try.rs +++ b/crates/ruff/src/rules/tryceratops/rules/raise_within_try.rs @@ -2,7 +2,7 @@ use rustpython_parser::ast::{Excepthandler, Stmt, StmtKind}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::visitor::{self, Visitor}; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use crate::checkers::ast::Checker; @@ -58,7 +58,7 @@ struct RaiseStatementVisitor<'a> { raises: Vec<&'a Stmt>, } -impl<'a, 'b> Visitor<'b> for RaiseStatementVisitor<'a> +impl<'a, 'b> StatementVisitor<'b> for RaiseStatementVisitor<'a> where 'b: 'a, { @@ -66,7 +66,7 @@ where match stmt.node { StmtKind::Raise { .. } => self.raises.push(stmt), StmtKind::Try { .. } | StmtKind::TryStar { .. } => (), - _ => visitor::walk_stmt(self, stmt), + _ => walk_stmt(self, stmt), } } } @@ -79,9 +79,7 @@ pub fn raise_within_try(checker: &mut Checker, body: &[Stmt], handlers: &[Except let raises = { let mut visitor = RaiseStatementVisitor::default(); - for stmt in body { - visitor.visit_stmt(stmt); - } + visitor.visit_body(body); visitor.raises }; diff --git a/crates/ruff/src/rules/tryceratops/rules/reraise_no_cause.rs b/crates/ruff/src/rules/tryceratops/rules/reraise_no_cause.rs index c18b01fb9a..6d17cb0a66 100644 --- a/crates/ruff/src/rules/tryceratops/rules/reraise_no_cause.rs +++ b/crates/ruff/src/rules/tryceratops/rules/reraise_no_cause.rs @@ -3,7 +3,7 @@ use rustpython_parser::ast::{ExprKind, Stmt}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::RaiseStatementVisitor; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::StatementVisitor; use crate::checkers::ast::Checker; @@ -50,9 +50,7 @@ impl Violation for ReraiseNoCause { pub fn reraise_no_cause(checker: &mut Checker, body: &[Stmt]) { let raises = { let mut visitor = RaiseStatementVisitor::default(); - for stmt in body { - visitor.visit_stmt(stmt); - } + visitor.visit_body(body); visitor.raises }; diff --git a/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs b/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs index a1fc850401..baf85fa206 100644 --- a/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs +++ b/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs @@ -2,8 +2,7 @@ use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::visitor; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use crate::checkers::ast::Checker; @@ -51,7 +50,7 @@ struct ControlFlowVisitor<'a> { continues: Vec<&'a Stmt>, } -impl<'a, 'b> Visitor<'b> for ControlFlowVisitor<'a> +impl<'a, 'b> StatementVisitor<'b> for ControlFlowVisitor<'a> where 'b: 'a, { @@ -65,19 +64,7 @@ where StmtKind::Return { .. } => self.returns.push(stmt), StmtKind::Break => self.breaks.push(stmt), StmtKind::Continue => self.continues.push(stmt), - _ => visitor::walk_stmt(self, stmt), - } - } - - fn visit_expr(&mut self, expr: &'b Expr) { - match &expr.node { - ExprKind::ListComp { .. } - | ExprKind::DictComp { .. } - | ExprKind::SetComp { .. } - | ExprKind::GeneratorExp { .. } => { - // Don't recurse. - } - _ => visitor::walk_expr(self, expr), + _ => walk_stmt(self, stmt), } } } diff --git a/crates/ruff/src/rules/tryceratops/rules/verbose_raise.rs b/crates/ruff/src/rules/tryceratops/rules/verbose_raise.rs index 426e359ae8..672d730895 100644 --- a/crates/ruff/src/rules/tryceratops/rules/verbose_raise.rs +++ b/crates/ruff/src/rules/tryceratops/rules/verbose_raise.rs @@ -2,8 +2,7 @@ use rustpython_parser::ast::{Excepthandler, ExcepthandlerKind, Expr, ExprKind, S use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::visitor; -use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use crate::checkers::ast::Checker; @@ -46,7 +45,7 @@ struct RaiseStatementVisitor<'a> { raises: Vec<(Option<&'a Expr>, Option<&'a Expr>)>, } -impl<'a, 'b> Visitor<'b> for RaiseStatementVisitor<'a> +impl<'a, 'b> StatementVisitor<'b> for RaiseStatementVisitor<'a> where 'b: 'a, { @@ -60,10 +59,10 @@ where body, finalbody, .. } => { for stmt in body.iter().chain(finalbody.iter()) { - visitor::walk_stmt(self, stmt); + walk_stmt(self, stmt); } } - _ => visitor::walk_stmt(self, stmt), + _ => walk_stmt(self, stmt), } } } @@ -80,9 +79,7 @@ pub fn verbose_raise(checker: &mut Checker, handlers: &[Excepthandler]) { { let raises = { let mut visitor = RaiseStatementVisitor::default(); - for stmt in body { - visitor.visit_stmt(stmt); - } + visitor.visit_body(body); visitor.raises }; for (exc, cause) in raises { diff --git a/crates/ruff_benchmark/benches/parser.rs b/crates/ruff_benchmark/benches/parser.rs index e339e91966..afdcc59832 100644 --- a/crates/ruff_benchmark/benches/parser.rs +++ b/crates/ruff_benchmark/benches/parser.rs @@ -1,7 +1,7 @@ use criterion::measurement::WallTime; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError}; -use ruff_python_ast::visitor::{walk_stmt, Visitor}; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use rustpython_parser::ast::Stmt; use std::time::Duration; @@ -40,7 +40,7 @@ struct CountVisitor { count: usize, } -impl<'a> Visitor<'a> for CountVisitor { +impl<'a> StatementVisitor<'a> for CountVisitor { fn visit_stmt(&mut self, stmt: &'a Stmt) { walk_stmt(self, stmt); self.count += 1; diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index 213ae1e770..e6915dccc1 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -18,8 +18,7 @@ use smallvec::SmallVec; use crate::call_path::CallPath; use crate::newlines::UniversalNewlineIterator; use crate::source_code::{Generator, Indexer, Locator, Stylist}; -use crate::visitor; -use crate::visitor::Visitor; +use crate::statement_visitor::{walk_body, walk_stmt, StatementVisitor}; /// Create an `Expr` with default location from an `ExprKind`. pub fn create_expr(node: ExprKind) -> Expr { @@ -816,13 +815,13 @@ pub fn resolve_imported_module_path<'a>( Some(Cow::Owned(qualified_path)) } -/// A [`Visitor`] that collects all `return` statements in a function or method. +/// A [`StatementVisitor`] that collects all `return` statements in a function or method. #[derive(Default)] pub struct ReturnStatementVisitor<'a> { pub returns: Vec>, } -impl<'a, 'b> Visitor<'b> for ReturnStatementVisitor<'a> +impl<'a, 'b> StatementVisitor<'b> for ReturnStatementVisitor<'a> where 'b: 'a, { @@ -832,18 +831,18 @@ where // Don't recurse. } StmtKind::Return { value } => self.returns.push(value.as_deref()), - _ => visitor::walk_stmt(self, stmt), + _ => walk_stmt(self, stmt), } } } -/// A [`Visitor`] that collects all `raise` statements in a function or method. +/// A [`StatementVisitor`] that collects all `raise` statements in a function or method. #[derive(Default)] pub struct RaiseStatementVisitor<'a> { pub raises: Vec<(TextRange, Option<&'a Expr>, Option<&'a Expr>)>, } -impl<'a, 'b> Visitor<'b> for RaiseStatementVisitor<'b> +impl<'a, 'b> StatementVisitor<'b> for RaiseStatementVisitor<'b> where 'b: 'a, { @@ -859,19 +858,19 @@ where | StmtKind::Try { .. } | StmtKind::TryStar { .. } => {} StmtKind::If { body, orelse, .. } => { - visitor::walk_body(self, body); - visitor::walk_body(self, orelse); + walk_body(self, body); + walk_body(self, orelse); } StmtKind::While { body, .. } | StmtKind::With { body, .. } | StmtKind::AsyncWith { body, .. } | StmtKind::For { body, .. } | StmtKind::AsyncFor { body, .. } => { - visitor::walk_body(self, body); + walk_body(self, body); } StmtKind::Match { cases, .. } => { for case in cases { - visitor::walk_body(self, &case.body); + walk_body(self, &case.body); } } _ => {} @@ -884,7 +883,7 @@ struct GlobalStatementVisitor<'a> { globals: FxHashMap<&'a str, &'a Stmt>, } -impl<'a> Visitor<'a> for GlobalStatementVisitor<'a> { +impl<'a> StatementVisitor<'a> for GlobalStatementVisitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { match &stmt.node { StmtKind::Global { names } => { @@ -897,7 +896,7 @@ impl<'a> Visitor<'a> for GlobalStatementVisitor<'a> { | StmtKind::ClassDef { .. } => { // Don't recurse. } - _ => visitor::walk_stmt(self, stmt), + _ => walk_stmt(self, stmt), } } } diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index f57158ff73..1e5dab1cf8 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -8,6 +8,7 @@ pub mod imports; pub mod newlines; pub mod relocate; pub mod source_code; +pub mod statement_visitor; pub mod str; pub mod token_kind; pub mod types; diff --git a/crates/ruff_python_ast/src/statement_visitor.rs b/crates/ruff_python_ast/src/statement_visitor.rs new file mode 100644 index 0000000000..9b666a602f --- /dev/null +++ b/crates/ruff_python_ast/src/statement_visitor.rs @@ -0,0 +1,111 @@ +//! Specialized AST visitor trait and walk functions that only visit statements. + +use rustpython_parser::ast::{Excepthandler, ExcepthandlerKind, MatchCase, Stmt, StmtKind}; + +/// A trait for AST visitors that only need to visit statements. +pub trait StatementVisitor<'a> { + fn visit_body(&mut self, body: &'a [Stmt]) { + walk_body(self, body); + } + fn visit_stmt(&mut self, stmt: &'a Stmt) { + walk_stmt(self, stmt); + } + fn visit_excepthandler(&mut self, excepthandler: &'a Excepthandler) { + walk_excepthandler(self, excepthandler); + } + fn visit_match_case(&mut self, match_case: &'a MatchCase) { + walk_match_case(self, match_case); + } +} + +pub fn walk_body<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, body: &'a [Stmt]) { + for stmt in body { + visitor.visit_stmt(stmt); + } +} + +pub fn walk_stmt<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) { + match &stmt.node { + StmtKind::FunctionDef { body, .. } => { + visitor.visit_body(body); + } + StmtKind::AsyncFunctionDef { body, .. } => { + visitor.visit_body(body); + } + StmtKind::For { body, orelse, .. } => { + visitor.visit_body(body); + visitor.visit_body(orelse); + } + StmtKind::ClassDef { body, .. } => { + visitor.visit_body(body); + } + StmtKind::AsyncFor { body, orelse, .. } => { + visitor.visit_body(body); + visitor.visit_body(orelse); + } + StmtKind::While { body, orelse, .. } => { + visitor.visit_body(body); + visitor.visit_body(orelse); + } + StmtKind::If { body, orelse, .. } => { + visitor.visit_body(body); + visitor.visit_body(orelse); + } + StmtKind::With { body, .. } => { + visitor.visit_body(body); + } + StmtKind::AsyncWith { body, .. } => { + visitor.visit_body(body); + } + StmtKind::Match { cases, .. } => { + for match_case in cases { + visitor.visit_match_case(match_case); + } + } + StmtKind::Try { + body, + handlers, + orelse, + finalbody, + } => { + visitor.visit_body(body); + for excepthandler in handlers { + visitor.visit_excepthandler(excepthandler); + } + visitor.visit_body(orelse); + visitor.visit_body(finalbody); + } + StmtKind::TryStar { + body, + handlers, + orelse, + finalbody, + } => { + visitor.visit_body(body); + for excepthandler in handlers { + visitor.visit_excepthandler(excepthandler); + } + visitor.visit_body(orelse); + visitor.visit_body(finalbody); + } + _ => {} + } +} + +pub fn walk_excepthandler<'a, V: StatementVisitor<'a> + ?Sized>( + visitor: &mut V, + excepthandler: &'a Excepthandler, +) { + match &excepthandler.node { + ExcepthandlerKind::ExceptHandler { body, .. } => { + visitor.visit_body(body); + } + } +} + +pub fn walk_match_case<'a, V: StatementVisitor<'a> + ?Sized>( + visitor: &mut V, + match_case: &'a MatchCase, +) { + visitor.visit_body(&match_case.body); +} diff --git a/crates/ruff_python_ast/src/visitor.rs b/crates/ruff_python_ast/src/visitor.rs index f82ee38ea8..d64e1c5745 100644 --- a/crates/ruff_python_ast/src/visitor.rs +++ b/crates/ruff_python_ast/src/visitor.rs @@ -1,9 +1,15 @@ +//! AST visitor trait and walk functions. + use rustpython_parser::ast::{ Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler, ExcepthandlerKind, Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, Stmt, StmtKind, Unaryop, Withitem, }; +/// A trait for AST visitors. Visits all nodes in the AST recursively. +/// +/// Prefer [`crate::statement_visitor::StatementVisitor`] for visitors that only need to visit +/// statements. pub trait Visitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { walk_stmt(self, stmt);