Add a specialized StatementVisitor (#4349)

This commit is contained in:
Charlie Marsh 2023-05-10 12:42:20 -04:00 committed by GitHub
parent 6532455672
commit fd34797d0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 167 additions and 121 deletions

View file

@ -8,7 +8,7 @@ use ruff_diagnostics::Diagnostic;
use ruff_python_ast::helpers::to_module_path; use ruff_python_ast::helpers::to_module_path;
use ruff_python_ast::imports::{ImportMap, ModuleImport}; use ruff_python_ast::imports::{ImportMap, ModuleImport};
use ruff_python_ast::source_code::{Indexer, Locator, Stylist}; use ruff_python_ast::source_code::{Indexer, Locator, Stylist};
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_stdlib::path::is_python_stub_file; use ruff_python_stdlib::path::is_python_stub_file;
use crate::directives::IsortDirectives; use crate::directives::IsortDirectives;

View file

@ -10,8 +10,7 @@ use rustpython_parser::Tok;
use ruff_python_ast::newlines::UniversalNewlineIterator; use ruff_python_ast::newlines::UniversalNewlineIterator;
use ruff_python_ast::source_code::Locator; use ruff_python_ast::source_code::Locator;
use ruff_python_ast::visitor; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::visitor::Visitor;
/// Extract doc lines (standalone comments) from a token sequence. /// Extract doc lines (standalone comments) from a token sequence.
pub fn doc_lines_from_tokens<'a>(lxr: &'a [LexResult], locator: &'a Locator<'a>) -> DocLines<'a> { pub fn doc_lines_from_tokens<'a>(lxr: &'a [LexResult], locator: &'a Locator<'a>) -> DocLines<'a> {
@ -75,7 +74,7 @@ struct StringLinesVisitor<'a> {
locator: &'a Locator<'a>, locator: &'a Locator<'a>,
} }
impl Visitor<'_> for StringLinesVisitor<'_> { impl StatementVisitor<'_> for StringLinesVisitor<'_> {
fn visit_stmt(&mut self, stmt: &Stmt) { fn visit_stmt(&mut self, stmt: &Stmt) {
if let StmtKind::Expr { value } = &stmt.node { if let StmtKind::Expr { value } = &stmt.node {
if let ExprKind::Constant { if let ExprKind::Constant {
@ -91,7 +90,7 @@ impl Visitor<'_> for StringLinesVisitor<'_> {
} }
} }
} }
visitor::walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
} }

View file

@ -3,7 +3,7 @@ use rustpython_parser::ast::{Constant, Expr, ExprKind, Stmt};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Violation}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::ReturnStatementVisitor; use ruff_python_ast::helpers::ReturnStatementVisitor;
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_ast::{cast, helpers}; use ruff_python_ast::{cast, helpers};
use ruff_python_semantic::analyze::visibility; use ruff_python_semantic::analyze::visibility;
use ruff_python_semantic::analyze::visibility::Visibility; use ruff_python_semantic::analyze::visibility::Visibility;
@ -416,9 +416,7 @@ impl Violation for AnyType {
fn is_none_returning(body: &[Stmt]) -> bool { fn is_none_returning(body: &[Stmt]) -> bool {
let mut visitor = ReturnStatementVisitor::default(); let mut visitor = ReturnStatementVisitor::default();
for stmt in body { visitor.visit_body(body);
visitor.visit_stmt(stmt);
}
for expr in visitor.returns.into_iter().flatten() { for expr in visitor.returns.into_iter().flatten() {
if !matches!( if !matches!(
expr.node, expr.node,

View file

@ -3,7 +3,7 @@ use rustpython_parser::ast::{ExprKind, Stmt};
use ruff_diagnostics::{Diagnostic, Violation}; use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::RaiseStatementVisitor; use ruff_python_ast::helpers::RaiseStatementVisitor;
use ruff_python_ast::visitor; use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_stdlib::str::is_lower; use ruff_python_stdlib::str::is_lower;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -25,7 +25,7 @@ impl Violation for RaiseWithoutFromInsideExcept {
pub fn raise_without_from_inside_except(checker: &mut Checker, body: &[Stmt]) { pub fn raise_without_from_inside_except(checker: &mut Checker, body: &[Stmt]) {
let raises = { let raises = {
let mut visitor = RaiseStatementVisitor::default(); let mut visitor = RaiseStatementVisitor::default();
visitor::walk_body(&mut visitor, body); visitor.visit_body(body);
visitor.raises visitor.raises
}; };

View file

@ -1,12 +1,8 @@
use ruff_text_size::{TextRange, TextSize}; use ruff_text_size::{TextRange, TextSize};
use rustpython_parser::ast::{ use rustpython_parser::ast::{Excepthandler, ExcepthandlerKind, MatchCase, Stmt, StmtKind};
Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler,
ExcepthandlerKind, Expr, ExprContext, Keyword, MatchCase, Operator, Pattern, Stmt, StmtKind,
Unaryop, Withitem,
};
use ruff_python_ast::source_code::Locator; use ruff_python_ast::source_code::Locator;
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::statement_visitor::StatementVisitor;
use crate::directives::IsortDirectives; use crate::directives::IsortDirectives;
use crate::rules::isort::helpers; use crate::rules::isort::helpers;
@ -111,7 +107,7 @@ impl<'a> ImportTracker<'a> {
} }
} }
impl<'a, 'b> Visitor<'b> for ImportTracker<'a> impl<'a, 'b> StatementVisitor<'b> for ImportTracker<'a>
where where
'b: 'a, 'b: 'a,
{ {
@ -226,8 +222,7 @@ where
} }
self.finalize(None); self.finalize(None);
} }
StmtKind::Match { subject, cases } => { StmtKind::Match { cases, .. } => {
self.visit_expr(subject);
for match_case in cases { for match_case in cases {
self.visit_match_case(match_case); self.visit_match_case(match_case);
} }
@ -268,24 +263,6 @@ where
self.nested = prev_nested; self.nested = prev_nested;
} }
fn visit_annotation(&mut self, _: &'b Expr) {}
fn visit_expr(&mut self, _: &'b Expr) {}
fn visit_constant(&mut self, _: &'b Constant) {}
fn visit_expr_context(&mut self, _: &'b ExprContext) {}
fn visit_boolop(&mut self, _: &'b Boolop) {}
fn visit_operator(&mut self, _: &'b Operator) {}
fn visit_unaryop(&mut self, _: &'b Unaryop) {}
fn visit_cmpop(&mut self, _: &'b Cmpop) {}
fn visit_comprehension(&mut self, _: &'b Comprehension) {}
fn visit_excepthandler(&mut self, excepthandler: &'b Excepthandler) { fn visit_excepthandler(&mut self, excepthandler: &'b Excepthandler) {
let prev_nested = self.nested; let prev_nested = self.nested;
self.nested = true; self.nested = true;
@ -299,22 +276,10 @@ where
self.nested = prev_nested; self.nested = prev_nested;
} }
fn visit_arguments(&mut self, _: &'b Arguments) {}
fn visit_arg(&mut self, _: &'b Arg) {}
fn visit_keyword(&mut self, _: &'b Keyword) {}
fn visit_alias(&mut self, _: &'b Alias) {}
fn visit_withitem(&mut self, _: &'b Withitem) {}
fn visit_match_case(&mut self, match_case: &'b MatchCase) { fn visit_match_case(&mut self, match_case: &'b MatchCase) {
for stmt in &match_case.body { for stmt in &match_case.body {
self.visit_stmt(stmt); self.visit_stmt(stmt);
} }
self.finalize(None); self.finalize(None);
} }
fn visit_pattern(&mut self, _: &'b Pattern) {}
} }

View file

@ -7,9 +7,8 @@ use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::comparable::ComparableExpr; use ruff_python_ast::comparable::ComparableExpr;
use ruff_python_ast::helpers::unparse_expr; use ruff_python_ast::helpers::unparse_expr;
use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::types::Node; use ruff_python_ast::types::Node;
use ruff_python_ast::visitor;
use ruff_python_ast::visitor::Visitor;
use ruff_python_semantic::context::Context; use ruff_python_semantic::context::Context;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -146,7 +145,7 @@ struct InnerForWithAssignTargetsVisitor<'a> {
assignment_targets: Vec<ExprWithInnerBindingKind<'a>>, assignment_targets: Vec<ExprWithInnerBindingKind<'a>>,
} }
impl<'a, 'b> Visitor<'b> for InnerForWithAssignTargetsVisitor<'a> impl<'a, 'b> StatementVisitor<'b> for InnerForWithAssignTargetsVisitor<'a>
where where
'b: 'a, 'b: 'a,
{ {
@ -225,7 +224,7 @@ where
StmtKind::FunctionDef { .. } => {} StmtKind::FunctionDef { .. } => {}
// Otherwise, do recurse. // Otherwise, do recurse.
_ => { _ => {
visitor::walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
} }
} }

View file

@ -4,7 +4,7 @@ use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::{identifier_range, ReturnStatementVisitor}; use ruff_python_ast::helpers::{identifier_range, ReturnStatementVisitor};
use ruff_python_ast::source_code::Locator; use ruff_python_ast::source_code::Locator;
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::statement_visitor::StatementVisitor;
#[violation] #[violation]
pub struct TooManyReturnStatements { pub struct TooManyReturnStatements {

View file

@ -4,8 +4,8 @@ use rustpython_parser::ast::{Constant, Expr, ExprKind, Stmt, StmtKind};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Fix}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Fix};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::{is_const_none, ReturnStatementVisitor}; use ruff_python_ast::helpers::{is_const_none, ReturnStatementVisitor};
use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_ast::types::RefEquality; use ruff_python_ast::types::RefEquality;
use ruff_python_ast::visitor::Visitor;
use crate::autofix::actions::delete_stmt; use crate::autofix::actions::delete_stmt;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;

View file

@ -3,9 +3,10 @@ use rustpython_parser::ast::{Expr, ExprContext, ExprKind, Stmt, StmtKind};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_ast::types::RefEquality; use ruff_python_ast::types::RefEquality;
use ruff_python_ast::visitor;
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::visitor::Visitor;
use ruff_python_ast::{statement_visitor, visitor};
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
use crate::registry::AsRule; use crate::registry::AsRule;
@ -59,7 +60,7 @@ struct YieldFromVisitor<'a> {
yields: Vec<YieldFrom<'a>>, yields: Vec<YieldFrom<'a>>,
} }
impl<'a> Visitor<'a> for YieldFromVisitor<'a> { impl<'a> StatementVisitor<'a> for YieldFromVisitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) { fn visit_stmt(&mut self, stmt: &'a Stmt) {
match &stmt.node { match &stmt.node {
StmtKind::For { StmtKind::For {
@ -97,20 +98,7 @@ impl<'a> Visitor<'a> for YieldFromVisitor<'a> {
| StmtKind::ClassDef { .. } => { | StmtKind::ClassDef { .. } => {
// Don't recurse into anything that defines a new scope. // Don't recurse into anything that defines a new scope.
} }
_ => visitor::walk_stmt(self, stmt), _ => statement_visitor::walk_stmt(self, stmt),
}
}
fn visit_expr(&mut self, expr: &'a Expr) {
match &expr.node {
ExprKind::ListComp { .. }
| ExprKind::SetComp { .. }
| ExprKind::DictComp { .. }
| ExprKind::GeneratorExp { .. }
| ExprKind::Lambda { .. } => {
// Don't recurse into anything that defines a new scope.
}
_ => visitor::walk_expr(self, expr),
} }
} }
} }

View file

@ -2,7 +2,7 @@ use rustpython_parser::ast::{Excepthandler, Stmt, StmtKind};
use ruff_diagnostics::{Diagnostic, Violation}; use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::visitor::{self, Visitor}; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -58,7 +58,7 @@ struct RaiseStatementVisitor<'a> {
raises: Vec<&'a Stmt>, raises: Vec<&'a Stmt>,
} }
impl<'a, 'b> Visitor<'b> for RaiseStatementVisitor<'a> impl<'a, 'b> StatementVisitor<'b> for RaiseStatementVisitor<'a>
where where
'b: 'a, 'b: 'a,
{ {
@ -66,7 +66,7 @@ where
match stmt.node { match stmt.node {
StmtKind::Raise { .. } => self.raises.push(stmt), StmtKind::Raise { .. } => self.raises.push(stmt),
StmtKind::Try { .. } | StmtKind::TryStar { .. } => (), StmtKind::Try { .. } | StmtKind::TryStar { .. } => (),
_ => visitor::walk_stmt(self, stmt), _ => walk_stmt(self, stmt),
} }
} }
} }
@ -79,9 +79,7 @@ pub fn raise_within_try(checker: &mut Checker, body: &[Stmt], handlers: &[Except
let raises = { let raises = {
let mut visitor = RaiseStatementVisitor::default(); let mut visitor = RaiseStatementVisitor::default();
for stmt in body { visitor.visit_body(body);
visitor.visit_stmt(stmt);
}
visitor.raises visitor.raises
}; };

View file

@ -3,7 +3,7 @@ use rustpython_parser::ast::{ExprKind, Stmt};
use ruff_diagnostics::{Diagnostic, Violation}; use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::RaiseStatementVisitor; use ruff_python_ast::helpers::RaiseStatementVisitor;
use ruff_python_ast::visitor::Visitor; use ruff_python_ast::statement_visitor::StatementVisitor;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -50,9 +50,7 @@ impl Violation for ReraiseNoCause {
pub fn reraise_no_cause(checker: &mut Checker, body: &[Stmt]) { pub fn reraise_no_cause(checker: &mut Checker, body: &[Stmt]) {
let raises = { let raises = {
let mut visitor = RaiseStatementVisitor::default(); let mut visitor = RaiseStatementVisitor::default();
for stmt in body { visitor.visit_body(body);
visitor.visit_stmt(stmt);
}
visitor.raises visitor.raises
}; };

View file

@ -2,8 +2,7 @@ use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
use ruff_diagnostics::{Diagnostic, Violation}; use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::visitor; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::visitor::Visitor;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -51,7 +50,7 @@ struct ControlFlowVisitor<'a> {
continues: Vec<&'a Stmt>, continues: Vec<&'a Stmt>,
} }
impl<'a, 'b> Visitor<'b> for ControlFlowVisitor<'a> impl<'a, 'b> StatementVisitor<'b> for ControlFlowVisitor<'a>
where where
'b: 'a, 'b: 'a,
{ {
@ -65,19 +64,7 @@ where
StmtKind::Return { .. } => self.returns.push(stmt), StmtKind::Return { .. } => self.returns.push(stmt),
StmtKind::Break => self.breaks.push(stmt), StmtKind::Break => self.breaks.push(stmt),
StmtKind::Continue => self.continues.push(stmt), StmtKind::Continue => self.continues.push(stmt),
_ => visitor::walk_stmt(self, stmt), _ => walk_stmt(self, stmt),
}
}
fn visit_expr(&mut self, expr: &'b Expr) {
match &expr.node {
ExprKind::ListComp { .. }
| ExprKind::DictComp { .. }
| ExprKind::SetComp { .. }
| ExprKind::GeneratorExp { .. } => {
// Don't recurse.
}
_ => visitor::walk_expr(self, expr),
} }
} }
} }

View file

@ -2,8 +2,7 @@ use rustpython_parser::ast::{Excepthandler, ExcepthandlerKind, Expr, ExprKind, S
use ruff_diagnostics::{Diagnostic, Violation}; use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::visitor; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::visitor::Visitor;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -46,7 +45,7 @@ struct RaiseStatementVisitor<'a> {
raises: Vec<(Option<&'a Expr>, Option<&'a Expr>)>, raises: Vec<(Option<&'a Expr>, Option<&'a Expr>)>,
} }
impl<'a, 'b> Visitor<'b> for RaiseStatementVisitor<'a> impl<'a, 'b> StatementVisitor<'b> for RaiseStatementVisitor<'a>
where where
'b: 'a, 'b: 'a,
{ {
@ -60,10 +59,10 @@ where
body, finalbody, .. body, finalbody, ..
} => { } => {
for stmt in body.iter().chain(finalbody.iter()) { for stmt in body.iter().chain(finalbody.iter()) {
visitor::walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
} }
_ => visitor::walk_stmt(self, stmt), _ => walk_stmt(self, stmt),
} }
} }
} }
@ -80,9 +79,7 @@ pub fn verbose_raise(checker: &mut Checker, handlers: &[Excepthandler]) {
{ {
let raises = { let raises = {
let mut visitor = RaiseStatementVisitor::default(); let mut visitor = RaiseStatementVisitor::default();
for stmt in body { visitor.visit_body(body);
visitor.visit_stmt(stmt);
}
visitor.raises visitor.raises
}; };
for (exc, cause) in raises { for (exc, cause) in raises {

View file

@ -1,7 +1,7 @@
use criterion::measurement::WallTime; use criterion::measurement::WallTime;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError}; use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError};
use ruff_python_ast::visitor::{walk_stmt, Visitor}; use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use rustpython_parser::ast::Stmt; use rustpython_parser::ast::Stmt;
use std::time::Duration; use std::time::Duration;
@ -40,7 +40,7 @@ struct CountVisitor {
count: usize, count: usize,
} }
impl<'a> Visitor<'a> for CountVisitor { impl<'a> StatementVisitor<'a> for CountVisitor {
fn visit_stmt(&mut self, stmt: &'a Stmt) { fn visit_stmt(&mut self, stmt: &'a Stmt) {
walk_stmt(self, stmt); walk_stmt(self, stmt);
self.count += 1; self.count += 1;

View file

@ -18,8 +18,7 @@ use smallvec::SmallVec;
use crate::call_path::CallPath; use crate::call_path::CallPath;
use crate::newlines::UniversalNewlineIterator; use crate::newlines::UniversalNewlineIterator;
use crate::source_code::{Generator, Indexer, Locator, Stylist}; use crate::source_code::{Generator, Indexer, Locator, Stylist};
use crate::visitor; use crate::statement_visitor::{walk_body, walk_stmt, StatementVisitor};
use crate::visitor::Visitor;
/// Create an `Expr` with default location from an `ExprKind`. /// Create an `Expr` with default location from an `ExprKind`.
pub fn create_expr(node: ExprKind) -> Expr { pub fn create_expr(node: ExprKind) -> Expr {
@ -816,13 +815,13 @@ pub fn resolve_imported_module_path<'a>(
Some(Cow::Owned(qualified_path)) Some(Cow::Owned(qualified_path))
} }
/// A [`Visitor`] that collects all `return` statements in a function or method. /// A [`StatementVisitor`] that collects all `return` statements in a function or method.
#[derive(Default)] #[derive(Default)]
pub struct ReturnStatementVisitor<'a> { pub struct ReturnStatementVisitor<'a> {
pub returns: Vec<Option<&'a Expr>>, pub returns: Vec<Option<&'a Expr>>,
} }
impl<'a, 'b> Visitor<'b> for ReturnStatementVisitor<'a> impl<'a, 'b> StatementVisitor<'b> for ReturnStatementVisitor<'a>
where where
'b: 'a, 'b: 'a,
{ {
@ -832,18 +831,18 @@ where
// Don't recurse. // Don't recurse.
} }
StmtKind::Return { value } => self.returns.push(value.as_deref()), StmtKind::Return { value } => self.returns.push(value.as_deref()),
_ => visitor::walk_stmt(self, stmt), _ => walk_stmt(self, stmt),
} }
} }
} }
/// A [`Visitor`] that collects all `raise` statements in a function or method. /// A [`StatementVisitor`] that collects all `raise` statements in a function or method.
#[derive(Default)] #[derive(Default)]
pub struct RaiseStatementVisitor<'a> { pub struct RaiseStatementVisitor<'a> {
pub raises: Vec<(TextRange, Option<&'a Expr>, Option<&'a Expr>)>, pub raises: Vec<(TextRange, Option<&'a Expr>, Option<&'a Expr>)>,
} }
impl<'a, 'b> Visitor<'b> for RaiseStatementVisitor<'b> impl<'a, 'b> StatementVisitor<'b> for RaiseStatementVisitor<'b>
where where
'b: 'a, 'b: 'a,
{ {
@ -859,19 +858,19 @@ where
| StmtKind::Try { .. } | StmtKind::Try { .. }
| StmtKind::TryStar { .. } => {} | StmtKind::TryStar { .. } => {}
StmtKind::If { body, orelse, .. } => { StmtKind::If { body, orelse, .. } => {
visitor::walk_body(self, body); walk_body(self, body);
visitor::walk_body(self, orelse); walk_body(self, orelse);
} }
StmtKind::While { body, .. } StmtKind::While { body, .. }
| StmtKind::With { body, .. } | StmtKind::With { body, .. }
| StmtKind::AsyncWith { body, .. } | StmtKind::AsyncWith { body, .. }
| StmtKind::For { body, .. } | StmtKind::For { body, .. }
| StmtKind::AsyncFor { body, .. } => { | StmtKind::AsyncFor { body, .. } => {
visitor::walk_body(self, body); walk_body(self, body);
} }
StmtKind::Match { cases, .. } => { StmtKind::Match { cases, .. } => {
for case in cases { for case in cases {
visitor::walk_body(self, &case.body); walk_body(self, &case.body);
} }
} }
_ => {} _ => {}
@ -884,7 +883,7 @@ struct GlobalStatementVisitor<'a> {
globals: FxHashMap<&'a str, &'a Stmt>, globals: FxHashMap<&'a str, &'a Stmt>,
} }
impl<'a> Visitor<'a> for GlobalStatementVisitor<'a> { impl<'a> StatementVisitor<'a> for GlobalStatementVisitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) { fn visit_stmt(&mut self, stmt: &'a Stmt) {
match &stmt.node { match &stmt.node {
StmtKind::Global { names } => { StmtKind::Global { names } => {
@ -897,7 +896,7 @@ impl<'a> Visitor<'a> for GlobalStatementVisitor<'a> {
| StmtKind::ClassDef { .. } => { | StmtKind::ClassDef { .. } => {
// Don't recurse. // Don't recurse.
} }
_ => visitor::walk_stmt(self, stmt), _ => walk_stmt(self, stmt),
} }
} }
} }

