Rewrite placement logic (#6040)

## Summary
This is a rewrite of the main comment placement logic. `place_comment`
now has three parts:

- place own line comments
  - between branches
  - after a branch
- place end-of-line comments
  - after colon
  - after a branch
- place comments for specific nodes (that include module level comments)

The rewrite fixed three bugs: `class A: # trailing comment` comments now
stay end-of-line, `try: # comment` remains end-of-line and deeply
indented try-else-finally comments remain with the right nested
statement.

It will be much easier to give more alternative branches nodes since
this is abstracted away by `is_node_with_body` and the first/last child
helpers. Adding new node types can now be done by adding an entry to the
`place_comment` match. The code went from 1526 lines before #6033 to
1213 lines now.

It thinks it easier to just read the new `placement.rs` rather than
reviewing the diff.

## Test Plan

The existing fixtures staying the same or improving plus new ones for
the bug fixes.
This commit is contained in:
konsti 2023-07-26 18:21:23 +02:00 committed by GitHub
parent 2cf00fee96
commit 13f9a16e33
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 541 additions and 673 deletions

View file

@ -4648,6 +4648,8 @@ impl AnyNodeRef<'_> {
| AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtTry(_) | AnyNodeRef::StmtTry(_)
| AnyNodeRef::StmtTryStar(_) | AnyNodeRef::StmtTryStar(_)
| AnyNodeRef::ExceptHandlerExceptHandler(_)
| AnyNodeRef::ElifElseClause(_)
) )
} }

View file

@ -35,5 +35,11 @@ class Test(aaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbb + cccccccccccccccccccccccc +
class Test(aaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbb * cccccccccccccccccccccccc + dddddddddddddddddddddd + eeeeeeeee, ffffffffffffffffff, gggggggggggggggggg): class Test(aaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbb * cccccccccccccccccccccccc + dddddddddddddddddddddd + eeeeeeeee, ffffffffffffffffff, gggggggggggggggggg):
pass pass
class Test(Aaaa): # trailing comment class TestTrailingComment1(Aaaa): # trailing comment
pass pass
class TestTrailingComment2: # trailing comment
pass

View file

@ -88,13 +88,17 @@ def f():
# comment # comment
if True: if True:
def f2(): def f():
pass pass
# 1 # 1
else: elif True:
def f2(): def f():
pass pass
# 2 # 2
else:
def f():
pass
# 3
if True: print("a") # 1 if True: print("a") # 1
elif True: print("b") # 2 elif True: print("b") # 2

View file

@ -106,16 +106,35 @@ except RuntimeError:
raise raise
try: try:
def f2(): def f():
pass pass
# a # a
except: except:
def f2(): def f():
pass pass
# b # b
else:
def f():
pass
# c
finally:
def f():
pass
# d
try: pass # a try: pass # a
except ZeroDivisionError: pass # b except ZeroDivisionError: pass # b
except: pass # b except: pass # c
else: pass # d else: pass # d
finally: pass # c finally: pass # e
try: # 1 preceding: any, following: first in body, enclosing: try
print(1) # 2 preceding: last in body, following: fist in alt body, enclosing: try
except ZeroDivisionError: # 3 preceding: test, following: fist in alt body, enclosing: try
print(2) # 4 preceding: last in body, following: fist in alt body, enclosing: exc
except: # 5 preceding: last in body, following: fist in alt body, enclosing: try
print(2) # 6 preceding: last in body, following: fist in alt body, enclosing: exc
else: # 7 preceding: last in body, following: fist in alt body, enclosing: exc
print(3) # 8 preceding: last in body, following: fist in alt body, enclosing: try
finally: # 9 preceding: last in body, following: fist in alt body, enclosing: try
print(3) # 10 preceding: last in body, following: any, enclosing: try

File diff suppressed because it is too large Load diff

View file

@ -4,18 +4,18 @@ expression: comments.debug(test_case.source_code)
--- ---
{ {
Node { Node {
kind: StmtExpr, kind: StmtPass,
range: 29..42, range: 12..16,
source: `print("test")`, source: `pass`,
}: { }: {
"leading": [ "leading": [],
"dangling": [],
"trailing": [
SourceComment { SourceComment {
text: "# Test", text: "# Test",
position: OwnLine, position: OwnLine,
formatted: false, formatted: false,
}, },
], ],
"dangling": [],
"trailing": [],
}, },
} }

View file

@ -132,15 +132,13 @@ impl FormatNodeRule<ExprSlice> for FormatExprSlice {
let step_leading_comments = comments.leading_comments(step.as_ref()); let step_leading_comments = comments.leading_comments(step.as_ref());
leading_comments_spacing(f, step_leading_comments)?; leading_comments_spacing(f, step_leading_comments)?;
step.format().fmt(f)?; step.format().fmt(f)?;
} else { } else if !dangling_step_comments.is_empty() {
if !dangling_step_comments.is_empty() {
// Put the colon and comments on their own lines // Put the colon and comments on their own lines
write!( write!(
f, f,
[hard_line_break(), dangling_comments(dangling_step_comments)] [hard_line_break(), dangling_comments(dangling_step_comments)]
)?; )?;
} }
}
} else { } else {
debug_assert!(step.is_none(), "step can't exist without a second colon"); debug_assert!(step.is_none(), "step can't exist without a second colon");
} }

