Insert parentheses for multi-argument generators (#12422)

## Summary

Closes https://github.com/astral-sh/ruff/issues/12420.
This commit is contained in:
Charlie Marsh 2024-07-20 12:41:55 -04:00 committed by GitHub
parent 4bcc96ae51
commit 2c1926beeb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 71 additions and 10 deletions

View file

@ -3,8 +3,18 @@ min([x.val for x in bar])
max([x.val for x in bar]) max([x.val for x in bar])
sum([x.val for x in bar], 0) sum([x.val for x in bar], 0)
# Ok # OK
sum(x.val for x in bar) sum(x.val for x in bar)
min(x.val for x in bar) min(x.val for x in bar)
max(x.val for x in bar) max(x.val for x in bar)
sum(x.val for x in bar, 0) sum(x.val for x in bar, 0)
# Multi-line
sum(
[
delta
for delta in timedelta_list
if delta
],
dt.timedelta(),
)

View file

@ -1,10 +1,10 @@
use ruff_python_ast::{self as ast, Expr, Keyword}; use ruff_python_ast::{self as ast, Expr, Keyword};
use ruff_diagnostics::Violation;
use ruff_diagnostics::{Diagnostic, FixAvailability}; use ruff_diagnostics::{Diagnostic, FixAvailability};
use ruff_diagnostics::{Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::any_over_expr; use ruff_python_ast::helpers::any_over_expr;
use ruff_text_size::Ranged; use ruff_text_size::{Ranged, TextSize};
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
@ -112,9 +112,30 @@ pub(crate) fn unnecessary_comprehension_in_call(
} }
let mut diagnostic = Diagnostic::new(UnnecessaryComprehensionInCall, arg.range()); let mut diagnostic = Diagnostic::new(UnnecessaryComprehensionInCall, arg.range());
if args.len() == 1 {
// If there's only one argument, remove the list or set brackets.
diagnostic.try_set_fix(|| { diagnostic.try_set_fix(|| {
fixes::fix_unnecessary_comprehension_in_call(expr, checker.locator(), checker.stylist()) fixes::fix_unnecessary_comprehension_in_call(expr, checker.locator(), checker.stylist())
}); });
} else {
// If there are multiple arguments, replace the list or set brackets with parentheses.
// If a function call has multiple arguments, one of which is a generator, then the
// generator must be parenthesized.
// Replace `[` with `(`.
let collection_start = Edit::replacement(
"(".to_string(),
arg.start(),
arg.start() + TextSize::from(1),
);
// Replace `]` with `)`.
let collection_end =
Edit::replacement(")".to_string(), arg.end() - TextSize::from(1), arg.end());
diagnostic.set_fix(Fix::unsafe_edits(collection_start, [collection_end]));
}
checker.diagnostics.push(diagnostic); checker.diagnostics.push(diagnostic);
} }

View file

@ -52,7 +52,7 @@ C419_1.py:3:5: C419 [*] Unnecessary list comprehension
3 |+max(x.val for x in bar) 3 |+max(x.val for x in bar)
4 4 | sum([x.val for x in bar], 0) 4 4 | sum([x.val for x in bar], 0)
5 5 | 5 5 |
6 6 | # Ok 6 6 | # OK
C419_1.py:4:5: C419 [*] Unnecessary list comprehension C419_1.py:4:5: C419 [*] Unnecessary list comprehension
| |
@ -61,7 +61,7 @@ C419_1.py:4:5: C419 [*] Unnecessary list comprehension
4 | sum([x.val for x in bar], 0) 4 | sum([x.val for x in bar], 0)
| ^^^^^^^^^^^^^^^^^^^^ C419 | ^^^^^^^^^^^^^^^^^^^^ C419
5 | 5 |
6 | # Ok 6 | # OK
| |
= help: Remove unnecessary list comprehension = help: Remove unnecessary list comprehension
@ -70,7 +70,37 @@ C419_1.py:4:5: C419 [*] Unnecessary list comprehension
2 2 | min([x.val for x in bar]) 2 2 | min([x.val for x in bar])
3 3 | max([x.val for x in bar]) 3 3 | max([x.val for x in bar])
4 |-sum([x.val for x in bar], 0) 4 |-sum([x.val for x in bar], 0)
4 |+sum(x.val for x in bar, 0) 4 |+sum((x.val for x in bar), 0)
5 5 | 5 5 |
6 6 | # Ok 6 6 | # OK
7 7 | sum(x.val for x in bar) 7 7 | sum(x.val for x in bar)
C419_1.py:14:5: C419 [*] Unnecessary list comprehension
|
12 | # Multi-line
13 | sum(
14 | [
| _____^
15 | | delta
16 | | for delta in timedelta_list
17 | | if delta
18 | | ],
| |_____^ C419
19 | dt.timedelta(),
20 | )
|
= help: Remove unnecessary list comprehension
Unsafe fix
11 11 |
12 12 | # Multi-line
13 13 | sum(
14 |- [
14 |+ (
15 15 | delta
16 16 | for delta in timedelta_list
17 17 | if delta
18 |- ],
18 |+ ),
19 19 | dt.timedelta(),
20 20 | )