[perflint] Catch a wider range of mutations in PERF101 (#9955)

## Summary

This PR ensures that if a list `x` is modified within a `for` loop, we
avoid flagging `list(x)` as unnecessary. Previously, we only detected
calls to exactly `.append`, and they couldn't be nested within other
statements.

Closes https://github.com/astral-sh/ruff/issues/9925.
This commit is contained in:
Charlie Marsh 2024-02-12 12:17:55 -05:00 committed by GitHub
parent e2785f3fb6
commit 0304623878
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 132 additions and 44 deletions

View file

@ -36,35 +36,47 @@ for i in list( # Comment
): # PERF101 ): # PERF101
pass pass
for i in list(foo_dict): # Ok for i in list(foo_dict): # OK
pass pass
for i in list(1): # Ok for i in list(1): # OK
pass pass
for i in list(foo_int): # Ok for i in list(foo_int): # OK
pass pass
import itertools import itertools
for i in itertools.product(foo_int): # Ok for i in itertools.product(foo_int): # OK
pass pass
for i in list(foo_list): # Ok for i in list(foo_list): # OK
foo_list.append(i + 1) foo_list.append(i + 1)
for i in list(foo_list): # PERF101 for i in list(foo_list): # PERF101
# Make sure we match the correct list # Make sure we match the correct list
other_list.append(i + 1) other_list.append(i + 1)
for i in list(foo_tuple): # Ok for i in list(foo_tuple): # OK
foo_tuple.append(i + 1) foo_tuple.append(i + 1)
for i in list(foo_set): # Ok for i in list(foo_set): # OK
foo_set.append(i + 1) foo_set.append(i + 1)
x, y, nested_tuple = (1, 2, (3, 4, 5)) x, y, nested_tuple = (1, 2, (3, 4, 5))
for i in list(nested_tuple): # PERF101 for i in list(nested_tuple): # PERF101
pass pass
for i in list(foo_list): # OK
if True:
foo_list.append(i + 1)
for i in list(foo_list): # OK
if True:
foo_list[i] = i + 1
for i in list(foo_list): # OK
if True:
del foo_list[i + 1]

View file

@ -1,5 +1,6 @@
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::{self as ast, Arguments, Expr, Stmt}; use ruff_python_ast::{self as ast, Arguments, Expr, Stmt};
use ruff_python_semantic::analyze::typing::find_assigned_value; use ruff_python_semantic::analyze::typing::find_assigned_value;
use ruff_text_size::TextRange; use ruff_text_size::TextRange;
@ -98,22 +99,25 @@ pub(crate) fn unnecessary_list_cast(checker: &mut Checker, iter: &Expr, body: &[
range: iterable_range, range: iterable_range,
.. ..
}) => { }) => {
// If the variable is being appended to, don't suggest removing the cast:
//
// ```python
// items = ["foo", "bar"]
// for item in list(items):
// items.append("baz")
// ```
//
// Here, removing the `list()` cast would change the behavior of the code.
if body.iter().any(|stmt| match_append(stmt, id)) {
return;
}
let Some(value) = find_assigned_value(id, checker.semantic()) else { let Some(value) = find_assigned_value(id, checker.semantic()) else {
return; return;
}; };
if matches!(value, Expr::Tuple(_) | Expr::List(_) | Expr::Set(_)) { if matches!(value, Expr::Tuple(_) | Expr::List(_) | Expr::Set(_)) {
// If the variable is being modified to, don't suggest removing the cast:
//
// ```python
// items = ["foo", "bar"]
// for item in list(items):
// items.append("baz")
// ```
//
// Here, removing the `list()` cast would change the behavior of the code.
let mut visitor = MutationVisitor::new(id);
visitor.visit_body(body);
if visitor.is_mutated {
return;
}
let mut diagnostic = Diagnostic::new(UnnecessaryListCast, *list_range); let mut diagnostic = Diagnostic::new(UnnecessaryListCast, *list_range);
diagnostic.set_fix(remove_cast(*list_range, *iterable_range)); diagnostic.set_fix(remove_cast(*list_range, *iterable_range));
checker.diagnostics.push(diagnostic); checker.diagnostics.push(diagnostic);
@ -123,28 +127,6 @@ pub(crate) fn unnecessary_list_cast(checker: &mut Checker, iter: &Expr, body: &[
} }
} }
/// Check if a statement is an `append` call to a given identifier.
///
/// For example, `foo.append(bar)` would return `true` if `id` is `foo`.
fn match_append(stmt: &Stmt, id: &str) -> bool {
let Some(ast::StmtExpr { value, .. }) = stmt.as_expr_stmt() else {
return false;
};
let Some(ast::ExprCall { func, .. }) = value.as_call_expr() else {
return false;
};
let Some(ast::ExprAttribute { value, attr, .. }) = func.as_attribute_expr() else {
return false;
};
if attr != "append" {
return false;
}
let Some(ast::ExprName { id: target_id, .. }) = value.as_name_expr() else {
return false;
};
target_id == id
}
/// Generate a [`Fix`] to remove a `list` cast from an expression. /// Generate a [`Fix`] to remove a `list` cast from an expression.
fn remove_cast(list_range: TextRange, iterable_range: TextRange) -> Fix { fn remove_cast(list_range: TextRange, iterable_range: TextRange) -> Fix {
Fix::safe_edits( Fix::safe_edits(
@ -152,3 +134,95 @@ fn remove_cast(list_range: TextRange, iterable_range: TextRange) -> Fix {
[Edit::deletion(iterable_range.end(), list_range.end())], [Edit::deletion(iterable_range.end(), list_range.end())],
) )
} }
/// A [`StatementVisitor`] that (conservatively) identifies mutations to a variable.
#[derive(Default)]
pub(crate) struct MutationVisitor<'a> {
pub(crate) target: &'a str,
pub(crate) is_mutated: bool,
}
impl<'a> MutationVisitor<'a> {
pub(crate) fn new(target: &'a str) -> Self {
Self {
target,
is_mutated: false,
}
}
}
impl<'a, 'b> StatementVisitor<'b> for MutationVisitor<'a>
where
'b: 'a,
{
fn visit_stmt(&mut self, stmt: &'b Stmt) {
if match_mutation(stmt, self.target) {
self.is_mutated = true;
} else {
walk_stmt(self, stmt);
}
}
}
/// Check if a statement is (probably) a modification to the list assigned to the given identifier.
///
/// For example, `foo.append(bar)` would return `true` if `id` is `foo`.
fn match_mutation(stmt: &Stmt, id: &str) -> bool {
match stmt {
// Ex) `foo.append(bar)`
Stmt::Expr(ast::StmtExpr { value, .. }) => {
let Some(ast::ExprCall { func, .. }) = value.as_call_expr() else {
return false;
};
let Some(ast::ExprAttribute { value, attr, .. }) = func.as_attribute_expr() else {
return false;
};
if !matches!(
attr.as_str(),
"append" | "insert" | "extend" | "remove" | "pop" | "clear" | "reverse" | "sort"
) {
return false;
}
let Some(ast::ExprName { id: target_id, .. }) = value.as_name_expr() else {
return false;
};
target_id == id
}
// Ex) `foo[0] = bar`
Stmt::Assign(ast::StmtAssign { targets, .. }) => targets.iter().any(|target| {
if let Some(ast::ExprSubscript { value: target, .. }) = target.as_subscript_expr() {
if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() {
return target_id == id;
}
}
false
}),
// Ex) `foo += bar`
Stmt::AugAssign(ast::StmtAugAssign { target, .. }) => {
if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() {
target_id == id
} else {
false
}
}
// Ex) `foo[0]: int = bar`
Stmt::AnnAssign(ast::StmtAnnAssign { target, .. }) => {
if let Some(ast::ExprSubscript { value: target, .. }) = target.as_subscript_expr() {
if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() {
return target_id == id;
}
}
false
}
// Ex) `del foo[0]`
Stmt::Delete(ast::StmtDelete { targets, .. }) => targets.iter().any(|target| {
if let Some(ast::ExprSubscript { value: target, .. }) = target.as_subscript_expr() {
if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() {
return target_id == id;
}
}
false
}),
_ => false,
}
}

View file

@ -178,7 +178,7 @@ PERF101.py:34:10: PERF101 [*] Do not cast an iterable to `list` before iterating
34 |+for i in {1, 2, 3}: # PERF101 34 |+for i in {1, 2, 3}: # PERF101
37 35 | pass 37 35 | pass
38 36 | 38 36 |
39 37 | for i in list(foo_dict): # Ok 39 37 | for i in list(foo_dict): # OK
PERF101.py:57:10: PERF101 [*] Do not cast an iterable to `list` before iterating over it PERF101.py:57:10: PERF101 [*] Do not cast an iterable to `list` before iterating over it
| |
@ -192,7 +192,7 @@ PERF101.py:57:10: PERF101 [*] Do not cast an iterable to `list` before iterating
= help: Remove `list()` cast = help: Remove `list()` cast
Safe fix Safe fix
54 54 | for i in list(foo_list): # Ok 54 54 | for i in list(foo_list): # OK
55 55 | foo_list.append(i + 1) 55 55 | foo_list.append(i + 1)
56 56 | 56 56 |
57 |-for i in list(foo_list): # PERF101 57 |-for i in list(foo_list): # PERF101
@ -218,5 +218,7 @@ PERF101.py:69:10: PERF101 [*] Do not cast an iterable to `list` before iterating
69 |-for i in list(nested_tuple): # PERF101 69 |-for i in list(nested_tuple): # PERF101
69 |+for i in nested_tuple: # PERF101 69 |+for i in nested_tuple: # PERF101
70 70 | pass 70 70 | pass
71 71 |
72 72 | for i in list(foo_list): # OK

View file

@ -935,7 +935,7 @@ where
} }
} }
/// A [`StatementVisitor`] that collects all `return` statements in a function or method. /// A [`Visitor`] that collects all `return` statements in a function or method.
#[derive(Default)] #[derive(Default)]
pub struct ReturnStatementVisitor<'a> { pub struct ReturnStatementVisitor<'a> {
pub returns: Vec<&'a ast::StmtReturn>, pub returns: Vec<&'a ast::StmtReturn>,