View file

@ -8,6 +8,7 @@ pub mod imports;
pub mod newlines; pub mod newlines;
pub mod relocate; pub mod relocate;
pub mod source_code; pub mod source_code;
pub mod statement_visitor;
pub mod str; pub mod str;
pub mod token_kind; pub mod token_kind;
pub mod types; pub mod types;

View file

@ -0,0 +1,111 @@
//! Specialized AST visitor trait and walk functions that only visit statements.
use rustpython_parser::ast::{Excepthandler, ExcepthandlerKind, MatchCase, Stmt, StmtKind};
/// A trait for AST visitors that only need to visit statements.
pub trait StatementVisitor<'a> {
fn visit_body(&mut self, body: &'a [Stmt]) {
walk_body(self, body);
}
fn visit_stmt(&mut self, stmt: &'a Stmt) {
walk_stmt(self, stmt);
}
fn visit_excepthandler(&mut self, excepthandler: &'a Excepthandler) {
walk_excepthandler(self, excepthandler);
}
fn visit_match_case(&mut self, match_case: &'a MatchCase) {
walk_match_case(self, match_case);
}
}
pub fn walk_body<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, body: &'a [Stmt]) {
for stmt in body {
visitor.visit_stmt(stmt);
}
}
pub fn walk_stmt<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
match &stmt.node {
StmtKind::FunctionDef { body, .. } => {
visitor.visit_body(body);
}
StmtKind::AsyncFunctionDef { body, .. } => {
visitor.visit_body(body);
}
StmtKind::For { body, orelse, .. } => {
visitor.visit_body(body);
visitor.visit_body(orelse);
}
StmtKind::ClassDef { body, .. } => {
visitor.visit_body(body);
}
StmtKind::AsyncFor { body, orelse, .. } => {
visitor.visit_body(body);
visitor.visit_body(orelse);
}
StmtKind::While { body, orelse, .. } => {
visitor.visit_body(body);
visitor.visit_body(orelse);
}
StmtKind::If { body, orelse, .. } => {
visitor.visit_body(body);
visitor.visit_body(orelse);
}
StmtKind::With { body, .. } => {
visitor.visit_body(body);
}
StmtKind::AsyncWith { body, .. } => {
visitor.visit_body(body);
}
StmtKind::Match { cases, .. } => {
for match_case in cases {
visitor.visit_match_case(match_case);
}
}
StmtKind::Try {
body,
handlers,
orelse,
finalbody,
} => {
visitor.visit_body(body);
for excepthandler in handlers {
visitor.visit_excepthandler(excepthandler);
}
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
}
StmtKind::TryStar {
body,
handlers,
orelse,
finalbody,
} => {
visitor.visit_body(body);
for excepthandler in handlers {
visitor.visit_excepthandler(excepthandler);
}
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
}
_ => {}
}
}
pub fn walk_excepthandler<'a, V: StatementVisitor<'a> + ?Sized>(
visitor: &mut V,
excepthandler: &'a Excepthandler,
) {
match &excepthandler.node {
ExcepthandlerKind::ExceptHandler { body, .. } => {
visitor.visit_body(body);
}
}
}
pub fn walk_match_case<'a, V: StatementVisitor<'a> + ?Sized>(
visitor: &mut V,
match_case: &'a MatchCase,
) {
visitor.visit_body(&match_case.body);
}

View file

@ -1,9 +1,15 @@
//! AST visitor trait and walk functions.
use rustpython_parser::ast::{ use rustpython_parser::ast::{
Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler, Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler,
ExcepthandlerKind, Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, ExcepthandlerKind, Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern,
PatternKind, Stmt, StmtKind, Unaryop, Withitem, PatternKind, Stmt, StmtKind, Unaryop, Withitem,
}; };
/// A trait for AST visitors. Visits all nodes in the AST recursively.
///
/// Prefer [`crate::statement_visitor::StatementVisitor`] for visitors that only need to visit
/// statements.
pub trait Visitor<'a> { pub trait Visitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) { fn visit_stmt(&mut self, stmt: &'a Stmt) {
walk_stmt(self, stmt); walk_stmt(self, stmt);