formatter: WithNodeLevel helper (#6212)

This commit is contained in:
Micha Reiser 2023-07-31 23:22:17 +02:00 committed by GitHub
parent 615337a54d
commit 38b5726948
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 208 additions and 212 deletions

View file

@ -4,7 +4,7 @@ use ruff_python_trivia::{SimpleToken, SimpleTokenKind, SimpleTokenizer};
use ruff_text_size::{TextRange, TextSize};
use crate::comments::{dangling_comments, SourceComment};
use crate::context::NodeLevel;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::prelude::*;
use crate::MagicTrailingComma;
@ -24,21 +24,18 @@ pub(crate) struct ParenthesizeIfExpands<'a, 'ast> {
impl<'ast> Format<PyFormatContext<'ast>> for ParenthesizeIfExpands<'_, 'ast> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'ast>>) -> FormatResult<()> {
let saved_level = f.context().node_level();
{
let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f);
f.context_mut()
.set_node_level(NodeLevel::ParenthesizedExpression);
let result = group(&format_args![
if_group_breaks(&text("(")),
soft_block_indent(&Arguments::from(&self.inner)),
if_group_breaks(&text(")")),
])
.fmt(f);
f.context_mut().set_node_level(saved_level);
result
write!(
f,
[group(&format_args![
if_group_breaks(&text("(")),
soft_block_indent(&Arguments::from(&self.inner)),
if_group_breaks(&text(")")),
])]
)
}
}
}

View file

@ -1,6 +1,7 @@
use crate::comments::Comments;
use crate::PyFormatOptions;
use ruff_formatter::{FormatContext, GroupId, SourceCode};
use ruff_formatter::prelude::*;
use ruff_formatter::{Arguments, Buffer, FormatContext, GroupId, SourceCode};
use ruff_source_file::Locator;
use std::fmt::{Debug, Formatter};
@ -94,3 +95,51 @@ impl NodeLevel {
)
}
}
pub(crate) struct WithNodeLevel<'ast, 'buf, B>
where
B: Buffer<Context = PyFormatContext<'ast>>,
{
buffer: &'buf mut B,
saved_level: NodeLevel,
}
impl<'ast, 'buf, B> WithNodeLevel<'ast, 'buf, B>
where
B: Buffer<Context = PyFormatContext<'ast>>,
{
pub(crate) fn new(level: NodeLevel, buffer: &'buf mut B) -> Self {
let context = buffer.state_mut().context_mut();
let saved_level = context.node_level();
context.set_node_level(level);
Self {
buffer,
saved_level,
}
}
#[inline]
pub(crate) fn write_fmt(&mut self, arguments: Arguments<B::Context>) -> FormatResult<()> {
self.buffer.write_fmt(arguments)
}
#[allow(unused)]
#[inline]
pub(crate) fn write_element(&mut self, element: FormatElement) -> FormatResult<()> {
self.buffer.write_element(element)
}
}
impl<'ast, B> Drop for WithNodeLevel<'ast, '_, B>
where
B: Buffer<Context = PyFormatContext<'ast>>,
{
fn drop(&mut self) {
self.buffer
.state_mut()
.context_mut()
.set_node_level(self.saved_level);
}
}

View file

