From 928ab63a64b357da96ea1a7ff2e0ce642c35e248 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 1 Aug 2023 11:30:59 -0400 Subject: [PATCH] Add empty lines before nested functions and classes (#6206) ## Summary This PR ensures that if a function or class is the first statement in a nested suite that _isn't_ a function or class body, we insert a leading newline. For example, given: ```python def f(): if True: def register_type(): pass ``` We _want_ to preserve the newline, whereas today, we remove it. Note that this only applies when the function or class doesn't have any leading comments. Closes https://github.com/astral-sh/ruff/issues/6066. --- .../src/module/mod_module.rs | 4 +- .../src/statement/stmt_class_def.rs | 3 +- .../src/statement/stmt_function_def.rs | 3 +- .../src/statement/suite.rs | 72 +++++++++++-------- ...patibility@simple_cases__function2.py.snap | 16 ++--- .../format@statement__function.py.snap | 2 + .../snapshots/format@statement__if.py.snap | 3 + .../snapshots/format@statement__try.py.snap | 4 ++ 8 files changed, 64 insertions(+), 43 deletions(-) diff --git a/crates/ruff_python_formatter/src/module/mod_module.rs b/crates/ruff_python_formatter/src/module/mod_module.rs index 61556fe1ec..654c92aa9b 100644 --- a/crates/ruff_python_formatter/src/module/mod_module.rs +++ b/crates/ruff_python_formatter/src/module/mod_module.rs @@ -1,4 +1,4 @@ -use crate::statement::suite::SuiteLevel; +use crate::statement::suite::SuiteKind; use crate::{AsFormat, FormatNodeRule, PyFormatter}; use ruff_formatter::prelude::hard_line_break; use ruff_formatter::{write, Buffer, FormatResult}; @@ -13,7 +13,7 @@ impl FormatNodeRule for FormatModModule { write!( f, [ - body.format().with_options(SuiteLevel::TopLevel), + body.format().with_options(SuiteKind::TopLevel), // Trailing newline at the end of the file hard_line_break() ] diff --git a/crates/ruff_python_formatter/src/statement/stmt_class_def.rs b/crates/ruff_python_formatter/src/statement/stmt_class_def.rs index 202afd2f01..ad6647670f 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_class_def.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_class_def.rs @@ -7,6 +7,7 @@ use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer}; use crate::comments::trailing_comments; use crate::expression::parentheses::{parenthesized, Parentheses}; use crate::prelude::*; +use crate::statement::suite::SuiteKind; #[derive(Default)] pub struct FormatStmtClassDef; @@ -52,7 +53,7 @@ impl FormatNodeRule for FormatStmtClassDef { [ text(":"), trailing_comments(trailing_head_comments), - block_indent(&body.format()) + block_indent(&body.format().with_options(SuiteKind::Class)) ] ) } diff --git a/crates/ruff_python_formatter/src/statement/stmt_function_def.rs b/crates/ruff_python_formatter/src/statement/stmt_function_def.rs index 27d569b5b3..dc89011450 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_function_def.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_function_def.rs @@ -8,6 +8,7 @@ use crate::comments::{leading_comments, trailing_comments}; use crate::expression::parentheses::{optional_parentheses, Parentheses}; use crate::prelude::*; +use crate::statement::suite::SuiteKind; use crate::FormatNodeRule; #[derive(Default)] @@ -111,7 +112,7 @@ impl FormatRule, PyFormatContext<'_>> for FormatAnyFun [ text(":"), trailing_comments(trailing_definition_comments), - block_indent(&item.body().format()) + block_indent(&item.body().format().with_options(SuiteKind::Function)) ] ) } diff --git a/crates/ruff_python_formatter/src/statement/suite.rs b/crates/ruff_python_formatter/src/statement/suite.rs index 735cf6cc63..0b68b8df29 100644 --- a/crates/ruff_python_formatter/src/statement/suite.rs +++ b/crates/ruff_python_formatter/src/statement/suite.rs @@ -8,38 +8,40 @@ use crate::prelude::*; /// Level at which the [`Suite`] appears in the source code. #[derive(Copy, Clone, Debug)] -pub enum SuiteLevel { +pub enum SuiteKind { /// Statements at the module level / top level TopLevel, - /// Statements in a nested body - Nested, -} + /// Statements in a function body. + Function, -impl SuiteLevel { - const fn is_nested(self) -> bool { - matches!(self, SuiteLevel::Nested) - } + /// Statements in a class body. + Class, + + /// Statements in any other body (e.g., `if` or `while`). + Other, } #[derive(Debug)] pub struct FormatSuite { - level: SuiteLevel, + kind: SuiteKind, } impl Default for FormatSuite { fn default() -> Self { FormatSuite { - level: SuiteLevel::Nested, + kind: SuiteKind::Other, } } } impl FormatRule> for FormatSuite { fn fmt(&self, statements: &Suite, f: &mut PyFormatter) -> FormatResult<()> { - let node_level = match self.level { - SuiteLevel::TopLevel => NodeLevel::TopLevel, - SuiteLevel::Nested => NodeLevel::CompoundStatement, + let node_level = match self.kind { + SuiteKind::TopLevel => NodeLevel::TopLevel, + SuiteKind::Function | SuiteKind::Class | SuiteKind::Other => { + NodeLevel::CompoundStatement + } }; let comments = f.context().comments().clone(); @@ -51,18 +53,33 @@ impl FormatRule> for FormatSuite { }; let mut f = WithNodeLevel::new(node_level, f); - // First entry has never any separator, doesn't matter which one we take. + + if matches!(self.kind, SuiteKind::Other) + && is_class_or_function_definition(first) + && !comments.has_leading_comments(first) + { + // Add an empty line for any nested functions or classes defined within non-function + // or class compound statements, e.g., this is stable formatting: + // ```python + // if True: + // + // def test(): + // ... + // ``` + write!(f, [empty_line()])?; + } + write!(f, [first.format()])?; let mut last = first; for statement in iter { if is_class_or_function_definition(last) || is_class_or_function_definition(statement) { - match self.level { - SuiteLevel::TopLevel => { + match self.kind { + SuiteKind::TopLevel => { write!(f, [empty_line(), empty_line(), statement.format()])?; } - SuiteLevel::Nested => { + SuiteKind::Function | SuiteKind::Class | SuiteKind::Other => { write!(f, [empty_line(), statement.format()])?; } } @@ -95,13 +112,12 @@ impl FormatRule> for FormatSuite { match lines_before(start, source) { 0 | 1 => write!(f, [hard_line_break()])?, 2 => write!(f, [empty_line()])?, - 3.. => { - if self.level.is_nested() { + 3.. => match self.kind { + SuiteKind::TopLevel => write!(f, [empty_line(), empty_line()])?, + SuiteKind::Function | SuiteKind::Class | SuiteKind::Other => { write!(f, [empty_line()])?; - } else { - write!(f, [empty_line(), empty_line()])?; } - } + }, } write!(f, [statement.format()])?; @@ -167,10 +183,10 @@ const fn is_import_definition(stmt: &Stmt) -> bool { } impl FormatRuleWithOptions> for FormatSuite { - type Options = SuiteLevel; + type Options = SuiteKind; fn with_options(mut self, options: Self::Options) -> Self { - self.level = options; + self.kind = options; self } } @@ -199,10 +215,10 @@ mod tests { use crate::comments::Comments; use crate::prelude::*; - use crate::statement::suite::SuiteLevel; + use crate::statement::suite::SuiteKind; use crate::PyFormatOptions; - fn format_suite(level: SuiteLevel) -> String { + fn format_suite(level: SuiteKind) -> String { let source = r#" a = 10 @@ -239,7 +255,7 @@ def trailing_func(): #[test] fn top_level() { - let formatted = format_suite(SuiteLevel::TopLevel); + let formatted = format_suite(SuiteKind::TopLevel); assert_eq!( formatted, @@ -274,7 +290,7 @@ def trailing_func(): #[test] fn nested_level() { - let formatted = format_suite(SuiteLevel::Nested); + let formatted = format_suite(SuiteKind::Other); assert_eq!( formatted, diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap index 23da8b6351..2570ff4de0 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__function2.py.snap @@ -73,22 +73,13 @@ with hmm_but_this_should_get_two_preceding_newlines(): elif os.name == "nt": try: import msvcrt -@@ -45,21 +44,16 @@ - pass - - except ImportError: -- - def i_should_be_followed_by_only_one_newline(): - pass - - elif False: -- +@@ -54,12 +53,10 @@ class IHopeYouAreHavingALovelyDay: def __call__(self): print("i_should_be_followed_by_only_one_newline") - else: -- + def foo(): pass - @@ -146,14 +137,17 @@ elif os.name == "nt": pass except ImportError: + def i_should_be_followed_by_only_one_newline(): pass elif False: + class IHopeYouAreHavingALovelyDay: def __call__(self): print("i_should_be_followed_by_only_one_newline") else: + def foo(): pass diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__function.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__function.py.snap index 31226f41f7..3880bc3f20 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__function.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__function.py.snap @@ -344,6 +344,7 @@ def with_leading_comment(): # looking from the position of the if # Regression test for https://github.com/python/cpython/blob/ad56340b665c5d8ac1f318964f71697bba41acb7/Lib/logging/__init__.py#L253-L260 if True: + def f1(): pass # a else: @@ -351,6 +352,7 @@ else: # Here it's actually a trailing comment if True: + def f2(): pass # a diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap index 3452bfe9b8..10ab45b872 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__if.py.snap @@ -203,14 +203,17 @@ def f(): if True: + def f(): pass # 1 elif True: + def f(): pass # 2 else: + def f(): pass # 3 diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap index 9f9afbc4a2..97db725e9b 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__try.py.snap @@ -263,18 +263,22 @@ except RuntimeError: raise try: + def f(): pass # a except: + def f(): pass # b else: + def f(): pass # c finally: + def f(): pass # d