Improve comprehension line break beheavior

<!--
Thank you for contributing to Ruff! To help us out with reviewing, please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

This PR improves the Black compatibility when it comes to breaking comprehensions. 

We want to avoid line breaks before the target and `in` whenever possible. Furthermore, `if X is not None` should be grouped together, similar to other binary like expressions

<!-- What's the purpose of the change? What does it do, and why? -->

## Test Plan

`cargo test`

<!-- How was it tested? -->
This commit is contained in:
Micha Reiser 2023-07-11 16:51:24 +02:00 committed by GitHub
parent 62a24e1028
commit 8b9193ab1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 121 additions and 316 deletions

View file

@ -30,3 +30,16 @@
# above g # above g
g # g g # g
] ]
[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb + [dddddddddddddddddd, eeeeeeeeeeeeeeeeeee]
for
ccccccccccccccccccccccccccccccccccccccc,
ddddddddddddddddddd, [eeeeeeeeeeeeeeeeeeeeee, fffffffffffffffffffffffff]
in
eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeffffffffffffffffffffffffggggggggggggggggggggghhhhhhhhhhhhhhothermoreeand_even_moreddddddddddddddddddddd
if
fffffffffffffffffffffffffffffffffffffffffff < gggggggggggggggggggggggggggggggggggggggggggggg < hhhhhhhhhhhhhhhhhhhhhhhhhh
if
gggggggggggggggggggggggggggggggggggggggggggg
]

View file

@ -33,6 +33,7 @@ impl FormatNodeRule<ExprCompare> for FormatExprCompare {
let comments = f.context().comments().clone(); let comments = f.context().comments().clone();
let inner = format_with(|f| {
write!(f, [in_parentheses_only_group(&left.format())])?; write!(f, [in_parentheses_only_group(&left.format())])?;
assert_eq!(comparators.len(), ops.len()); assert_eq!(comparators.len(), ops.len());
@ -63,6 +64,9 @@ impl FormatNodeRule<ExprCompare> for FormatExprCompare {
} }
Ok(()) Ok(())
});
in_parentheses_only_group(&inner).fmt(f)
} }
} }

View file

@ -3,13 +3,25 @@ use crate::prelude::*;
use crate::AsFormat; use crate::AsFormat;
use crate::{FormatNodeRule, PyFormatter}; use crate::{FormatNodeRule, PyFormatter};
use ruff_formatter::{format_args, write, Buffer, FormatResult}; use ruff_formatter::{format_args, write, Buffer, FormatResult};
use rustpython_parser::ast::{Comprehension, Ranged}; use rustpython_parser::ast::{Comprehension, Expr, Ranged};
#[derive(Default)] #[derive(Default)]
pub struct FormatComprehension; pub struct FormatComprehension;
impl FormatNodeRule<Comprehension> for FormatComprehension { impl FormatNodeRule<Comprehension> for FormatComprehension {
fn fmt_fields(&self, item: &Comprehension, f: &mut PyFormatter) -> FormatResult<()> { fn fmt_fields(&self, item: &Comprehension, f: &mut PyFormatter) -> FormatResult<()> {
struct Spacer<'a>(&'a Expr);
impl Format<PyFormatContext<'_>> for Spacer<'_> {
fn fmt(&self, f: &mut PyFormatter) -> FormatResult<()> {
if f.context().comments().has_leading_comments(self.0) {
soft_line_break_or_space().fmt(f)
} else {
space().fmt(f)
}
}
}
let Comprehension { let Comprehension {
range: _, range: _,
target, target,
@ -18,33 +30,40 @@ impl FormatNodeRule<Comprehension> for FormatComprehension {
is_async, is_async,
} = item; } = item;
let comments = f.context().comments().clone();
if *is_async { if *is_async {
write!(f, [text("async"), space()])?; write!(f, [text("async"), space()])?;
} }
let comments = f.context().comments().clone();
let dangling_item_comments = comments.dangling_comments(item); let dangling_item_comments = comments.dangling_comments(item);
let (before_target_comments, before_in_comments) = dangling_item_comments.split_at( let (before_target_comments, before_in_comments) = dangling_item_comments.split_at(
dangling_item_comments dangling_item_comments
.partition_point(|comment| comment.slice().end() < target.range().start()), .partition_point(|comment| comment.slice().end() < target.range().start()),
); );
let trailing_in_comments = comments.dangling_comments(iter); let trailing_in_comments = comments.dangling_comments(iter);
let in_spacer = format_with(|f| {
if before_in_comments.is_empty() {
space().fmt(f)
} else {
soft_line_break_or_space().fmt(f)
}
});
write!( write!(
f, f,
[ [
text("for"), text("for"),
trailing_comments(before_target_comments), trailing_comments(before_target_comments),
group(&format_args!( group(&format_args!(
soft_line_break_or_space(), Spacer(target),
target.format(), target.format(),
soft_line_break_or_space(), in_spacer,
leading_comments(before_in_comments), leading_comments(before_in_comments),
text("in"), text("in"),
trailing_comments(trailing_in_comments), trailing_comments(trailing_in_comments),
soft_line_break_or_space(), Spacer(iter),
iter.format(), iter.format(),
)), )),
] ]
@ -64,7 +83,7 @@ impl FormatNodeRule<Comprehension> for FormatComprehension {
leading_comments(own_line_if_comments), leading_comments(own_line_if_comments),
text("if"), text("if"),
trailing_comments(end_of_line_if_comments), trailing_comments(end_of_line_if_comments),
soft_line_break_or_space(), Spacer(if_case),
if_case.format(), if_case.format(),
))); )));
} }