@ -4,8 +4,8 @@ use ruff_formatter::{format_args, write};
use ruff_python_ast::node::{AnyNodeRef, AstNode};
use crate::comments::trailing_comments;
use crate::context::NodeLevel;
use crate::context::PyFormatContext;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::expression::expr_tuple::TupleParentheses;
use crate::expression::parentheses::{NeedsParentheses, OptionalParentheses};
use crate::prelude::*;
@ -30,34 +30,22 @@ impl FormatNodeRule<ExprSubscript> for FormatExprSubscript {
"A subscript expression can only have a single dangling comment, the one after the bracket"
);
if let NodeLevel::Expression(Some(group_id)) = f.context().node_level() {
if let NodeLevel::Expression(Some(_)) = f.context().node_level() {
// Enforce the optional parentheses for parenthesized values.
f.context_mut().set_node_level(NodeLevel::Expression(None));
let result = value.format().fmt(f);
f.context_mut()
.set_node_level(NodeLevel::Expression(Some(group_id)));
result?;
let mut f = WithNodeLevel::new(NodeLevel::Expression(None), f);
write!(f, [value.format()])?;
} else {
value.format().fmt(f)?;
}
let format_slice = format_with(|f: &mut PyFormatter| {
let saved_level = f.context().node_level();
f.context_mut()
.set_node_level(NodeLevel::ParenthesizedExpression);
let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f);
let result = if let Expr::Tuple(tuple) = slice.as_ref() {
tuple
.format()
.with_options(TupleParentheses::Preserve)
.fmt(f)
if let Expr::Tuple(tuple) = slice.as_ref() {
write!(f, [tuple.format().with_options(TupleParentheses::Preserve)])
} else {
slice.format().fmt(f)
};
f.context_mut().set_node_level(saved_level);
result
write!(f, [slice.format()])
}
});
write!(

View file

@ -3,12 +3,14 @@ use std::cmp::Ordering;
use ruff_python_ast as ast;
use ruff_python_ast::{Expr, Operator};
use ruff_formatter::{FormatOwnedWithRule, FormatRefWithRule, FormatRule, FormatRuleWithOptions};
use ruff_formatter::{
write, FormatOwnedWithRule, FormatRefWithRule, FormatRule, FormatRuleWithOptions,
};
use ruff_python_ast::node::AnyNodeRef;
use ruff_python_ast::visitor::preorder::{walk_expr, PreorderVisitor};
use crate::builders::parenthesize_if_expands;
use crate::context::NodeLevel;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::expression::parentheses::{
is_expression_parenthesized, optional_parentheses, parenthesized, NeedsParentheses,
OptionalParentheses, Parentheses, Parenthesize,
@ -106,21 +108,16 @@ impl FormatRule<Expr, PyFormatContext<'_>> for FormatExpr {
if parenthesize {
parenthesized("(", &format_expr, ")").fmt(f)
} else {
let saved_level = match f.context().node_level() {
saved_level @ (NodeLevel::TopLevel | NodeLevel::CompoundStatement) => {
f.context_mut().set_node_level(NodeLevel::Expression(None));
Some(saved_level)
let level = match f.context().node_level() {
NodeLevel::TopLevel | NodeLevel::CompoundStatement => NodeLevel::Expression(None),
saved_level @ (NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression) => {
saved_level
}
NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression => None,
};
let result = Format::fmt(&format_expr, f);
let mut f = WithNodeLevel::new(level, f);
if let Some(saved_level) = saved_level {
f.context_mut().set_node_level(saved_level);
}
result
write!(f, [format_expr])
}
}
}

View file

@ -5,7 +5,7 @@ use ruff_formatter::{format_args, write, Argument, Arguments};
use ruff_python_ast::node::AnyNodeRef;
use ruff_python_trivia::{first_non_trivia_token, SimpleToken, SimpleTokenKind, SimpleTokenizer};
use crate::context::NodeLevel;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::prelude::*;
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
@ -134,23 +134,20 @@ impl<'ast> Format<PyFormatContext<'ast>> for FormatParenthesized<'_, 'ast> {
let current_level = f.context().node_level();
f.context_mut()
.set_node_level(NodeLevel::ParenthesizedExpression);
let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f);
let result = if let NodeLevel::Expression(Some(group_id)) = current_level {
if let NodeLevel::Expression(Some(group_id)) = current_level {
// Use fits expanded if there's an enclosing group that adds the optional parentheses.
// This ensures that expanding this parenthesized expression does not expand the optional parentheses group.
fits_expanded(&inner)
.with_condition(Some(Condition::if_group_fits_on_line(group_id)))
.fmt(f)
write!(
f,
[fits_expanded(&inner)
.with_condition(Some(Condition::if_group_fits_on_line(group_id)))]
)
} else {
// It's not necessary to wrap the content if it is not inside of an optional_parentheses group.
inner.fmt(f)
};
f.context_mut().set_node_level(current_level);
result
write!(f, [inner])
}
}
}
@ -173,35 +170,30 @@ pub(crate) struct FormatOptionalParentheses<'content, 'ast> {
impl<'ast> Format<PyFormatContext<'ast>> for FormatOptionalParentheses<'_, 'ast> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'ast>>) -> FormatResult<()> {
let saved_level = f.context().node_level();
// The group id is used as a condition in [`in_parentheses_only`] to create a conditional group
// that is only active if the optional parentheses group expands.
let parens_id = f.group_id("optional_parentheses");
f.context_mut()
.set_node_level(NodeLevel::Expression(Some(parens_id)));
let mut f = WithNodeLevel::new(NodeLevel::Expression(Some(parens_id)), f);
// We can't use `soft_block_indent` here because that would always increment the indent,
// even if the group does not break (the indent is not soft). This would result in
// too deep indentations if a `parenthesized` group expands. Using `indent_if_group_breaks`
// gives us the desired *soft* indentation that is only present if the optional parentheses
// are shown.
let result = group(&format_args![
if_group_breaks(&text("(")),
indent_if_group_breaks(
&format_args![soft_line_break(), Arguments::from(&self.content)],
parens_id
),
soft_line_break(),
if_group_breaks(&text(")"))
])
.with_group_id(Some(parens_id))
.fmt(f);
f.context_mut().set_node_level(saved_level);
result
write!(
f,
[group(&format_args![
if_group_breaks(&text("(")),
indent_if_group_breaks(
&format_args![soft_line_break(), Arguments::from(&self.content)],
parens_id
),
soft_line_break(),
if_group_breaks(&text(")"))
])
.with_group_id(Some(parens_id))]
)
}
}

