Replace parents statement stack with a Nodes abstraction (#4233)

This commit is contained in:
Charlie Marsh 2023-05-06 12:12:41 -04:00 committed by GitHub
parent 2c91412321
commit c1f0661225
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 362 additions and 279 deletions

View file

@ -1,14 +1,16 @@
use ruff_text_size::TextRange; use ruff_text_size::TextRange;
use rustpython_parser::ast::{Expr, Stmt}; use rustpython_parser::ast::{Expr, Stmt};
use ruff_python_ast::types::RefEquality;
use ruff_python_semantic::analyze::visibility::{Visibility, VisibleScope}; use ruff_python_semantic::analyze::visibility::{Visibility, VisibleScope};
use ruff_python_semantic::node::NodeId;
use ruff_python_semantic::scope::ScopeId; use ruff_python_semantic::scope::ScopeId;
use crate::checkers::ast::AnnotationContext; use crate::checkers::ast::AnnotationContext;
use crate::docstrings::definition::Definition; use crate::docstrings::definition::Definition;
type Context<'a> = (ScopeId, Vec<RefEquality<'a, Stmt>>); /// A snapshot of the current scope and statement, which will be restored when visiting any
/// deferred definitions.
type Context<'a> = (ScopeId, Option<NodeId>);
/// A collection of AST nodes that are deferred for later analysis. /// A collection of AST nodes that are deferred for later analysis.
/// Used to, e.g., store functions, whose bodies shouldn't be analyzed until all /// Used to, e.g., store functions, whose bodies shouldn't be analyzed until all

View file