View file

@ -5,7 +5,6 @@ use crate::prelude::*;
use crate::{FormatNodeRule, PyFormatter}; use crate::{FormatNodeRule, PyFormatter};
use ruff_formatter::FormatRuleWithOptions; use ruff_formatter::FormatRuleWithOptions;
use ruff_formatter::{write, Buffer, FormatResult}; use ruff_formatter::{write, Buffer, FormatResult};
use ruff_python_ast::node::AstNode;
use rustpython_ast::ExceptHandlerExceptHandler; use rustpython_ast::ExceptHandlerExceptHandler;
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default)]
@ -45,7 +44,7 @@ impl FormatNodeRule<ExceptHandlerExceptHandler> for FormatExceptHandlerExceptHan
} = item; } = item;
let comments_info = f.context().comments().clone(); let comments_info = f.context().comments().clone();
let dangling_comments = comments_info.dangling_comments(item.as_any_node_ref()); let dangling_comments = comments_info.dangling_comments(item);
write!( write!(
f, f,
@ -75,7 +74,7 @@ impl FormatNodeRule<ExceptHandlerExceptHandler> for FormatExceptHandlerExceptHan
[ [
text(":"), text(":"),
trailing_comments(dangling_comments), trailing_comments(dangling_comments),
block_indent(&body.format()) block_indent(&body.format()),
] ]
) )
} }

View file

