Only omit optional parens if the expression ends or starts with a parenthesized expression

<!--
Thank you for contributing to Ruff! To help us out with reviewing, please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

This PR matches Black's behavior, where it only omits the optional parentheses if the expression starts or ends with a parenthesized expression:

```python
a + [aaa, bbb, cccc] * c # Don't omit
[aaa, bbb, cccc] + a * c # Split
a + c * [aaa, bbb, ccc] # Split 
```

<!-- What's the purpose of the change? What does it do, and why? -->

## Test Plan

This improves the Jaccard index from 0.945 to 0.946.
This commit is contained in:
Micha Reiser 2023-07-11 17:05:25 +02:00 committed by GitHub
parent 8b9193ab1f
commit 30bec3fcfa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 159 deletions

View file

@ -198,37 +198,30 @@ impl<'ast> IntoFormat<PyFormatContext<'ast>> for Expr {
/// ///
/// This mimics Black's [`_maybe_split_omitting_optional_parens`](https://github.com/psf/black/blob/d1248ca9beaf0ba526d265f4108836d89cf551b7/src/black/linegen.py#L746-L820) /// This mimics Black's [`_maybe_split_omitting_optional_parens`](https://github.com/psf/black/blob/d1248ca9beaf0ba526d265f4108836d89cf551b7/src/black/linegen.py#L746-L820)
fn can_omit_optional_parentheses(expr: &Expr, context: &PyFormatContext) -> bool { fn can_omit_optional_parentheses(expr: &Expr, context: &PyFormatContext) -> bool {
let mut visitor = MaxOperatorPriorityVisitor::new(context.source()); let mut visitor = CanOmitOptionalParenthesesVisitor::new(context.source());
visitor.visit_subexpression(expr); visitor.visit_subexpression(expr);
visitor.can_omit()
let (max_operator_priority, operation_count, any_parenthesized_expression) = visitor.finish();
if operation_count > 1 {
false
} else if max_operator_priority == OperatorPriority::Attribute {
true
} else {
// Only use the more complex IR when there is any expression that we can possibly split by
any_parenthesized_expression
}
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct MaxOperatorPriorityVisitor<'input> { struct CanOmitOptionalParenthesesVisitor<'input> {
max_priority: OperatorPriority, max_priority: OperatorPriority,
max_priority_count: u32, max_priority_count: u32,
any_parenthesized_expressions: bool, any_parenthesized_expressions: bool,
last: Option<&'input Expr>,
first: Option<&'input Expr>,
source: &'input str, source: &'input str,
} }
impl<'input> MaxOperatorPriorityVisitor<'input> { impl<'input> CanOmitOptionalParenthesesVisitor<'input> {
fn new(source: &'input str) -> Self { fn new(source: &'input str) -> Self {
Self { Self {
source, source,
max_priority: OperatorPriority::None, max_priority: OperatorPriority::None,
max_priority_count: 0, max_priority_count: 0,
any_parenthesized_expressions: false, any_parenthesized_expressions: false,
last: None,
first: None,
} }
} }
@ -305,6 +298,7 @@ impl<'input> MaxOperatorPriorityVisitor<'input> {
self.any_parenthesized_expressions = true; self.any_parenthesized_expressions = true;
// Only walk the function, the arguments are always parenthesized // Only walk the function, the arguments are always parenthesized
self.visit_expr(func); self.visit_expr(func);
self.last = Some(expr);
return; return;
} }
Expr::Subscript(_) => { Expr::Subscript(_) => {
@ -351,23 +345,41 @@ impl<'input> MaxOperatorPriorityVisitor<'input> {
walk_expr(self, expr); walk_expr(self, expr);
} }
fn finish(self) -> (OperatorPriority, u32, bool) { fn can_omit(self) -> bool {
( if self.max_priority_count > 1 {
self.max_priority, false
self.max_priority_count, } else if self.max_priority == OperatorPriority::Attribute {
self.any_parenthesized_expressions, true
) } else if !self.any_parenthesized_expressions {
// Only use the more complex IR when there is any expression that we can possibly split by
false
} else {
// Only use the layout if the first or last expression has parentheses of some sort.
let first_parenthesized = self
.first
.map_or(false, |first| has_parentheses(first, self.source));
let last_parenthesized = self
.last
.map_or(false, |last| has_parentheses(last, self.source));
first_parenthesized || last_parenthesized
}
} }
} }
impl<'input> PreorderVisitor<'input> for MaxOperatorPriorityVisitor<'input> { impl<'input> PreorderVisitor<'input> for CanOmitOptionalParenthesesVisitor<'input> {
fn visit_expr(&mut self, expr: &'input Expr) { fn visit_expr(&mut self, expr: &'input Expr) {
self.last = Some(expr);
// Rule only applies for non-parenthesized expressions. // Rule only applies for non-parenthesized expressions.
if is_expression_parenthesized(AnyNodeRef::from(expr), self.source) { if is_expression_parenthesized(AnyNodeRef::from(expr), self.source) {
self.any_parenthesized_expressions = true; self.any_parenthesized_expressions = true;
} else { } else {
self.visit_subexpression(expr); self.visit_subexpression(expr);
} }
if self.first.is_none() {
self.first = Some(expr);
}
} }
} }

