Insert newline after nested function or class statements (#7946)

**Summary** Insert a newline after nested function and class
definitions, unless there is a trailing own line comment.

We need to e.g. format
```python
if platform.system() == "Linux":
    if sys.version > (3, 10):
        def f():
            print("old")
    else:
        def f():
            print("new")
    f()
```
as
```python
if platform.system() == "Linux":
    if sys.version > (3, 10):

        def f():
            print("old")

    else:

        def f():
            print("new")

    f()
```
even though `f()` is directly preceded by an if statement, not a
function or class definition. See the comments and fixtures for trailing
own line comment handling.

**Test Plan** I checked that the new content of `newlines.py` matches
black's formatting.

---------

Co-authored-by: Charlie Marsh <charlie.r.marsh@gmail.com>
This commit is contained in:
konsti 2023-10-18 11:45:58 +02:00 committed by GitHub
parent dda4ceda71
commit 0c3123e07e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 430 additions and 77 deletions

View file

@ -1,6 +1,7 @@
###
# Blank lines around functions
###
import sys
x = 1
@ -159,3 +160,97 @@ def f():
# comment
x = 1
def f():
if True:
def double(s):
return s + s
print("below function")
if True:
class A:
x = 1
print("below class")
if True:
def double(s):
return s + s
#
print("below comment function")
if True:
class A:
x = 1
#
print("below comment class")
if True:
def double(s):
return s + s
#
print("below comment function 2")
if True:
def double(s):
return s + s
#
def outer():
def inner():
pass
print("below nested functions")
if True:
def double(s):
return s + s
print("below function")
if True:
class A:
x = 1
print("below class")
def outer():
def inner():
pass
print("below nested functions")
class Path:
if sys.version_info >= (3, 11):
def joinpath(self): ...
# The .open method comes from pathlib.pyi and should be kept in sync.
@overload
def open(self): ...
def fakehttp():
class FakeHTTPConnection:
if mock_close:
def close(self):
pass
FakeHTTPConnection.fakedata = fakedata
if True:
if False:
def x():
def y():
pass
#comment
print()
# NOTE: Please keep this the last block in this file. This tests that we don't insert
# empty line(s) at the end of the file due to nested function
if True:
def nested_trailing_function():
pass

View file

@ -347,9 +347,9 @@ fn handle_end_of_line_comment_around_body<'a>(
// ```
// The first earlier branch filters out ambiguities e.g. around try-except-finally.
if let Some(preceding) = comment.preceding_node() {
if let Some(last_child) = last_child_in_body(preceding) {
if let Some(last_child) = preceding.last_child_in_body() {
let innermost_child =
std::iter::successors(Some(last_child), |parent| last_child_in_body(*parent))
std::iter::successors(Some(last_child), AnyNodeRef::last_child_in_body)
.last()
.unwrap_or(last_child);
return CommentPlacement::trailing(innermost_child, comment);
@ -670,7 +670,7 @@ fn handle_own_line_comment_after_branch<'a>(
preceding: AnyNodeRef<'a>,
locator: &Locator,
) -> CommentPlacement<'a> {
let Some(last_child) = last_child_in_body(preceding) else {
let Some(last_child) = preceding.last_child_in_body() else {
return CommentPlacement::Default(comment);
};
@ -734,7 +734,7 @@ fn handle_own_line_comment_after_branch<'a>(
return CommentPlacement::trailing(last_child_in_parent, comment);
}
Ordering::Greater => {
if let Some(nested_child) = last_child_in_body(last_child_in_parent) {
if let Some(nested_child) = last_child_in_parent.last_child_in_body() {
// The comment belongs to the inner block.
parent = Some(last_child_in_parent);
last_child_in_parent = nested_child;
@ -2176,65 +2176,6 @@ where
right.is_some_and(|right| left.ptr_eq(right.into()))
}
/// The last child of the last branch, if the node has multiple branches.
fn last_child_in_body(node: AnyNodeRef) -> Option<AnyNodeRef> {
let body = match node {
AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. })
| AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. })
| AnyNodeRef::StmtWith(ast::StmtWith { body, .. })
| AnyNodeRef::MatchCase(MatchCase { body, .. })
| AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler {
body, ..
})
| AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. }) => body,
AnyNodeRef::StmtIf(ast::StmtIf {
body,
elif_else_clauses,
..
}) => elif_else_clauses.last().map_or(body, |clause| &clause.body),
AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. })
| AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => {
if orelse.is_empty() {
body
} else {
orelse
}
}
AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => {
return cases.last().map(AnyNodeRef::from);
}
AnyNodeRef::StmtTry(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
if finalbody.is_empty() {
if orelse.is_empty() {
if handlers.is_empty() {
body
} else {
return handlers.last().map(AnyNodeRef::from);
}
} else {
orelse
}
} else {
finalbody
}
}
// Not a node that contains an indented child node.
_ => return None,
};
body.last().map(AnyNodeRef::from)
}
/// Returns `true` if `statement` is the first statement in an alternate `body` (e.g. the else of an if statement)
fn is_first_statement_in_alternate_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bool {
match has_body {

View file

@ -155,13 +155,65 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
while let Some(following) = iter.next() {
let following_comments = comments.leading_dangling_trailing(following);
let needs_empty_lines = if is_class_or_function_definition(following) {
// Here we insert empty lines even if the preceding has a trailing own line comment
true
} else {
// Find nested class or function definitions that need an empty line after them.
//
// ```python
// def f():
// if True:
//
// def double(s):
// return s + s
//
// print("below function")
// ```
std::iter::successors(
Some(AnyNodeRef::from(preceding)),
AnyNodeRef::last_child_in_body,
)
.take_while(|last_child|
// If there is a comment between preceding and following the empty lines were
// inserted before the comment by preceding and there are no extra empty lines
// after the comment.
// ```python
// class Test:
// def a(self):
// pass
// # trailing comment
//
//
// # two lines before, one line after
//
// c = 30
// ````
// This also includes nested class/function definitions, so we stop recursing
// once we see a node with a trailing own line comment:
// ```python
// def f():
// if True:
//
// def double(s):
// return s + s
//
// # nested trailing own line comment
// print("below function with trailing own line comment")
// ```
!comments.has_trailing_own_line(*last_child))
.any(|last_child| {
matches!(
last_child,
AnyNodeRef::StmtFunctionDef(_) | AnyNodeRef::StmtClassDef(_)
)
})
};
// Add empty lines before and after a function or class definition. If the preceding
// node is a function or class, and contains trailing comments, then the statement
// itself will add the requisite empty lines when formatting its comments.
if (is_class_or_function_definition(preceding)
&& !preceding_comments.has_trailing_own_line())
|| is_class_or_function_definition(following)
{
if needs_empty_lines {
if source_type.is_stub() {
stub_file_empty_lines(
self.kind,

View file

@ -73,7 +73,7 @@ with hmm_but_this_should_get_two_preceding_newlines():
elif os.name == "nt":
try:
import msvcrt
@@ -54,12 +53,10 @@
@@ -54,7 +53,6 @@
class IHopeYouAreHavingALovelyDay:
def __call__(self):
print("i_should_be_followed_by_only_one_newline")
@ -81,11 +81,6 @@ with hmm_but_this_should_get_two_preceding_newlines():
else:
def foo():
pass
-
with hmm_but_this_should_get_two_preceding_newlines():
pass
```
## Ruff Output
@ -151,6 +146,7 @@ else:
def foo():
pass
with hmm_but_this_should_get_two_preceding_newlines():
pass
```

View file

@ -7,6 +7,7 @@ input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/newlines.p
###
# Blank lines around functions
###
import sys
x = 1
@ -165,13 +166,107 @@ def f():
# comment
x = 1
```
def f():
if True:
def double(s):
return s + s
print("below function")
if True:
class A:
x = 1
print("below class")
if True:
def double(s):
return s + s
#
print("below comment function")
if True:
class A:
x = 1
#
print("below comment class")
if True:
def double(s):
return s + s
#
print("below comment function 2")
if True:
def double(s):
return s + s
#
def outer():
def inner():
pass
print("below nested functions")
if True:
def double(s):
return s + s
print("below function")
if True:
class A:
x = 1
print("below class")
def outer():
def inner():
pass
print("below nested functions")
class Path:
if sys.version_info >= (3, 11):
def joinpath(self): ...
# The .open method comes from pathlib.pyi and should be kept in sync.
@overload
def open(self): ...
def fakehttp():
class FakeHTTPConnection:
if mock_close:
def close(self):
pass
FakeHTTPConnection.fakedata = fakedata
if True:
if False:
def x():
def y():
pass
#comment
print()
# NOTE: Please keep this the last block in this file. This tests that we don't insert
# empty line(s) at the end of the file due to nested function
if True:
def nested_trailing_function():
pass```
## Output
```py
###
# Blank lines around functions
###
import sys
x = 1
@ -339,6 +434,118 @@ def f():
# comment
x = 1
def f():
if True:
def double(s):
return s + s
print("below function")
if True:
class A:
x = 1
print("below class")
if True:
def double(s):
return s + s
#
print("below comment function")
if True:
class A:
x = 1
#
print("below comment class")
if True:
def double(s):
return s + s
#
print("below comment function 2")
if True:
def double(s):
return s + s
#
def outer():
def inner():
pass
print("below nested functions")
if True:
def double(s):
return s + s
print("below function")
if True:
class A:
x = 1
print("below class")
def outer():
def inner():
pass
print("below nested functions")
class Path:
if sys.version_info >= (3, 11):
def joinpath(self):
...
# The .open method comes from pathlib.pyi and should be kept in sync.
@overload
def open(self):
...
def fakehttp():
class FakeHTTPConnection:
if mock_close:
def close(self):
pass
FakeHTTPConnection.fakedata = fakedata
if True:
if False:
def x():
def y():
pass
# comment
print()
# NOTE: Please keep this the last block in this file. This tests that we don't insert
# empty line(s) at the end of the file due to nested function
if True:
def nested_trailing_function():
pass
```

View file

@ -410,6 +410,7 @@ else:
pass
# 3
if True:
print("a") # 1
elif True:

View file

@ -311,6 +311,7 @@ finally:
pass
# d
try:
pass # a
except ZeroDivisionError: