Correctly handle newlines after/before comments (#4895)

<!--
Thank you for contributing to Ruff! To help us out with reviewing, please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

This issue fixes the removal of empty lines between a leading comment and the previous statement:

```python
a  = 20

# leading comment
b = 10
```

Ruff removed the empty line between `a` and `b` because:
* The leading comments formatting does not preserve leading newlines (to avoid adding new lines at the top of a body)
* The `JoinNodesBuilder` counted the lines before `b`, which is 1 -> Doesn't insert a new line

This is fixed by changing the `JoinNodesBuilder` to count the lines instead *after* the last node. This correctly gives 1, and the `# leading comment` will insert the empty lines between any other leading comment or the node.



## Test Plan

I added a new test for empty lines.
This commit is contained in:
Micha Reiser 2023-06-07 14:49:43 +02:00 committed by GitHub
parent 222ca98a41
commit 6ab3fc60f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 332 additions and 124 deletions

View file

@ -0,0 +1,32 @@
# Removes the line above
a = 10 # Keeps the line above
# Separated by one line from `a` and `b`
b = 20
# Adds two lines after `b`
class Test:
def a(self):
pass
# trailing comment
# two lines before, one line after
c = 30
while a == 10:
...
# trailing comment with one line before
# one line before this leading comment
d = 40
while b == 20:
...
# no empty line before
e = 50 # one empty line before

View file

@ -1,7 +1,8 @@
use crate::context::NodeLevel; use crate::context::NodeLevel;
use crate::prelude::*; use crate::prelude::*;
use crate::trivia::lines_before; use crate::trivia::{lines_after, skip_trailing_trivia};
use ruff_formatter::write; use ruff_formatter::write;
use ruff_text_size::TextSize;
use rustpython_parser::ast::Ranged; use rustpython_parser::ast::Ranged;
/// Provides Python specific extensions to [`Formatter`]. /// Provides Python specific extensions to [`Formatter`].
@ -26,7 +27,7 @@ impl<'buf, 'ast> PyFormatterExtensions<'ast, 'buf> for PyFormatter<'ast, 'buf> {
pub(crate) struct JoinNodesBuilder<'fmt, 'ast, 'buf> { pub(crate) struct JoinNodesBuilder<'fmt, 'ast, 'buf> {
fmt: &'fmt mut PyFormatter<'ast, 'buf>, fmt: &'fmt mut PyFormatter<'ast, 'buf>,
result: FormatResult<()>, result: FormatResult<()>,
has_elements: bool, last_end: Option<TextSize>,
node_level: NodeLevel, node_level: NodeLevel,
} }
@ -35,7 +36,7 @@ impl<'fmt, 'ast, 'buf> JoinNodesBuilder<'fmt, 'ast, 'buf> {
Self { Self {
fmt, fmt,
result: Ok(()), result: Ok(()),
has_elements: false, last_end: None,
node_level: level, node_level: level,
} }
} }
@ -47,22 +48,43 @@ impl<'fmt, 'ast, 'buf> JoinNodesBuilder<'fmt, 'ast, 'buf> {
T: Ranged, T: Ranged,
{ {
let node_level = self.node_level; let node_level = self.node_level;
let separator = format_with(|f: &mut PyFormatter| match node_level {
NodeLevel::TopLevel => match lines_before(node.start(), f.context().contents()) {
0 | 1 => hard_line_break().fmt(f),
2 => empty_line().fmt(f),
_ => write!(f, [empty_line(), empty_line()]),
},
NodeLevel::CompoundStatement => {
match lines_before(node.start(), f.context().contents()) {
0 | 1 => hard_line_break().fmt(f),
_ => empty_line().fmt(f),
}
}
NodeLevel::Expression => hard_line_break().fmt(f),
});
self.entry_with_separator(&separator, content); self.result = self.result.and_then(|_| {
if let Some(last_end) = self.last_end.replace(node.end()) {
let source = self.fmt.context().contents();
let count_lines = |offset| {
// It's necessary to skip any trailing line comment because RustPython doesn't include trailing comments
// in the node's range
// ```python
// a # The range of `a` ends right before this comment
//
// b
// ```
//
// Simply using `lines_after` doesn't work if a statement has a trailing comment because
// it then counts the lines between the statement and the trailing comment, which is
// always 0. This is why it skips any trailing trivia (trivia that's on the same line)
// and counts the lines after.
let after_trailing_trivia = skip_trailing_trivia(offset, source);
lines_after(after_trailing_trivia, source)
};
match node_level {
NodeLevel::TopLevel => match count_lines(last_end) {
0 | 1 => hard_line_break().fmt(self.fmt),
2 => empty_line().fmt(self.fmt),
_ => write!(self.fmt, [empty_line(), empty_line()]),
},
NodeLevel::CompoundStatement => match count_lines(last_end) {
0 | 1 => hard_line_break().fmt(self.fmt),
_ => empty_line().fmt(self.fmt),
},
NodeLevel::Expression => hard_line_break().fmt(self.fmt),
}?;
}
content.fmt(self.fmt)
});
} }
/// Writes a sequence of node with their content tuples, inserting the appropriate number of line breaks between any two of them /// Writes a sequence of node with their content tuples, inserting the appropriate number of line breaks between any two of them
@ -98,17 +120,20 @@ impl<'fmt, 'ast, 'buf> JoinNodesBuilder<'fmt, 'ast, 'buf> {
} }
/// Writes a single entry using the specified separator to separate the entry from a previous entry. /// Writes a single entry using the specified separator to separate the entry from a previous entry.
pub(crate) fn entry_with_separator( pub(crate) fn entry_with_separator<T>(
&mut self, &mut self,
separator: &dyn Format<PyFormatContext<'ast>>, separator: &dyn Format<PyFormatContext<'ast>>,
content: &dyn Format<PyFormatContext<'ast>>, content: &dyn Format<PyFormatContext<'ast>>,
) { node: &T,
) where
T: Ranged,
{
self.result = self.result.and_then(|_| { self.result = self.result.and_then(|_| {
if self.has_elements { if self.last_end.is_some() {
separator.fmt(self.fmt)?; separator.fmt(self.fmt)?;
} }
self.has_elements = true; self.last_end = Some(node.end());
content.fmt(self.fmt) content.fmt(self.fmt)
}); });

View file

@ -355,11 +355,12 @@ Formatted twice:
#[ignore] #[ignore]
#[test] #[test]
fn quick_test() { fn quick_test() {
let src = r#" let src = r#"AAAAAAAAAAAAA = AAAAAAAAAAAAA # type: ignore
while True:
if something.changed: call_to_some_function_asdf(
do.stuff() # trailing comment foo,
other [AAAAAAAAAAAAAAAAAAAAAAA, AAAAAAAAAAAAAAAAAAAAAAA, AAAAAAAAAAAAAAAAAAAAAAA, BBBBBBBBBBBB], # type: ignore
)
"#; "#;
// Tokenize once // Tokenize once
let mut tokens = Vec::new(); let mut tokens = Vec::new();

View file

@ -84,24 +84,7 @@ if True:
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,75 +1,49 @@ @@ -18,44 +18,26 @@
import core, time, a
from . import A, B, C
-
# keeps existing trailing comma
from foo import (
bar,
)
-
# also keeps existing structure
from foo import (
baz,
qux,
)
-
# `as` works as well
from foo import (
xyzzy as magic, xyzzy as magic,
) )
@ -154,12 +137,12 @@ if True:
- "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa wraps %s" - "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa wraps %s"
- % bar - % bar
-) -)
-
+y = {"oneple": (1,),} +y = {"oneple": (1,),}
+assert False, ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa wraps %s" % bar) +assert False, ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa wraps %s" % bar)
# looping over a 1-tuple should also not get wrapped # looping over a 1-tuple should also not get wrapped
for x in (1,): for x in (1,):
pass @@ -63,13 +45,9 @@
for (x,) in (1,), (2,), (3,): for (x,) in (1,), (2,), (3,):
pass pass
@ -175,7 +158,7 @@ if True:
print("foo %r", (foo.bar,)) print("foo %r", (foo.bar,))
if True: if True:
@@ -79,21 +53,15 @@ @@ -79,21 +57,15 @@
) )
if True: if True:
@ -210,15 +193,18 @@ if True:
import core, time, a import core, time, a
from . import A, B, C from . import A, B, C
# keeps existing trailing comma # keeps existing trailing comma
from foo import ( from foo import (
bar, bar,
) )
# also keeps existing structure # also keeps existing structure
from foo import ( from foo import (
baz, baz,
qux, qux,
) )
# `as` works as well # `as` works as well
from foo import ( from foo import (
xyzzy as magic, xyzzy as magic,
@ -244,6 +230,7 @@ nested_long_lines = ["aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbb
x = {"oneple": (1,)} x = {"oneple": (1,)}
y = {"oneple": (1,),} y = {"oneple": (1,),}
assert False, ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa wraps %s" % bar) assert False, ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa wraps %s" % bar)
# looping over a 1-tuple should also not get wrapped # looping over a 1-tuple should also not get wrapped
for x in (1,): for x in (1,):
pass pass

View file

@ -30,11 +30,8 @@ x = [
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -10,6 +10,9 @@ @@ -12,4 +12,6 @@
1, 2,
3, 4,
] ]
+
# fmt: on # fmt: on
-x = [1, 2, 3, 4] -x = [1, 2, 3, 4]
@ -58,7 +55,6 @@ x = [
1, 2, 1, 2,
3, 4, 3, 4,
] ]
# fmt: on # fmt: on
x = [ x = [

View file

@ -97,16 +97,7 @@ elif unformatted:
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -9,8 +9,6 @@ @@ -44,7 +44,7 @@
] # Includes an formatted indentation.
},
)
-
-
# Regression test for https://github.com/psf/black/issues/2015.
run(
# fmt: off
@@ -44,7 +42,7 @@
print ( "This won't be formatted" ) print ( "This won't be formatted" )
print ( "This won't be formatted either" ) print ( "This won't be formatted either" )
else: else:
@ -115,7 +106,7 @@ elif unformatted:
# Regression test for https://github.com/psf/black/issues/3184. # Regression test for https://github.com/psf/black/issues/3184.
@@ -61,7 +59,7 @@ @@ -61,7 +61,7 @@
elif param[0:4] in ("ZZZZ",): elif param[0:4] in ("ZZZZ",):
print ( "This won't be formatted either" ) print ( "This won't be formatted either" )
@ -124,7 +115,7 @@ elif unformatted:
# Regression test for https://github.com/psf/black/issues/2985. # Regression test for https://github.com/psf/black/issues/2985.
@@ -72,10 +70,7 @@ @@ -72,10 +72,7 @@
class Factory(t.Protocol): class Factory(t.Protocol):
@ -136,7 +127,7 @@ elif unformatted:
# Regression test for https://github.com/psf/black/issues/3436. # Regression test for https://github.com/psf/black/issues/3436.
@@ -83,5 +78,5 @@ @@ -83,5 +80,5 @@
return x return x
# fmt: off # fmt: off
elif unformatted: elif unformatted:
@ -160,6 +151,8 @@ setup(
] # Includes an formatted indentation. ] # Includes an formatted indentation.
}, },
) )
# Regression test for https://github.com/psf/black/issues/2015. # Regression test for https://github.com/psf/black/issues/2015.
run( run(
# fmt: off # fmt: off

View file

@ -188,11 +188,8 @@ some_module.some_function(
): ):
pass pass
@@ -100,15 +56,7 @@ @@ -103,12 +59,5 @@
some_module.some_function(
argument1, (one_element_tuple,), argument4, argument5, argument6
)
-
# Inner trailing comma causes outer to explode # Inner trailing comma causes outer to explode
some_module.some_function( some_module.some_function(
- argument1, - argument1,
@ -268,6 +265,7 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
some_module.some_function( some_module.some_function(
argument1, (one_element_tuple,), argument4, argument5, argument6 argument1, (one_element_tuple,), argument4, argument5, argument6
) )
# Inner trailing comma causes outer to explode # Inner trailing comma causes outer to explode
some_module.some_function( some_module.some_function(
argument1, (one, two,), argument4, argument5, argument6 argument1, (one, two,), argument4, argument5, argument6

View file

@ -62,7 +62,7 @@ __all__ = (
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -2,12 +2,13 @@ @@ -2,8 +2,10 @@
# flake8: noqa # flake8: noqa
@ -74,11 +74,7 @@ __all__ = (
ERROR, ERROR,
) )
import sys import sys
- @@ -22,33 +24,16 @@
# This relies on each of the submodules having an __all__ variable.
from .base_events import *
from .coroutines import *
@@ -22,33 +23,16 @@
from ..streams import * from ..streams import *
from some_library import ( from some_library import (
@ -134,6 +130,7 @@ from logging import (
ERROR, ERROR,
) )
import sys import sys
# This relies on each of the submodules having an __all__ variable. # This relies on each of the submodules having an __all__ variable.
from .base_events import * from .base_events import *
from .coroutines import * from .coroutines import *

View file

@ -25,11 +25,9 @@ list_of_types = [tuple[int,],]
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -2,21 +2,9 @@ @@ -4,19 +4,9 @@
# in a single-element subscript.
a: tuple[int,]
b = tuple[int,] b = tuple[int,]
-
# The magic comma still applies to multi-element subscripts. # The magic comma still applies to multi-element subscripts.
-c: tuple[ -c: tuple[
- int, - int,
@ -39,9 +37,9 @@ list_of_types = [tuple[int,],]
- int, - int,
- int, - int,
-] -]
-
+c: tuple[int, int,] +c: tuple[int, int,]
+d = tuple[int, int,] +d = tuple[int, int,]
# Magic commas still work as expected for non-subscripts. # Magic commas still work as expected for non-subscripts.
-small_list = [ -small_list = [
- 1, - 1,
@ -60,9 +58,11 @@ list_of_types = [tuple[int,],]
# in a single-element subscript. # in a single-element subscript.
a: tuple[int,] a: tuple[int,]
b = tuple[int,] b = tuple[int,]
# The magic comma still applies to multi-element subscripts. # The magic comma still applies to multi-element subscripts.
c: tuple[int, int,] c: tuple[int, int,]
d = tuple[int, int,] d = tuple[int, int,]
# Magic commas still work as expected for non-subscripts. # Magic commas still work as expected for non-subscripts.
small_list = [1,] small_list = [1,]
list_of_types = [tuple[int,],] list_of_types = [tuple[int,],]

View file

@ -89,15 +89,6 @@ return np.divide(
def function_dont_replace_spaces(): def function_dont_replace_spaces():
@@ -47,8 +47,6 @@
o = settings(max_examples=10**6.0)
p = {(k, k**2): v**2.0 for k, v in pairs}
q = [10.5**i for i in range(6)]
-
-
# WE SHOULD DEFINITELY NOT EAT THESE COMMENTS (https://github.com/psf/black/issues/2873)
if hasattr(view, "sum_of_weights"):
return np.divide( # type: ignore[no-any-return]
``` ```
## Ruff Output ## Ruff Output
@ -152,6 +143,8 @@ n = count <= 10**5.0
o = settings(max_examples=10**6.0) o = settings(max_examples=10**6.0)
p = {(k, k**2): v**2.0 for k, v in pairs} p = {(k, k**2): v**2.0 for k, v in pairs}
q = [10.5**i for i in range(6)] q = [10.5**i for i in range(6)]
# WE SHOULD DEFINITELY NOT EAT THESE COMMENTS (https://github.com/psf/black/issues/2873) # WE SHOULD DEFINITELY NOT EAT THESE COMMENTS (https://github.com/psf/black/issues/2873)
if hasattr(view, "sum_of_weights"): if hasattr(view, "sum_of_weights"):
return np.divide( # type: ignore[no-any-return] return np.divide( # type: ignore[no-any-return]

View file

@ -25,7 +25,7 @@ xxxxxxxxx_yyy_zzzzzzzz[xx.xxxxxx(x_yyy_zzzzzz.xxxxx[0]), x_yyy_zzzzzz.xxxxxx(xxx
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -2,20 +2,10 @@ @@ -2,20 +2,11 @@
# Left hand side fits in a single line but will still be exploded by the # Left hand side fits in a single line but will still be exploded by the
# magic trailing comma. # magic trailing comma.
@ -41,7 +41,7 @@ xxxxxxxxx_yyy_zzzzzzzz[xx.xxxxxx(x_yyy_zzzzzz.xxxxx[0]), x_yyy_zzzzzz.xxxxxx(xxx
arg1, arg1,
arg2, arg2,
) )
-
# Make when when the left side of assignment plus the opening paren "... = (" is # Make when when the left side of assignment plus the opening paren "... = (" is
# exactly line length limit + 1, it won't be split like that. # exactly line length limit + 1, it won't be split like that.
-xxxxxxxxx_yyy_zzzzzzzz[ -xxxxxxxxx_yyy_zzzzzzzz[
@ -61,6 +61,7 @@ first_value, (m1, m2,), third_value = xxxxxx_yyyyyy_zzzzzz_wwwwww_uuuuuuu_vvvvvv
arg1, arg1,
arg2, arg2,
) )
# Make when when the left side of assignment plus the opening paren "... = (" is # Make when when the left side of assignment plus the opening paren "... = (" is
# exactly line length limit + 1, it won't be split like that. # exactly line length limit + 1, it won't be split like that.
xxxxxxxxx_yyy_zzzzzzzz[xx.xxxxxx(x_yyy_zzzzzz.xxxxx[0]), x_yyy_zzzzzz.xxxxxx(xxxx=1)] = 1 xxxxxxxxx_yyy_zzzzzzzz[xx.xxxxxx(x_yyy_zzzzzz.xxxxx[0]), x_yyy_zzzzzz.xxxxxx(xxxx=1)] = 1

View file

@ -32,17 +32,16 @@ for (((((k, v))))) in d.items():
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,27 +1,16 @@ @@ -1,5 +1,5 @@
# Only remove tuple brackets after `for` # Only remove tuple brackets after `for`
-for k, v in d.items(): -for k, v in d.items():
+for (k, v) in d.items(): +for (k, v) in d.items():
print(k, v) print(k, v)
-
# Don't touch tuple brackets after `in` # Don't touch tuple brackets after `in`
for module in (core, _unicodefun): @@ -8,20 +8,12 @@
if hasattr(module, "_verify_python3_env"):
module._verify_python3_env = lambda: None module._verify_python3_env = lambda: None
-
# Brackets remain for long for loop lines # Brackets remain for long for loop lines
-for ( -for (
- why_would_anyone_choose_to_name_a_loop_variable_with_a_name_this_long, - why_would_anyone_choose_to_name_a_loop_variable_with_a_name_this_long,
@ -59,7 +58,7 @@ for (((((k, v))))) in d.items():
-): -):
+for (k, v) in dfkasdjfldsjflkdsjflkdsjfdslkfjldsjfgkjdshgkljjdsfldgkhsdofudsfudsofajdslkfjdslkfjldisfjdffjsdlkfjdlkjjkdflskadjldkfjsalkfjdasj.items(): +for (k, v) in dfkasdjfldsjflkdsjflkdsjfdslkfjldsjfgkjdshgkljjdsfldgkhsdofudsfudsofajdslkfjdslkfjldisfjdffjsdlkfjdlkjjkdflskadjldkfjsalkfjdasj.items():
print(k, v) print(k, v)
-
# Test deeply nested brackets # Test deeply nested brackets
-for k, v in d.items(): -for k, v in d.items():
+for (((((k, v))))) in d.items(): +for (((((k, v))))) in d.items():
@ -72,16 +71,19 @@ for (((((k, v))))) in d.items():
# Only remove tuple brackets after `for` # Only remove tuple brackets after `for`
for (k, v) in d.items(): for (k, v) in d.items():
print(k, v) print(k, v)
# Don't touch tuple brackets after `in` # Don't touch tuple brackets after `in`
for module in (core, _unicodefun): for module in (core, _unicodefun):
if hasattr(module, "_verify_python3_env"): if hasattr(module, "_verify_python3_env"):
module._verify_python3_env = lambda: None module._verify_python3_env = lambda: None
# Brackets remain for long for loop lines # Brackets remain for long for loop lines
for (why_would_anyone_choose_to_name_a_loop_variable_with_a_name_this_long, i_dont_know_but_we_should_still_check_the_behaviour_if_they_do) in d.items(): for (why_would_anyone_choose_to_name_a_loop_variable_with_a_name_this_long, i_dont_know_but_we_should_still_check_the_behaviour_if_they_do) in d.items():
print(k, v) print(k, v)
for (k, v) in dfkasdjfldsjflkdsjflkdsjfdslkfjldsjfgkjdshgkljjdsfldgkhsdofudsfudsofajdslkfjdslkfjldisfjdffjsdlkfjdlkjjkdflskadjldkfjsalkfjdasj.items(): for (k, v) in dfkasdjfldsjflkdsjflkdsjfdslkfjldsjfgkjdshgkljjdsfldgkhsdofudsfudsofajdslkfjdslkfjldisfjdffjsdlkfjdlkjjkdflskadjldkfjsalkfjdasj.items():
print(k, v) print(k, v)
# Test deeply nested brackets # Test deeply nested brackets
for (((((k, v))))) in d.items(): for (((((k, v))))) in d.items():
print(k, v) print(k, v)

View file

@ -60,30 +60,28 @@ func(
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,25 +1,43 @@ @@ -3,23 +3,45 @@
# We should not remove the trailing comma in a single-element subscript.
a: tuple[int,]
b = tuple[int,] b = tuple[int,]
-
# But commas in multiple element subscripts should be removed. # But commas in multiple element subscripts should be removed.
-c: tuple[int, int] -c: tuple[int, int]
-d = tuple[int, int] -d = tuple[int, int]
-
+c: tuple[int, int,] +c: tuple[int, int,]
+d = tuple[int, int,] +d = tuple[int, int,]
# Remove commas for non-subscripts. # Remove commas for non-subscripts.
-small_list = [1] -small_list = [1]
-list_of_types = [tuple[int,]] -list_of_types = [tuple[int,]]
-small_set = {1} -small_set = {1}
-set_of_types = {tuple[int,]} -set_of_types = {tuple[int,]}
-
+small_list = [1,] +small_list = [1,]
+list_of_types = [tuple[int,],] +list_of_types = [tuple[int,],]
+small_set = {1,} +small_set = {1,}
+set_of_types = {tuple[int,],} +set_of_types = {tuple[int,],}
# Except single element tuples # Except single element tuples
small_tuple = (1,) small_tuple = (1,)
-
# Trailing commas in multiple chained non-nested parens. # Trailing commas in multiple chained non-nested parens.
-zero(one).two(three).four(five) -zero(one).two(three).four(five)
+zero( +zero(
@ -126,16 +124,20 @@ func(
# We should not remove the trailing comma in a single-element subscript. # We should not remove the trailing comma in a single-element subscript.
a: tuple[int,] a: tuple[int,]
b = tuple[int,] b = tuple[int,]
# But commas in multiple element subscripts should be removed. # But commas in multiple element subscripts should be removed.
c: tuple[int, int,] c: tuple[int, int,]
d = tuple[int, int,] d = tuple[int, int,]
# Remove commas for non-subscripts. # Remove commas for non-subscripts.
small_list = [1,] small_list = [1,]
list_of_types = [tuple[int,],] list_of_types = [tuple[int,],]
small_set = {1,} small_set = {1,}
set_of_types = {tuple[int,],} set_of_types = {tuple[int,],}
# Except single element tuples # Except single element tuples
small_tuple = (1,) small_tuple = (1,)
# Trailing commas in multiple chained non-nested parens. # Trailing commas in multiple chained non-nested parens.
zero( zero(
one, one,

View file

@ -46,7 +46,7 @@ assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx(
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,28 +1,10 @@ @@ -1,28 +1,11 @@
-zero( -zero(
- one, - one,
-).two( -).two(
@ -54,15 +54,15 @@ assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx(
-).four( -).four(
- five, - five,
-) -)
- +zero(one,).two(three,).four(five,)
-func1(arg1).func2( -func1(arg1).func2(
- arg2, - arg2,
-).func3(arg3).func4( -).func3(arg3).func4(
- arg4, - arg4,
-).func5(arg5) -).func5(arg5)
+zero(one,).two(three,).four(five,)
+func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5) +func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5)
# Inner one-element tuple shouldn't explode # Inner one-element tuple shouldn't explode
func1(arg1).func2(arg1, (one_tuple,)).func3(arg3) func1(arg1).func2(arg1, (one_tuple,)).func3(arg3)
@ -78,14 +78,6 @@ assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx(
# Example from https://github.com/psf/black/issues/3229 # Example from https://github.com/psf/black/issues/3229
@@ -41,7 +23,6 @@
long_module.long_class.long_func().another_func()
== long_module.long_class.long_func()["some_key"].another_func(arg1)
)
-
# Regression test for https://github.com/psf/black/issues/3414.
assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx(
xxxxxxxxx
``` ```
## Ruff Output ## Ruff Output
@ -94,6 +86,7 @@ assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx(
zero(one,).two(three,).four(five,) zero(one,).two(three,).four(five,)
func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5) func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5)
# Inner one-element tuple shouldn't explode # Inner one-element tuple shouldn't explode
func1(arg1).func2(arg1, (one_tuple,)).func3(arg3) func1(arg1).func2(arg1, (one_tuple,)).func3(arg3)
@ -116,6 +109,7 @@ assert (
long_module.long_class.long_func().another_func() long_module.long_class.long_func().another_func()
== long_module.long_class.long_func()["some_key"].another_func(arg1) == long_module.long_class.long_func()["some_key"].another_func(arg1)
) )
# Regression test for https://github.com/psf/black/issues/3414. # Regression test for https://github.com/psf/black/issues/3414.
assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx( assert xxxxxxxxx.xxxxxxxxx.xxxxxxxxx(
xxxxxxxxx xxxxxxxxx

View file

@ -20,7 +20,7 @@ this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890")
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,12 +1,6 @@ @@ -1,12 +1,7 @@
# This is a standalone comment. # This is a standalone comment.
-( -(
- sdfjklsdfsjldkflkjsf, - sdfjklsdfsjldkflkjsf,
@ -28,8 +28,8 @@ this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890")
- sdfsdjfklsdfjlksdljkf, - sdfsdjfklsdfjlksdljkf,
- sdsfsdfjskdflsfsdf, - sdsfsdfjskdflsfsdf,
-) = (1, 2, 3) -) = (1, 2, 3)
-
+sdfjklsdfsjldkflkjsf, sdfjsdfjlksdljkfsdlkf, sdfsdjfklsdfjlksdljkf, sdsfsdfjskdflsfsdf = 1, 2, 3 +sdfjklsdfsjldkflkjsf, sdfjsdfjlksdljkfsdlkf, sdfsdjfklsdfjlksdljkf, sdsfsdfjskdflsfsdf = 1, 2, 3
# This is as well. # This is as well.
-(this_will_be_wrapped_in_parens,) = struct.unpack(b"12345678901234567890") -(this_will_be_wrapped_in_parens,) = struct.unpack(b"12345678901234567890")
+this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890") +this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890")
@ -42,6 +42,7 @@ this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890")
```py ```py
# This is a standalone comment. # This is a standalone comment.
sdfjklsdfsjldkflkjsf, sdfjsdfjlksdljkfsdlkf, sdfsdjfklsdfjlksdljkf, sdsfsdfjskdflsfsdf = 1, 2, 3 sdfjklsdfsjldkflkjsf, sdfjsdfjlksdljkfsdlkf, sdfsdjfklsdfjlksdljkf, sdsfsdfjskdflsfsdf = 1, 2, 3
# This is as well. # This is as well.
this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890") this_will_be_wrapped_in_parens, = struct.unpack(b"12345678901234567890")

View file

@ -65,6 +65,8 @@ not (aaaaaaaaaaaaaa + {a for x in bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb
# leading right comment # leading right comment
b b
) )
# Black breaks the right side first for the following expressions: # Black breaks the right side first for the following expressions:
( (
aaaaaaaaaaaaaa aaaaaaaaaaaaaa
@ -100,11 +102,14 @@ aaaaaaaaaaaaaa + [
aaaaaaaaaaaaaa aaaaaaaaaaaaaa
+ {a for x in bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb} + {a for x in bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb}
) )
# Wraps it in parentheses if it needs to break both left and right # Wraps it in parentheses if it needs to break both left and right
( (
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+ [bbbbbbbbbbbbbbbbbbbbbb, ccccccccccccccccccccc, dddddddddddddddd, eee] + [bbbbbbbbbbbbbbbbbbbbbb, ccccccccccccccccccccc, dddddddddddddddd, eee]
) # comment ) # comment
# But only for expressions that have a statement parent. # But only for expressions that have a statement parent.
( (
not (aaaaaaaaaaaaaa + {a for x in bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb}) not (aaaaaaaaaaaaaa + {a for x in bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb})
@ -112,6 +117,8 @@ aaaaaaaaaaaaaa + [
[ [
a + [bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb] in c, a + [bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb] in c,
] ]
# leading comment # leading comment
( (
# comment # comment

View file

@ -0,0 +1,80 @@
---
source: crates/ruff_python_formatter/src/lib.rs
expression: snapshot
---
## Input
```py
# Removes the line above
a = 10 # Keeps the line above
# Separated by one line from `a` and `b`
b = 20
# Adds two lines after `b`
class Test:
def a(self):
pass
# trailing comment
# two lines before, one line after
c = 30
while a == 10:
...
# trailing comment with one line before
# one line before this leading comment
d = 40
while b == 20:
...
# no empty line before
e = 50 # one empty line before
```
## Output
```py
# Removes the line above
a = 10 # Keeps the line above
# Separated by one line from `a` and `b`
b = 20
# Adds two lines after `b`
class Test:
def a(self):
pass
# two lines before, one line after
c = 30
while a == 10:
...
# trailing comment with one line before
# one line before this leading comment
d = 40
while b == 20:
...
# no empty line before
e = 50 # one empty line before
```

View file

@ -1,7 +1,10 @@
use crate::context::NodeLevel; use crate::context::NodeLevel;
use crate::prelude::*; use crate::prelude::*;
use ruff_formatter::{format_args, FormatOwnedWithRule, FormatRefWithRule, FormatRuleWithOptions}; use crate::trivia::lines_before;
use rustpython_parser::ast::{Stmt, Suite}; use ruff_formatter::{
format_args, write, FormatOwnedWithRule, FormatRefWithRule, FormatRuleWithOptions,
};
use rustpython_parser::ast::{Ranged, Stmt, Suite};
/// Level at which the [`Suite`] appears in the source code. /// Level at which the [`Suite`] appears in the source code.
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
@ -13,6 +16,12 @@ pub enum SuiteLevel {
Nested, Nested,
} }
impl SuiteLevel {
const fn is_nested(self) -> bool {
matches!(self, SuiteLevel::Nested)
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct FormatSuite { pub struct FormatSuite {
level: SuiteLevel, level: SuiteLevel,
@ -33,6 +42,9 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
SuiteLevel::Nested => NodeLevel::CompoundStatement, SuiteLevel::Nested => NodeLevel::CompoundStatement,
}; };
let comments = f.context().comments().clone();
let source = f.context().contents();
let saved_level = f.context().node_level(); let saved_level = f.context().node_level();
f.context_mut().set_node_level(node_level); f.context_mut().set_node_level(node_level);
@ -46,6 +58,7 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
// First entry has never any separator, doesn't matter which one we take; // First entry has never any separator, doesn't matter which one we take;
joiner.entry(first, &first.format()); joiner.entry(first, &first.format());
let mut last = first;
let mut is_last_function_or_class_definition = is_class_or_function_definition(first); let mut is_last_function_or_class_definition = is_class_or_function_definition(first);
for statement in iter { for statement in iter {
@ -58,18 +71,59 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
joiner.entry_with_separator( joiner.entry_with_separator(
&format_args![empty_line(), empty_line()], &format_args![empty_line(), empty_line()],
&statement.format(), &statement.format(),
statement,
); );
} }
SuiteLevel::Nested => { SuiteLevel::Nested => {
joiner joiner.entry_with_separator(&empty_line(), &statement.format(), statement);
.entry_with_separator(&format_args![empty_line()], &statement.format());
} }
} }
} else if is_compound_statement(last) {
// Handles the case where a body has trailing comments. The issue is that RustPython does not include
// the comments in the range of the suite. This means, the body ends right after the last statement in the body.
// ```python
// def test():
// ...
// # The body of `test` ends right after `...` and before this comment
//
// # leading comment
//
//
// a = 10
// ```
// Using `lines_after` for the node doesn't work because it would count the lines after the `...`
// which is 0 instead of 1, the number of lines between the trailing comment and
// the leading comment. This is why the suite handling counts the lines before the
// start of the next statement or before the first leading comments for compound statements.
let separator = format_with(|f| {
let start = if let Some(first_leading) =
comments.leading_comments(statement.into()).first()
{
first_leading.slice().start()
} else {
statement.start()
};
match lines_before(start, source) {
0 | 1 => hard_line_break().fmt(f),
2 => empty_line().fmt(f),
3.. => {
if self.level.is_nested() {
empty_line().fmt(f)
} else {
write!(f, [empty_line(), empty_line()])
}
}
}
});
joiner.entry_with_separator(&separator, &statement.format(), statement);
} else { } else {
joiner.entry(statement, &statement.format()); joiner.entry(statement, &statement.format());
} }
is_last_function_or_class_definition = is_current_function_or_class_definition; is_last_function_or_class_definition = is_current_function_or_class_definition;
last = statement;
} }
let result = joiner.finish(); let result = joiner.finish();
@ -87,6 +141,24 @@ const fn is_class_or_function_definition(stmt: &Stmt) -> bool {
) )
} }
const fn is_compound_statement(stmt: &Stmt) -> bool {
matches!(
stmt,
Stmt::FunctionDef(_)
| Stmt::AsyncFunctionDef(_)
| Stmt::ClassDef(_)
| Stmt::While(_)
| Stmt::For(_)
| Stmt::AsyncFor(_)
| Stmt::Match(_)
| Stmt::With(_)
| Stmt::AsyncWith(_)
| Stmt::If(_)
| Stmt::Try(_)
| Stmt::TryStar(_)
)
}
impl FormatRuleWithOptions<Suite, PyFormatContext<'_>> for FormatSuite { impl FormatRuleWithOptions<Suite, PyFormatContext<'_>> for FormatSuite {
type Options = SuiteLevel; type Options = SuiteLevel;

View file

@ -146,6 +146,33 @@ pub(crate) fn lines_after(offset: TextSize, code: &str) -> u32 {
newlines newlines
} }
/// Returns the position after skipping any trailing trivia up to, but not including the newline character.
pub(crate) fn skip_trailing_trivia(offset: TextSize, code: &str) -> TextSize {
let rest = &code[usize::from(offset)..];
let mut iter = rest.char_indices();
while let Some((relative_offset, c)) = iter.next() {
match c {
'\n' | '\r' => return offset + TextSize::try_from(relative_offset).unwrap(),
'#' => {
// Skip the comment
let newline_offset = iter
.as_str()
.find(['\n', '\r'])
.unwrap_or(iter.as_str().len());
return offset
+ TextSize::try_from(relative_offset + '#'.len_utf8() + newline_offset)
.unwrap();
}
c if is_python_whitespace(c) => continue,
_ => return offset + TextSize::try_from(relative_offset).unwrap(),
}
}
offset + rest.text_len()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::trivia::{lines_after, lines_before}; use crate::trivia::{lines_after, lines_before};