Improve comprehension line break beheavior

<!--
Thank you for contributing to Ruff! To help us out with reviewing, please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

This PR improves the Black compatibility when it comes to breaking comprehensions. 

We want to avoid line breaks before the target and `in` whenever possible. Furthermore, `if X is not None` should be grouped together, similar to other binary like expressions

<!-- What's the purpose of the change? What does it do, and why? -->

## Test Plan

`cargo test`

<!-- How was it tested? -->
This commit is contained in:
Micha Reiser 2023-07-11 16:51:24 +02:00 committed by GitHub
parent 62a24e1028
commit 8b9193ab1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 121 additions and 316 deletions

View file

@ -30,3 +30,16 @@
# above g
g # g
]
[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb + [dddddddddddddddddd, eeeeeeeeeeeeeeeeeee]
for
ccccccccccccccccccccccccccccccccccccccc,
ddddddddddddddddddd, [eeeeeeeeeeeeeeeeeeeeee, fffffffffffffffffffffffff]
in
eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeffffffffffffffffffffffffggggggggggggggggggggghhhhhhhhhhhhhhothermoreeand_even_moreddddddddddddddddddddd
if
fffffffffffffffffffffffffffffffffffffffffff < gggggggggggggggggggggggggggggggggggggggggggggg < hhhhhhhhhhhhhhhhhhhhhhhhhh
if
gggggggggggggggggggggggggggggggggggggggggggg
]

View file

@ -33,6 +33,7 @@ impl FormatNodeRule<ExprCompare> for FormatExprCompare {
let comments = f.context().comments().clone();
let inner = format_with(|f| {
write!(f, [in_parentheses_only_group(&left.format())])?;
assert_eq!(comparators.len(), ops.len());
@ -63,6 +64,9 @@ impl FormatNodeRule<ExprCompare> for FormatExprCompare {
}
Ok(())
});
in_parentheses_only_group(&inner).fmt(f)
}
}

View file

@ -3,13 +3,25 @@ use crate::prelude::*;
use crate::AsFormat;
use crate::{FormatNodeRule, PyFormatter};
use ruff_formatter::{format_args, write, Buffer, FormatResult};
use rustpython_parser::ast::{Comprehension, Ranged};
use rustpython_parser::ast::{Comprehension, Expr, Ranged};
#[derive(Default)]
pub struct FormatComprehension;
impl FormatNodeRule<Comprehension> for FormatComprehension {
fn fmt_fields(&self, item: &Comprehension, f: &mut PyFormatter) -> FormatResult<()> {
struct Spacer<'a>(&'a Expr);
impl Format<PyFormatContext<'_>> for Spacer<'_> {
fn fmt(&self, f: &mut PyFormatter) -> FormatResult<()> {
if f.context().comments().has_leading_comments(self.0) {
soft_line_break_or_space().fmt(f)
} else {
space().fmt(f)
}
}
}
let Comprehension {
range: _,
target,
@ -18,33 +30,40 @@ impl FormatNodeRule<Comprehension> for FormatComprehension {
is_async,
} = item;
let comments = f.context().comments().clone();
if *is_async {
write!(f, [text("async"), space()])?;
}
let comments = f.context().comments().clone();
let dangling_item_comments = comments.dangling_comments(item);
let (before_target_comments, before_in_comments) = dangling_item_comments.split_at(
dangling_item_comments
.partition_point(|comment| comment.slice().end() < target.range().start()),
);
let trailing_in_comments = comments.dangling_comments(iter);
let in_spacer = format_with(|f| {
if before_in_comments.is_empty() {
space().fmt(f)
} else {
soft_line_break_or_space().fmt(f)
}
});
write!(
f,
[
text("for"),
trailing_comments(before_target_comments),
group(&format_args!(
soft_line_break_or_space(),
Spacer(target),
target.format(),
soft_line_break_or_space(),
in_spacer,
leading_comments(before_in_comments),
text("in"),
trailing_comments(trailing_in_comments),
soft_line_break_or_space(),
Spacer(iter),
iter.format(),
)),
]
@ -64,7 +83,7 @@ impl FormatNodeRule<Comprehension> for FormatComprehension {
leading_comments(own_line_if_comments),
text("if"),
trailing_comments(end_of_line_if_comments),
soft_line_break_or_space(),
Spacer(if_case),
if_case.format(),
)));
}

View file

@ -42,7 +42,7 @@ def make_arange(n):
```diff
--- Black
+++ Ruff
@@ -2,29 +2,27 @@
@@ -2,14 +2,11 @@
def f():
@ -59,17 +59,7 @@ def make_arange(n):
async def func():
if test:
out_batched = [
i
- async for i in aitertools._async_map(
- self.async_inc, arange(8), batch_size=3
- )
+ async for
+ i
+ in
+ aitertools._async_map(self.async_inc, arange(8), batch_size=3)
]
@@ -23,8 +20,8 @@
def awaited_generator_value(n):
@ -100,10 +90,9 @@ async def func():
if test:
out_batched = [
i
async for
i
in
aitertools._async_map(self.async_inc, arange(8), batch_size=3)
async for i in aitertools._async_map(
self.async_inc, arange(8), batch_size=3
)
]

View file

@ -224,43 +224,18 @@ instruction()#comment with bad spacing
if (
self._proc is not None
# has the child process finished?
@@ -115,7 +123,12 @@
@@ -115,7 +123,9 @@
arg3=True,
)
lcomp = [
- element for element in collection if element is not None # yup # yup # right
+ element # yup
+ for
+ element
+ in
+ collection # yup
+ for element in collection # yup
+ if element is not None # right
]
lcomp2 = [
# hello
@@ -123,7 +136,9 @@
# yup
for element in collection
# right
- if element is not None
+ if
+ element
+ is not None
]
lcomp3 = [
# This one is actually too long to fit in a single line.
@@ -131,7 +146,9 @@
# yup
for element in collection.select_elements()
# right
- if element is not None
+ if
+ element
+ is not None
]
while True:
if False:
@@ -143,7 +160,10 @@
@@ -143,7 +153,10 @@
# let's return
return Node(
syms.simple_stmt,
@ -272,14 +247,13 @@ instruction()#comment with bad spacing
)
@@ -158,7 +178,11 @@
@@ -158,7 +171,10 @@
class Test:
def _init_host(self, parsed) -> None:
- if parsed.hostname is None or not parsed.hostname.strip(): # type: ignore
+ if (
+ parsed.hostname
+ is None # type: ignore
+ parsed.hostname is None # type: ignore
+ or not parsed.hostname.strip()
+ ):
pass
@ -416,10 +390,7 @@ short
)
lcomp = [
element # yup
for
element
in
collection # yup
for element in collection # yup
if element is not None # right
]
lcomp2 = [
@ -428,9 +399,7 @@ short
# yup
for element in collection
# right
if
element
is not None
if element is not None
]
lcomp3 = [
# This one is actually too long to fit in a single line.
@ -438,9 +407,7 @@ short
# yup
for element in collection.select_elements()
# right
if
element
is not None
if element is not None
]
while True:
if False:
@ -471,8 +438,7 @@ CONFIG_FILES = (
class Test:
def _init_host(self, parsed) -> None:
if (
parsed.hostname
is None # type: ignore
parsed.hostname is None # type: ignore
or not parsed.hostname.strip()
):
pass

View file

@ -1,184 +0,0 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/comments3.py
---
## Input
```py
# The percent-percent comments are Spyder IDE cells.
# %%
def func():
x = """
a really long string
"""
lcomp3 = [
# This one is actually too long to fit in a single line.
element.split("\n", 1)[0]
# yup
for element in collection.select_elements()
# right
if element is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if exc not in _seen:
embedded.append(
# This should be left alone (before)
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# This should be left alone (after)
)
# everything is fine if the expression isn't nested
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# %%
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -12,7 +12,9 @@
# yup
for element in collection.select_elements()
# right
- if element is not None
+ if
+ element
+ is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
```
## Ruff Output
```py
# The percent-percent comments are Spyder IDE cells.
# %%
def func():
x = """
a really long string
"""
lcomp3 = [
# This one is actually too long to fit in a single line.
element.split("\n", 1)[0]
# yup
for element in collection.select_elements()
# right
if
element
is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if exc not in _seen:
embedded.append(
# This should be left alone (before)
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# This should be left alone (after)
)
# everything is fine if the expression isn't nested
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# %%
```
## Black Output
```py
# The percent-percent comments are Spyder IDE cells.
# %%
def func():
x = """
a really long string
"""
lcomp3 = [
# This one is actually too long to fit in a single line.
element.split("\n", 1)[0]
# yup
for element in collection.select_elements()
# right
if element is not None
]
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if exc not in _seen:
embedded.append(
# This should be left alone (before)
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# This should be left alone (after)
)
# everything is fine if the expression isn't nested
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=set(_seen),
)
# %%
```

View file

@ -21,13 +21,12 @@ else:
```diff
--- Black
+++ Ruff
@@ -1,7 +1,8 @@
@@ -1,7 +1,7 @@
a, b, c = 3, 4, 5
if (
a == 3
- and b != 9 # fmt: skip
+ and b
+ != 9 # fmt: skip
+ and b != 9 # fmt: skip
and c is not None
):
print("I'm good!")
@ -39,8 +38,7 @@ else:
a, b, c = 3, 4, 5
if (
a == 3
and b
!= 9 # fmt: skip
and b != 9 # fmt: skip
and c is not None
):
print("I'm good!")

View file

@ -133,7 +133,7 @@ def __await__(): return (yield)
def spaces_types(
@@ -64,19 +63,17 @@
@@ -64,19 +63,15 @@
def spaces2(result=_core.Value(None)):
@ -153,15 +153,13 @@ def __await__(): return (yield)
- .all()
- )
+ result = session.query(models.Customer.id).filter(
+ models.Customer.account_id
+ == account_id,
+ models.Customer.email
+ == email_address,
+ models.Customer.account_id == account_id,
+ models.Customer.email == email_address,
+ ).order_by(models.Customer.id.asc()).all()
def long_lines():
@@ -135,14 +132,8 @@
@@ -135,14 +130,8 @@
a,
**kwargs,
) -> A:
@ -254,10 +252,8 @@ def spaces2(result=_core.Value(None)):
def example(session):
result = session.query(models.Customer.id).filter(
models.Customer.account_id
== account_id,
models.Customer.email
== email_address,
models.Customer.account_id == account_id,
models.Customer.email == email_address,
).order_by(models.Customer.id.asc()).all()