View file

@ -280,8 +280,15 @@ if True:
#[test] #[test]
fn quick_test() { fn quick_test() {
let src = r#" let src = r#"
def foo() -> tuple[int, int, int,]: if a * [
return 2 bbbbbbbbbbbbbbbbbbbbbb,
cccccccccccccccccccccccccccccdddddddddddddddddddddddddd,
] + a * e * [
ffff,
gggg,
hhhhhhhhhhhhhh,
] * c:
pass
"#; "#;
// Tokenize once // Tokenize once

View file

@ -1,135 +0,0 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/trailing_comma_optional_parens1.py
---
## Input
```py
if e1234123412341234.winerror not in (_winapi.ERROR_SEM_TIMEOUT,
_winapi.ERROR_PIPE_BUSY) or _check_timeout(t):
pass
if x:
if y:
new_id = max(Vegetable.objects.order_by('-id')[0].id,
Mineral.objects.order_by('-id')[0].id) + 1
class X:
def get_help_text(self):
return ngettext(
"Your password must contain at least %(min_length)d character.",
"Your password must contain at least %(min_length)d characters.",
self.min_length,
) % {'min_length': self.min_length}
class A:
def b(self):
if self.connection.mysql_is_mariadb and (
10,
4,
3,
) < self.connection.mysql_version < (10, 5, 2):
pass
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -6,13 +6,10 @@
if x:
if y:
- new_id = (
- max(
- Vegetable.objects.order_by("-id")[0].id,
- Mineral.objects.order_by("-id")[0].id,
- )
- + 1
- )
+ new_id = max(
+ Vegetable.objects.order_by("-id")[0].id,
+ Mineral.objects.order_by("-id")[0].id,
+ ) + 1
class X:
```
## Ruff Output
```py
if e1234123412341234.winerror not in (
_winapi.ERROR_SEM_TIMEOUT,
_winapi.ERROR_PIPE_BUSY,
) or _check_timeout(t):
pass
if x:
if y:
new_id = max(
Vegetable.objects.order_by("-id")[0].id,
Mineral.objects.order_by("-id")[0].id,
) + 1
class X:
def get_help_text(self):
return ngettext(
"Your password must contain at least %(min_length)d character.",
"Your password must contain at least %(min_length)d characters.",
self.min_length,
) % {"min_length": self.min_length}
class A:
def b(self):
if self.connection.mysql_is_mariadb and (
10,
4,
3,
) < self.connection.mysql_version < (10, 5, 2):
pass
```
## Black Output
```py
if e1234123412341234.winerror not in (
_winapi.ERROR_SEM_TIMEOUT,
_winapi.ERROR_PIPE_BUSY,
) or _check_timeout(t):
pass
if x:
if y:
new_id = (
max(
Vegetable.objects.order_by("-id")[0].id,
Mineral.objects.order_by("-id")[0].id,
)
+ 1
)
class X:
def get_help_text(self):
return ngettext(
"Your password must contain at least %(min_length)d character.",
"Your password must contain at least %(min_length)d characters.",
self.min_length,
) % {"min_length": self.min_length}
class A:
def b(self):
if self.connection.mysql_is_mariadb and (
10,
4,
3,
) < self.connection.mysql_version < (10, 5, 2):
pass
```