Add AnyNodeRef.visit_preorder

<!--
Thank you for contributing to Ruff! To help us out with reviewing, please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

This PR adds the `AnyNodeRef.visit_preorder` method. I'll need this method to mark all comments of a suppressed node's children as formatted (in debug builds). 

I'm not super happy with this because it now requires a double-dispatch where the `walk_*` methods call into `node.visit_preorder` and the `visit_preorder` then calls back into the visitor. Meaning,
the new implementation now probably results in way more function calls. The other downside is that `AnyNodeRef` now contains code that is difficult to auto-generate. This could be mitigated by extracting the `visit_preorder` method into its own `VisitPreorder` trait. 

Anyway, this approach solves the need and avoids duplicating the visiting code once more. 

<!-- What's the purpose of the change? What does it do, and why? -->

## Test Plan

`cargo test`

<!-- How was it tested? -->
This commit is contained in:
Micha Reiser 2023-08-10 08:35:09 +02:00 committed by GitHub
parent c1bc67686c
commit ac5c8bb3b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 1321 additions and 640 deletions

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,8 @@
use crate::node::AnyNodeRef;
use crate::node::{AnyNodeRef, AstNode};
use crate::{
self as ast, Alias, Arguments, BoolOp, CmpOp, Comprehension, Constant, Decorator,
ElifElseClause, ExceptHandler, Expr, Keyword, MatchCase, Mod, Operator, Parameter,
ParameterWithDefault, Parameters, Pattern, Stmt, TypeParam, TypeParamTypeVar, TypeParams,
UnaryOp, WithItem,
Alias, Arguments, BoolOp, CmpOp, Comprehension, Constant, Decorator, ElifElseClause,
ExceptHandler, Expr, Keyword, MatchCase, Mod, Operator, Parameter, ParameterWithDefault,
Parameters, Pattern, Stmt, TypeParam, TypeParams, UnaryOp, WithItem,
};
/// Visitor that traverses all nodes recursively in pre-order.
@ -152,10 +151,8 @@ where
let node = AnyNodeRef::from(module);
if visitor.enter_node(node).is_traverse() {
match module {
Mod::Module(ast::ModModule { body, range: _ }) => {
visitor.visit_body(body);
}
Mod::Expression(ast::ModExpression { body, range: _ }) => visitor.visit_expr(body),
Mod::Module(module) => module.visit_preorder(visitor),
Mod::Expression(module) => module.visit_preorder(visitor),
}
}
@ -179,246 +176,32 @@ where
if visitor.enter_node(node).is_traverse() {
match stmt {
Stmt::Expr(ast::StmtExpr { value, range: _ }) => visitor.visit_expr(value),
Stmt::FunctionDef(ast::StmtFunctionDef {
parameters,
body,
decorator_list,
returns,
type_params,
..
}) => {
for decorator in decorator_list {
visitor.visit_decorator(decorator);
}
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
visitor.visit_parameters(parameters);
for expr in returns {
visitor.visit_annotation(expr);
}
visitor.visit_body(body);
}
Stmt::ClassDef(ast::StmtClassDef {
arguments,
body,
decorator_list,
type_params,
..
}) => {
for decorator in decorator_list {
visitor.visit_decorator(decorator);
}
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
if let Some(arguments) = arguments {
visitor.visit_arguments(arguments);
}
visitor.visit_body(body);
}
Stmt::Return(ast::StmtReturn { value, range: _ }) => {
if let Some(expr) = value {
visitor.visit_expr(expr);
}
}
Stmt::Delete(ast::StmtDelete { targets, range: _ }) => {
for expr in targets {
visitor.visit_expr(expr);
}
}
Stmt::TypeAlias(ast::StmtTypeAlias {
range: _,
name,
type_params,
value,
}) => {
visitor.visit_expr(name);
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
visitor.visit_expr(value);
}
Stmt::Assign(ast::StmtAssign {
targets,
value,
range: _,
}) => {
for expr in targets {
visitor.visit_expr(expr);
}
visitor.visit_expr(value);
}
Stmt::AugAssign(ast::StmtAugAssign {
target,
op,
value,
range: _,
}) => {
visitor.visit_expr(target);
visitor.visit_operator(op);
visitor.visit_expr(value);
}
Stmt::AnnAssign(ast::StmtAnnAssign {
target,
annotation,
value,
range: _,
simple: _,
}) => {
visitor.visit_expr(target);
visitor.visit_annotation(annotation);
if let Some(expr) = value {
visitor.visit_expr(expr);
}
}
Stmt::For(ast::StmtFor {
target,
iter,
body,
orelse,
..
}) => {
visitor.visit_expr(target);
visitor.visit_expr(iter);
visitor.visit_body(body);
visitor.visit_body(orelse);
}
Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
}) => {
visitor.visit_expr(test);
visitor.visit_body(body);
visitor.visit_body(orelse);
}
Stmt::If(ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
}) => {
visitor.visit_expr(test);
visitor.visit_body(body);
for clause in elif_else_clauses {
visitor.visit_elif_else_clause(clause);
}
}
Stmt::With(ast::StmtWith {
items,
body,
is_async: _,
range: _,
}) => {
for with_item in items {
visitor.visit_with_item(with_item);
}
visitor.visit_body(body);
}
Stmt::Match(ast::StmtMatch {
subject,
cases,
range: _,
}) => {
visitor.visit_expr(subject);
for match_case in cases {
visitor.visit_match_case(match_case);
}
}
Stmt::Raise(ast::StmtRaise {
exc,
cause,
range: _,
}) => {
if let Some(expr) = exc {
visitor.visit_expr(expr);
};
if let Some(expr) = cause {
visitor.visit_expr(expr);
};
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
range: _,
})
| Stmt::TryStar(ast::StmtTryStar {
body,
handlers,
orelse,
finalbody,
range: _,
}) => {
visitor.visit_body(body);
for except_handler in handlers {
visitor.visit_except_handler(except_handler);
}
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
}
Stmt::Assert(ast::StmtAssert {
test,
msg,
range: _,
}) => {
visitor.visit_expr(test);
if let Some(expr) = msg {
visitor.visit_expr(expr);
}
}
Stmt::Import(ast::StmtImport { names, range: _ }) => {
for alias in names {
visitor.visit_alias(alias);
}
}
Stmt::ImportFrom(ast::StmtImportFrom {
range: _,
module: _,
names,
level: _,
}) => {
for alias in names {
visitor.visit_alias(alias);
}
}
Stmt::Pass(_)
| Stmt::Break(_)
| Stmt::Continue(_)
| Stmt::Global(_)
| Stmt::Nonlocal(_)
| Stmt::IpyEscapeCommand(_) => {}
Stmt::Expr(stmt) => stmt.visit_preorder(visitor),
Stmt::FunctionDef(stmt) => stmt.visit_preorder(visitor),
Stmt::ClassDef(stmt) => stmt.visit_preorder(visitor),
Stmt::Return(stmt) => stmt.visit_preorder(visitor),
Stmt::Delete(stmt) => stmt.visit_preorder(visitor),
Stmt::TypeAlias(stmt) => stmt.visit_preorder(visitor),
Stmt::Assign(stmt) => stmt.visit_preorder(visitor),
Stmt::AugAssign(stmt) => stmt.visit_preorder(visitor),
Stmt::AnnAssign(stmt) => stmt.visit_preorder(visitor),
Stmt::For(stmt) => stmt.visit_preorder(visitor),
Stmt::While(stmt) => stmt.visit_preorder(visitor),
Stmt::If(stmt) => stmt.visit_preorder(visitor),
Stmt::With(stmt) => stmt.visit_preorder(visitor),
Stmt::Match(stmt) => stmt.visit_preorder(visitor),
Stmt::Raise(stmt) => stmt.visit_preorder(visitor),
Stmt::Try(stmt) => stmt.visit_preorder(visitor),
Stmt::TryStar(stmt) => stmt.visit_preorder(visitor),
Stmt::Assert(stmt) => stmt.visit_preorder(visitor),
Stmt::Import(stmt) => stmt.visit_preorder(visitor),
Stmt::ImportFrom(stmt) => stmt.visit_preorder(visitor),
Stmt::Pass(stmt) => stmt.visit_preorder(visitor),
Stmt::Break(stmt) => stmt.visit_preorder(visitor),
Stmt::Continue(stmt) => stmt.visit_preorder(visitor),
Stmt::Global(stmt) => stmt.visit_preorder(visitor),
Stmt::Nonlocal(stmt) => stmt.visit_preorder(visitor),
Stmt::IpyEscapeCommand(stmt) => stmt.visit_preorder(visitor),
}
}
@ -452,7 +235,7 @@ where
{
let node = AnyNodeRef::from(decorator);
if visitor.enter_node(node).is_traverse() {
visitor.visit_expr(&decorator.expression);
decorator.visit_preorder(visitor);
}
visitor.leave_node(node);
@ -465,261 +248,34 @@ where
let node = AnyNodeRef::from(expr);
if visitor.enter_node(node).is_traverse() {
match expr {
Expr::BoolOp(ast::ExprBoolOp {
op,
values,
range: _,
}) => match values.as_slice() {
[left, rest @ ..] => {
visitor.visit_expr(left);
visitor.visit_bool_op(op);
for expr in rest {
visitor.visit_expr(expr);
}
}
[] => {
visitor.visit_bool_op(op);
}
},
Expr::NamedExpr(ast::ExprNamedExpr {
target,
value,
range: _,
}) => {
visitor.visit_expr(target);
visitor.visit_expr(value);
}
Expr::BinOp(ast::ExprBinOp {
left,
op,
right,
range: _,
}) => {
visitor.visit_expr(left);
visitor.visit_operator(op);
visitor.visit_expr(right);
}
Expr::UnaryOp(ast::ExprUnaryOp {
op,
operand,
range: _,
}) => {
visitor.visit_unary_op(op);
visitor.visit_expr(operand);
}
Expr::Lambda(ast::ExprLambda {
parameters,
body,
range: _,
}) => {
visitor.visit_parameters(parameters);
visitor.visit_expr(body);
}
Expr::IfExp(ast::ExprIfExp {
test,
body,
orelse,
range: _,
}) => {
// `body if test else orelse`
visitor.visit_expr(body);
visitor.visit_expr(test);
visitor.visit_expr(orelse);
}
Expr::Dict(ast::ExprDict {
keys,
values,
range: _,
}) => {
for (key, value) in keys.iter().zip(values) {
if let Some(key) = key {
visitor.visit_expr(key);
}
visitor.visit_expr(value);
}
}
Expr::Set(ast::ExprSet { elts, range: _ }) => {
for expr in elts {
visitor.visit_expr(expr);
}
}
Expr::ListComp(ast::ExprListComp {
elt,
generators,
range: _,
}) => {
visitor.visit_expr(elt);
for comprehension in generators {
visitor.visit_comprehension(comprehension);
}
}
Expr::SetComp(ast::ExprSetComp {
elt,
generators,
range: _,
}) => {
visitor.visit_expr(elt);
for comprehension in generators {
visitor.visit_comprehension(comprehension);
}
}
Expr::DictComp(ast::ExprDictComp {
key,
value,
generators,
range: _,
}) => {
visitor.visit_expr(key);
visitor.visit_expr(value);
for comprehension in generators {
visitor.visit_comprehension(comprehension);
}
}
Expr::GeneratorExp(ast::ExprGeneratorExp {
elt,
generators,
range: _,
}) => {
visitor.visit_expr(elt);
for comprehension in generators {
visitor.visit_comprehension(comprehension);
}
}
Expr::Await(ast::ExprAwait { value, range: _ })
| Expr::YieldFrom(ast::ExprYieldFrom { value, range: _ }) => visitor.visit_expr(value),
Expr::Yield(ast::ExprYield { value, range: _ }) => {
if let Some(expr) = value {
visitor.visit_expr(expr);
}
}
Expr::Compare(ast::ExprCompare {
left,
ops,
comparators,
range: _,
}) => {
visitor.visit_expr(left);
for (op, comparator) in ops.iter().zip(comparators) {
visitor.visit_cmp_op(op);
visitor.visit_expr(comparator);
}
}
Expr::Call(ast::ExprCall {
func,
arguments,
range: _,
}) => {
visitor.visit_expr(func);
visitor.visit_arguments(arguments);
}
Expr::FormattedValue(ast::ExprFormattedValue {
value, format_spec, ..
}) => {
visitor.visit_expr(value);
if let Some(expr) = format_spec {
visitor.visit_format_spec(expr);
}
}
Expr::FString(ast::ExprFString { values, range: _ }) => {
for expr in values {
visitor.visit_expr(expr);
}
}
Expr::Constant(ast::ExprConstant {
value,
range: _,
kind: _,
}) => visitor.visit_constant(value),
Expr::Attribute(ast::ExprAttribute {
value,
attr: _,
ctx: _,
range: _,
}) => {
visitor.visit_expr(value);
}
Expr::Subscript(ast::ExprSubscript {
value,
slice,
ctx: _,
range: _,
}) => {
visitor.visit_expr(value);
visitor.visit_expr(slice);
}
Expr::Starred(ast::ExprStarred {
value,
ctx: _,
range: _,
}) => {
visitor.visit_expr(value);
}
Expr::Name(ast::ExprName {
id: _,
ctx: _,
range: _,
}) => {}
Expr::List(ast::ExprList {
elts,
ctx: _,
range: _,
}) => {
for expr in elts {
visitor.visit_expr(expr);
}
}
Expr::Tuple(ast::ExprTuple {
elts,
ctx: _,
range: _,
}) => {
for expr in elts {
visitor.visit_expr(expr);
}
}
Expr::Slice(ast::ExprSlice {
lower,
upper,
step,
range: _,
}) => {
if let Some(expr) = lower {
visitor.visit_expr(expr);
}
if let Some(expr) = upper {
visitor.visit_expr(expr);
}
if let Some(expr) = step {
visitor.visit_expr(expr);
}
}
Expr::IpyEscapeCommand(_) => (),
Expr::BoolOp(expr) => expr.visit_preorder(visitor),
Expr::NamedExpr(expr) => expr.visit_preorder(visitor),
Expr::BinOp(expr) => expr.visit_preorder(visitor),
Expr::UnaryOp(expr) => expr.visit_preorder(visitor),
Expr::Lambda(expr) => expr.visit_preorder(visitor),
Expr::IfExp(expr) => expr.visit_preorder(visitor),
Expr::Dict(expr) => expr.visit_preorder(visitor),
Expr::Set(expr) => expr.visit_preorder(visitor),
Expr::ListComp(expr) => expr.visit_preorder(visitor),
Expr::SetComp(expr) => expr.visit_preorder(visitor),
Expr::DictComp(expr) => expr.visit_preorder(visitor),
Expr::GeneratorExp(expr) => expr.visit_preorder(visitor),
Expr::Await(expr) => expr.visit_preorder(visitor),
Expr::Yield(expr) => expr.visit_preorder(visitor),
Expr::YieldFrom(expr) => expr.visit_preorder(visitor),
Expr::Compare(expr) => expr.visit_preorder(visitor),
Expr::Call(expr) => expr.visit_preorder(visitor),
Expr::FormattedValue(expr) => expr.visit_preorder(visitor),
Expr::FString(expr) => expr.visit_preorder(visitor),
Expr::Constant(expr) => expr.visit_preorder(visitor),
Expr::Attribute(expr) => expr.visit_preorder(visitor),
Expr::Subscript(expr) => expr.visit_preorder(visitor),
Expr::Starred(expr) => expr.visit_preorder(visitor),
Expr::Name(expr) => expr.visit_preorder(visitor),
Expr::List(expr) => expr.visit_preorder(visitor),
Expr::Tuple(expr) => expr.visit_preorder(visitor),
Expr::Slice(expr) => expr.visit_preorder(visitor),
Expr::IpyEscapeCommand(expr) => expr.visit_preorder(visitor),
}
}
@ -732,12 +288,7 @@ where
{
let node = AnyNodeRef::from(comprehension);
if visitor.enter_node(node).is_traverse() {
visitor.visit_expr(&comprehension.target);
visitor.visit_expr(&comprehension.iter);
for expr in &comprehension.ifs {
visitor.visit_expr(expr);
}
comprehension.visit_preorder(visitor);
}
visitor.leave_node(node);
@ -749,10 +300,7 @@ where
{
let node = AnyNodeRef::from(elif_else_clause);
if visitor.enter_node(node).is_traverse() {
if let Some(test) = &elif_else_clause.test {
visitor.visit_expr(test);
}
visitor.visit_body(&elif_else_clause.body);
elif_else_clause.visit_preorder(visitor);
}
visitor.leave_node(node);
@ -765,17 +313,7 @@ where
let node = AnyNodeRef::from(except_handler);
if visitor.enter_node(node).is_traverse() {
match except_handler {
ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler {
range: _,
type_,
name: _,
body,
}) => {
if let Some(expr) = type_ {
visitor.visit_expr(expr);
}
visitor.visit_body(body);
}
ExceptHandler::ExceptHandler(except_handler) => except_handler.visit_preorder(visitor),
}
}
visitor.leave_node(node);
@ -799,13 +337,7 @@ where
{
let node = AnyNodeRef::from(arguments);
if visitor.enter_node(node).is_traverse() {
for arg in &arguments.args {
visitor.visit_expr(arg);
}
for keyword in &arguments.keywords {
visitor.visit_keyword(keyword);
}
arguments.visit_preorder(visitor);
}
visitor.leave_node(node);
@ -817,21 +349,7 @@ where
{
let node = AnyNodeRef::from(parameters);
if visitor.enter_node(node).is_traverse() {
for arg in parameters.posonlyargs.iter().chain(&parameters.args) {
visitor.visit_parameter_with_default(arg);
}
if let Some(arg) = &parameters.vararg {
visitor.visit_parameter(arg);
}
for arg in &parameters.kwonlyargs {
visitor.visit_parameter_with_default(arg);
}
if let Some(arg) = &parameters.kwarg {
visitor.visit_parameter(arg);
}
parameters.visit_preorder(visitor);
}
visitor.leave_node(node);
@ -844,9 +362,7 @@ where
let node = AnyNodeRef::from(parameter);
if visitor.enter_node(node).is_traverse() {
if let Some(expr) = &parameter.annotation {
visitor.visit_annotation(expr);
}
parameter.visit_preorder(visitor);
}
visitor.leave_node(node);
}
@ -859,10 +375,7 @@ pub fn walk_parameter_with_default<'a, V>(
{
let node = AnyNodeRef::from(parameter_with_default);
if visitor.enter_node(node).is_traverse() {
visitor.visit_parameter(&parameter_with_default.parameter);
if let Some(expr) = &parameter_with_default.default {
visitor.visit_expr(expr);
}
parameter_with_default.visit_preorder(visitor);
}
visitor.leave_node(node);
@ -876,7 +389,7 @@ where
let node = AnyNodeRef::from(keyword);
if visitor.enter_node(node).is_traverse() {
visitor.visit_expr(&keyword.value);
keyword.visit_preorder(visitor);
}
visitor.leave_node(node);
}
@ -887,11 +400,7 @@ where
{
let node = AnyNodeRef::from(with_item);
if visitor.enter_node(node).is_traverse() {
visitor.visit_expr(&with_item.context_expr);
if let Some(expr) = &with_item.optional_vars {
visitor.visit_expr(expr);
}
with_item.visit_preorder(visitor);
}
visitor.leave_node(node);
}
@ -902,9 +411,7 @@ where
{
let node = AnyNodeRef::from(type_params);
if visitor.enter_node(node).is_traverse() {
for type_param in &type_params.type_params {
visitor.visit_type_param(type_param);
}
type_params.visit_preorder(visitor);
}
visitor.leave_node(node);
}
@ -916,16 +423,9 @@ where
let node = AnyNodeRef::from(type_param);
if visitor.enter_node(node).is_traverse() {
match type_param {
TypeParam::TypeVar(TypeParamTypeVar {
bound,
name: _,
range: _,
}) => {
if let Some(expr) = bound {
visitor.visit_expr(expr);
}
}
TypeParam::TypeVarTuple(_) | TypeParam::ParamSpec(_) => {}
TypeParam::TypeVar(type_param) => type_param.visit_preorder(visitor),
TypeParam::TypeVarTuple(type_param) => type_param.visit_preorder(visitor),
TypeParam::ParamSpec(type_param) => type_param.visit_preorder(visitor),
}
}
visitor.leave_node(node);
@ -937,11 +437,7 @@ where
{
let node = AnyNodeRef::from(match_case);
if visitor.enter_node(node).is_traverse() {
visitor.visit_pattern(&match_case.pattern);
if let Some(expr) = &match_case.guard {
visitor.visit_expr(expr);
}
visitor.visit_body(&match_case.body);
match_case.visit_preorder(visitor);
}
visitor.leave_node(node);
}
@ -953,66 +449,14 @@ where
let node = AnyNodeRef::from(pattern);
if visitor.enter_node(node).is_traverse() {
match pattern {
Pattern::MatchValue(ast::PatternMatchValue { value, range: _ }) => {
visitor.visit_expr(value);
}
Pattern::MatchSingleton(ast::PatternMatchSingleton { value, range: _ }) => {
visitor.visit_constant(value);
}
Pattern::MatchSequence(ast::PatternMatchSequence { patterns, range: _ }) => {
for pattern in patterns {
visitor.visit_pattern(pattern);
}
}
Pattern::MatchMapping(ast::PatternMatchMapping {
keys,
patterns,
range: _,
rest: _,
}) => {
for (key, pattern) in keys.iter().zip(patterns) {
visitor.visit_expr(key);
visitor.visit_pattern(pattern);
}
}
Pattern::MatchClass(ast::PatternMatchClass {
cls,
patterns,
kwd_attrs: _,
kwd_patterns,
range: _,
}) => {
visitor.visit_expr(cls);
for pattern in patterns {
visitor.visit_pattern(pattern);
}
for pattern in kwd_patterns {
visitor.visit_pattern(pattern);
}
}
Pattern::MatchStar(_) => {}
Pattern::MatchAs(ast::PatternMatchAs {
pattern,
range: _,
name: _,
}) => {
if let Some(pattern) = pattern {
visitor.visit_pattern(pattern);
}
}
Pattern::MatchOr(ast::PatternMatchOr { patterns, range: _ }) => {
for pattern in patterns {
visitor.visit_pattern(pattern);
}
}
Pattern::MatchValue(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchSingleton(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchSequence(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchMapping(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchClass(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchStar(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchAs(pattern) => pattern.visit_preorder(visitor),
Pattern::MatchOr(pattern) => pattern.visit_preorder(visitor),
}
}
visitor.leave_node(node);
@ -1051,6 +495,8 @@ where
V: PreorderVisitor<'a> + ?Sized,
{
let node = AnyNodeRef::from(alias);
visitor.enter_node(node);
if visitor.enter_node(node).is_traverse() {
alias.visit_preorder(visitor);
}
visitor.leave_node(node);
}