From 1ea88ea56b12e6c5b07deb08d6ea98619e4d587c Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 1 Feb 2023 08:12:36 -0500 Subject: [PATCH] Avoid iterating over body twice (#2439) --- .../test/fixtures/flake8_simplify/SIM110.py | 10 +++++ src/checkers/ast.rs | 42 ++++++++++++------- src/rules/flake8_simplify/rules/ast_for.rs | 27 ++++++------ ...ke8_simplify__tests__SIM110_SIM110.py.snap | 19 +++++++++ 4 files changed, 68 insertions(+), 30 deletions(-) diff --git a/resources/test/fixtures/flake8_simplify/SIM110.py b/resources/test/fixtures/flake8_simplify/SIM110.py index ddf6f82d6c..f06af587bb 100644 --- a/resources/test/fixtures/flake8_simplify/SIM110.py +++ b/resources/test/fixtures/flake8_simplify/SIM110.py @@ -115,3 +115,13 @@ def f(): else: return True return False + + +def f(): + x = 1 + + # SIM110 + for x in iterable: + if check(x): + return True + return False diff --git a/src/checkers/ast.rs b/src/checkers/ast.rs index b885deba73..6c1ef16f4d 100644 --- a/src/checkers/ast.rs +++ b/src/checkers/ast.rs @@ -87,6 +87,9 @@ pub struct Checker<'a> { deferred_functions: Vec<(&'a Stmt, DeferralContext<'a>, VisibleScope)>, deferred_lambdas: Vec<(&'a Expr, DeferralContext<'a>)>, deferred_assignments: Vec>, + // Body iteration; used to peek at siblings. + body: &'a [Stmt], + body_index: usize, // Internal, derivative state. visible_scope: VisibleScope, in_annotation: bool, @@ -145,6 +148,9 @@ impl<'a> Checker<'a> { deferred_functions: vec![], deferred_lambdas: vec![], deferred_assignments: vec![], + // Body iteration. + body: &[], + body_index: 0, // Internal, derivative state. visible_scope: VisibleScope { modifier: Modifier::Module, @@ -1590,7 +1596,11 @@ where if self.settings.rules.enabled(&Rule::ConvertLoopToAny) || self.settings.rules.enabled(&Rule::ConvertLoopToAll) { - flake8_simplify::rules::convert_for_loop_to_any_all(self, stmt, None); + flake8_simplify::rules::convert_for_loop_to_any_all( + self, + stmt, + self.current_sibling_stmt(), + ); } if self.settings.rules.enabled(&Rule::KeyInDict) { flake8_simplify::rules::key_in_dict_for(self, target, iter); @@ -3694,19 +3704,18 @@ where flake8_pie::rules::no_unnecessary_pass(self, body); } - if self.settings.rules.enabled(&Rule::ConvertLoopToAny) - || self.settings.rules.enabled(&Rule::ConvertLoopToAll) - { - for (stmt, sibling) in body.iter().tuple_windows() { - if matches!(stmt.node, StmtKind::For { .. }) - && matches!(sibling.node, StmtKind::Return { .. }) - { - flake8_simplify::rules::convert_for_loop_to_any_all(self, stmt, Some(sibling)); - } - } + let prev_body = self.body; + let prev_body_index = self.body_index; + self.body = body; + self.body_index = 0; + + for stmt in body { + self.visit_stmt(stmt); + self.body_index += 1; } - visitor::walk_body(self, body); + self.body = prev_body; + self.body_index = prev_body_index; } } @@ -3795,6 +3804,11 @@ impl<'a> Checker<'a> { self.exprs.iter().rev().nth(2) } + /// Return the `Stmt` that immediately follows the current `Stmt`, if any. + pub fn current_sibling_stmt(&self) -> Option<&'a Stmt> { + self.body.get(self.body_index + 1) + } + pub fn current_scope(&self) -> &Scope { &self.scopes[*(self.scope_stack.last().expect("No current scope found"))] } @@ -5219,9 +5233,7 @@ pub fn check_ast( }; // Iterate over the AST. - for stmt in python_ast { - checker.visit_stmt(stmt); - } + checker.visit_body(python_ast); // Check any deferred statements. checker.check_deferred_functions(); diff --git a/src/rules/flake8_simplify/rules/ast_for.rs b/src/rules/flake8_simplify/rules/ast_for.rs index 85cedd572c..cce4ca1ebe 100644 --- a/src/rules/flake8_simplify/rules/ast_for.rs +++ b/src/rules/flake8_simplify/rules/ast_for.rs @@ -1,5 +1,5 @@ use rustpython_ast::{ - Comprehension, Constant, Expr, ExprContext, ExprKind, Stmt, StmtKind, Unaryop, + Comprehension, Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind, Unaryop, }; use crate::ast::helpers::{create_expr, create_stmt, unparse_stmt}; @@ -16,6 +16,7 @@ struct Loop<'a> { test: &'a Expr, target: &'a Expr, iter: &'a Expr, + terminal: Location, } /// Extract the returned boolean values a `StmtKind::For` with an `else` body. @@ -78,6 +79,7 @@ fn return_values_for_else(stmt: &Stmt) -> Option { test: nested_test, target, iter, + terminal: stmt.end_location.unwrap(), }) } @@ -142,6 +144,7 @@ fn return_values_for_siblings<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option) { - if let Some(loop_info) = match sibling { - // Ex) `for` loop with an `else: return True` or `else: return False`. - None => return_values_for_else(stmt), - // Ex) `for` loop followed by `return True` or `return False` - Some(sibling) => return_values_for_siblings(stmt, sibling), - } { + // There are two cases to consider: + // - `for` loop with an `else: return True` or `else: return False`. + // - `for` loop followed by `return True` or `return False` + if let Some(loop_info) = return_values_for_else(stmt) + .or_else(|| sibling.and_then(|sibling| return_values_for_siblings(stmt, sibling))) + { if loop_info.return_value && !loop_info.next_return_value { if checker.settings.rules.enabled(&Rule::ConvertLoopToAny) { let contents = return_stmt( @@ -203,10 +206,7 @@ pub fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt, sibling: diagnostic.amend(Fix::replacement( contents, stmt.location, - match sibling { - None => stmt.end_location.unwrap(), - Some(sibling) => sibling.end_location.unwrap(), - }, + loop_info.terminal, )); } checker.diagnostics.push(diagnostic); @@ -253,10 +253,7 @@ pub fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt, sibling: diagnostic.amend(Fix::replacement( contents, stmt.location, - match sibling { - None => stmt.end_location.unwrap(), - Some(sibling) => sibling.end_location.unwrap(), - }, + loop_info.terminal, )); } checker.diagnostics.push(diagnostic); diff --git a/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM110_SIM110.py.snap b/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM110_SIM110.py.snap index 74bbab50d9..9242e34718 100644 --- a/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM110_SIM110.py.snap +++ b/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM110_SIM110.py.snap @@ -59,4 +59,23 @@ expression: diagnostics row: 77 column: 20 parent: ~ +- kind: + ConvertLoopToAny: + any: return any(check(x) for x in iterable) + location: + row: 124 + column: 4 + end_location: + row: 126 + column: 23 + fix: + content: + - return any(check(x) for x in iterable) + location: + row: 124 + column: 4 + end_location: + row: 127 + column: 16 + parent: ~