From ea3d3a655d58474467f302e7f17a17e01bb3329a Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 10 May 2023 12:50:47 -0400 Subject: [PATCH] Add a `Snapshot` abstraction for deferring and restoring visitor context (#4353) --- crates/ruff/src/checkers/ast/deferred.rs | 22 ++-- crates/ruff/src/checkers/ast/mod.rs | 115 ++++++++------------- crates/ruff_python_semantic/src/context.rs | 27 +++++ crates/ruff_python_semantic/src/node.rs | 1 + 4 files changed, 81 insertions(+), 84 deletions(-) diff --git a/crates/ruff/src/checkers/ast/deferred.rs b/crates/ruff/src/checkers/ast/deferred.rs index 346c054826..ae79c37542 100644 --- a/crates/ruff/src/checkers/ast/deferred.rs +++ b/crates/ruff/src/checkers/ast/deferred.rs @@ -2,26 +2,20 @@ use ruff_text_size::TextRange; use rustpython_parser::ast::Expr; use ruff_python_semantic::analyze::visibility::{Visibility, VisibleScope}; -use ruff_python_semantic::node::NodeId; -use ruff_python_semantic::scope::ScopeId; +use ruff_python_semantic::context::Snapshot; -use crate::checkers::ast::AnnotationContext; use crate::docstrings::definition::Definition; -/// A snapshot of the current scope and statement, which will be restored when visiting any -/// deferred definitions. -type Context<'a> = (ScopeId, Option); - /// A collection of AST nodes that are deferred for later analysis. /// Used to, e.g., store functions, whose bodies shouldn't be analyzed until all /// module-level definitions have been analyzed. #[derive(Default)] pub struct Deferred<'a> { - pub definitions: Vec<(Definition<'a>, Visibility, Context<'a>)>, - pub string_type_definitions: Vec<(TextRange, &'a str, AnnotationContext, Context<'a>)>, - pub type_definitions: Vec<(&'a Expr, AnnotationContext, Context<'a>)>, - pub functions: Vec<(Context<'a>, VisibleScope)>, - pub lambdas: Vec<(&'a Expr, Context<'a>)>, - pub for_loops: Vec>, - pub assignments: Vec>, + pub definitions: Vec<(Definition<'a>, Visibility, Snapshot)>, + pub string_type_definitions: Vec<(TextRange, &'a str, Snapshot)>, + pub type_definitions: Vec<(&'a Expr, Snapshot)>, + pub functions: Vec<(Snapshot, VisibleScope)>, + pub lambdas: Vec<(&'a Expr, Snapshot)>, + pub for_loops: Vec, + pub assignments: Vec, } diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index c20e865494..f573f1a944 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -55,8 +55,6 @@ use crate::{autofix, docstrings, noqa, warn_user}; mod deferred; -type AnnotationContext = (bool, bool); - pub struct Checker<'a> { // Settings, static metadata, etc. path: &'a Path, @@ -1706,9 +1704,7 @@ where .. } => { if self.settings.rules.enabled(Rule::UnusedLoopControlVariable) { - self.deferred - .for_loops - .push((self.ctx.scope_id, self.ctx.stmt_id)); + self.deferred.for_loops.push(self.ctx.snapshot()); } if self .settings @@ -1994,11 +1990,9 @@ where pyupgrade::rules::yield_in_for_loop(self, stmt); } let scope = transition_scope(self.ctx.visible_scope, stmt, Documentable::Function); - self.deferred.definitions.push(( - definition, - scope.visibility, - (self.ctx.scope_id, self.ctx.stmt_id), - )); + self.deferred + .definitions + .push((definition, scope.visibility, self.ctx.snapshot())); self.ctx.visible_scope = scope; // If any global bindings don't already exist in the global scope, add it. @@ -2033,10 +2027,9 @@ where globals, })); - self.deferred.functions.push(( - (self.ctx.scope_id, self.ctx.stmt_id), - self.ctx.visible_scope, - )); + self.deferred + .functions + .push((self.ctx.snapshot(), self.ctx.visible_scope)); } StmtKind::ClassDef { body, @@ -2056,11 +2049,9 @@ where Documentable::Class, ); let scope = transition_scope(self.ctx.visible_scope, stmt, Documentable::Class); - self.deferred.definitions.push(( - definition, - scope.visibility, - (self.ctx.scope_id, self.ctx.stmt_id), - )); + self.deferred + .definitions + .push((definition, scope.visibility, self.ctx.snapshot())); self.ctx.visible_scope = scope; // If any global bindings don't already exist in the global scope, add it. @@ -2273,15 +2264,12 @@ where self.deferred.string_type_definitions.push(( expr.range(), value, - (self.ctx.in_annotation, self.ctx.in_type_checking_block), - (self.ctx.scope_id, self.ctx.stmt_id), + self.ctx.snapshot(), )); } else { - self.deferred.type_definitions.push(( - expr, - (self.ctx.in_annotation, self.ctx.in_type_checking_block), - (self.ctx.scope_id, self.ctx.stmt_id), - )); + self.deferred + .type_definitions + .push((expr, self.ctx.snapshot())); } return; } @@ -3526,8 +3514,7 @@ where self.deferred.string_type_definitions.push(( expr.range(), value, - (self.ctx.in_annotation, self.ctx.in_type_checking_block), - (self.ctx.scope_id, self.ctx.stmt_id), + self.ctx.snapshot(), )); } if self @@ -3648,9 +3635,7 @@ where // Recurse. match &expr.node { ExprKind::Lambda { .. } => { - self.deferred - .lambdas - .push((expr, (self.ctx.scope_id, self.ctx.stmt_id))); + self.deferred.lambdas.push((expr, self.ctx.snapshot())); } ExprKind::IfExp { test, body, orelse } => { visit_boolean_test!(self, test); @@ -4793,7 +4778,7 @@ impl<'a> Checker<'a> { docstring, }, self.ctx.visible_scope.visibility, - (self.ctx.scope_id, self.ctx.stmt_id), + self.ctx.snapshot(), )); docstring.is_some() } @@ -4801,13 +4786,9 @@ impl<'a> Checker<'a> { fn check_deferred_type_definitions(&mut self) { while !self.deferred.type_definitions.is_empty() { let type_definitions = std::mem::take(&mut self.deferred.type_definitions); - for (expr, (in_annotation, in_type_checking_block), (scope_id, stmt_id)) in - type_definitions - { - self.ctx.scope_id = scope_id; - self.ctx.stmt_id = stmt_id; - self.ctx.in_annotation = in_annotation; - self.ctx.in_type_checking_block = in_type_checking_block; + for (expr, snapshot) in type_definitions { + self.ctx.restore(snapshot); + self.ctx.in_type_definition = true; self.ctx.in_deferred_type_definition = true; self.visit_expr(expr); @@ -4820,11 +4801,15 @@ impl<'a> Checker<'a> { fn check_deferred_string_type_definitions(&mut self, allocator: &'a typed_arena::Arena) { while !self.deferred.string_type_definitions.is_empty() { let type_definitions = std::mem::take(&mut self.deferred.string_type_definitions); - for (range, value, (in_annotation, in_type_checking_block), (scope_id, stmt_id)) in - type_definitions - { + for (range, value, snapshot) in type_definitions { if let Ok((expr, kind)) = parse_type_annotation(value, range, self.locator) { - if in_annotation && self.ctx.annotations_future_enabled { + let expr = allocator.alloc(expr); + + self.ctx.restore(snapshot); + self.ctx.in_type_definition = true; + self.ctx.in_deferred_string_type_definition = Some(kind); + + if self.ctx.in_annotation && self.ctx.annotations_future_enabled { if self.settings.rules.enabled(Rule::QuotedAnnotation) { pyupgrade::rules::quoted_annotation(self, value, range); } @@ -4834,16 +4819,8 @@ impl<'a> Checker<'a> { flake8_pyi::rules::quoted_annotation_in_stub(self, value, range); } } - - let expr = allocator.alloc(expr); - - self.ctx.scope_id = scope_id; - self.ctx.stmt_id = stmt_id; - self.ctx.in_annotation = in_annotation; - self.ctx.in_type_checking_block = in_type_checking_block; - self.ctx.in_type_definition = true; - self.ctx.in_deferred_string_type_definition = Some(kind); self.visit_expr(expr); + self.ctx.in_deferred_string_type_definition = None; self.ctx.in_type_definition = false; } else { @@ -4867,9 +4844,8 @@ impl<'a> Checker<'a> { fn check_deferred_functions(&mut self) { while !self.deferred.functions.is_empty() { let deferred_functions = std::mem::take(&mut self.deferred.functions); - for ((scope_id, stmt_id), visibility) in deferred_functions { - self.ctx.scope_id = scope_id; - self.ctx.stmt_id = stmt_id; + for (snapshot, visibility) in deferred_functions { + self.ctx.restore(snapshot); self.ctx.visible_scope = visibility; match &self.ctx.stmt().node { @@ -4883,7 +4859,7 @@ impl<'a> Checker<'a> { } } - self.deferred.assignments.push((scope_id, stmt_id)); + self.deferred.assignments.push(snapshot); } } } @@ -4891,9 +4867,8 @@ impl<'a> Checker<'a> { fn check_deferred_lambdas(&mut self) { while !self.deferred.lambdas.is_empty() { let lambdas = std::mem::take(&mut self.deferred.lambdas); - for (expr, (scope_id, stmt_id)) in lambdas { - self.ctx.scope_id = scope_id; - self.ctx.stmt_id = stmt_id; + for (expr, snapshot) in lambdas { + self.ctx.restore(snapshot); if let ExprKind::Lambda { args, body } = &expr.node { self.visit_arguments(args); @@ -4902,7 +4877,7 @@ impl<'a> Checker<'a> { unreachable!("Expected ExprKind::Lambda"); } - self.deferred.assignments.push((scope_id, stmt_id)); + self.deferred.assignments.push(snapshot); } } } @@ -4910,13 +4885,15 @@ impl<'a> Checker<'a> { fn check_deferred_assignments(&mut self) { while !self.deferred.assignments.is_empty() { let assignments = std::mem::take(&mut self.deferred.assignments); - for (scope_id, ..) in assignments { + for snapshot in assignments { + self.ctx.restore(snapshot); + // pyflakes if self.settings.rules.enabled(Rule::UnusedVariable) { - pyflakes::rules::unused_variable(self, scope_id); + pyflakes::rules::unused_variable(self, self.ctx.scope_id); } if self.settings.rules.enabled(Rule::UnusedAnnotation) { - pyflakes::rules::unused_annotation(self, scope_id); + pyflakes::rules::unused_annotation(self, self.ctx.scope_id); } if !self.is_stub { @@ -4928,7 +4905,7 @@ impl<'a> Checker<'a> { Rule::UnusedStaticMethodArgument, Rule::UnusedLambdaArgument, ]) { - let scope = &self.ctx.scopes[scope_id]; + let scope = &self.ctx.scopes[self.ctx.scope_id]; let parent = &self.ctx.scopes[scope.parent.unwrap()]; self.diagnostics .extend(flake8_unused_arguments::rules::unused_arguments( @@ -4947,9 +4924,8 @@ impl<'a> Checker<'a> { while !self.deferred.for_loops.is_empty() { let for_loops = std::mem::take(&mut self.deferred.for_loops); - for (scope_id, stmt_id) in for_loops { - self.ctx.scope_id = scope_id; - self.ctx.stmt_id = stmt_id; + for snapshot in for_loops { + self.ctx.restore(snapshot); if let StmtKind::For { target, body, .. } | StmtKind::AsyncFor { target, body, .. } = &self.ctx.stmt().node @@ -5442,9 +5418,8 @@ impl<'a> Checker<'a> { let mut overloaded_name: Option = None; while !self.deferred.definitions.is_empty() { let definitions = std::mem::take(&mut self.deferred.definitions); - for (definition, visibility, (scope_id, stmt_id)) in definitions { - self.ctx.scope_id = scope_id; - self.ctx.stmt_id = stmt_id; + for (definition, visibility, snapshot) in definitions { + self.ctx.restore(snapshot); // flake8-annotations if enforce_annotations { diff --git a/crates/ruff_python_semantic/src/context.rs b/crates/ruff_python_semantic/src/context.rs index 5718d64dc2..92d8b7fb1d 100644 --- a/crates/ruff_python_semantic/src/context.rs +++ b/crates/ruff_python_semantic/src/context.rs @@ -19,6 +19,15 @@ use crate::binding::{ use crate::node::{NodeId, Nodes}; use crate::scope::{Scope, ScopeId, ScopeKind, Scopes}; +/// A snapshot of the [`Context`] at a given point in the AST traversal. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Snapshot { + scope_id: ScopeId, + stmt_id: Option, + in_annotation: bool, + in_type_checking_block: bool, +} + #[allow(clippy::struct_excessive_bools)] pub struct Context<'a> { pub typing_modules: &'a [String], @@ -435,4 +444,22 @@ impl<'a> Context<'a> { } exceptions } + + /// Generate a [`Snapshot`] of the current context. + pub fn snapshot(&self) -> Snapshot { + Snapshot { + scope_id: self.scope_id, + stmt_id: self.stmt_id, + in_annotation: self.in_annotation, + in_type_checking_block: self.in_type_checking_block, + } + } + + /// Restore the context to the given [`Snapshot`]. + pub fn restore(&mut self, snapshot: Snapshot) { + self.scope_id = snapshot.scope_id; + self.stmt_id = snapshot.stmt_id; + self.in_annotation = snapshot.in_annotation; + self.in_type_checking_block = snapshot.in_type_checking_block; + } } diff --git a/crates/ruff_python_semantic/src/node.rs b/crates/ruff_python_semantic/src/node.rs index 7d84e03461..bc3dbba7d5 100644 --- a/crates/ruff_python_semantic/src/node.rs +++ b/crates/ruff_python_semantic/src/node.rs @@ -30,6 +30,7 @@ impl From for usize { } } +/// A [`Node`] represents a statement in a program, along with a pointer to its parent (if any). #[derive(Debug)] struct Node<'a> { /// The statement this node represents.