View file

@ -111,20 +111,6 @@ return np.divide(
q = [10.5**i for i in range(6)]
@@ -55,9 +55,11 @@
view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
- where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
+ where=view.sum_of_weights**2
+ > view.sum_of_weights_squared, # type: ignore[union-attr]
)
return np.divide(
- where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
+ where=view.sum_of_weights_of_weight_long**2
+ > view.sum_of_weights_squared, # type: ignore
)
```
## Ruff Output
@ -187,13 +173,11 @@ if hasattr(view, "sum_of_weights"):
view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
where=view.sum_of_weights**2
> view.sum_of_weights_squared, # type: ignore[union-attr]
where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
)
return np.divide(
where=view.sum_of_weights_of_weight_long**2
> view.sum_of_weights_squared, # type: ignore
where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
)
```

View file

@ -154,10 +154,8 @@ f(
# TODO(konstin): Call chains/fluent interface (https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#call-chains)
result = session.query(models.Customer.id).filter(
models.Customer.account_id
== 10000,
models.Customer.email
== "user@example.org",
models.Customer.account_id == 10000,
models.Customer.email == "user@example.org",
).order_by(models.Customer.id.asc()).all()
# TODO(konstin): Black has this special case for comment placement where everything stays in one line
f("aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa", "aaaaaaaa")

View file

@ -36,6 +36,19 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/expression
# above g
g # g
]
[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb + [dddddddddddddddddd, eeeeeeeeeeeeeeeeeee]
for
ccccccccccccccccccccccccccccccccccccccc,
ddddddddddddddddddd, [eeeeeeeeeeeeeeeeeeeeee, fffffffffffffffffffffffff]
in
eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeffffffffffffffffffffffffggggggggggggggggggggghhhhhhhhhhhhhhothermoreeand_even_moreddddddddddddddddddddd
if
fffffffffffffffffffffffffffffffffffffffffff < gggggggggggggggggggggggggggggggggggggggggggggg < hhhhhhhhhhhhhhhhhhhhhhhhhh
if
gggggggggggggggggggggggggggggggggggggggggggg
]
```
## Output
@ -44,20 +57,14 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/expression
[
i
for
i
in
[
for i in [
1,
]
]
[
a # a
for # for
c # c
in # in
e # e
for c in e # for # c # in # e
]
[
@ -80,6 +87,21 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/expression
# above g
g # g
]
[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb
+ [dddddddddddddddddd, eeeeeeeeeeeeeeeeeee]
for (
ccccccccccccccccccccccccccccccccccccccc,
ddddddddddddddddddd,
[eeeeeeeeeeeeeeeeeeeeee, fffffffffffffffffffffffff],
) in eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeffffffffffffffffffffffffggggggggggggggggggggghhhhhhhhhhhhhhothermoreeand_even_moreddddddddddddddddddddd
if fffffffffffffffffffffffffffffffffffffffffff
< gggggggggggggggggggggggggggggggggggggggggggggg
< hhhhhhhhhhhhhhhhhhhhhhhhhh
if gggggggggggggggggggggggggggggggggggggggggggg
]
```