@ -17,8 +17,9 @@ use ruff_python_ast::source_code::{Indexer, Locator, Stylist};
use ruff_python_ast::types::{Node, RefEquality}; use ruff_python_ast::types::{Node, RefEquality};
use ruff_python_ast::typing::parse_type_annotation; use ruff_python_ast::typing::parse_type_annotation;
use ruff_python_ast::visitor::{walk_excepthandler, walk_pattern, Visitor}; use ruff_python_ast::visitor::{walk_excepthandler, walk_pattern, Visitor};
use ruff_python_ast::{branch_detection, cast, helpers, str, visitor}; use ruff_python_ast::{cast, helpers, str, visitor};
use ruff_python_semantic::analyze; use ruff_python_semantic::analyze;
use ruff_python_semantic::analyze::branch_detection;
use ruff_python_semantic::analyze::typing::{Callable, SubscriptKind}; use ruff_python_semantic::analyze::typing::{Callable, SubscriptKind};
use ruff_python_semantic::binding::{ use ruff_python_semantic::binding::{
Binding, BindingId, BindingKind, Exceptions, ExecutionContext, Export, FromImportation, Binding, BindingId, BindingKind, Exceptions, ExecutionContext, Export, FromImportation,
@ -175,7 +176,7 @@ where
'b: 'a, 'b: 'a,
{ {
fn visit_stmt(&mut self, stmt: &'b Stmt) { fn visit_stmt(&mut self, stmt: &'b Stmt) {
self.ctx.push_parent(stmt); self.ctx.push_stmt(stmt);
// Track whether we've seen docstrings, non-imports, etc. // Track whether we've seen docstrings, non-imports, etc.
match &stmt.node { match &stmt.node {
@ -196,7 +197,7 @@ where
self.ctx.futures_allowed = false; self.ctx.futures_allowed = false;
if !self.ctx.seen_import_boundary if !self.ctx.seen_import_boundary
&& !helpers::is_assignment_to_a_dunder(stmt) && !helpers::is_assignment_to_a_dunder(stmt)
&& !helpers::in_nested_block(self.ctx.parents.iter().rev().map(Into::into)) && !helpers::in_nested_block(self.ctx.parents())
{ {
self.ctx.seen_import_boundary = true; self.ctx.seen_import_boundary = true;
} }
@ -230,7 +231,7 @@ where
synthetic_usage: usage, synthetic_usage: usage,
typing_usage: None, typing_usage: None,
range: *range, range: *range,
source: Some(RefEquality(stmt)), source: Some(stmt),
context, context,
exceptions, exceptions,
}); });
@ -260,7 +261,7 @@ where
synthetic_usage: usage, synthetic_usage: usage,
typing_usage: None, typing_usage: None,
range: *range, range: *range,
source: Some(RefEquality(stmt)), source: Some(stmt),
context, context,
exceptions, exceptions,
}); });
@ -303,10 +304,9 @@ where
} }
StmtKind::Break => { StmtKind::Break => {
if self.settings.rules.enabled(Rule::BreakOutsideLoop) { if self.settings.rules.enabled(Rule::BreakOutsideLoop) {
if let Some(diagnostic) = pyflakes::rules::break_outside_loop( if let Some(diagnostic) =
stmt, pyflakes::rules::break_outside_loop(stmt, &mut self.ctx.parents().skip(1))
&mut self.ctx.parents.iter().rev().map(Into::into).skip(1), {
) {
self.diagnostics.push(diagnostic); self.diagnostics.push(diagnostic);
} }
} }
@ -315,7 +315,7 @@ where
if self.settings.rules.enabled(Rule::ContinueOutsideLoop) { if self.settings.rules.enabled(Rule::ContinueOutsideLoop) {
if let Some(diagnostic) = pyflakes::rules::continue_outside_loop( if let Some(diagnostic) = pyflakes::rules::continue_outside_loop(
stmt, stmt,
&mut self.ctx.parents.iter().rev().map(Into::into).skip(1), &mut self.ctx.parents().skip(1),
) { ) {
self.diagnostics.push(diagnostic); self.diagnostics.push(diagnostic);
} }
@ -688,7 +688,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: stmt.range(), range: stmt.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -904,7 +904,7 @@ where
synthetic_usage: Some((self.ctx.scope_id, alias.range())), synthetic_usage: Some((self.ctx.scope_id, alias.range())),
typing_usage: None, typing_usage: None,
range: alias.range(), range: alias.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -934,7 +934,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: alias.range(), range: alias.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -962,7 +962,7 @@ where
}, },
typing_usage: None, typing_usage: None,
range: alias.range(), range: alias.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -1222,7 +1222,7 @@ where
synthetic_usage: Some((self.ctx.scope_id, alias.range())), synthetic_usage: Some((self.ctx.scope_id, alias.range())),
typing_usage: None, typing_usage: None,
range: alias.range(), range: alias.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -1318,7 +1318,7 @@ where
}, },
typing_usage: None, typing_usage: None,
range: alias.range(), range: alias.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -1714,7 +1714,7 @@ where
if self.settings.rules.enabled(Rule::UnusedLoopControlVariable) { if self.settings.rules.enabled(Rule::UnusedLoopControlVariable) {
self.deferred self.deferred
.for_loops .for_loops
.push((stmt, (self.ctx.scope_id, self.ctx.parents.clone()))); .push((stmt, (self.ctx.scope_id, self.ctx.stmt_id)));
} }
if self if self
.settings .settings
@ -2003,7 +2003,7 @@ where
self.deferred.definitions.push(( self.deferred.definitions.push((
definition, definition,
scope.visibility, scope.visibility,
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
)); ));
self.ctx.visible_scope = scope; self.ctx.visible_scope = scope;
@ -2022,7 +2022,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: stmt.range(), range: stmt.range(),
source: Some(RefEquality(stmt)), source: Some(stmt),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}); });
@ -2041,7 +2041,7 @@ where
self.deferred.functions.push(( self.deferred.functions.push((
stmt, stmt,
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
self.ctx.visible_scope, self.ctx.visible_scope,
)); ));
} }
@ -2066,7 +2066,7 @@ where
self.deferred.definitions.push(( self.deferred.definitions.push((
definition, definition,
scope.visibility, scope.visibility,
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
)); ));
self.ctx.visible_scope = scope; self.ctx.visible_scope = scope;
@ -2085,7 +2085,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: stmt.range(), range: stmt.range(),
source: Some(RefEquality(stmt)), source: Some(stmt),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}); });
@ -2246,7 +2246,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: stmt.range(), range: stmt.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -2255,7 +2255,7 @@ where
_ => {} _ => {}
} }
self.ctx.pop_parent(); self.ctx.pop_stmt();
} }
fn visit_annotation(&mut self, expr: &'b Expr) { fn visit_annotation(&mut self, expr: &'b Expr) {
@ -2281,13 +2281,13 @@ where
expr.range(), expr.range(),
value, value,
(self.ctx.in_annotation, self.ctx.in_type_checking_block), (self.ctx.in_annotation, self.ctx.in_type_checking_block),
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
)); ));
} else { } else {
self.deferred.type_definitions.push(( self.deferred.type_definitions.push((
expr, expr,
(self.ctx.in_annotation, self.ctx.in_type_checking_block), (self.ctx.in_annotation, self.ctx.in_type_checking_block),
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
)); ));
} }
return; return;
@ -3514,7 +3514,7 @@ where
expr.range(), expr.range(),
value, value,
(self.ctx.in_annotation, self.ctx.in_type_checking_block), (self.ctx.in_annotation, self.ctx.in_type_checking_block),
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
)); ));
} }
if self if self
@ -3637,7 +3637,7 @@ where
ExprKind::Lambda { .. } => { ExprKind::Lambda { .. } => {
self.deferred self.deferred
.lambdas .lambdas
.push((expr, (self.ctx.scope_id, self.ctx.parents.clone()))); .push((expr, (self.ctx.scope_id, self.ctx.stmt_id)));
} }
ExprKind::IfExp { test, body, orelse } => { ExprKind::IfExp { test, body, orelse } => {
visit_boolean_test!(self, test); visit_boolean_test!(self, test);
@ -4121,7 +4121,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: arg.range(), range: arg.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4165,7 +4165,7 @@ where
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: pattern.range(), range: pattern.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4212,12 +4212,7 @@ impl<'a> Checker<'a> {
if !existing.kind.is_builtin() if !existing.kind.is_builtin()
&& existing.source.map_or(true, |left| { && existing.source.map_or(true, |left| {
binding.source.map_or(true, |right| { binding.source.map_or(true, |right| {
!branch_detection::different_forks( !branch_detection::different_forks(left, right, &self.ctx.stmts)
left,
right,
&self.ctx.depths,
&self.ctx.child_to_parent,
)
}) })
}) })
{ {
@ -4517,7 +4512,7 @@ impl<'a> Checker<'a> {
} }
fn handle_node_store(&mut self, id: &'a str, expr: &Expr) { fn handle_node_store(&mut self, id: &'a str, expr: &Expr) {
let parent = self.ctx.current_stmt().0; let parent = self.ctx.current_stmt();
if self.settings.rules.enabled(Rule::UndefinedLocal) { if self.settings.rules.enabled(Rule::UndefinedLocal) {
pyflakes::rules::undefined_local(self, id); pyflakes::rules::undefined_local(self, id);
@ -4576,7 +4571,7 @@ impl<'a> Checker<'a> {
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: expr.range(), range: expr.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4596,7 +4591,7 @@ impl<'a> Checker<'a> {
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: expr.range(), range: expr.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4613,7 +4608,7 @@ impl<'a> Checker<'a> {
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: expr.range(), range: expr.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4696,7 +4691,7 @@ impl<'a> Checker<'a> {
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: expr.range(), range: expr.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4718,7 +4713,7 @@ impl<'a> Checker<'a> {
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: expr.range(), range: expr.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4734,7 +4729,7 @@ impl<'a> Checker<'a> {
synthetic_usage: None, synthetic_usage: None,
typing_usage: None, typing_usage: None,
range: expr.range(), range: expr.range(),
source: Some(*self.ctx.current_stmt()), source: Some(self.ctx.current_stmt()),
context: self.ctx.execution_context(), context: self.ctx.execution_context(),
exceptions: self.ctx.exceptions(), exceptions: self.ctx.exceptions(),
}, },
@ -4745,7 +4740,7 @@ impl<'a> Checker<'a> {
let ExprKind::Name { id, .. } = &expr.node else { let ExprKind::Name { id, .. } = &expr.node else {
return; return;
}; };
if helpers::on_conditional_branch(&mut self.ctx.parents.iter().rev().map(Into::into)) { if helpers::on_conditional_branch(&mut self.ctx.parents()) {
return; return;
} }
@ -4780,7 +4775,7 @@ impl<'a> Checker<'a> {
docstring, docstring,
}, },
self.ctx.visible_scope.visibility, self.ctx.visible_scope.visibility,
(self.ctx.scope_id, self.ctx.parents.clone()), (self.ctx.scope_id, self.ctx.stmt_id),
)); ));
docstring.is_some() docstring.is_some()
} }
@ -4788,11 +4783,11 @@ impl<'a> Checker<'a> {
fn check_deferred_type_definitions(&mut self) { fn check_deferred_type_definitions(&mut self) {
while !self.deferred.type_definitions.is_empty() { while !self.deferred.type_definitions.is_empty() {
let type_definitions = std::mem::take(&mut self.deferred.type_definitions); let type_definitions = std::mem::take(&mut self.deferred.type_definitions);
for (expr, (in_annotation, in_type_checking_block), (scope_id, parents)) in for (expr, (in_annotation, in_type_checking_block), (scope_id, node_id)) in
type_definitions type_definitions
{ {
self.ctx.scope_id = scope_id; self.ctx.scope_id = scope_id;
self.ctx.parents = parents; self.ctx.stmt_id = node_id;
self.ctx.in_annotation = in_annotation; self.ctx.in_annotation = in_annotation;
self.ctx.in_type_checking_block = in_type_checking_block; self.ctx.in_type_checking_block = in_type_checking_block;
self.ctx.in_type_definition = true; self.ctx.in_type_definition = true;
@ -4807,7 +4802,7 @@ impl<'a> Checker<'a> {
fn check_deferred_string_type_definitions(&mut self, allocator: &'a typed_arena::Arena<Expr>) { fn check_deferred_string_type_definitions(&mut self, allocator: &'a typed_arena::Arena<Expr>) {
while !self.deferred.string_type_definitions.is_empty() { while !self.deferred.string_type_definitions.is_empty() {
let type_definitions = std::mem::take(&mut self.deferred.string_type_definitions); let type_definitions = std::mem::take(&mut self.deferred.string_type_definitions);
for (range, value, (in_annotation, in_type_checking_block), (scope_id, parents)) in for (range, value, (in_annotation, in_type_checking_block), (scope_id, node_id)) in
type_definitions type_definitions
{ {
if let Ok((expr, kind)) = parse_type_annotation(value, range, self.locator) { if let Ok((expr, kind)) = parse_type_annotation(value, range, self.locator) {
@ -4825,7 +4820,7 @@ impl<'a> Checker<'a> {
let expr = allocator.alloc(expr); let expr = allocator.alloc(expr);
self.ctx.scope_id = scope_id; self.ctx.scope_id = scope_id;
self.ctx.parents = parents; self.ctx.stmt_id = node_id;
self.ctx.in_annotation = in_annotation; self.ctx.in_annotation = in_annotation;
self.ctx.in_type_checking_block = in_type_checking_block; self.ctx.in_type_checking_block = in_type_checking_block;
self.ctx.in_type_definition = true; self.ctx.in_type_definition = true;
@ -4854,10 +4849,9 @@ impl<'a> Checker<'a> {
fn check_deferred_functions(&mut self) { fn check_deferred_functions(&mut self) {
while !self.deferred.functions.is_empty() { while !self.deferred.functions.is_empty() {
let deferred_functions = std::mem::take(&mut self.deferred.functions); let deferred_functions = std::mem::take(&mut self.deferred.functions);
for (stmt, (scope_id, parents), visibility) in deferred_functions { for (stmt, (scope_id, node_id), visibility) in deferred_functions {
let parents_snapshot = parents.len();
self.ctx.scope_id = scope_id; self.ctx.scope_id = scope_id;
self.ctx.parents = parents; self.ctx.stmt_id = node_id;
self.ctx.visible_scope = visibility; self.ctx.visible_scope = visibility;
match &stmt.node { match &stmt.node {
@ -4871,10 +4865,7 @@ impl<'a> Checker<'a> {
} }
} }
let mut parents = std::mem::take(&mut self.ctx.parents); self.deferred.assignments.push((scope_id, node_id));
parents.truncate(parents_snapshot);
self.deferred.assignments.push((scope_id, parents));
} }
} }
} }
@ -4882,11 +4873,9 @@ impl<'a> Checker<'a> {
fn check_deferred_lambdas(&mut self) { fn check_deferred_lambdas(&mut self) {
while !self.deferred.lambdas.is_empty() { while !self.deferred.lambdas.is_empty() {
let lambdas = std::mem::take(&mut self.deferred.lambdas); let lambdas = std::mem::take(&mut self.deferred.lambdas);
for (expr, (scope_id, parents)) in lambdas { for (expr, (scope_id, node_id)) in lambdas {
let parents_snapshot = parents.len();
self.ctx.scope_id = scope_id; self.ctx.scope_id = scope_id;
self.ctx.parents = parents; self.ctx.stmt_id = node_id;
if let ExprKind::Lambda { args, body } = &expr.node { if let ExprKind::Lambda { args, body } = &expr.node {
self.visit_arguments(args); self.visit_arguments(args);
@ -4895,9 +4884,7 @@ impl<'a> Checker<'a> {
unreachable!("Expected ExprKind::Lambda"); unreachable!("Expected ExprKind::Lambda");
} }
let mut parents = std::mem::take(&mut self.ctx.parents); self.deferred.assignments.push((scope_id, node_id));
parents.truncate(parents_snapshot);
self.deferred.assignments.push((scope_id, parents));
} }
} }
} }
@ -4942,9 +4929,9 @@ impl<'a> Checker<'a> {
while !self.deferred.for_loops.is_empty() { while !self.deferred.for_loops.is_empty() {
let for_loops = std::mem::take(&mut self.deferred.for_loops); let for_loops = std::mem::take(&mut self.deferred.for_loops);
for (stmt, (scope_id, parents)) in for_loops { for (stmt, (scope_id, node_id)) in for_loops {
self.ctx.scope_id = scope_id; self.ctx.scope_id = scope_id;
self.ctx.parents = parents; self.ctx.stmt_id = node_id;
if let StmtKind::For { target, body, .. } if let StmtKind::For { target, body, .. }
| StmtKind::AsyncFor { target, body, .. } = &stmt.node | StmtKind::AsyncFor { target, body, .. } = &stmt.node
@ -5216,9 +5203,9 @@ impl<'a> Checker<'a> {
// Collect all unused imports by location. (Multiple unused imports at the same // Collect all unused imports by location. (Multiple unused imports at the same
// location indicates an `import from`.) // location indicates an `import from`.)
type UnusedImport<'a> = (&'a str, &'a TextRange); type UnusedImport<'a> = (&'a str, &'a TextRange);
type BindingContext<'a, 'b> = ( type BindingContext<'a> = (
&'a RefEquality<'b, Stmt>, RefEquality<'a, Stmt>,
Option<&'a RefEquality<'b, Stmt>>, Option<RefEquality<'a, Stmt>>,
Exceptions, Exceptions,
); );
@ -5245,10 +5232,9 @@ impl<'a> Checker<'a> {
continue; continue;
} }
let defined_by = binding.source.as_ref().unwrap(); let child = binding.source.unwrap();
let defined_in = self.ctx.child_to_parent.get(defined_by); let parent = self.ctx.stmts.parent(child);
let exceptions = binding.exceptions; let exceptions = binding.exceptions;
let child: &Stmt = defined_by.into();
let diagnostic_offset = binding.range.start(); let diagnostic_offset = binding.range.start();
let parent_offset = if matches!(child.node, StmtKind::ImportFrom { .. }) { let parent_offset = if matches!(child.node, StmtKind::ImportFrom { .. }) {
@ -5263,12 +5249,12 @@ impl<'a> Checker<'a> {
}) })
{ {
ignored ignored
.entry((defined_by, defined_in, exceptions)) .entry((RefEquality(child), parent.map(RefEquality), exceptions))
.or_default() .or_default()
.push((full_name, &binding.range)); .push((full_name, &binding.range));
} else { } else {
unused unused
.entry((defined_by, defined_in, exceptions)) .entry((RefEquality(child), parent.map(RefEquality), exceptions))
.or_default() .or_default()
.push((full_name, &binding.range)); .push((full_name, &binding.range));
} }
@ -5299,7 +5285,7 @@ impl<'a> Checker<'a> {
) { ) {
Ok(fix) => { Ok(fix) => {
if fix.is_deletion() || fix.content() == Some("pass") { if fix.is_deletion() || fix.content() == Some("pass") {
self.deletions.insert(*defined_by); self.deletions.insert(defined_by);
} }
Some(fix) Some(fix)
} }
@ -5336,11 +5322,10 @@ impl<'a> Checker<'a> {
diagnostics.push(diagnostic); diagnostics.push(diagnostic);
} }
} }
for ((defined_by, .., exceptions), unused_imports) in ignored for ((child, .., exceptions), unused_imports) in ignored
.into_iter() .into_iter()
.sorted_by_key(|((defined_by, ..), ..)| defined_by.start()) .sorted_by_key(|((defined_by, ..), ..)| defined_by.start())
{ {
let child: &Stmt = defined_by.into();
let multiple = unused_imports.len() > 1; let multiple = unused_imports.len() > 1;
let in_except_handler = exceptions let in_except_handler = exceptions
.intersects(Exceptions::MODULE_NOT_FOUND_ERROR | Exceptions::IMPORT_ERROR); .intersects(Exceptions::MODULE_NOT_FOUND_ERROR | Exceptions::IMPORT_ERROR);
@ -5436,9 +5421,9 @@ impl<'a> Checker<'a> {
let mut overloaded_name: Option<String> = None; let mut overloaded_name: Option<String> = None;
while !self.deferred.definitions.is_empty() { while !self.deferred.definitions.is_empty() {
let definitions = std::mem::take(&mut self.deferred.definitions); let definitions = std::mem::take(&mut self.deferred.definitions);
for (definition, visibility, (scope_id, parents)) in definitions { for (definition, visibility, (scope_id, node_id)) in definitions {
self.ctx.scope_id = scope_id; self.ctx.scope_id = scope_id;
self.ctx.parents = parents; self.ctx.stmt_id = node_id;
// flake8-annotations // flake8-annotations
if enforce_annotations { if enforce_annotations {

View file

@ -1,6 +1,7 @@
use crate::registry::{Linter, Rule};
use std::fmt::Formatter; use std::fmt::Formatter;
use crate::registry::{Linter, Rule};
#[derive(PartialEq, Eq, PartialOrd, Ord)] #[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct NoqaCode(&'static str, &'static str); pub struct NoqaCode(&'static str, &'static str);

View file

@ -168,10 +168,9 @@ pub fn unused_loop_control_variable(
let scope = checker.ctx.scope(); let scope = checker.ctx.scope();
let binding = scope.bindings_for_name(name).find_map(|index| { let binding = scope.bindings_for_name(name).find_map(|index| {
let binding = &checker.ctx.bindings[*index]; let binding = &checker.ctx.bindings[*index];
binding binding.source.and_then(|source| {
.source (RefEquality(source) == RefEquality(stmt)).then_some(binding)
.as_ref() })
.and_then(|source| (source == &RefEquality(stmt)).then_some(binding))
}); });
if let Some(binding) = binding { if let Some(binding) = binding {
if binding.kind.is_loop_var() { if binding.kind.is_loop_var() {

View file

@ -33,7 +33,7 @@ fn match_async_exit_stack(checker: &Checker) -> bool {
if attr != "enter_async_context" { if attr != "enter_async_context" {
return false; return false;
} }
for parent in &checker.ctx.parents { for parent in checker.ctx.parents() {
if let StmtKind::With { items, .. } = &parent.node { if let StmtKind::With { items, .. } = &parent.node {
for item in items { for item in items {
if let ExprKind::Call { func, .. } = &item.context_expr.node { if let ExprKind::Call { func, .. } = &item.context_expr.node {
@ -68,7 +68,7 @@ fn match_exit_stack(checker: &Checker) -> bool {
if attr != "enter_context" { if attr != "enter_context" {
return false; return false;
} }
for parent in &checker.ctx.parents { for parent in checker.ctx.parents() {
if let StmtKind::With { items, .. } = &parent.node { if let StmtKind::With { items, .. } = &parent.node {
for item in items { for item in items {
if let ExprKind::Call { func, .. } = &item.context_expr.node { if let ExprKind::Call { func, .. } = &item.context_expr.node {

View file

@ -60,11 +60,7 @@ pub fn empty_type_checking_block<'a, 'b>(
// Delete the entire type-checking block. // Delete the entire type-checking block.
if checker.patch(diagnostic.kind.rule()) { if checker.patch(diagnostic.kind.rule()) {
let parent = checker let parent = checker.ctx.stmts.parent(stmt);
.ctx
.child_to_parent
.get(&RefEquality(stmt))
.map(Into::into);
let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect(); let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect();
match delete_stmt( match delete_stmt(
stmt, stmt,

View file

@ -214,11 +214,7 @@ fn remove_unused_variable(
)) ))
} else { } else {
// If (e.g.) assigning to a constant (`x = 1`), delete the entire statement. // If (e.g.) assigning to a constant (`x = 1`), delete the entire statement.
let parent = checker let parent = checker.ctx.stmts.parent(stmt);
.ctx
.child_to_parent
.get(&RefEquality(stmt))
.map(Into::into);
let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect(); let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect();
match delete_stmt( match delete_stmt(
stmt, stmt,
@ -259,11 +255,7 @@ fn remove_unused_variable(
)) ))
} else { } else {
// If assigning to a constant (`x = 1`), delete the entire statement. // If assigning to a constant (`x = 1`), delete the entire statement.
let parent = checker let parent = checker.ctx.stmts.parent(stmt);
.ctx
.child_to_parent
.get(&RefEquality(stmt))
.map(Into::into);
let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect(); let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect();
match delete_stmt( match delete_stmt(
stmt, stmt,
@ -336,7 +328,7 @@ pub fn unused_variable(checker: &mut Checker, scope: ScopeId) {
binding.range, binding.range,
); );
if checker.patch(diagnostic.kind.rule()) { if checker.patch(diagnostic.kind.rule()) {
if let Some(stmt) = binding.source.as_ref().map(Into::into) { if let Some(stmt) = binding.source {
if let Some((kind, fix)) = remove_unused_variable(stmt, binding.range, checker) if let Some((kind, fix)) = remove_unused_variable(stmt, binding.range, checker)
{ {
if matches!(kind, DeletionKind::Whole) { if matches!(kind, DeletionKind::Whole) {

View file

@ -57,9 +57,7 @@ pub fn global_statement(checker: &mut Checker, name: &str) {
if binding.kind.is_global() { if binding.kind.is_global() {
let source: &Stmt = binding let source: &Stmt = binding
.source .source
.as_ref() .expect("`global` bindings should always have a `source`");
.expect("`global` bindings should always have a `source`")
.into();
let diagnostic = Diagnostic::new( let diagnostic = Diagnostic::new(
GlobalStatement { GlobalStatement {
name: name.to_string(), name: name.to_string(),

View file

@ -164,9 +164,9 @@ fn fix_py2_block(
let defined_by = checker.ctx.current_stmt(); let defined_by = checker.ctx.current_stmt();
let defined_in = checker.ctx.current_stmt_parent(); let defined_in = checker.ctx.current_stmt_parent();
return match delete_stmt( return match delete_stmt(
defined_by.into(), defined_by,
if block.starter == Tok::If { if block.starter == Tok::If {
defined_in.map(Into::into) defined_in
} else { } else {
None None
}, },
@ -176,7 +176,7 @@ fn fix_py2_block(
checker.stylist, checker.stylist,
) { ) {
Ok(fix) => { Ok(fix) => {
checker.deletions.insert(RefEquality(defined_by.into())); checker.deletions.insert(RefEquality(defined_by));
Some(fix) Some(fix)
} }
Err(err) => { Err(err) => {

View file

@ -1,4 +1,4 @@
use rustpython_parser::ast::{ArgData, Expr, ExprKind, Stmt, StmtKind}; use rustpython_parser::ast::{ArgData, Expr, ExprKind, StmtKind};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
@ -39,14 +39,13 @@ pub fn super_call_with_parameters(checker: &mut Checker, expr: &Expr, func: &Exp
return; return;
} }
let scope = checker.ctx.scope(); let scope = checker.ctx.scope();
let parents: Vec<&Stmt> = checker.ctx.parents.iter().map(Into::into).collect();
// Check: are we in a Function scope? // Check: are we in a Function scope?
if !matches!(scope.kind, ScopeKind::Function { .. }) { if !matches!(scope.kind, ScopeKind::Function { .. }) {
return; return;
} }
let mut parents = parents.iter().rev(); let mut parents = checker.ctx.parents();
// For a `super` invocation to be unnecessary, the first argument needs to match // For a `super` invocation to be unnecessary, the first argument needs to match
// the enclosing class, and the second argument needs to match the first // the enclosing class, and the second argument needs to match the first

View file

@ -4,6 +4,7 @@ use rustpython_parser::ast::{Alias, AliasData, Located, Stmt};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::types::RefEquality;
use crate::autofix; use crate::autofix;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -113,8 +114,8 @@ pub fn unnecessary_builtin_import(
.collect(); .collect();
match autofix::actions::remove_unused_imports( match autofix::actions::remove_unused_imports(
unused_imports.iter().map(String::as_str), unused_imports.iter().map(String::as_str),
defined_by.into(), defined_by,
defined_in.map(Into::into), defined_in,
&deleted, &deleted,
checker.locator, checker.locator,
checker.indexer, checker.indexer,
@ -122,7 +123,7 @@ pub fn unnecessary_builtin_import(
) { ) {
Ok(fix) => { Ok(fix) => {
if fix.is_deletion() || fix.content() == Some("pass") { if fix.is_deletion() || fix.content() == Some("pass") {
checker.deletions.insert(*defined_by); checker.deletions.insert(RefEquality(defined_by));
} }
diagnostic.set_fix(fix); diagnostic.set_fix(fix);
} }

View file

@ -4,6 +4,7 @@ use rustpython_parser::ast::{Alias, AliasData, Located, Stmt};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::types::RefEquality;
use crate::autofix; use crate::autofix;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -93,8 +94,8 @@ pub fn unnecessary_future_import(checker: &mut Checker, stmt: &Stmt, names: &[Lo
.collect(); .collect();
match autofix::actions::remove_unused_imports( match autofix::actions::remove_unused_imports(
unused_imports.iter().map(String::as_str), unused_imports.iter().map(String::as_str),
defined_by.into(), defined_by,
defined_in.map(Into::into), defined_in,
&deleted, &deleted,
checker.locator, checker.locator,
checker.indexer, checker.indexer,
@ -102,7 +103,7 @@ pub fn unnecessary_future_import(checker: &mut Checker, stmt: &Stmt, names: &[Lo
) { ) {
Ok(fix) => { Ok(fix) => {
if fix.is_deletion() || fix.content() == Some("pass") { if fix.is_deletion() || fix.content() == Some("pass") {
checker.deletions.insert(*defined_by); checker.deletions.insert(RefEquality(defined_by));
} }
diagnostic.set_fix(fix); diagnostic.set_fix(fix);
} }

View file

@ -4,6 +4,7 @@ use rustpython_parser::ast::{Expr, ExprKind, Stmt};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::types::RefEquality;
use crate::autofix::actions; use crate::autofix::actions;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -53,8 +54,8 @@ pub fn useless_metaclass_type(checker: &mut Checker, stmt: &Stmt, value: &Expr,
let defined_by = checker.ctx.current_stmt(); let defined_by = checker.ctx.current_stmt();
let defined_in = checker.ctx.current_stmt_parent(); let defined_in = checker.ctx.current_stmt_parent();
match actions::delete_stmt( match actions::delete_stmt(
defined_by.into(), defined_by,
defined_in.map(Into::into), defined_in,
&deleted, &deleted,
checker.locator, checker.locator,
checker.indexer, checker.indexer,
@ -62,7 +63,7 @@ pub fn useless_metaclass_type(checker: &mut Checker, stmt: &Stmt, value: &Expr,
) { ) {
Ok(fix) => { Ok(fix) => {
if fix.is_deletion() || fix.content() == Some("pass") { if fix.is_deletion() || fix.content() == Some("pass") {
checker.deletions.insert(*defined_by); checker.deletions.insert(RefEquality(defined_by));
} }
diagnostic.set_fix(fix); diagnostic.set_fix(fix);
} }

View file

@ -1,113 +0,0 @@
use std::cmp::Ordering;
use rustc_hash::FxHashMap;
use rustpython_parser::ast::ExcepthandlerKind::ExceptHandler;
use rustpython_parser::ast::{Stmt, StmtKind};
use crate::types::RefEquality;
/// Return the common ancestor of `left` and `right` below `stop`, or `None`.
fn common_ancestor<'a>(
left: RefEquality<'a, Stmt>,
right: RefEquality<'a, Stmt>,
stop: Option<RefEquality<'a, Stmt>>,
depths: &'a FxHashMap<RefEquality<'a, Stmt>, usize>,
child_to_parent: &'a FxHashMap<RefEquality<'a, Stmt>, RefEquality<'a, Stmt>>,
) -> Option<RefEquality<'a, Stmt>> {
if Some(left) == stop || Some(right) == stop {
return None;
}
if left == right {
return Some(left);
}
let left_depth = depths.get(&left)?;
let right_depth = depths.get(&right)?;
match left_depth.cmp(right_depth) {
Ordering::Less => common_ancestor(
left,
*child_to_parent.get(&right)?,
stop,
depths,
child_to_parent,
),
Ordering::Equal => common_ancestor(
*child_to_parent.get(&left)?,
*child_to_parent.get(&right)?,
stop,
depths,
child_to_parent,
),
Ordering::Greater => common_ancestor(
*child_to_parent.get(&left)?,
right,
stop,
depths,
child_to_parent,
),
}
}
/// Return the alternative branches for a given node.
fn alternatives(stmt: RefEquality<Stmt>) -> Vec<Vec<RefEquality<Stmt>>> {
match &stmt.as_ref().node {
StmtKind::If { body, .. } => vec![body.iter().map(RefEquality).collect()],
StmtKind::Try {
body,
handlers,
orelse,
..
}
| StmtKind::TryStar {
body,
handlers,
orelse,
..
} => vec![body.iter().chain(orelse.iter()).map(RefEquality).collect()]
.into_iter()
.chain(handlers.iter().map(|handler| {
let ExceptHandler { body, .. } = &handler.node;
body.iter().map(RefEquality).collect()
}))
.collect(),
StmtKind::Match { cases, .. } => cases
.iter()
.map(|case| case.body.iter().map(RefEquality).collect())
.collect(),
_ => vec![],
}
}
/// Return `true` if `stmt` is a descendent of any of the nodes in `ancestors`.
fn descendant_of<'a>(
stmt: RefEquality<'a, Stmt>,
ancestors: &[RefEquality<'a, Stmt>],
stop: RefEquality<'a, Stmt>,
depths: &FxHashMap<RefEquality<'a, Stmt>, usize>,
child_to_parent: &FxHashMap<RefEquality<'a, Stmt>, RefEquality<'a, Stmt>>,
) -> bool {
ancestors.iter().any(|ancestor| {
common_ancestor(stmt, *ancestor, Some(stop), depths, child_to_parent).is_some()
})
}
/// Return `true` if `left` and `right` are on different branches of an `if` or
/// `try` statement.
pub fn different_forks<'a>(
left: RefEquality<'a, Stmt>,
right: RefEquality<'a, Stmt>,
depths: &FxHashMap<RefEquality<'a, Stmt>, usize>,
child_to_parent: &FxHashMap<RefEquality<'a, Stmt>, RefEquality<'a, Stmt>>,
) -> bool {
if let Some(ancestor) = common_ancestor(left, right, None, depths, child_to_parent) {
for items in alternatives(ancestor) {
let l = descendant_of(left, &items, ancestor, depths, child_to_parent);
let r = descendant_of(right, &items, ancestor, depths, child_to_parent);
if l ^ r {
return true;
}
}
}
false
}

View file

@ -1,5 +1,4 @@
pub mod all; pub mod all;
pub mod branch_detection;
pub mod call_path; pub mod call_path;
pub mod cast; pub mod cast;
pub mod comparable; pub mod comparable;

View file

@ -70,3 +70,15 @@ impl<'a> From<&RefEquality<'a, Expr>> for &'a Expr {
r.0 r.0
} }
} }
impl<'a> From<RefEquality<'a, Stmt>> for &'a Stmt {
fn from(r: RefEquality<'a, Stmt>) -> Self {
r.0
}
}
impl<'a> From<RefEquality<'a, Expr>> for &'a Expr {
fn from(r: RefEquality<'a, Expr>) -> Self {
r.0
}
}

View file

@ -0,0 +1,104 @@
use std::cmp::Ordering;
use ruff_python_ast::types::RefEquality;
use rustpython_parser::ast::ExcepthandlerKind::ExceptHandler;
use rustpython_parser::ast::{Stmt, StmtKind};
use crate::node::Nodes;
/// Return the common ancestor of `left` and `right` below `stop`, or `None`.
fn common_ancestor<'a>(
left: &'a Stmt,
right: &'a Stmt,
stop: Option<&'a Stmt>,
node_tree: &Nodes<'a>,
) -> Option<&'a Stmt> {
if stop.map_or(false, |stop| {
RefEquality(left) == RefEquality(stop) || RefEquality(right) == RefEquality(stop)
}) {
return None;
}
if RefEquality(left) == RefEquality(right) {
return Some(left);
}
let left_id = node_tree.node_id(left)?;
let right_id = node_tree.node_id(right)?;
let left_depth = node_tree.depth(left_id);
let right_depth = node_tree.depth(right_id);
match left_depth.cmp(&right_depth) {
Ordering::Less => {
let right_id = node_tree.parent_id(right_id)?;
common_ancestor(left, node_tree[right_id], stop, node_tree)
}
Ordering::Equal => {
let left_id = node_tree.parent_id(left_id)?;
let right_id = node_tree.parent_id(right_id)?;
common_ancestor(node_tree[left_id], node_tree[right_id], stop, node_tree)
}
Ordering::Greater => {
let left_id = node_tree.parent_id(left_id)?;
common_ancestor(node_tree[left_id], right, stop, node_tree)
}
}
}
/// Return the alternative branches for a given node.
fn alternatives(stmt: &Stmt) -> Vec<Vec<&Stmt>> {
match &stmt.node {
StmtKind::If { body, .. } => vec![body.iter().collect()],
StmtKind::Try {
body,
handlers,
orelse,
..
}
| StmtKind::TryStar {
body,
handlers,
orelse,
..
} => vec![body.iter().chain(orelse.iter()).collect()]
.into_iter()
.chain(handlers.iter().map(|handler| {
let ExceptHandler { body, .. } = &handler.node;
body.iter().collect()
}))
.collect(),
StmtKind::Match { cases, .. } => cases
.iter()
.map(|case| case.body.iter().collect())
.collect(),
_ => vec![],
}
}
/// Return `true` if `stmt` is a descendent of any of the nodes in `ancestors`.
fn descendant_of<'a>(
stmt: &'a Stmt,
ancestors: &[&'a Stmt],
stop: &'a Stmt,
node_tree: &Nodes<'a>,
) -> bool {
ancestors
.iter()
.any(|ancestor| common_ancestor(stmt, ancestor, Some(stop), node_tree).is_some())
}
/// Return `true` if `left` and `right` are on different branches of an `if` or
/// `try` statement.
pub fn different_forks<'a>(left: &'a Stmt, right: &'a Stmt, node_tree: &Nodes<'a>) -> bool {
if let Some(ancestor) = common_ancestor(left, right, None, node_tree) {
for items in alternatives(ancestor) {
let l = descendant_of(left, &items, ancestor, node_tree);
let r = descendant_of(right, &items, ancestor, node_tree);
if l ^ r {
return true;
}
}
}
false
}

View file

@ -1,3 +1,4 @@
pub mod branch_detection;
pub mod function_type; pub mod function_type;
pub mod logging; pub mod logging;
pub mod typing; pub mod typing;

View file

@ -5,8 +5,6 @@ use bitflags::bitflags;
use ruff_text_size::TextRange; use ruff_text_size::TextRange;
use rustpython_parser::ast::Stmt; use rustpython_parser::ast::Stmt;
use ruff_python_ast::types::RefEquality;
use crate::scope::ScopeId; use crate::scope::ScopeId;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -16,7 +14,7 @@ pub struct Binding<'a> {
/// The context in which the binding was created. /// The context in which the binding was created.
pub context: ExecutionContext, pub context: ExecutionContext,
/// The statement in which the [`Binding`] was defined. /// The statement in which the [`Binding`] was defined.
pub source: Option<RefEquality<'a, Stmt>>, pub source: Option<&'a Stmt>,
/// Tuple of (scope index, range) indicating the scope and range at which /// Tuple of (scope index, range) indicating the scope and range at which
/// the binding was last used in a runtime context. /// the binding was last used in a runtime context.
pub runtime_usage: Option<(ScopeId, TextRange)>, pub runtime_usage: Option<(ScopeId, TextRange)>,

View file

@ -1,7 +1,6 @@
use std::path::Path; use std::path::Path;
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use rustc_hash::FxHashMap;
use rustpython_parser::ast::{Expr, Stmt}; use rustpython_parser::ast::{Expr, Stmt};
use smallvec::smallvec; use smallvec::smallvec;
@ -17,26 +16,26 @@ use crate::binding::{
Binding, BindingId, BindingKind, Bindings, Exceptions, ExecutionContext, FromImportation, Binding, BindingId, BindingKind, Bindings, Exceptions, ExecutionContext, FromImportation,
Importation, SubmoduleImportation, Importation, SubmoduleImportation,
}; };
use crate::node::{NodeId, Nodes};
use crate::scope::{Scope, ScopeId, ScopeKind, Scopes}; use crate::scope::{Scope, ScopeId, ScopeKind, Scopes};
#[allow(clippy::struct_excessive_bools)] #[allow(clippy::struct_excessive_bools)]
pub struct Context<'a> { pub struct Context<'a> {
pub typing_modules: &'a [String], pub typing_modules: &'a [String],
pub module_path: Option<Vec<String>>, pub module_path: Option<Vec<String>>,
// Retain all scopes and parent nodes, along with a stack of indices to track which are active // Stack of all visited statements, along with the identifier of the current statement.
// at various points in time. pub stmts: Nodes<'a>,
pub parents: Vec<RefEquality<'a, Stmt>>, pub stmt_id: Option<NodeId>,
pub depths: FxHashMap<RefEquality<'a, Stmt>, usize>, // Stack of all scopes, along with the identifier of the current scope.
pub child_to_parent: FxHashMap<RefEquality<'a, Stmt>, RefEquality<'a, Stmt>>, pub scopes: Scopes<'a>,
pub scope_id: ScopeId,
pub dead_scopes: Vec<ScopeId>,
// A stack of all bindings created in any scope, at any point in execution. // A stack of all bindings created in any scope, at any point in execution.
pub bindings: Bindings<'a>, pub bindings: Bindings<'a>,
// Map from binding index to indexes of bindings that shadow it in other scopes. // Map from binding index to indexes of bindings that shadow it in other scopes.
pub shadowed_bindings: pub shadowed_bindings:
std::collections::HashMap<BindingId, Vec<BindingId>, BuildNoHashHasher<BindingId>>, std::collections::HashMap<BindingId, Vec<BindingId>, BuildNoHashHasher<BindingId>>,
pub exprs: Vec<RefEquality<'a, Expr>>, pub exprs: Vec<RefEquality<'a, Expr>>,
pub scopes: Scopes<'a>,
pub scope_id: ScopeId,
pub dead_scopes: Vec<ScopeId>,
// Body iteration; used to peek at siblings. // Body iteration; used to peek at siblings.
pub body: &'a [Stmt], pub body: &'a [Stmt],
pub body_index: usize, pub body_index: usize,
@ -68,15 +67,14 @@ impl<'a> Context<'a> {
Self { Self {
typing_modules, typing_modules,
module_path, module_path,
parents: Vec::default(), stmts: Nodes::default(),
depths: FxHashMap::default(), stmt_id: None,
child_to_parent: FxHashMap::default(),
bindings: Bindings::default(),
shadowed_bindings: IntMap::default(),
exprs: Vec::default(),
scopes: Scopes::default(), scopes: Scopes::default(),
scope_id: ScopeId::global(), scope_id: ScopeId::global(),
dead_scopes: Vec::default(), dead_scopes: Vec::default(),
bindings: Bindings::default(),
shadowed_bindings: IntMap::default(),
exprs: Vec::default(),
body: &[], body: &[],
body_index: 0, body_index: 0,
visible_scope: VisibleScope { visible_scope: VisibleScope {
@ -254,10 +252,7 @@ impl<'a> Context<'a> {
.take(scope_index) .take(scope_index)
.all(|scope| scope.get(name).is_none()) .all(|scope| scope.get(name).is_none())
{ {
return Some(( return Some((binding.source.unwrap(), format!("{name}.{member}")));
binding.source.as_ref().unwrap().into(),
format!("{name}.{member}"),
));
} }
} }
} }
@ -273,10 +268,7 @@ impl<'a> Context<'a> {
.take(scope_index) .take(scope_index)
.all(|scope| scope.get(name).is_none()) .all(|scope| scope.get(name).is_none())
{ {
return Some(( return Some((binding.source.unwrap(), (*name).to_string()));
binding.source.as_ref().unwrap().into(),
(*name).to_string(),
));
} }
} }
} }
@ -291,10 +283,7 @@ impl<'a> Context<'a> {
.take(scope_index) .take(scope_index)
.all(|scope| scope.get(name).is_none()) .all(|scope| scope.get(name).is_none())
{ {
return Some(( return Some((binding.source.unwrap(), format!("{name}.{member}")));
binding.source.as_ref().unwrap().into(),
format!("{name}.{member}"),
));
} }
} }
} }
@ -306,18 +295,15 @@ impl<'a> Context<'a> {
}) })
} }
pub fn push_parent(&mut self, parent: &'a Stmt) { /// Push a [`Stmt`] onto the stack.
let num_existing = self.parents.len(); pub fn push_stmt(&mut self, stmt: &'a Stmt) {
self.parents.push(RefEquality(parent)); self.stmt_id = Some(self.stmts.insert(stmt, self.stmt_id));
self.depths.insert(self.parents[num_existing], num_existing);
if num_existing > 0 {
self.child_to_parent
.insert(self.parents[num_existing], self.parents[num_existing - 1]);
}
} }
pub fn pop_parent(&mut self) { /// Pop the current [`Stmt`] off the stack.
self.parents.pop().expect("Attempted to pop without parent"); pub fn pop_stmt(&mut self) {
let node_id = self.stmt_id.expect("Attempted to pop without statement");
self.stmt_id = self.stmts.parent_id(node_id);
} }
pub fn push_expr(&mut self, expr: &'a Expr) { pub fn push_expr(&mut self, expr: &'a Expr) {
@ -345,13 +331,16 @@ impl<'a> Context<'a> {
} }
/// Return the current `Stmt`. /// Return the current `Stmt`.
pub fn current_stmt(&self) -> &RefEquality<'a, Stmt> { pub fn current_stmt(&self) -> &'a Stmt {
self.parents.iter().rev().next().expect("No parent found") let node_id = self.stmt_id.expect("No current statement");
self.stmts[node_id]
} }
/// Return the parent `Stmt` of the current `Stmt`, if any. /// Return the parent `Stmt` of the current `Stmt`, if any.
pub fn current_stmt_parent(&self) -> Option<&RefEquality<'a, Stmt>> { pub fn current_stmt_parent(&self) -> Option<&'a Stmt> {
self.parents.iter().rev().nth(1) let node_id = self.stmt_id.expect("No current statement");
let parent_id = self.stmts.parent_id(node_id)?;
Some(self.stmts[parent_id])
} }
/// Return the parent `Expr` of the current `Expr`. /// Return the parent `Expr` of the current `Expr`.
@ -399,6 +388,11 @@ impl<'a> Context<'a> {
self.scopes.ancestors(self.scope_id) self.scopes.ancestors(self.scope_id)
} }
pub fn parents(&self) -> impl Iterator<Item = &Stmt> + '_ {
let node_id = self.stmt_id.expect("No current statement");
self.stmts.ancestor_ids(node_id).map(|id| self.stmts[id])
}
/// Returns `true` if the context is in an exception handler. /// Returns `true` if the context is in an exception handler.
pub const fn in_exception_handler(&self) -> bool { pub const fn in_exception_handler(&self) -> bool {
self.in_exception_handler self.in_exception_handler

View file

@ -1,4 +1,5 @@
pub mod analyze; pub mod analyze;
pub mod binding; pub mod binding;
pub mod context; pub mod context;
pub mod node;
pub mod scope; pub mod scope;

View file

@ -0,0 +1,112 @@
use std::num::{NonZeroU32, TryFromIntError};
use std::ops::{Index, IndexMut};
use rustc_hash::FxHashMap;
use rustpython_parser::ast::Stmt;
use ruff_python_ast::types::RefEquality;
/// Id uniquely identifying a statement in a program.
///
/// Using a `u32` is sufficient because Ruff only supports parsing documents with a size of max `u32::max`
/// and it is impossible to have more statements than characters in the file. We use a `NonZeroU32` to
/// take advantage of memory layout optimizations.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct NodeId(NonZeroU32);
/// Convert a `usize` to a `NodeId` (by adding 1 to the value, and casting to `NonZeroU32`).
impl TryFrom<usize> for NodeId {
type Error = TryFromIntError;
fn try_from(value: usize) -> Result<Self, Self::Error> {
Ok(Self(NonZeroU32::try_from(u32::try_from(value)? + 1)?))
}
}
/// Convert a `NodeId` to a `usize` (by subtracting 1 from the value, and casting to `usize`).
impl From<NodeId> for usize {
fn from(value: NodeId) -> Self {
value.0.get() as usize - 1
}
}
#[derive(Debug)]
struct Node<'a> {
/// The statement this node represents.
stmt: &'a Stmt,
/// The ID of the parent of this node, if any.
parent: Option<NodeId>,
/// The depth of this node in the tree.
depth: u32,
}
/// The nodes of a program indexed by [`NodeId`]
#[derive(Debug, Default)]
pub struct Nodes<'a> {
nodes: Vec<Node<'a>>,
node_to_id: FxHashMap<RefEquality<'a, Stmt>, NodeId>,
}
impl<'a> Nodes<'a> {
/// Inserts a new node into the node tree and returns its unique id.
///
/// Panics if a node with the same pointer already exists.
pub fn insert(&mut self, stmt: &'a Stmt, parent: Option<NodeId>) -> NodeId {
let next_id = NodeId::try_from(self.nodes.len()).unwrap();
if let Some(existing_id) = self.node_to_id.insert(RefEquality(stmt), next_id) {
panic!("Node already exists with id {existing_id:?}");
}
self.nodes.push(Node {
stmt,
parent,
depth: parent.map_or(0, |parent| self.nodes[usize::from(parent)].depth + 1),
});
next_id
}
/// Returns the [`NodeId`] of the given node.
#[inline]
pub fn node_id(&self, node: &'a Stmt) -> Option<NodeId> {
self.node_to_id.get(&RefEquality(node)).copied()
}
/// Return the [`NodeId`] of the parent node.
#[inline]
pub fn parent_id(&self, node_id: NodeId) -> Option<NodeId> {
self.nodes[usize::from(node_id)].parent
}
/// Return the depth of the node.
#[inline]
pub fn depth(&self, node_id: NodeId) -> u32 {
self.nodes[usize::from(node_id)].depth
}
/// Returns an iterator over all [`NodeId`] ancestors, starting from the given [`NodeId`].
pub fn ancestor_ids(&self, node_id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
std::iter::successors(Some(node_id), |&node_id| {
self.nodes[usize::from(node_id)].parent
})
}
/// Return the parent of the given node.
pub fn parent(&self, node: &'a Stmt) -> Option<&'a Stmt> {
let node_id = self.node_to_id.get(&RefEquality(node))?;
let parent_id = self.nodes[usize::from(*node_id)].parent?;
Some(self[parent_id])
}
}
impl<'a> Index<NodeId> for Nodes<'a> {
type Output = &'a Stmt;
fn index(&self, index: NodeId) -> &Self::Output {
&self.nodes[usize::from(index)].stmt
}
}
impl<'a> IndexMut<NodeId> for Nodes<'a> {
fn index_mut(&mut self, index: NodeId) -> &mut Self::Output {
&mut self.nodes[usize::from(index)].stmt
}
}