View file

@ -42,7 +42,7 @@ def make_arange(n):
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -2,29 +2,27 @@ @@ -2,14 +2,11 @@
def f(): def f():
@ -59,17 +59,7 @@ def make_arange(n):
async def func(): async def func():
if test: @@ -23,8 +20,8 @@
out_batched = [
i
- async for i in aitertools._async_map(
- self.async_inc, arange(8), batch_size=3
- )
+ async for
+ i
+ in
+ aitertools._async_map(self.async_inc, arange(8), batch_size=3)
]
def awaited_generator_value(n): def awaited_generator_value(n):
@ -100,10 +90,9 @@ async def func():
if test: if test:
out_batched = [ out_batched = [
i i
async for async for i in aitertools._async_map(
i self.async_inc, arange(8), batch_size=3
in )
aitertools._async_map(self.async_inc, arange(8), batch_size=3)
] ]

View file

@ -224,43 +224,18 @@ instruction()#comment with bad spacing
if ( if (
self._proc is not None self._proc is not None
# has the child process finished? # has the child process finished?
@@ -115,7 +123,12 @@ @@ -115,7 +123,9 @@
arg3=True, arg3=True,
) )
lcomp = [ lcomp = [
- element for element in collection if element is not None # yup # yup # right - element for element in collection if element is not None # yup # yup # right
+ element # yup + element # yup
+ for + for element in collection # yup
+ element
+ in
+ collection # yup
+ if element is not None # right + if element is not None # right
] ]
lcomp2 = [ lcomp2 = [
# hello # hello
@@ -123,7 +136,9 @@ @@ -143,7 +153,10 @@
# yup
for element in collection
# right
- if element is not None
+ if
+ element
+ is not None
]
lcomp3 = [
# This one is actually too long to fit in a single line.
@@ -131,7 +146,9 @@
# yup
for element in collection.select_elements()
# right
- if element is not None
+ if
+ element
+ is not None
]
while True:
if False:
@@ -143,7 +160,10 @@
# let's return # let's return
return Node( return Node(
syms.simple_stmt, syms.simple_stmt,
@ -272,14 +247,13 @@ instruction()#comment with bad spacing
) )
@@ -158,7 +178,11 @@ @@ -158,7 +171,10 @@
class Test: class Test:
def _init_host(self, parsed) -> None: def _init_host(self, parsed) -> None:
- if parsed.hostname is None or not parsed.hostname.strip(): # type: ignore - if parsed.hostname is None or not parsed.hostname.strip(): # type: ignore
+ if ( + if (
+ parsed.hostname + parsed.hostname is None # type: ignore
+ is None # type: ignore
+ or not parsed.hostname.strip() + or not parsed.hostname.strip()
+ ): + ):
pass pass
@ -416,10 +390,7 @@ short
) )
lcomp = [ lcomp = [
element # yup element # yup
for for element in collection # yup
element
in
collection # yup
if element is not None # right if element is not None # right
] ]
lcomp2 = [ lcomp2 = [
@ -428,9 +399,7 @@ short
# yup # yup
for element in collection for element in collection
# right # right
if if element is not None
element
is not None
] ]
lcomp3 = [ lcomp3 = [
# This one is actually too long to fit in a single line. # This one is actually too long to fit in a single line.
@ -438,9 +407,7 @@ short
# yup # yup
for element in collection.select_elements() for element in collection.select_elements()
# right # right
if if element is not None
element
is not None
] ]
while True: while True:
if False: if False:
@ -471,8 +438,7 @@ CONFIG_FILES = (
class Test: class Test:
def _init_host(self, parsed) -> None: def _init_host(self, parsed) -> None:
if ( if (
parsed.hostname parsed.hostname is None # type: ignore
is None # type: ignore
or not parsed.hostname.strip() or not parsed.hostname.strip()
): ):
pass pass

View file

@ -1,184 +0,0 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/comments3.py
---
## Input
```py
# The percent-percent comments are Spyder IDE cells.
# %%
def func():
x = """
a really long string
"""
lcomp3 = [
# This one is actually too long to fit in a single line.
element.split("\n", 1)[0]
# yup
for element in collection.select_elements()
# right
if element is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if exc not in _seen:
embedded.append(
# This should be left alone (before)
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# This should be left alone (after)
)
# everything is fine if the expression isn't nested
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# %%
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -12,7 +12,9 @@
# yup
for element in collection.select_elements()
# right
- if element is not None
+ if
+ element
+ is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
```
## Ruff Output
```py
# The percent-percent comments are Spyder IDE cells.
# %%
def func():
x = """
a really long string
"""
lcomp3 = [
# This one is actually too long to fit in a single line.
element.split("\n", 1)[0]
# yup
for element in collection.select_elements()
# right
if
element
is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if exc not in _seen:
embedded.append(
# This should be left alone (before)
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# This should be left alone (after)
)
# everything is fine if the expression isn't nested
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# %%
```
## Black Output
```py
# The percent-percent comments are Spyder IDE cells.
# %%
def func():
x = """
a really long string
"""
lcomp3 = [
# This one is actually too long to fit in a single line.
element.split("\n", 1)[0]
# yup
for element in collection.select_elements()
# right
if element is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if exc not in _seen:
embedded.append(
# This should be left alone (before)
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# This should be left alone (after)
)
# everything is fine if the expression isn't nested
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# %%
```

View file

@ -21,13 +21,12 @@ else:
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,7 +1,8 @@ @@ -1,7 +1,7 @@
a, b, c = 3, 4, 5 a, b, c = 3, 4, 5
if ( if (
a == 3 a == 3
- and b != 9 # fmt: skip - and b != 9 # fmt: skip
+ and b + and b != 9 # fmt: skip
+ != 9 # fmt: skip
and c is not None and c is not None
): ):
print("I'm good!") print("I'm good!")
@ -39,8 +38,7 @@ else:
a, b, c = 3, 4, 5 a, b, c = 3, 4, 5
if ( if (
a == 3 a == 3
and b and b != 9 # fmt: skip
!= 9 # fmt: skip
and c is not None and c is not None
): ):
print("I'm good!") print("I'm good!")

View file

@ -133,7 +133,7 @@ def __await__(): return (yield)
def spaces_types( def spaces_types(
@@ -64,19 +63,17 @@ @@ -64,19 +63,15 @@
def spaces2(result=_core.Value(None)): def spaces2(result=_core.Value(None)):
@ -153,15 +153,13 @@ def __await__(): return (yield)
- .all() - .all()
- ) - )
+ result = session.query(models.Customer.id).filter( + result = session.query(models.Customer.id).filter(
+ models.Customer.account_id + models.Customer.account_id == account_id,
+ == account_id, + models.Customer.email == email_address,
+ models.Customer.email
+ == email_address,
+ ).order_by(models.Customer.id.asc()).all() + ).order_by(models.Customer.id.asc()).all()
def long_lines(): def long_lines():
@@ -135,14 +132,8 @@ @@ -135,14 +130,8 @@
a, a,
**kwargs, **kwargs,
) -> A: ) -> A:
@ -254,10 +252,8 @@ def spaces2(result=_core.Value(None)):
def example(session): def example(session):
result = session.query(models.Customer.id).filter( result = session.query(models.Customer.id).filter(
models.Customer.account_id models.Customer.account_id == account_id,
== account_id, models.Customer.email == email_address,
models.Customer.email
== email_address,
).order_by(models.Customer.id.asc()).all() ).order_by(models.Customer.id.asc()).all()