View file

@ -11,7 +11,7 @@ use crate::comments::{
dangling_comments, leading_comments, leading_node_comments, trailing_comments,
CommentLinePosition, SourceComment,
};
use crate::context::NodeLevel;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::expression::parentheses::parenthesized;
use crate::prelude::*;
use crate::FormatNodeRule;
@ -61,10 +61,6 @@ impl FormatNodeRule<Arguments> for FormatArguments {
kwarg,
} = item;
let saved_level = f.context().node_level();
f.context_mut()
.set_node_level(NodeLevel::ParenthesizedExpression);
let comments = f.context().comments().clone();
let dangling = comments.dangling_comments(item);
let (slash, star) = find_argument_separators(f.context().source(), item);
@ -192,6 +188,8 @@ impl FormatNodeRule<Arguments> for FormatArguments {
Ok(())
});
let mut f = WithNodeLevel::new(NodeLevel::ParenthesizedExpression, f);
let num_arguments = posonlyargs.len()
+ args.len()
+ usize::from(vararg.is_some())
@ -199,7 +197,7 @@ impl FormatNodeRule<Arguments> for FormatArguments {
+ usize::from(kwarg.is_some());
if self.parentheses == ArgumentsParentheses::Never {
group(&format_inner).fmt(f)?;
write!(f, [group(&format_inner)])
} else if num_arguments == 0 {
// No arguments, format any dangling comments between `()`
write!(
@ -209,14 +207,10 @@ impl FormatNodeRule<Arguments> for FormatArguments {
block_indent(&dangling_comments(dangling)),
text(")")
]
)?;
)
} else {
parenthesized("(", &group(&format_inner), ")").fmt(f)?;
write!(f, [parenthesized("(", &group(&format_inner), ")")])
}
f.context_mut().set_node_level(saved_level);
Ok(())
}
fn fmt_dangling_comments(&self, _node: &Arguments, _f: &mut PyFormatter) -> FormatResult<()> {

View file

@ -2,7 +2,7 @@ use ruff_python_ast::{Expr, StmtAssign};
use ruff_formatter::{format_args, write, FormatError};
use crate::context::NodeLevel;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::expression::parentheses::{Parentheses, Parenthesize};
use crate::expression::{has_own_parentheses, maybe_parenthesize_expression};
use crate::prelude::*;
@ -61,13 +61,10 @@ impl Format<PyFormatContext<'_>> for FormatTargets<'_> {
None
};
let saved_level = f.context().node_level();
f.context_mut()
.set_node_level(NodeLevel::Expression(group_id));
let format_first = format_with(|f: &mut PyFormatter| {
let result = if can_omit_parentheses {
first.format().with_options(Parentheses::Never).fmt(f)
let mut f = WithNodeLevel::new(NodeLevel::Expression(group_id), f);
if can_omit_parentheses {
write!(f, [first.format().with_options(Parentheses::Never)])
} else {
write!(
f,
@ -77,11 +74,7 @@ impl Format<PyFormatContext<'_>> for FormatTargets<'_> {
if_group_breaks(&text(")"))
]
)
};
f.context_mut().set_node_level(saved_level);
result
}
});
write!(

View file

@ -3,7 +3,7 @@ use ruff_python_ast::helpers::is_compound_statement;
use ruff_python_ast::{Ranged, Stmt, Suite};
use ruff_python_trivia::{lines_after, lines_before, skip_trailing_trivia};
use crate::context::NodeLevel;
use crate::context::{NodeLevel, WithNodeLevel};
use crate::prelude::*;
/// Level at which the [`Suite`] appears in the source code.
@ -45,125 +45,111 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
let comments = f.context().comments().clone();
let source = f.context().source();
let saved_level = f.context().node_level();
f.context_mut().set_node_level(node_level);
let mut iter = statements.iter();
let Some(first) = iter.next() else {
return Ok(());
};
// Wrap the entire formatting operation in a `format_with` to ensure that we restore
// context regardless of whether an error occurs.
let formatted = format_with(|f| {
let mut iter = statements.iter();
let Some(first) = iter.next() else {
return Ok(());
};
let mut f = WithNodeLevel::new(node_level, f);
// First entry has never any separator, doesn't matter which one we take.
write!(f, [first.format()])?;
// First entry has never any separator, doesn't matter which one we take.
write!(f, [first.format()])?;
let mut last = first;
let mut last = first;
for statement in iter {
if is_class_or_function_definition(last)
|| is_class_or_function_definition(statement)
{
match self.level {
SuiteLevel::TopLevel => {
write!(f, [empty_line(), empty_line(), statement.format()])?;
}
SuiteLevel::Nested => {
write!(f, [empty_line(), statement.format()])?;
}
for statement in iter {
if is_class_or_function_definition(last) || is_class_or_function_definition(statement) {
match self.level {
SuiteLevel::TopLevel => {
write!(f, [empty_line(), empty_line(), statement.format()])?;
}
} else if is_import_definition(last) && !is_import_definition(statement) {
write!(f, [empty_line(), statement.format()])?;
} else if is_compound_statement(last) {
// Handles the case where a body has trailing comments. The issue is that RustPython does not include
// the comments in the range of the suite. This means, the body ends right after the last statement in the body.
// ```python
// def test():
// ...
// # The body of `test` ends right after `...` and before this comment
//
// # leading comment
//
//
// a = 10
// ```
// Using `lines_after` for the node doesn't work because it would count the lines after the `...`
// which is 0 instead of 1, the number of lines between the trailing comment and
// the leading comment. This is why the suite handling counts the lines before the
// start of the next statement or before the first leading comments for compound statements.
let start =
if let Some(first_leading) = comments.leading_comments(statement).first() {
first_leading.slice().start()
} else {
statement.start()
};
match lines_before(start, source) {
0 | 1 => write!(f, [hard_line_break()])?,
2 => write!(f, [empty_line()])?,
3.. => {
if self.level.is_nested() {
write!(f, [empty_line()])?;
} else {
write!(f, [empty_line(), empty_line()])?;
}
}
SuiteLevel::Nested => {
write!(f, [empty_line(), statement.format()])?;
}
write!(f, [statement.format()])?;
} else {
// Insert the appropriate number of empty lines based on the node level, e.g.:
// * [`NodeLevel::Module`]: Up to two empty lines
// * [`NodeLevel::CompoundStatement`]: Up to one empty line
// * [`NodeLevel::Expression`]: No empty lines
let count_lines = |offset| {
// It's necessary to skip any trailing line comment because RustPython doesn't include trailing comments
// in the node's range
// ```python
// a # The range of `a` ends right before this comment
//
// b
// ```
//
// Simply using `lines_after` doesn't work if a statement has a trailing comment because
// it then counts the lines between the statement and the trailing comment, which is
// always 0. This is why it skips any trailing trivia (trivia that's on the same line)
// and counts the lines after.
let after_trailing_trivia = skip_trailing_trivia(offset, source);
lines_after(after_trailing_trivia, source)
}
} else if is_import_definition(last) && !is_import_definition(statement) {
write!(f, [empty_line(), statement.format()])?;
} else if is_compound_statement(last) {
// Handles the case where a body has trailing comments. The issue is that RustPython does not include
// the comments in the range of the suite. This means, the body ends right after the last statement in the body.
// ```python
// def test():
// ...
// # The body of `test` ends right after `...` and before this comment
//
// # leading comment
//
//
// a = 10
// ```
// Using `lines_after` for the node doesn't work because it would count the lines after the `...`
// which is 0 instead of 1, the number of lines between the trailing comment and
// the leading comment. This is why the suite handling counts the lines before the
// start of the next statement or before the first leading comments for compound statements.
let start =
if let Some(first_leading) = comments.leading_comments(statement).first() {
first_leading.slice().start()
} else {
statement.start()
};
match node_level {
NodeLevel::TopLevel => match count_lines(last.end()) {
0 | 1 => write!(f, [hard_line_break()])?,
2 => write!(f, [empty_line()])?,
_ => write!(f, [empty_line(), empty_line()])?,
},
NodeLevel::CompoundStatement => match count_lines(last.end()) {
0 | 1 => write!(f, [hard_line_break()])?,
_ => write!(f, [empty_line()])?,
},
NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression => {
write!(f, [hard_line_break()])?;
match lines_before(start, source) {
0 | 1 => write!(f, [hard_line_break()])?,
2 => write!(f, [empty_line()])?,
3.. => {
if self.level.is_nested() {
write!(f, [empty_line()])?;
} else {
write!(f, [empty_line(), empty_line()])?;
}
}
write!(f, [statement.format()])?;
}
last = statement;
write!(f, [statement.format()])?;
} else {
// Insert the appropriate number of empty lines based on the node level, e.g.:
// * [`NodeLevel::Module`]: Up to two empty lines
// * [`NodeLevel::CompoundStatement`]: Up to one empty line
// * [`NodeLevel::Expression`]: No empty lines
let count_lines = |offset| {
// It's necessary to skip any trailing line comment because RustPython doesn't include trailing comments
// in the node's range
// ```python
// a # The range of `a` ends right before this comment
//
// b
// ```
//
// Simply using `lines_after` doesn't work if a statement has a trailing comment because
// it then counts the lines between the statement and the trailing comment, which is
// always 0. This is why it skips any trailing trivia (trivia that's on the same line)
// and counts the lines after.
let after_trailing_trivia = skip_trailing_trivia(offset, source);
lines_after(after_trailing_trivia, source)
};
match node_level {
NodeLevel::TopLevel => match count_lines(last.end()) {
0 | 1 => write!(f, [hard_line_break()])?,
2 => write!(f, [empty_line()])?,
_ => write!(f, [empty_line(), empty_line()])?,
},
NodeLevel::CompoundStatement => match count_lines(last.end()) {
0 | 1 => write!(f, [hard_line_break()])?,
_ => write!(f, [empty_line()])?,
},
NodeLevel::Expression(_) | NodeLevel::ParenthesizedExpression => {
write!(f, [hard_line_break()])?;
}
}
write!(f, [statement.format()])?;
}
Ok(())
});
last = statement;
}
let result = formatted.fmt(f);
f.context_mut().set_node_level(saved_level);
result
Ok(())
}
}