diff --git a/crates/ruff_python_formatter/src/builders.rs b/crates/ruff_python_formatter/src/builders.rs index 84bdcba390..9a0ebe93f0 100644 --- a/crates/ruff_python_formatter/src/builders.rs +++ b/crates/ruff_python_formatter/src/builders.rs @@ -4,7 +4,7 @@ use ruff_python_trivia::{SimpleToken, SimpleTokenKind, SimpleTokenizer}; use ruff_text_size::{TextRange, TextSize}; use crate::comments::{dangling_comments, SourceComment}; -use crate::context::NodeLevel; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::prelude::*; use crate::MagicTrailingComma; @@ -24,21 +24,18 @@ pub(crate) struct ParenthesizeIfExpands<'a, 'ast> { impl<'ast> Format> for ParenthesizeIfExpands<'_, 'ast> { fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { - let saved_level = f.context().node_level(); + { + let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f); - f.context_mut() - .set_node_level(NodeLevel::ParenthesizedExpression); - - let result = group(&format_args![ - if_group_breaks(&text("(")), - soft_block_indent(&Arguments::from(&self.inner)), - if_group_breaks(&text(")")), - ]) - .fmt(f); - - f.context_mut().set_node_level(saved_level); - - result + write!( + f, + [group(&format_args![ + if_group_breaks(&text("(")), + soft_block_indent(&Arguments::from(&self.inner)), + if_group_breaks(&text(")")), + ])] + ) + } } } diff --git a/crates/ruff_python_formatter/src/context.rs b/crates/ruff_python_formatter/src/context.rs index 049fd67e54..23a3530a24 100644 --- a/crates/ruff_python_formatter/src/context.rs +++ b/crates/ruff_python_formatter/src/context.rs @@ -1,6 +1,7 @@ use crate::comments::Comments; use crate::PyFormatOptions; -use ruff_formatter::{FormatContext, GroupId, SourceCode}; +use ruff_formatter::prelude::*; +use ruff_formatter::{Arguments, Buffer, FormatContext, GroupId, SourceCode}; use ruff_source_file::Locator; use std::fmt::{Debug, Formatter}; @@ -94,3 +95,51 @@ impl NodeLevel { ) } } + +pub(crate) struct WithNodeLevel<'ast, 'buf, B> +where + B: Buffer>, +{ + buffer: &'buf mut B, + saved_level: NodeLevel, +} + +impl<'ast, 'buf, B> WithNodeLevel<'ast, 'buf, B> +where + B: Buffer>, +{ + pub(crate) fn new(level: NodeLevel, buffer: &'buf mut B) -> Self { + let context = buffer.state_mut().context_mut(); + let saved_level = context.node_level(); + + context.set_node_level(level); + + Self { + buffer, + saved_level, + } + } + + #[inline] + pub(crate) fn write_fmt(&mut self, arguments: Arguments) -> FormatResult<()> { + self.buffer.write_fmt(arguments) + } + + #[allow(unused)] + #[inline] + pub(crate) fn write_element(&mut self, element: FormatElement) -> FormatResult<()> { + self.buffer.write_element(element) + } +} + +impl<'ast, B> Drop for WithNodeLevel<'ast, '_, B> +where + B: Buffer>, +{ + fn drop(&mut self) { + self.buffer + .state_mut() + .context_mut() + .set_node_level(self.saved_level); + } +} diff --git a/crates/ruff_python_formatter/src/expression/expr_subscript.rs b/crates/ruff_python_formatter/src/expression/expr_subscript.rs index 47089d2893..464028c2e1 100644 --- a/crates/ruff_python_formatter/src/expression/expr_subscript.rs +++ b/crates/ruff_python_formatter/src/expression/expr_subscript.rs @@ -4,8 +4,8 @@ use ruff_formatter::{format_args, write}; use ruff_python_ast::node::{AnyNodeRef, AstNode}; use crate::comments::trailing_comments; -use crate::context::NodeLevel; use crate::context::PyFormatContext; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::expression::expr_tuple::TupleParentheses; use crate::expression::parentheses::{NeedsParentheses, OptionalParentheses}; use crate::prelude::*; @@ -30,34 +30,22 @@ impl FormatNodeRule for FormatExprSubscript { "A subscript expression can only have a single dangling comment, the one after the bracket" ); - if let NodeLevel::Expression(Some(group_id)) = f.context().node_level() { + if let NodeLevel::Expression(Some(_)) = f.context().node_level() { // Enforce the optional parentheses for parenthesized values. - f.context_mut().set_node_level(NodeLevel::Expression(None)); - let result = value.format().fmt(f); - f.context_mut() - .set_node_level(NodeLevel::Expression(Some(group_id))); - result?; + let mut f = WithNodeLevel::new(NodeLevel::Expression(None), f); + write!(f, [value.format()])?; } else { value.format().fmt(f)?; } let format_slice = format_with(|f: &mut PyFormatter| { - let saved_level = f.context().node_level(); - f.context_mut() - .set_node_level(NodeLevel::ParenthesizedExpression); + let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f); - let result = if let Expr::Tuple(tuple) = slice.as_ref() { - tuple - .format() - .with_options(TupleParentheses::Preserve) - .fmt(f) + if let Expr::Tuple(tuple) = slice.as_ref() { + write!(f, [tuple.format().with_options(TupleParentheses::Preserve)]) } else { - slice.format().fmt(f) - }; - - f.context_mut().set_node_level(saved_level); - - result + write!(f, [slice.format()]) + } }); write!( diff --git a/crates/ruff_python_formatter/src/expression/mod.rs b/crates/ruff_python_formatter/src/expression/mod.rs index 690509c1cd..a7c395eb04 100644 --- a/crates/ruff_python_formatter/src/expression/mod.rs +++ b/crates/ruff_python_formatter/src/expression/mod.rs @@ -3,12 +3,14 @@ use std::cmp::Ordering; use ruff_python_ast as ast; use ruff_python_ast::{Expr, Operator}; -use ruff_formatter::{FormatOwnedWithRule, FormatRefWithRule, FormatRule, FormatRuleWithOptions}; +use ruff_formatter::{ + write, FormatOwnedWithRule, FormatRefWithRule, FormatRule, FormatRuleWithOptions, +}; use ruff_python_ast::node::AnyNodeRef; use ruff_python_ast::visitor::preorder::{walk_expr, PreorderVisitor}; use crate::builders::parenthesize_if_expands; -use crate::context::NodeLevel; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::expression::parentheses::{ is_expression_parenthesized, optional_parentheses, parenthesized, NeedsParentheses, OptionalParentheses, Parentheses, Parenthesize, @@ -106,21 +108,16 @@ impl FormatRule> for FormatExpr { if parenthesize { parenthesized("(", &format_expr, ")").fmt(f) } else { - let saved_level = match f.context().node_level() { - saved_level @ (NodeLevel::TopLevel | NodeLevel::CompoundStatement) => { - f.context_mut().set_node_level(NodeLevel::Expression(None)); - Some(saved_level) + let level = match f.context().node_level() { + NodeLevel::TopLevel | NodeLevel::CompoundStatement => NodeLevel::Expression(None), + saved_level @ (NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression) => { + saved_level } - NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression => None, }; - let result = Format::fmt(&format_expr, f); + let mut f = WithNodeLevel::new(level, f); - if let Some(saved_level) = saved_level { - f.context_mut().set_node_level(saved_level); - } - - result + write!(f, [format_expr]) } } } diff --git a/crates/ruff_python_formatter/src/expression/parentheses.rs b/crates/ruff_python_formatter/src/expression/parentheses.rs index bdef4796bf..abe7cc9d62 100644 --- a/crates/ruff_python_formatter/src/expression/parentheses.rs +++ b/crates/ruff_python_formatter/src/expression/parentheses.rs @@ -5,7 +5,7 @@ use ruff_formatter::{format_args, write, Argument, Arguments}; use ruff_python_ast::node::AnyNodeRef; use ruff_python_trivia::{first_non_trivia_token, SimpleToken, SimpleTokenKind, SimpleTokenizer}; -use crate::context::NodeLevel; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::prelude::*; #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -134,23 +134,20 @@ impl<'ast> Format> for FormatParenthesized<'_, 'ast> { let current_level = f.context().node_level(); - f.context_mut() - .set_node_level(NodeLevel::ParenthesizedExpression); + let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f); - let result = if let NodeLevel::Expression(Some(group_id)) = current_level { + if let NodeLevel::Expression(Some(group_id)) = current_level { // Use fits expanded if there's an enclosing group that adds the optional parentheses. // This ensures that expanding this parenthesized expression does not expand the optional parentheses group. - fits_expanded(&inner) - .with_condition(Some(Condition::if_group_fits_on_line(group_id))) - .fmt(f) + write!( + f, + [fits_expanded(&inner) + .with_condition(Some(Condition::if_group_fits_on_line(group_id)))] + ) } else { // It's not necessary to wrap the content if it is not inside of an optional_parentheses group. - inner.fmt(f) - }; - - f.context_mut().set_node_level(current_level); - - result + write!(f, [inner]) + } } } @@ -173,35 +170,30 @@ pub(crate) struct FormatOptionalParentheses<'content, 'ast> { impl<'ast> Format> for FormatOptionalParentheses<'_, 'ast> { fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { - let saved_level = f.context().node_level(); - // The group id is used as a condition in [`in_parentheses_only`] to create a conditional group // that is only active if the optional parentheses group expands. let parens_id = f.group_id("optional_parentheses"); - f.context_mut() - .set_node_level(NodeLevel::Expression(Some(parens_id))); + let mut f = WithNodeLevel::new(NodeLevel::Expression(Some(parens_id)), f); // We can't use `soft_block_indent` here because that would always increment the indent, // even if the group does not break (the indent is not soft). This would result in // too deep indentations if a `parenthesized` group expands. Using `indent_if_group_breaks` // gives us the desired *soft* indentation that is only present if the optional parentheses // are shown. - let result = group(&format_args![ - if_group_breaks(&text("(")), - indent_if_group_breaks( - &format_args![soft_line_break(), Arguments::from(&self.content)], - parens_id - ), - soft_line_break(), - if_group_breaks(&text(")")) - ]) - .with_group_id(Some(parens_id)) - .fmt(f); - - f.context_mut().set_node_level(saved_level); - - result + write!( + f, + [group(&format_args![ + if_group_breaks(&text("(")), + indent_if_group_breaks( + &format_args![soft_line_break(), Arguments::from(&self.content)], + parens_id + ), + soft_line_break(), + if_group_breaks(&text(")")) + ]) + .with_group_id(Some(parens_id))] + ) } } diff --git a/crates/ruff_python_formatter/src/other/arguments.rs b/crates/ruff_python_formatter/src/other/arguments.rs index 7ef2594610..9f90dfb25d 100644 --- a/crates/ruff_python_formatter/src/other/arguments.rs +++ b/crates/ruff_python_formatter/src/other/arguments.rs @@ -11,7 +11,7 @@ use crate::comments::{ dangling_comments, leading_comments, leading_node_comments, trailing_comments, CommentLinePosition, SourceComment, }; -use crate::context::NodeLevel; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::expression::parentheses::parenthesized; use crate::prelude::*; use crate::FormatNodeRule; @@ -61,10 +61,6 @@ impl FormatNodeRule for FormatArguments { kwarg, } = item; - let saved_level = f.context().node_level(); - f.context_mut() - .set_node_level(NodeLevel::ParenthesizedExpression); - let comments = f.context().comments().clone(); let dangling = comments.dangling_comments(item); let (slash, star) = find_argument_separators(f.context().source(), item); @@ -192,6 +188,8 @@ impl FormatNodeRule for FormatArguments { Ok(()) }); + let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f); + let num_arguments = posonlyargs.len() + args.len() + usize::from(vararg.is_some()) @@ -199,7 +197,7 @@ impl FormatNodeRule for FormatArguments { + usize::from(kwarg.is_some()); if self.parentheses == ArgumentsParentheses::Never { - group(&format_inner).fmt(f)?; + write!(f, [group(&format_inner)]) } else if num_arguments == 0 { // No arguments, format any dangling comments between `()` write!( @@ -209,14 +207,10 @@ impl FormatNodeRule for FormatArguments { block_indent(&dangling_comments(dangling)), text(")") ] - )?; + ) } else { - parenthesized("(", &group(&format_inner), ")").fmt(f)?; + write!(f, [parenthesized("(", &group(&format_inner), ")")]) } - - f.context_mut().set_node_level(saved_level); - - Ok(()) } fn fmt_dangling_comments(&self, _node: &Arguments, _f: &mut PyFormatter) -> FormatResult<()> { diff --git a/crates/ruff_python_formatter/src/statement/stmt_assign.rs b/crates/ruff_python_formatter/src/statement/stmt_assign.rs index 7465591b66..c73634e44b 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_assign.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_assign.rs @@ -2,7 +2,7 @@ use ruff_python_ast::{Expr, StmtAssign}; use ruff_formatter::{format_args, write, FormatError}; -use crate::context::NodeLevel; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::expression::parentheses::{Parentheses, Parenthesize}; use crate::expression::{has_own_parentheses, maybe_parenthesize_expression}; use crate::prelude::*; @@ -61,13 +61,10 @@ impl Format> for FormatTargets<'_> { None }; - let saved_level = f.context().node_level(); - f.context_mut() - .set_node_level(NodeLevel::Expression(group_id)); - let format_first = format_with(|f: &mut PyFormatter| { - let result = if can_omit_parentheses { - first.format().with_options(Parentheses::Never).fmt(f) + let mut f = WithNodeLevel::new(NodeLevel::Expression(group_id), f); + if can_omit_parentheses { + write!(f, [first.format().with_options(Parentheses::Never)]) } else { write!( f, @@ -77,11 +74,7 @@ impl Format> for FormatTargets<'_> { if_group_breaks(&text(")")) ] ) - }; - - f.context_mut().set_node_level(saved_level); - - result + } }); write!( diff --git a/crates/ruff_python_formatter/src/statement/suite.rs b/crates/ruff_python_formatter/src/statement/suite.rs index acce13d79c..735cf6cc63 100644 --- a/crates/ruff_python_formatter/src/statement/suite.rs +++ b/crates/ruff_python_formatter/src/statement/suite.rs @@ -3,7 +3,7 @@ use ruff_python_ast::helpers::is_compound_statement; use ruff_python_ast::{Ranged, Stmt, Suite}; use ruff_python_trivia::{lines_after, lines_before, skip_trailing_trivia}; -use crate::context::NodeLevel; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::prelude::*; /// Level at which the [`Suite`] appears in the source code. @@ -45,125 +45,111 @@ impl FormatRule> for FormatSuite { let comments = f.context().comments().clone(); let source = f.context().source(); - let saved_level = f.context().node_level(); - f.context_mut().set_node_level(node_level); + let mut iter = statements.iter(); + let Some(first) = iter.next() else { + return Ok(()); + }; - // Wrap the entire formatting operation in a `format_with` to ensure that we restore - // context regardless of whether an error occurs. - let formatted = format_with(|f| { - let mut iter = statements.iter(); - let Some(first) = iter.next() else { - return Ok(()); - }; + let mut f = WithNodeLevel::new(node_level, f); + // First entry has never any separator, doesn't matter which one we take. + write!(f, [first.format()])?; - // First entry has never any separator, doesn't matter which one we take. - write!(f, [first.format()])?; + let mut last = first; - let mut last = first; - - for statement in iter { - if is_class_or_function_definition(last) - || is_class_or_function_definition(statement) - { - match self.level { - SuiteLevel::TopLevel => { - write!(f, [empty_line(), empty_line(), statement.format()])?; - } - SuiteLevel::Nested => { - write!(f, [empty_line(), statement.format()])?; - } + for statement in iter { + if is_class_or_function_definition(last) || is_class_or_function_definition(statement) { + match self.level { + SuiteLevel::TopLevel => { + write!(f, [empty_line(), empty_line(), statement.format()])?; } - } else if is_import_definition(last) && !is_import_definition(statement) { - write!(f, [empty_line(), statement.format()])?; - } else if is_compound_statement(last) { - // Handles the case where a body has trailing comments. The issue is that RustPython does not include - // the comments in the range of the suite. This means, the body ends right after the last statement in the body. - // ```python - // def test(): - // ... - // # The body of `test` ends right after `...` and before this comment - // - // # leading comment - // - // - // a = 10 - // ``` - // Using `lines_after` for the node doesn't work because it would count the lines after the `...` - // which is 0 instead of 1, the number of lines between the trailing comment and - // the leading comment. This is why the suite handling counts the lines before the - // start of the next statement or before the first leading comments for compound statements. - let start = - if let Some(first_leading) = comments.leading_comments(statement).first() { - first_leading.slice().start() - } else { - statement.start() - }; - - match lines_before(start, source) { - 0 | 1 => write!(f, [hard_line_break()])?, - 2 => write!(f, [empty_line()])?, - 3.. => { - if self.level.is_nested() { - write!(f, [empty_line()])?; - } else { - write!(f, [empty_line(), empty_line()])?; - } - } + SuiteLevel::Nested => { + write!(f, [empty_line(), statement.format()])?; } - - write!(f, [statement.format()])?; - } else { - // Insert the appropriate number of empty lines based on the node level, e.g.: - // * [`NodeLevel::Module`]: Up to two empty lines - // * [`NodeLevel::CompoundStatement`]: Up to one empty line - // * [`NodeLevel::Expression`]: No empty lines - - let count_lines = |offset| { - // It's necessary to skip any trailing line comment because RustPython doesn't include trailing comments - // in the node's range - // ```python - // a # The range of `a` ends right before this comment - // - // b - // ``` - // - // Simply using `lines_after` doesn't work if a statement has a trailing comment because - // it then counts the lines between the statement and the trailing comment, which is - // always 0. This is why it skips any trailing trivia (trivia that's on the same line) - // and counts the lines after. - let after_trailing_trivia = skip_trailing_trivia(offset, source); - lines_after(after_trailing_trivia, source) + } + } else if is_import_definition(last) && !is_import_definition(statement) { + write!(f, [empty_line(), statement.format()])?; + } else if is_compound_statement(last) { + // Handles the case where a body has trailing comments. The issue is that RustPython does not include + // the comments in the range of the suite. This means, the body ends right after the last statement in the body. + // ```python + // def test(): + // ... + // # The body of `test` ends right after `...` and before this comment + // + // # leading comment + // + // + // a = 10 + // ``` + // Using `lines_after` for the node doesn't work because it would count the lines after the `...` + // which is 0 instead of 1, the number of lines between the trailing comment and + // the leading comment. This is why the suite handling counts the lines before the + // start of the next statement or before the first leading comments for compound statements. + let start = + if let Some(first_leading) = comments.leading_comments(statement).first() { + first_leading.slice().start() + } else { + statement.start() }; - match node_level { - NodeLevel::TopLevel => match count_lines(last.end()) { - 0 | 1 => write!(f, [hard_line_break()])?, - 2 => write!(f, [empty_line()])?, - _ => write!(f, [empty_line(), empty_line()])?, - }, - NodeLevel::CompoundStatement => match count_lines(last.end()) { - 0 | 1 => write!(f, [hard_line_break()])?, - _ => write!(f, [empty_line()])?, - }, - NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression => { - write!(f, [hard_line_break()])?; + match lines_before(start, source) { + 0 | 1 => write!(f, [hard_line_break()])?, + 2 => write!(f, [empty_line()])?, + 3.. => { + if self.level.is_nested() { + write!(f, [empty_line()])?; + } else { + write!(f, [empty_line(), empty_line()])?; } } - - write!(f, [statement.format()])?; } - last = statement; + write!(f, [statement.format()])?; + } else { + // Insert the appropriate number of empty lines based on the node level, e.g.: + // * [`NodeLevel::Module`]: Up to two empty lines + // * [`NodeLevel::CompoundStatement`]: Up to one empty line + // * [`NodeLevel::Expression`]: No empty lines + + let count_lines = |offset| { + // It's necessary to skip any trailing line comment because RustPython doesn't include trailing comments + // in the node's range + // ```python + // a # The range of `a` ends right before this comment + // + // b + // ``` + // + // Simply using `lines_after` doesn't work if a statement has a trailing comment because + // it then counts the lines between the statement and the trailing comment, which is + // always 0. This is why it skips any trailing trivia (trivia that's on the same line) + // and counts the lines after. + let after_trailing_trivia = skip_trailing_trivia(offset, source); + lines_after(after_trailing_trivia, source) + }; + + match node_level { + NodeLevel::TopLevel => match count_lines(last.end()) { + 0 | 1 => write!(f, [hard_line_break()])?, + 2 => write!(f, [empty_line()])?, + _ => write!(f, [empty_line(), empty_line()])?, + }, + NodeLevel::CompoundStatement => match count_lines(last.end()) { + 0 | 1 => write!(f, [hard_line_break()])?, + _ => write!(f, [empty_line()])?, + }, + NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression => { + write!(f, [hard_line_break()])?; + } + } + + write!(f, [statement.format()])?; } - Ok(()) - }); + last = statement; + } - let result = formatted.fmt(f); - - f.context_mut().set_node_level(saved_level); - - result + Ok(()) } }