@ -1,6 +1,6 @@
use crate::comments; use crate::comments;
use crate::comments::leading_alternate_branch_comments;
use crate::comments::SourceComment; use crate::comments::SourceComment;
use crate::comments::{leading_alternate_branch_comments, trailing_comments};
use crate::other::except_handler_except_handler::ExceptHandlerKind; use crate::other::except_handler_except_handler::ExceptHandlerKind;
use crate::prelude::*; use crate::prelude::*;
use crate::statement::FormatRefWithRule; use crate::statement::FormatRefWithRule;
@ -134,8 +134,7 @@ impl Format<PyFormatContext<'_>> for AnyStatementTry<'_> {
let orelse = self.orelse(); let orelse = self.orelse();
let finalbody = self.finalbody(); let finalbody = self.finalbody();
write!(f, [text("try:"), block_indent(&body.format())])?; (_, dangling_comments) = format_case("try", body, None, dangling_comments, f)?;
let mut previous_node = body.last(); let mut previous_node = body.last();
for handler in handlers { for handler in handlers {
@ -183,15 +182,18 @@ fn format_case<'a>(
let case_comments_start = let case_comments_start =
dangling_comments.partition_point(|comment| comment.slice().end() <= last.end()); dangling_comments.partition_point(|comment| comment.slice().end() <= last.end());
let (case_comments, rest) = dangling_comments.split_at(case_comments_start); let (case_comments, rest) = dangling_comments.split_at(case_comments_start);
let partition_point =
case_comments.partition_point(|comment| comment.line_position().is_own_line());
write!( write!(
f, f,
[leading_alternate_branch_comments( [
case_comments, leading_alternate_branch_comments(&case_comments[..partition_point], previous_node),
previous_node text(name),
)] text(":"),
trailing_comments(&case_comments[partition_point..]),
block_indent(&body.format())
]
)?; )?;
write!(f, [text(name), text(":"), block_indent(&body.format())])?;
(None, rest) (None, rest)
} else { } else {
(None, dangling_comments) (None, dangling_comments)

View file

@ -74,7 +74,7 @@ async def test_async_with():
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,62 +1,63 @@ @@ -1,62 +1,62 @@
# Make sure a leading comment is not removed. # Make sure a leading comment is not removed.
-def some_func( unformatted, args ): # fmt: skip -def some_func( unformatted, args ): # fmt: skip
+def some_func(unformatted, args): # fmt: skip +def some_func(unformatted, args): # fmt: skip
@ -134,15 +134,13 @@ async def test_async_with():
-try : # fmt: skip -try : # fmt: skip
+try: +try: # fmt: skip
+ # fmt: skip
some_call() some_call()
-except UnformattedError as ex: # fmt: skip -except UnformattedError as ex: # fmt: skip
- handle_exception()
-finally : # fmt: skip
+except UnformattedError as ex: # fmt: skip +except UnformattedError as ex: # fmt: skip
+ handle_exception() # fmt: skip handle_exception()
+finally: -finally : # fmt: skip
+finally: # fmt: skip
finally_call() finally_call()
@ -207,12 +205,11 @@ async def test_async_for():
print("Do something") print("Do something")
try: try: # fmt: skip
# fmt: skip
some_call() some_call()
except UnformattedError as ex: # fmt: skip except UnformattedError as ex: # fmt: skip
handle_exception() # fmt: skip handle_exception()
finally: finally: # fmt: skip
finally_call() finally_call()

View file

@ -32,14 +32,12 @@ def h():
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,18 +1,26 @@ @@ -1,18 +1,25 @@
def f(): # type: ignore def f(): # type: ignore
... ...
-class x: # some comment
+ +
+class x: class x: # some comment
+ # some comment
... ...
-class y: ... # comment -class y: ... # comment
@ -71,8 +69,7 @@ def f(): # type: ignore
... ...
class x: class x: # some comment
# some comment
... ...

View file

@ -41,8 +41,14 @@ class Test(aaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbb + cccccccccccccccccccccccc +
class Test(aaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbb * cccccccccccccccccccccccc + dddddddddddddddddddddd + eeeeeeeee, ffffffffffffffffff, gggggggggggggggggg): class Test(aaaaaaaaaaaaaaa + bbbbbbbbbbbbbbbbbbbbbb * cccccccccccccccccccccccc + dddddddddddddddddddddd + eeeeeeeee, ffffffffffffffffff, gggggggggggggggggg):
pass pass
class Test(Aaaa): # trailing comment class TestTrailingComment1(Aaaa): # trailing comment
pass pass
class TestTrailingComment2: # trailing comment
pass
``` ```
## Output ## Output
@ -102,7 +108,11 @@ class Test(
pass pass
class Test(Aaaa): # trailing comment class TestTrailingComment1(Aaaa): # trailing comment
pass
class TestTrailingComment2: # trailing comment
pass pass
``` ```

View file

@ -94,13 +94,17 @@ def f():
# comment # comment
if True: if True:
def f2(): def f():
pass pass
# 1 # 1
else: elif True:
def f2(): def f():
pass pass
# 2 # 2
else:
def f():
pass
# 3
if True: print("a") # 1 if True: print("a") # 1
elif True: print("b") # 2 elif True: print("b") # 2
@ -199,13 +203,17 @@ def f():
if True: if True:
def f2(): def f():
pass pass
# 1 # 1
else: elif True:
def f2(): def f():
pass pass
# 2 # 2
else:
def f():
pass
# 3
if True: if True:
print("a") # 1 print("a") # 1

View file

@ -112,19 +112,38 @@ except RuntimeError:
raise raise
try: try:
def f2(): def f():
pass pass
# a # a
except: except:
def f2(): def f():
pass pass
# b # b
else:
def f():
pass
# c
finally:
def f():
pass
# d
try: pass # a try: pass # a
except ZeroDivisionError: pass # b except ZeroDivisionError: pass # b
except: pass # b except: pass # c
else: pass # d else: pass # d
finally: pass # c finally: pass # e
try: # 1 preceding: any, following: first in body, enclosing: try
print(1) # 2 preceding: last in body, following: fist in alt body, enclosing: try
except ZeroDivisionError: # 3 preceding: test, following: fist in alt body, enclosing: try
print(2) # 4 preceding: last in body, following: fist in alt body, enclosing: exc
except: # 5 preceding: last in body, following: fist in alt body, enclosing: try
print(2) # 6 preceding: last in body, following: fist in alt body, enclosing: exc
else: # 7 preceding: last in body, following: fist in alt body, enclosing: exc
print(3) # 8 preceding: last in body, following: fist in alt body, enclosing: try
finally: # 9 preceding: last in body, following: fist in alt body, enclosing: try
print(3) # 10 preceding: last in body, following: any, enclosing: try
``` ```
## Output ## Output
@ -140,8 +159,7 @@ except KeyError: # should remove brackets and be a single line
... ...
try: try: # try
# try
... ...
# end of body # end of body
# before except # before except
@ -160,8 +178,7 @@ finally:
# with line breaks # with line breaks
try: try: # try
# try
... ...
# end of body # end of body
@ -213,8 +230,7 @@ except:
# try/except*, mostly the same as try # try/except*, mostly the same as try
try: try: # try
# try
... ...
# end of body # end of body
# before except # before except
@ -247,24 +263,43 @@ except RuntimeError:
raise raise
try: try:
def f2(): def f():
pass pass
# a # a
except: except:
def f2(): def f():
pass pass
# b # b
else:
def f():
pass
# c
finally:
def f():
pass
# d
try: try:
pass # a pass # a
except ZeroDivisionError: except ZeroDivisionError:
pass # b pass # b
except: except:
pass # b pass # c
else: else:
pass # d pass # d
finally: finally:
pass # c pass # e
try: # 1 preceding: any, following: first in body, enclosing: try
print(1) # 2 preceding: last in body, following: fist in alt body, enclosing: try
except ZeroDivisionError: # 3 preceding: test, following: fist in alt body, enclosing: try
print(2) # 4 preceding: last in body, following: fist in alt body, enclosing: exc
except: # 5 preceding: last in body, following: fist in alt body, enclosing: try
print(2) # 6 preceding: last in body, following: fist in alt body, enclosing: exc
else: # 7 preceding: last in body, following: fist in alt body, enclosing: exc
print(3) # 8 preceding: last in body, following: fist in alt body, enclosing: try
finally: # 9 preceding: last in body, following: fist in alt body, enclosing: try
print(3) # 10 preceding: last in body, following: any, enclosing: try
``` ```