View file

@ -111,20 +111,6 @@ return np.divide(
q = [10.5**i for i in range(6)] q = [10.5**i for i in range(6)]
@@ -55,9 +55,11 @@
view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
- where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
+ where=view.sum_of_weights**2
+ > view.sum_of_weights_squared, # type: ignore[union-attr]
)
return np.divide(
- where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
+ where=view.sum_of_weights_of_weight_long**2
+ > view.sum_of_weights_squared, # type: ignore
)
``` ```
## Ruff Output ## Ruff Output
@ -187,13 +173,11 @@ if hasattr(view, "sum_of_weights"):
view.variance, # type: ignore[union-attr] view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr] view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr] out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
where=view.sum_of_weights**2 where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
> view.sum_of_weights_squared, # type: ignore[union-attr]
) )
return np.divide( return np.divide(
where=view.sum_of_weights_of_weight_long**2 where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
> view.sum_of_weights_squared, # type: ignore
) )
``` ```

View file

@ -154,10 +154,8 @@ f(
# TODO(konstin): Call chains/fluent interface (https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#call-chains) # TODO(konstin): Call chains/fluent interface (https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#call-chains)
result = session.query(models.Customer.id).filter( result = session.query(models.Customer.id).filter(
models.Customer.account_id models.Customer.account_id == 10000,
== 10000, models.Customer.email == "user@example.org",
models.Customer.email
== "user@example.org",
).order_by(models.Customer.id.asc()).all() ).order_by(models.Customer.id.asc()).all()
# TODO(konstin): Black has this special case for comment placement where everything stays in one line # TODO(konstin): Black has this special case for comment placement where everything stays in one line
f("aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa") f("aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa")

View file

@ -36,6 +36,19 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/expression
# above g # above g
g # g g # g
] ]
[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb + [dddddddddddddddddd, eeeeeeeeeeeeeeeeeee]
for
ccccccccccccccccccccccccccccccccccccccc,
ddddddddddddddddddd, [eeeeeeeeeeeeeeeeeeeeee, fffffffffffffffffffffffff]
in
eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeffffffffffffffffffffffffggggggggggggggggggggghhhhhhhhhhhhhhothermoreeand_even_moreddddddddddddddddddddd
if
fffffffffffffffffffffffffffffffffffffffffff < gggggggggggggggggggggggggggggggggggggggggggggg < hhhhhhhhhhhhhhhhhhhhhhhhhh
if
gggggggggggggggggggggggggggggggggggggggggggg
]
``` ```
## Output ## Output
@ -44,20 +57,14 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/expression
[ [
i i
for for i in [
i
in
[
1, 1,
] ]
] ]
[ [
a # a a # a
for # for for c in e # for # c # in # e
c # c
in # in
e # e
] ]
[ [
@ -80,6 +87,21 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/expression
# above g # above g
g # g g # g
] ]
[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb
+ [dddddddddddddddddd, eeeeeeeeeeeeeeeeeee]
for (
ccccccccccccccccccccccccccccccccccccccc,
ddddddddddddddddddd,
[eeeeeeeeeeeeeeeeeeeeee, fffffffffffffffffffffffff],
) in eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeffffffffffffffffffffffffggggggggggggggggggggghhhhhhhhhhhhhhothermoreeand_even_moreddddddddddddddddddddd
if fffffffffffffffffffffffffffffffffffffffffff
< gggggggggggggggggggggggggggggggggggggggggggggg
< hhhhhhhhhhhhhhhhhhhhhhhhhh
if gggggggggggggggggggggggggggggggggggggggggggg
]
``` ```