diff --git a/src/ast/operations.rs b/src/ast/operations.rs index 60bb1f3a66..b15284d4bd 100644 --- a/src/ast/operations.rs +++ b/src/ast/operations.rs @@ -79,12 +79,18 @@ struct GlobalVisitor<'a> { impl<'a> Visitor<'a> for GlobalVisitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { - if let StmtKind::Global { names } = &stmt.node { - for name in names { - self.globals.insert(name, stmt); + match &stmt.node { + StmtKind::Global { names } => { + for name in names { + self.globals.insert(name, stmt); + } } - } else { - visitor::walk_stmt(self, stmt); + StmtKind::FunctionDef { .. } + | StmtKind::AsyncFunctionDef { .. } + | StmtKind::ClassDef { .. } => { + // Don't recurse. + } + _ => visitor::walk_stmt(self, stmt), } } } diff --git a/src/check_ast.rs b/src/check_ast.rs index 5ea0a7edd9..930016c618 100644 --- a/src/check_ast.rs +++ b/src/check_ast.rs @@ -77,7 +77,7 @@ pub struct Checker<'a> { deferred_type_definitions: Vec<(&'a Expr, bool, DeferralContext<'a>)>, deferred_functions: Vec<(&'a Stmt, DeferralContext<'a>, VisibleScope)>, deferred_lambdas: Vec<(&'a Expr, DeferralContext<'a>)>, - deferred_assignments: Vec<(usize, DeferralContext<'a>)>, + deferred_assignments: Vec>, // Internal, derivative state. visible_scope: VisibleScope, in_f_string: Option, @@ -241,23 +241,6 @@ where StmtKind::Global { names } => { let scope_index = *self.scope_stack.last().expect("No current scope found"); if scope_index != GLOBAL_SCOPE_INDEX { - // If the binding doesn't already exist in the global scope, add it. - for name in names { - if !self.scopes[GLOBAL_SCOPE_INDEX] - .values - .contains_key(&name.as_str()) - { - let index = self.bindings.len(); - self.bindings.push(Binding { - kind: BindingKind::Assignment, - used: None, - range: Range::from_located(stmt), - source: Some(RefEquality(stmt)), - }); - self.scopes[GLOBAL_SCOPE_INDEX].values.insert(name, index); - } - } - // Add the binding to the current scope. let scope = &mut self.scopes[scope_index]; let usage = Some((scope.id, Range::from_located(stmt))); @@ -601,15 +584,6 @@ where for expr in decorator_list { self.visit_expr(expr); } - - let globals = operations::extract_globals(body); - self.push_scope(Scope::new(ScopeKind::Class(ClassDef { - name, - bases, - keywords, - decorator_list, - globals, - }))); } StmtKind::Import { names } => { if self.settings.enabled.contains(&CheckCode::E402) { @@ -1150,7 +1124,20 @@ where // Recurse. let prev_visible_scope = self.visible_scope.clone(); match &stmt.node { - StmtKind::FunctionDef { body, .. } | StmtKind::AsyncFunctionDef { body, .. } => { + StmtKind::FunctionDef { + body, + name, + args, + decorator_list, + .. + } + | StmtKind::AsyncFunctionDef { + body, + name, + args, + decorator_list, + .. + } => { if self.settings.enabled.contains(&CheckCode::B021) { flake8_bugbear::plugins::f_string_docstring(self, body); } @@ -1165,13 +1152,44 @@ where .push((definition, scope.visibility.clone())); self.visible_scope = scope; + // If any global bindings don't already exist in the global scope, add it. + let globals = operations::extract_globals(body); + for (name, stmt) in operations::extract_globals(body) { + if !self.scopes[GLOBAL_SCOPE_INDEX].values.contains_key(name) { + let index = self.bindings.len(); + self.bindings.push(Binding { + kind: BindingKind::Assignment, + used: None, + range: Range::from_located(stmt), + source: Some(RefEquality(stmt)), + }); + self.scopes[GLOBAL_SCOPE_INDEX].values.insert(name, index); + } + } + + self.push_scope(Scope::new(ScopeKind::Function(FunctionDef { + name, + body, + args, + decorator_list, + async_: matches!(stmt.node, StmtKind::AsyncFunctionDef { .. }), + globals, + }))); + self.deferred_functions.push(( stmt, (self.scope_stack.clone(), self.parents.clone()), self.visible_scope.clone(), )); } - StmtKind::ClassDef { body, .. } => { + StmtKind::ClassDef { + body, + name, + bases, + keywords, + decorator_list, + .. + } => { if self.settings.enabled.contains(&CheckCode::B021) { flake8_bugbear::plugins::f_string_docstring(self, body); } @@ -1186,6 +1204,29 @@ where .push((definition, scope.visibility.clone())); self.visible_scope = scope; + // If any global bindings don't already exist in the global scope, add it. + let globals = operations::extract_globals(body); + for (name, stmt) in &globals { + if !self.scopes[GLOBAL_SCOPE_INDEX].values.contains_key(name) { + let index = self.bindings.len(); + self.bindings.push(Binding { + kind: BindingKind::Assignment, + used: None, + range: Range::from_located(stmt), + source: Some(RefEquality(stmt)), + }); + self.scopes[GLOBAL_SCOPE_INDEX].values.insert(name, index); + } + } + + self.push_scope(Scope::new(ScopeKind::Class(ClassDef { + name, + bases, + keywords, + decorator_list, + globals, + }))); + for stmt in body { self.visit_stmt(stmt); } @@ -1237,18 +1278,24 @@ where self.visible_scope = prev_visible_scope; // Post-visit. - if let StmtKind::ClassDef { name, .. } = &stmt.node { - self.pop_scope(); - self.add_binding( - name, - Binding { - kind: BindingKind::ClassDefinition, - used: None, - range: Range::from_located(stmt), - source: Some(self.current_parent().clone()), - }, - ); - }; + match &stmt.node { + StmtKind::FunctionDef { .. } | StmtKind::AsyncFunctionDef { .. } => { + self.pop_scope(); + } + StmtKind::ClassDef { name, .. } => { + self.pop_scope(); + self.add_binding( + name, + Binding { + kind: BindingKind::ClassDefinition, + used: None, + range: Range::from_located(stmt), + source: Some(self.current_parent().clone()), + }, + ); + } + _ => {} + } self.pop_parent(); } @@ -3073,30 +3120,8 @@ impl<'a> Checker<'a> { self.visible_scope = visibility; match &stmt.node { - StmtKind::FunctionDef { - name, - body, - args, - decorator_list, - .. - } - | StmtKind::AsyncFunctionDef { - name, - body, - args, - decorator_list, - .. - } => { - let globals = operations::extract_globals(body); - self.push_scope(Scope::new(ScopeKind::Function(FunctionDef { - name, - body, - args, - decorator_list, - async_: matches!(stmt.node, StmtKind::AsyncFunctionDef { .. }), - globals, - }))); - + StmtKind::FunctionDef { body, args, .. } + | StmtKind::AsyncFunctionDef { body, args, .. } => { self.visit_arguments(args); for stmt in body { self.visit_stmt(stmt); @@ -3105,12 +3130,7 @@ impl<'a> Checker<'a> { _ => unreachable!("Expected StmtKind::FunctionDef | StmtKind::AsyncFunctionDef"), } - self.deferred_assignments.push(( - *self.scope_stack.last().expect("No current scope found"), - (scopes, parents), - )); - - self.pop_scope(); + self.deferred_assignments.push((scopes, parents)); } } @@ -3121,29 +3141,25 @@ impl<'a> Checker<'a> { self.parents = parents.clone(); if let ExprKind::Lambda { args, body } = &expr.node { - self.push_scope(Scope::new(ScopeKind::Lambda(Lambda { args, body }))); self.visit_arguments(args); self.visit_expr(body); } else { unreachable!("Expected ExprKind::Lambda"); } - self.deferred_assignments.push(( - *self.scope_stack.last().expect("No current scope found"), - (scopes, parents), - )); - - self.pop_scope(); + self.deferred_assignments.push((scopes, parents)); } } fn check_deferred_assignments(&mut self) { self.deferred_assignments.reverse(); - while let Some((index, (scopes, _parents))) = self.deferred_assignments.pop() { + while let Some((scopes, _parents)) = self.deferred_assignments.pop() { + let scope_index = scopes[scopes.len() - 1]; + let parent_scope_index = scopes[scopes.len() - 2]; if self.settings.enabled.contains(&CheckCode::F841) { self.add_checks( pyflakes::checks::unused_variable( - &self.scopes[index], + &self.scopes[scope_index], &self.bindings, &self.settings.dummy_variable_rgx, ) @@ -3153,7 +3169,7 @@ impl<'a> Checker<'a> { if self.settings.enabled.contains(&CheckCode::F842) { self.add_checks( pyflakes::checks::unused_annotation( - &self.scopes[index], + &self.scopes[scope_index], &self.bindings, &self.settings.dummy_variable_rgx, ) @@ -3169,10 +3185,8 @@ impl<'a> Checker<'a> { self.add_checks( flake8_unused_arguments::plugins::unused_arguments( self, - &self.scopes[*scopes - .last() - .expect("Expected parent scope above function scope")], - &self.scopes[index], + &self.scopes[parent_scope_index], + &self.scopes[scope_index], &self.bindings, ) .into_iter(), diff --git a/src/pyflakes/mod.rs b/src/pyflakes/mod.rs index 262cc1b84a..b1929fee7c 100644 --- a/src/pyflakes/mod.rs +++ b/src/pyflakes/mod.rs @@ -404,14 +404,13 @@ mod tests { "#, &[], )?; - // Pyflakes allows this, but it causes other issues. - // flakes( - // r#" - // def c(): bar - // def b(): global bar; bar = 1 - // "#, - // &[], - // )?; + flakes( + r#" + def c(): bar + def b(): global bar; bar = 1 + "#, + &[], + )?; Ok(()) }