diff --git a/crates/ruff_python_formatter/src/builders.rs b/crates/ruff_python_formatter/src/builders.rs new file mode 100644 index 0000000000..62d7317f36 --- /dev/null +++ b/crates/ruff_python_formatter/src/builders.rs @@ -0,0 +1,212 @@ +use crate::context::NodeLevel; +use crate::prelude::*; +use crate::trivia::lines_before; +use ruff_formatter::write; +use rustpython_parser::ast::Ranged; + +/// Provides Python specific extensions to [`Formatter`]. +pub(crate) trait PyFormatterExtensions<'ast, 'buf> { + /// Creates a joiner that inserts the appropriate number of empty lines between two nodes, depending on the + /// line breaks that separate the two nodes in the source document. The `level` customizes the maximum allowed + /// empty lines between any two nodes. Separates any two nodes by at least a hard line break. + /// + /// * [`NodeLevel::Module`]: Up to two empty lines + /// * [`NodeLevel::Statement`]: Up to one empty line + /// * [`NodeLevel::Parenthesized`]: No empty lines + fn join_nodes<'fmt>(&'fmt mut self, level: NodeLevel) -> JoinNodesBuilder<'fmt, 'ast, 'buf>; +} + +impl<'buf, 'ast> PyFormatterExtensions<'ast, 'buf> for PyFormatter<'ast, 'buf> { + fn join_nodes<'fmt>(&'fmt mut self, level: NodeLevel) -> JoinNodesBuilder<'fmt, 'ast, 'buf> { + JoinNodesBuilder::new(self, level) + } +} + +#[must_use = "must eventually call `finish()` on the builder."] +pub(crate) struct JoinNodesBuilder<'fmt, 'ast, 'buf> { + fmt: &'fmt mut PyFormatter<'ast, 'buf>, + result: FormatResult<()>, + has_elements: bool, + node_level: NodeLevel, +} + +impl<'fmt, 'ast, 'buf> JoinNodesBuilder<'fmt, 'ast, 'buf> { + fn new(fmt: &'fmt mut PyFormatter<'ast, 'buf>, level: NodeLevel) -> Self { + Self { + fmt, + result: Ok(()), + has_elements: false, + node_level: level, + } + } + + /// Writes a `node`, inserting the appropriate number of line breaks depending on the number of + /// line breaks that were present in the source document. Uses `content` to format the `node`. + pub(crate) fn entry(&mut self, node: &T, content: &dyn Format>) + where + T: Ranged, + { + let node_level = self.node_level; + let separator = format_with(|f: &mut PyFormatter| match node_level { + NodeLevel::TopLevel => match lines_before(f.context().contents(), node.start()) { + 0 | 1 => hard_line_break().fmt(f), + 2 => empty_line().fmt(f), + _ => write!(f, [empty_line(), empty_line()]), + }, + NodeLevel::Statement => match lines_before(f.context().contents(), node.start()) { + 0 | 1 => hard_line_break().fmt(f), + _ => empty_line().fmt(f), + }, + NodeLevel::Parenthesized => hard_line_break().fmt(f), + }); + + self.entry_with_separator(&separator, content); + } + + /// Writes a sequence of node with their content tuples, inserting the appropriate number of line breaks between any two of them + /// depending on the number of line breaks that exist in the source document. + #[allow(unused)] + pub(crate) fn entries(&mut self, entries: I) -> &mut Self + where + T: Ranged, + F: Format>, + I: IntoIterator, + { + for (node, content) in entries { + self.entry(&node, &content); + } + + self + } + + /// Writes a sequence of nodes, using their [`AsFormat`] implementation to format the content. + /// Inserts the appropriate number of line breaks between any two nodes, depending on the number of + /// line breaks in the source document. + #[allow(unused)] + pub(crate) fn nodes<'a, T, I>(&mut self, nodes: I) -> &mut Self + where + T: Ranged + AsFormat> + 'a, + I: IntoIterator, + { + for node in nodes { + self.entry(node, &node.format()); + } + + self + } + + /// Writes a single entry using the specified separator to separate the entry from a previous entry. + pub(crate) fn entry_with_separator( + &mut self, + separator: &dyn Format>, + content: &dyn Format>, + ) { + self.result = self.result.and_then(|_| { + if self.has_elements { + separator.fmt(self.fmt)?; + } + + self.has_elements = true; + + content.fmt(self.fmt) + }); + } + + /// Finishes the joiner and gets the format result. + pub(crate) fn finish(&mut self) -> FormatResult<()> { + self.result + } +} + +#[cfg(test)] +mod tests { + use crate::comments::Comments; + use crate::context::{NodeLevel, PyFormatContext}; + use crate::prelude::*; + use ruff_formatter::format; + use ruff_formatter::SimpleFormatOptions; + use rustpython_parser::ast::ModModule; + use rustpython_parser::Parse; + + fn format_ranged(level: NodeLevel) -> String { + let source = r#" +a = 10 + + + +three_leading_newlines = 80 + + +two_leading_newlines = 20 + +one_leading_newline = 10 +no_leading_newline = 30 +"#; + + let module = ModModule::parse(source, "test.py").unwrap(); + + let context = + PyFormatContext::new(SimpleFormatOptions::default(), source, Comments::default()); + + let test_formatter = + format_with(|f: &mut PyFormatter| f.join_nodes(level).nodes(&module.body).finish()); + + let formatted = format!(context, [test_formatter]).unwrap(); + let printed = formatted.print().unwrap(); + + printed.as_code().to_string() + } + + // Keeps up to two empty lines + #[test] + fn ranged_builder_top_level() { + let printed = format_ranged(NodeLevel::TopLevel); + + assert_eq!( + &printed, + r#"a = 10 + + +three_leading_newlines = 80 + + +two_leading_newlines = 20 + +one_leading_newline = 10 +no_leading_newline = 30"# + ); + } + + // Should keep at most one empty level + #[test] + fn ranged_builder_statement_level() { + let printed = format_ranged(NodeLevel::Statement); + + assert_eq!( + &printed, + r#"a = 10 + +three_leading_newlines = 80 + +two_leading_newlines = 20 + +one_leading_newline = 10 +no_leading_newline = 30"# + ); + } + + // Removes all empty lines + #[test] + fn ranged_builder_parenthesized_level() { + let printed = format_ranged(NodeLevel::Parenthesized); + + assert_eq!( + &printed, + r#"a = 10 +three_leading_newlines = 80 +two_leading_newlines = 20 +one_leading_newline = 10 +no_leading_newline = 30"# + ); + } +} diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index 20edde3d99..49cc9fadc9 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -17,6 +17,7 @@ use ruff_python_ast::source_code::{CommentRanges, CommentRangesBuilder, Locator} use crate::comments::{dangling_comments, leading_comments, trailing_comments, Comments}; use crate::context::PyFormatContext; +pub(crate) mod builders; pub mod cli; mod comments; pub(crate) mod context; diff --git a/crates/ruff_python_formatter/src/prelude.rs b/crates/ruff_python_formatter/src/prelude.rs index 309224762b..7382684dba 100644 --- a/crates/ruff_python_formatter/src/prelude.rs +++ b/crates/ruff_python_formatter/src/prelude.rs @@ -1,4 +1,7 @@ #[allow(unused_imports)] -pub(crate) use crate::{AsFormat, FormattedIterExt as _, IntoFormat, PyFormatContext, PyFormatter}; +pub(crate) use crate::{ + builders::PyFormatterExtensions, AsFormat, FormattedIterExt as _, IntoFormat, PyFormatContext, + PyFormatter, +}; #[allow(unused_imports)] pub(crate) use ruff_formatter::prelude::*; diff --git a/crates/ruff_python_formatter/src/statement/mod.rs b/crates/ruff_python_formatter/src/statement/mod.rs index a330b6b5c9..42a18d1d3e 100644 --- a/crates/ruff_python_formatter/src/statement/mod.rs +++ b/crates/ruff_python_formatter/src/statement/mod.rs @@ -1,6 +1,5 @@ -use crate::context::PyFormatContext; -use crate::{AsFormat, IntoFormat, PyFormatter}; -use ruff_formatter::{Format, FormatOwnedWithRule, FormatRefWithRule, FormatResult, FormatRule}; +use crate::prelude::*; +use ruff_formatter::{FormatOwnedWithRule, FormatRefWithRule}; use rustpython_parser::ast::Stmt; pub(crate) mod stmt_ann_assign; @@ -30,6 +29,7 @@ pub(crate) mod stmt_try; pub(crate) mod stmt_try_star; pub(crate) mod stmt_while; pub(crate) mod stmt_with; +pub(crate) mod suite; #[derive(Default)] pub struct FormatStmt; diff --git a/crates/ruff_python_formatter/src/statement/suite.rs b/crates/ruff_python_formatter/src/statement/suite.rs new file mode 100644 index 0000000000..86c7654c71 --- /dev/null +++ b/crates/ruff_python_formatter/src/statement/suite.rs @@ -0,0 +1,211 @@ +use crate::context::NodeLevel; +use crate::prelude::*; +use ruff_formatter::{format_args, FormatOwnedWithRule, FormatRefWithRule, FormatRuleWithOptions}; +use rustpython_parser::ast::{Stmt, Suite}; + +/// Level at which the [`Suite`] appears in the source code. +#[derive(Copy, Clone, Debug)] +pub enum SuiteLevel { + /// Statements at the module level / top level + TopLevel, + + /// Statements in a nested body + Nested, +} + +#[derive(Debug)] +pub struct FormatSuite { + level: SuiteLevel, +} + +impl Default for FormatSuite { + fn default() -> Self { + FormatSuite { + level: SuiteLevel::Nested, + } + } +} + +impl FormatRule> for FormatSuite { + fn fmt(&self, statements: &Suite, f: &mut PyFormatter) -> FormatResult<()> { + let mut joiner = f.join_nodes(match self.level { + SuiteLevel::TopLevel => NodeLevel::TopLevel, + SuiteLevel::Nested => NodeLevel::Statement, + }); + + let mut iter = statements.iter(); + let Some(first) = iter.next() else { + return Ok(()) + }; + + // First entry has never any separator, doesn't matter which one we take; + joiner.entry(first, &first.format()); + + let mut is_last_function_or_class_definition = is_class_or_function_definition(first); + + for statement in iter { + let is_current_function_or_class_definition = + is_class_or_function_definition(statement); + + if is_last_function_or_class_definition || is_current_function_or_class_definition { + match self.level { + SuiteLevel::TopLevel => { + joiner.entry_with_separator( + &format_args![empty_line(), empty_line()], + &statement.format(), + ); + } + SuiteLevel::Nested => { + joiner + .entry_with_separator(&format_args![empty_line()], &statement.format()); + } + } + } else { + joiner.entry(statement, &statement.format()); + } + + is_last_function_or_class_definition = is_current_function_or_class_definition; + } + + joiner.finish() + } +} + +const fn is_class_or_function_definition(stmt: &Stmt) -> bool { + matches!( + stmt, + Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) + ) +} + +impl FormatRuleWithOptions> for FormatSuite { + type Options = SuiteLevel; + + fn with_options(mut self, options: Self::Options) -> Self { + self.level = options; + self + } +} + +impl<'ast> AsFormat> for Suite { + type Format<'a> = FormatRefWithRule<'a, Suite, FormatSuite, PyFormatContext<'ast>>; + + fn format(&self) -> Self::Format<'_> { + FormatRefWithRule::new(self, FormatSuite::default()) + } +} + +impl<'ast> IntoFormat> for Suite { + type Format = FormatOwnedWithRule>; + fn into_format(self) -> Self::Format { + FormatOwnedWithRule::new(self, FormatSuite::default()) + } +} + +#[cfg(test)] +mod tests { + use crate::comments::Comments; + use crate::prelude::*; + use crate::statement::suite::SuiteLevel; + use ruff_formatter::{format, SimpleFormatOptions}; + use rustpython_parser::ast::Suite; + use rustpython_parser::Parse; + + fn format_suite(level: SuiteLevel) -> String { + let source = r#" +a = 10 + + + +three_leading_newlines = 80 + + +two_leading_newlines = 20 + +one_leading_newline = 10 +no_leading_newline = 30 +class InTheMiddle: + pass +trailing_statement = 1 +def func(): + pass +def trailing_func(): + pass +"#; + + let statements = Suite::parse(source, "test.py").unwrap(); + + let context = + PyFormatContext::new(SimpleFormatOptions::default(), source, Comments::default()); + + let test_formatter = + format_with(|f: &mut PyFormatter| statements.format().with_options(level).fmt(f)); + + let formatted = format!(context, [test_formatter]).unwrap(); + let printed = formatted.print().unwrap(); + + printed.as_code().to_string() + } + + #[test] + fn top_level() { + let formatted = format_suite(SuiteLevel::TopLevel); + + assert_eq!( + formatted, + r#"a = 10 + + +three_leading_newlines = 80 + + +two_leading_newlines = 20 + +one_leading_newline = 10 +no_leading_newline = 30 + + +class InTheMiddle: + pass + + +trailing_statement = 1 + + +def func(): + pass + + +def trailing_func(): + pass"# + ); + } + + #[test] + fn nested_level() { + let formatted = format_suite(SuiteLevel::Nested); + + assert_eq!( + formatted, + r#"a = 10 + +three_leading_newlines = 80 + +two_leading_newlines = 20 + +one_leading_newline = 10 +no_leading_newline = 30 + +class InTheMiddle: + pass + +trailing_statement = 1 + +def func(): + pass + +def trailing_func(): + pass"# + ); + } +}