Remove suite body tracking from SemanticModel (#5848)

## Summary

The `SemanticModel` currently stores the "body" of a given `Suite`,
along with the current statement index. This is used to support "next
sibling" queries, but we only use this in exactly one place -- the rule
that simplifies constructs like this to `any` or `all`:

```python
for x in y:
    if x == 0:
        return True
return False
```

Instead of tracking the state, we can just do a (slightly more
expensive) traversal, by finding the node within its parent and
returning the next node in the body.

Note that we'll only have to do this extremely rarely -- namely, for
functions that contain something like:

```python
for x in y:
    if x == 0:
        return True
```
This commit is contained in:
Charlie Marsh 2023-07-18 18:58:31 -04:00 committed by GitHub
parent a93254f026
commit 2d505e2b04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 340 additions and 250 deletions

View file

@ -1404,11 +1404,7 @@ where
}
if stmt.is_for_stmt() {
if self.enabled(Rule::ReimplementedBuiltin) {
flake8_simplify::rules::convert_for_loop_to_any_all(
self,
stmt,
self.semantic.sibling_stmt(),
);
flake8_simplify::rules::convert_for_loop_to_any_all(self, stmt);
}
if self.enabled(Rule::InDictKeys) {
flake8_simplify::rules::key_in_dict_for(self, target, iter);
@ -4237,21 +4233,10 @@ where
flake8_pie::rules::no_unnecessary_pass(self, body);
}
// Step 2: Binding
let prev_body = self.semantic.body;
let prev_body_index = self.semantic.body_index;
self.semantic.body = body;
self.semantic.body_index = 0;
// Step 3: Traversal
for stmt in body {
self.visit_stmt(stmt);
self.semantic.body_index += 1;
}
// Step 4: Clean-up
self.semantic.body = prev_body;
self.semantic.body_index = prev_body_index;
}
}

View file

@ -1,4 +1,4 @@
use ruff_text_size::{TextRange, TextSize};
use ruff_text_size::TextRange;
use rustpython_parser::ast::{
self, CmpOp, Comprehension, Constant, Expr, ExprContext, Ranged, Stmt, UnaryOp,
};
@ -7,10 +7,11 @@ use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::any_over_expr;
use ruff_python_ast::source_code::Generator;
use ruff_python_ast::traversal;
use crate::checkers::ast::Checker;
use crate::line_width::LineWidth;
use crate::registry::{AsRule, Rule};
use crate::registry::AsRule;
/// ## What it does
/// Checks for `for` loops that can be replaced with a builtin function, like
@ -38,7 +39,7 @@ use crate::registry::{AsRule, Rule};
/// - [Python documentation: `all`](https://docs.python.org/3/library/functions.html#all)
#[violation]
pub struct ReimplementedBuiltin {
repl: String,
replacement: String,
}
impl Violation for ReimplementedBuiltin {
@ -46,200 +47,222 @@ impl Violation for ReimplementedBuiltin {
#[derive_message_formats]
fn message(&self) -> String {
let ReimplementedBuiltin { repl } = self;
format!("Use `{repl}` instead of `for` loop")
let ReimplementedBuiltin { replacement } = self;
format!("Use `{replacement}` instead of `for` loop")
}
fn autofix_title(&self) -> Option<String> {
let ReimplementedBuiltin { repl } = self;
Some(format!("Replace with `{repl}`"))
let ReimplementedBuiltin { replacement } = self;
Some(format!("Replace with `{replacement}`"))
}
}
/// SIM110, SIM111
pub(crate) fn convert_for_loop_to_any_all(
checker: &mut Checker,
stmt: &Stmt,
sibling: Option<&Stmt>,
) {
// There are two cases to consider:
pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) {
if !checker.semantic().scope().kind.is_any_function() {
return;
}
// The `for` loop itself must consist of an `if` with a `return`.
let Some(loop_) = match_loop(stmt) else {
return;
};
// Afterwards, there are two cases to consider:
// - `for` loop with an `else: return True` or `else: return False`.
// - `for` loop followed by `return True` or `return False`
if let Some(loop_info) = return_values_for_else(stmt)
.or_else(|| sibling.and_then(|sibling| return_values_for_siblings(stmt, sibling)))
{
// Check if loop_info.target, loop_info.iter, or loop_info.test contains `await`.
if contains_await(loop_info.target)
|| contains_await(loop_info.iter)
|| contains_await(loop_info.test)
{
return;
}
if loop_info.return_value && !loop_info.next_return_value {
if checker.enabled(Rule::ReimplementedBuiltin) {
let contents = return_stmt(
"any",
loop_info.test,
loop_info.target,
loop_info.iter,
checker.generator(),
);
// - `for` loop followed by `return True` or `return False`.
let Some(terminal) = match_else_return(stmt).or_else(|| {
let parent = checker.semantic().stmt_parent()?;
let suite = traversal::suite(stmt, parent)?;
let sibling = traversal::next_sibling(stmt, suite)?;
match_sibling_return(stmt, sibling)
}) else {
return;
};
// Don't flag if the resulting expression would exceed the maximum line length.
let line_start = checker.locator.line_start(stmt.start());
if LineWidth::new(checker.settings.tab_size)
.add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())])
.add_str(&contents)
> checker.settings.line_length
{
return;
}
// Check if any of the expressions contain an `await` expression.
if contains_await(loop_.target) || contains_await(loop_.iter) || contains_await(loop_.test) {
return;
}
let mut diagnostic = Diagnostic::new(
ReimplementedBuiltin {
repl: contents.clone(),
},
TextRange::new(stmt.start(), loop_info.terminal),
);
if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("any") {
diagnostic.set_fix(Fix::suggested(Edit::replacement(
contents,
stmt.start(),
loop_info.terminal,
)));
}
checker.diagnostics.push(diagnostic);
match (loop_.return_value, terminal.return_value) {
// Replace with `any`.
(true, false) => {
let contents = return_stmt(
"any",
loop_.test,
loop_.target,
loop_.iter,
checker.generator(),
);
// Don't flag if the resulting expression would exceed the maximum line length.
let line_start = checker.locator.line_start(stmt.start());
if LineWidth::new(checker.settings.tab_size)
.add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())])
.add_str(&contents)
> checker.settings.line_length
{
return;
}
}
if !loop_info.return_value && loop_info.next_return_value {
if checker.enabled(Rule::ReimplementedBuiltin) {
// Invert the condition.
let test = {
if let Expr::UnaryOp(ast::ExprUnaryOp {
op: UnaryOp::Not,
operand,
range: _,
}) = &loop_info.test
{
*operand.clone()
} else if let Expr::Compare(ast::ExprCompare {
left,
ops,
comparators,
range: _,
}) = &loop_info.test
{
if let ([op], [comparator]) = (ops.as_slice(), comparators.as_slice()) {
let op = match op {
CmpOp::Eq => CmpOp::NotEq,
CmpOp::NotEq => CmpOp::Eq,
CmpOp::Lt => CmpOp::GtE,
CmpOp::LtE => CmpOp::Gt,
CmpOp::Gt => CmpOp::LtE,
CmpOp::GtE => CmpOp::Lt,
CmpOp::Is => CmpOp::IsNot,
CmpOp::IsNot => CmpOp::Is,
CmpOp::In => CmpOp::NotIn,
CmpOp::NotIn => CmpOp::In,
};
let node = ast::ExprCompare {
left: left.clone(),
ops: vec![op],
comparators: vec![comparator.clone()],
range: TextRange::default(),
};
node.into()
} else {
let node = ast::ExprUnaryOp {
op: UnaryOp::Not,
operand: Box::new(loop_info.test.clone()),
range: TextRange::default(),
};
node.into()
}
let mut diagnostic = Diagnostic::new(
ReimplementedBuiltin {
replacement: contents.to_string(),
},
TextRange::new(stmt.start(), terminal.stmt.end()),
);
if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("any") {
diagnostic.set_fix(Fix::suggested(Edit::replacement(
contents,
stmt.start(),
terminal.stmt.end(),
)));
}
checker.diagnostics.push(diagnostic);
}
// Replace with `all`.
(false, true) => {
// Invert the condition.
let test = {
if let Expr::UnaryOp(ast::ExprUnaryOp {
op: UnaryOp::Not,
operand,
range: _,
}) = &loop_.test
{
*operand.clone()
} else if let Expr::Compare(ast::ExprCompare {
left,
ops,
comparators,
range: _,
}) = &loop_.test
{
if let ([op], [comparator]) = (ops.as_slice(), comparators.as_slice()) {
let op = match op {
CmpOp::Eq => CmpOp::NotEq,
CmpOp::NotEq => CmpOp::Eq,
CmpOp::Lt => CmpOp::GtE,
CmpOp::LtE => CmpOp::Gt,
CmpOp::Gt => CmpOp::LtE,
CmpOp::GtE => CmpOp::Lt,
CmpOp::Is => CmpOp::IsNot,
CmpOp::IsNot => CmpOp::Is,
CmpOp::In => CmpOp::NotIn,
CmpOp::NotIn => CmpOp::In,
};
let node = ast::ExprCompare {
left: left.clone(),
ops: vec![op],
comparators: vec![comparator.clone()],
range: TextRange::default(),
};
node.into()
} else {
let node = ast::ExprUnaryOp {
op: UnaryOp::Not,
operand: Box::new(loop_info.test.clone()),
operand: Box::new(loop_.test.clone()),
range: TextRange::default(),
};
node.into()
}
};
let contents = return_stmt(
"all",
&test,
loop_info.target,
loop_info.iter,
checker.generator(),
);
// Don't flag if the resulting expression would exceed the maximum line length.
let line_start = checker.locator.line_start(stmt.start());
if LineWidth::new(checker.settings.tab_size)
.add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())])
.add_str(&contents)
> checker.settings.line_length
{
return;
} else {
let node = ast::ExprUnaryOp {
op: UnaryOp::Not,
operand: Box::new(loop_.test.clone()),
range: TextRange::default(),
};
node.into()
}
};
let contents = return_stmt("all", &test, loop_.target, loop_.iter, checker.generator());
let mut diagnostic = Diagnostic::new(
ReimplementedBuiltin {
repl: contents.clone(),
},
TextRange::new(stmt.start(), loop_info.terminal),
);
if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("all") {
diagnostic.set_fix(Fix::suggested(Edit::replacement(
contents,
stmt.start(),
loop_info.terminal,
)));
}
checker.diagnostics.push(diagnostic);
// Don't flag if the resulting expression would exceed the maximum line length.
let line_start = checker.locator.line_start(stmt.start());
if LineWidth::new(checker.settings.tab_size)
.add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())])
.add_str(&contents)
> checker.settings.line_length
{
return;
}
let mut diagnostic = Diagnostic::new(
ReimplementedBuiltin {
replacement: contents.to_string(),
},
TextRange::new(stmt.start(), terminal.stmt.end()),
);
if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("all") {
diagnostic.set_fix(Fix::suggested(Edit::replacement(
contents,
stmt.start(),
terminal.stmt.end(),
)));
}
checker.diagnostics.push(diagnostic);
}
_ => {}
}
}
/// Represents a `for` loop with a conditional `return`, like:
/// ```python
/// for x in y:
/// if x == 0:
/// return True
/// ```
#[derive(Debug)]
struct Loop<'a> {
/// The `return` value of the loop.
return_value: bool,
next_return_value: bool,
/// The test condition in the loop.
test: &'a Expr,
/// The target of the loop.
target: &'a Expr,
/// The iterator of the loop.
iter: &'a Expr,
terminal: TextSize,
}
/// Extract the returned boolean values a `Stmt::For` with an `else` body.
fn return_values_for_else(stmt: &Stmt) -> Option<Loop> {
/// Represents a `return` statement following a `for` loop, like:
/// ```python
/// for x in y:
/// if x == 0:
/// return True
/// return False
/// ```
///
/// Or:
/// ```python
/// for x in y:
/// if x == 0:
/// return True
/// else:
/// return False
/// ```
#[derive(Debug)]
struct Terminal<'a> {
return_value: bool,
stmt: &'a Stmt,
}
fn match_loop(stmt: &Stmt) -> Option<Loop> {
let Stmt::For(ast::StmtFor {
body,
target,
iter,
orelse,
..
body, target, iter, ..
}) = stmt
else {
return None;
};
// The loop itself should contain a single `if` statement, with an `else`
// containing a single `return True` or `return False`.
if body.len() != 1 {
return None;
}
if orelse.len() != 1 {
return None;
}
let Stmt::If(ast::StmtIf {
// The loop itself should contain a single `if` statement, with a single `return` statement in
// the body.
let [Stmt::If(ast::StmtIf {
body: nested_body,
test: nested_test,
elif_else_clauses: nested_elif_else_clauses,
range: _,
}) = &body[0]
})] = body.as_slice()
else {
return None;
};
@ -263,15 +286,35 @@ fn return_values_for_else(stmt: &Stmt) -> Option<Loop> {
return None;
};
// The `else` block has to contain a single `return True` or `return False`.
let Stmt::Return(ast::StmtReturn {
value: next_value,
range: _,
}) = &orelse[0]
else {
Some(Loop {
return_value: *value,
test: nested_test,
target,
iter,
})
}
/// If a `Stmt::For` contains an `else` with a single boolean `return`, return the [`Terminal`]
/// representing that `return`.
///
/// For example, matches the `return` in:
/// ```python
/// for x in y:
/// if x == 0:
/// return True
/// return False
/// ```
fn match_else_return(stmt: &Stmt) -> Option<Terminal> {
let Stmt::For(ast::StmtFor { orelse, .. }) = stmt else {
return None;
};
let Some(next_value) = next_value else {
// The `else` block has to contain a single `return True` or `return False`.
let [Stmt::Return(ast::StmtReturn {
value: Some(next_value),
range: _,
})] = orelse.as_slice()
else {
return None;
};
let Expr::Constant(ast::ExprConstant {
@ -282,78 +325,41 @@ fn return_values_for_else(stmt: &Stmt) -> Option<Loop> {
return None;
};
Some(Loop {
return_value: *value,
next_return_value: *next_value,
test: nested_test,
target,
iter,
terminal: stmt.end(),
Some(Terminal {
return_value: *next_value,
stmt,
})
}
/// Extract the returned boolean values from subsequent `Stmt::For` and
/// `Stmt::Return` statements, or `None`.
fn return_values_for_siblings<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option<Loop<'a>> {
let Stmt::For(ast::StmtFor {
body,
target,
iter,
orelse,
..
}) = stmt
else {
/// If a `Stmt::For` is followed by a boolean `return`, return the [`Terminal`] representing that
/// `return`.
///
/// For example, matches the `return` in:
/// ```python
/// for x in y:
/// if x == 0:
/// return True
/// else:
/// return False
/// ```
fn match_sibling_return<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option<Terminal<'a>> {
let Stmt::For(ast::StmtFor { orelse, .. }) = stmt else {
return None;
};
// The loop itself should contain a single `if` statement, with a single `return
// True` or `return False`.
if body.len() != 1 {
return None;
}
// The loop itself shouldn't have an `else` block.
if !orelse.is_empty() {
return None;
}
let Stmt::If(ast::StmtIf {
body: nested_body,
test: nested_test,
elif_else_clauses: nested_elif_else_clauses,
range: _,
}) = &body[0]
else {
return None;
};
if nested_body.len() != 1 {
return None;
}
if !nested_elif_else_clauses.is_empty() {
return None;
}
let Stmt::Return(ast::StmtReturn { value, range: _ }) = &nested_body[0] else {
return None;
};
let Some(value) = value else {
return None;
};
let Expr::Constant(ast::ExprConstant {
value: Constant::Bool(value),
..
}) = value.as_ref()
else {
return None;
};
// The next statement has to be a `return True` or `return False`.
let Stmt::Return(ast::StmtReturn {
value: next_value,
value: Some(next_value),
range: _,
}) = &sibling
else {
return None;
};
let Some(next_value) = next_value else {
return None;
};
let Expr::Constant(ast::ExprConstant {
value: Constant::Bool(next_value),
..
@ -362,13 +368,9 @@ fn return_values_for_siblings<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option<L
return None;
};
Some(Loop {
return_value: *value,
next_return_value: *next_value,
test: nested_test,
target,
iter,
terminal: sibling.end(),
Some(Terminal {
return_value: *next_value,
stmt: sibling,
})
}

View file

@ -15,6 +15,7 @@ pub mod statement_visitor;
pub mod stmt_if;
pub mod str;
pub mod token_kind;
pub mod traversal;
pub mod types;
pub mod typing;
pub mod visitor;

View file

@ -0,0 +1,113 @@
//! Utilities for manually traversing a Python AST.
use rustpython_ast::{ExceptHandler, Stmt, Suite};
use rustpython_parser::ast;
/// Given a [`Stmt`] and its parent, return the [`Suite`] that contains the [`Stmt`].
pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> {
match parent {
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => Some(body),
Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) => Some(body),
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => Some(body),
Stmt::For(ast::StmtFor { body, orelse, .. }) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else {
None
}
}
Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else {
None
}
}
Stmt::While(ast::StmtWhile { body, orelse, .. }) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else {
None
}
}
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => {
if body.contains(stmt) {
Some(body)
} else {
elif_else_clauses
.iter()
.map(|elif_else_clause| &elif_else_clause.body)
.find(|body| body.contains(stmt))
}
}
Stmt::With(ast::StmtWith { body, .. }) => Some(body),
Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => Some(body),
Stmt::Match(ast::StmtMatch { cases, .. }) => cases
.iter()
.map(|case| &case.body)
.find(|body| body.contains(stmt)),
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else if finalbody.contains(stmt) {
Some(finalbody)
} else {
handlers
.iter()
.filter_map(ExceptHandler::as_except_handler)
.map(|handler| &handler.body)
.find(|body| body.contains(stmt))
}
}
Stmt::TryStar(ast::StmtTryStar {
body,
handlers,
orelse,
finalbody,
..
}) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else if finalbody.contains(stmt) {
Some(finalbody)
} else {
handlers
.iter()
.filter_map(ExceptHandler::as_except_handler)
.map(|handler| &handler.body)
.find(|body| body.contains(stmt))
}
}
_ => None,
}
}
/// Given a [`Stmt`] and its containing [`Suite`], return the next [`Stmt`] in the [`Suite`].
pub fn next_sibling<'a>(stmt: &'a Stmt, suite: &'a Suite) -> Option<&'a Stmt> {
let mut iter = suite.iter();
while let Some(sibling) = iter.next() {
if sibling == stmt {
return iter.next();
}
}
None
}

View file

@ -108,10 +108,6 @@ pub struct SemanticModel<'a> {
/// by way of the `global x` statement.
rebinding_scopes: HashMap<BindingId, Vec<ScopeId>, BuildNoHashHasher<BindingId>>,
/// Body iteration; used to peek at siblings.
pub body: &'a [Stmt],
pub body_index: usize,
/// Flags for the semantic model.
pub flags: SemanticModelFlags,
@ -137,8 +133,6 @@ impl<'a> SemanticModel<'a> {
shadowed_bindings: IntMap::default(),
delayed_annotations: IntMap::default(),
rebinding_scopes: IntMap::default(),
body: &[],
body_index: 0,
flags: SemanticModelFlags::new(path),
handled_exceptions: Vec::default(),
}
@ -757,11 +751,6 @@ impl<'a> SemanticModel<'a> {
self.exprs.iter().rev().skip(1)
}
/// Return the `Stmt` that immediately follows the current `Stmt`, if any.
pub fn sibling_stmt(&self) -> Option<&'a Stmt> {
self.body.get(self.body_index + 1)
}
/// Returns a reference to the global scope
pub fn global_scope(&self) -> &Scope<'a> {
self.scopes.global()