Create dedicated Body nodes in the formatter CST (#3223)

This commit is contained in:
Charlie Marsh 2023-02-27 17:55:05 -05:00 committed by GitHub
parent cd6413ca09
commit 2261e194a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1239 additions and 611 deletions

View file

@ -1,6 +1,6 @@
use crate::core::visitor; use crate::core::visitor;
use crate::core::visitor::Visitor; use crate::core::visitor::Visitor;
use crate::cst::{Alias, Excepthandler, Expr, Pattern, SliceIndex, Stmt}; use crate::cst::{Alias, Body, Excepthandler, Expr, Pattern, SliceIndex, Stmt};
use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken}; use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken};
struct AttachmentVisitor { struct AttachmentVisitor {
@ -8,6 +8,14 @@ struct AttachmentVisitor {
} }
impl<'a> Visitor<'a> for AttachmentVisitor { impl<'a> Visitor<'a> for AttachmentVisitor {
fn visit_body(&mut self, body: &'a mut Body) {
let trivia = self.index.body.remove(&body.id());
if let Some(comments) = trivia {
body.trivia.extend(comments);
}
visitor::walk_body(self, body);
}
fn visit_stmt(&mut self, stmt: &'a mut Stmt) { fn visit_stmt(&mut self, stmt: &'a mut Stmt) {
let trivia = self.index.stmt.remove(&stmt.id()); let trivia = self.index.stmt.remove(&stmt.id());
if let Some(comments) = trivia { if let Some(comments) = trivia {
@ -59,5 +67,8 @@ impl<'a> Visitor<'a> for AttachmentVisitor {
pub fn attach(python_cst: &mut [Stmt], trivia: Vec<TriviaToken>) { pub fn attach(python_cst: &mut [Stmt], trivia: Vec<TriviaToken>) {
let index = decorate_trivia(trivia, python_cst); let index = decorate_trivia(trivia, python_cst);
AttachmentVisitor { index }.visit_body(python_cst); let mut visitor = AttachmentVisitor { index };
for stmt in python_cst {
visitor.visit_stmt(stmt);
}
} }

View file

@ -1,3 +1,8 @@
use rustpython_parser::ast::Location;
use crate::core::locator::Locator;
use crate::core::types::Range;
/// Return the leading quote for a string or byte literal (e.g., `"""`). /// Return the leading quote for a string or byte literal (e.g., `"""`).
pub fn leading_quote(content: &str) -> Option<&str> { pub fn leading_quote(content: &str) -> Option<&str> {
if let Some(first_line) = content.lines().next() { if let Some(first_line) = content.lines().next() {
@ -32,6 +37,99 @@ pub fn is_radix_literal(content: &str) -> bool {
|| content.starts_with("0X") || content.starts_with("0X")
} }
/// Expand the range of a compound statement.
///
/// `location` is the start of the compound statement (e.g., the `if` in `if x:`).
/// `end_location` is the end of the last statement in the body.
pub fn expand_indented_block(
location: Location,
end_location: Location,
locator: &Locator,
) -> (Location, Location) {
let contents = locator.contents();
let start_index = locator.index(location);
let end_index = locator.index(end_location);
// Find the colon, which indicates the end of the header.
let mut nesting = 0;
let mut colon = None;
for (start, tok, _end) in rustpython_parser::lexer::lex_located(
&contents[start_index..end_index],
rustpython_parser::Mode::Module,
location,
)
.flatten()
{
match tok {
rustpython_parser::Tok::Colon if nesting == 0 => {
colon = Some(start);
break;
}
rustpython_parser::Tok::Lpar
| rustpython_parser::Tok::Lsqb
| rustpython_parser::Tok::Lbrace => nesting += 1,
rustpython_parser::Tok::Rpar
| rustpython_parser::Tok::Rsqb
| rustpython_parser::Tok::Rbrace => nesting -= 1,
_ => {}
}
}
let colon_location = colon.unwrap();
let colon_index = locator.index(colon_location);
// From here, we have two options: simple statement or compound statement.
let indent = rustpython_parser::lexer::lex_located(
&contents[colon_index..end_index],
rustpython_parser::Mode::Module,
colon_location,
)
.flatten()
.find_map(|(start, tok, _end)| match tok {
rustpython_parser::Tok::Indent => Some(start),
_ => None,
});
let Some(indent_location) = indent else {
// Simple statement: from the colon to the end of the line.
return (colon_location, Location::new(end_location.row() + 1, 0));
};
// Compound statement: from the colon to the end of the block.
let mut offset = 0;
for (index, line) in contents[end_index..].lines().skip(1).enumerate() {
if line.is_empty() {
continue;
}
if line
.chars()
.take(indent_location.column())
.all(char::is_whitespace)
{
offset = index + 1;
} else {
break;
}
}
let end_location = Location::new(end_location.row() + 1 + offset, 0);
(colon_location, end_location)
}
/// Return true if the `orelse` block of an `if` statement is an `elif` statement.
pub fn is_elif(orelse: &[rustpython_parser::ast::Stmt], locator: &Locator) -> bool {
if orelse.len() == 1 && matches!(orelse[0].node, rustpython_parser::ast::StmtKind::If { .. }) {
let (source, start, end) = locator.slice(Range::new(
orelse[0].location,
orelse[0].end_location.unwrap(),
));
if source[start..end].starts_with("elif") {
return true;
}
}
false
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[test] #[test]

View file

@ -108,6 +108,15 @@ impl<'a> Locator<'a> {
self.index.get_or_init(|| index(self.contents)) self.index.get_or_init(|| index(self.contents))
} }
pub fn index(&self, location: Location) -> usize {
let index = self.get_or_init_index();
truncate(location, index, self.contents)
}
pub fn contents(&self) -> &str {
self.contents
}
/// Slice the source code at a [`Range`]. /// Slice the source code at a [`Range`].
pub fn slice(&self, range: Range) -> (Rc<str>, usize, usize) { pub fn slice(&self, range: Range) -> (Rc<str>, usize, usize) {
let index = self.get_or_init_index(); let index = self.get_or_init_index();

View file

@ -1,8 +1,8 @@
use rustpython_parser::ast::Constant; use rustpython_parser::ast::Constant;
use crate::cst::{ use crate::cst::{
Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind, Expr, Alias, Arg, Arguments, Body, Boolop, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind,
ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex, Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex,
SliceIndexKind, Stmt, StmtKind, Unaryop, Withitem, SliceIndexKind, Stmt, StmtKind, Unaryop, Withitem,
}; };
@ -67,13 +67,13 @@ pub trait Visitor<'a> {
fn visit_pattern(&mut self, pattern: &'a mut Pattern) { fn visit_pattern(&mut self, pattern: &'a mut Pattern) {
walk_pattern(self, pattern); walk_pattern(self, pattern);
} }
fn visit_body(&mut self, body: &'a mut [Stmt]) { fn visit_body(&mut self, body: &'a mut Body) {
walk_body(self, body); walk_body(self, body);
} }
} }
pub fn walk_body<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, body: &'a mut [Stmt]) { pub fn walk_body<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, body: &'a mut Body) {
for stmt in body { for stmt in &mut body.node {
visitor.visit_stmt(stmt); visitor.visit_stmt(stmt);
} }
} }
@ -173,7 +173,9 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
visitor.visit_expr(iter); visitor.visit_expr(iter);
visitor.visit_expr(target); visitor.visit_expr(target);
visitor.visit_body(body); visitor.visit_body(body);
visitor.visit_body(orelse); if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
} }
StmtKind::AsyncFor { StmtKind::AsyncFor {
target, target,
@ -185,17 +187,25 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
visitor.visit_expr(iter); visitor.visit_expr(iter);
visitor.visit_expr(target); visitor.visit_expr(target);
visitor.visit_body(body); visitor.visit_body(body);
visitor.visit_body(orelse); if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
} }
StmtKind::While { test, body, orelse } => { StmtKind::While { test, body, orelse } => {
visitor.visit_expr(test); visitor.visit_expr(test);
visitor.visit_body(body); visitor.visit_body(body);
visitor.visit_body(orelse); if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
} }
StmtKind::If { test, body, orelse } => { StmtKind::If {
test, body, orelse, ..
} => {
visitor.visit_expr(test); visitor.visit_expr(test);
visitor.visit_body(body); visitor.visit_body(body);
visitor.visit_body(orelse); if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
} }
StmtKind::With { items, body, .. } => { StmtKind::With { items, body, .. } => {
for withitem in items { for withitem in items {
@ -210,7 +220,6 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
visitor.visit_body(body); visitor.visit_body(body);
} }
StmtKind::Match { subject, cases } => { StmtKind::Match { subject, cases } => {
// TODO(charlie): Handle `cases`.
visitor.visit_expr(subject); visitor.visit_expr(subject);
for match_case in cases { for match_case in cases {
visitor.visit_match_case(match_case); visitor.visit_match_case(match_case);
@ -234,8 +243,12 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
for excepthandler in handlers { for excepthandler in handlers {
visitor.visit_excepthandler(excepthandler); visitor.visit_excepthandler(excepthandler);
} }
visitor.visit_body(orelse); if let Some(orelse) = orelse {
visitor.visit_body(finalbody); visitor.visit_body(orelse);
}
if let Some(finalbody) = finalbody {
visitor.visit_body(finalbody);
}
} }
StmtKind::TryStar { StmtKind::TryStar {
body, body,
@ -247,8 +260,12 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
for excepthandler in handlers { for excepthandler in handlers {
visitor.visit_excepthandler(excepthandler); visitor.visit_excepthandler(excepthandler);
} }
visitor.visit_body(orelse); if let Some(orelse) = orelse {
visitor.visit_body(finalbody); visitor.visit_body(orelse);
}
if let Some(finalbody) = finalbody {
visitor.visit_body(finalbody);
}
} }
StmtKind::Assert { test, msg } => { StmtKind::Assert { test, msg } => {
visitor.visit_expr(test); visitor.visit_expr(test);

File diff suppressed because it is too large Load diff

View file

@ -4,17 +4,55 @@ use ruff_text_size::{TextRange, TextSize};
use crate::context::ASTFormatContext; use crate::context::ASTFormatContext;
use crate::core::types::Range; use crate::core::types::Range;
use crate::cst::Stmt; use crate::cst::{Body, Stmt};
use crate::shared_traits::AsFormat; use crate::shared_traits::AsFormat;
use crate::trivia::{Relationship, TriviaKind};
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct Block<'a> { pub struct Block<'a> {
body: &'a [Stmt], body: &'a Body,
} }
impl Format<ASTFormatContext<'_>> for Block<'_> { impl Format<ASTFormatContext<'_>> for Block<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> { fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
for (i, stmt) in self.body.iter().enumerate() { for (i, stmt) in self.body.node.iter().enumerate() {
if i > 0 {
write!(f, [hard_line_break()])?;
}
write!(f, [stmt.format()])?;
}
for trivia in &self.body.trivia {
if matches!(trivia.relationship, Relationship::Dangling) {
match trivia.kind {
TriviaKind::EmptyLine => {
write!(f, [empty_line()])?;
}
TriviaKind::OwnLineComment(range) => {
write!(f, [literal(range), hard_line_break()])?;
}
_ => {}
}
}
}
Ok(())
}
}
#[inline]
pub fn block(body: &Body) -> Block {
Block { body }
}
#[derive(Copy, Clone)]
pub struct Statements<'a> {
suite: &'a [Stmt],
}
impl Format<ASTFormatContext<'_>> for Statements<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
for (i, stmt) in self.suite.iter().enumerate() {
if i > 0 { if i > 0 {
write!(f, [hard_line_break()])?; write!(f, [hard_line_break()])?;
} }
@ -24,9 +62,8 @@ impl Format<ASTFormatContext<'_>> for Block<'_> {
} }
} }
#[inline] pub fn statements(suite: &[Stmt]) -> Statements {
pub fn block(body: &[Stmt]) -> Block { Statements { suite }
Block { body }
} }
#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]

View file

@ -72,13 +72,12 @@ pub struct EndOfLineComments<'a, T> {
impl<T> Format<ASTFormatContext<'_>> for EndOfLineComments<'_, T> { impl<T> Format<ASTFormatContext<'_>> for EndOfLineComments<'_, T> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> { fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let mut first = true; let mut first = true;
for range in self.item.trivia.iter().filter_map(|trivia| { for range in self
if trivia.relationship.is_trailing() { .item
trivia.kind.end_of_line_comment() .trivia
} else { .iter()
None .filter_map(|trivia| trivia.kind.end_of_line_comment())
} {
}) {
if std::mem::take(&mut first) { if std::mem::take(&mut first) {
write!(f, [line_suffix(&text(" "))])?; write!(f, [line_suffix(&text(" "))])?;
} }

View file

@ -4,6 +4,7 @@ use ruff_formatter::write;
use crate::context::ASTFormatContext; use crate::context::ASTFormatContext;
use crate::cst::MatchCase; use crate::cst::MatchCase;
use crate::format::builders::block; use crate::format::builders::block;
use crate::format::comments::{end_of_line_comments, leading_comments};
use crate::shared_traits::AsFormat; use crate::shared_traits::AsFormat;
pub struct FormatMatchCase<'a> { pub struct FormatMatchCase<'a> {
@ -26,12 +27,16 @@ impl Format<ASTFormatContext<'_>> for FormatMatchCase<'_> {
body, body,
} = self.item; } = self.item;
write!(f, [leading_comments(pattern)])?;
write!(f, [text("case")])?; write!(f, [text("case")])?;
write!(f, [space(), pattern.format()])?; write!(f, [space(), pattern.format()])?;
if let Some(guard) = &guard { if let Some(guard) = &guard {
write!(f, [space(), text("if"), space(), guard.format()])?; write!(f, [space(), text("if"), space(), guard.format()])?;
} }
write!(f, [text(":")])?; write!(f, [text(":")])?;
write!(f, [end_of_line_comments(body)])?;
write!(f, [block_indent(&block(body))])?; write!(f, [block_indent(&block(body))])?;
Ok(()) Ok(())

View file

@ -6,8 +6,8 @@ use ruff_text_size::TextSize;
use crate::context::ASTFormatContext; use crate::context::ASTFormatContext;
use crate::cst::{ use crate::cst::{
Alias, Arguments, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt, StmtKind, Alias, Arguments, Body, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt,
Withitem, StmtKind, Withitem,
}; };
use crate::format::builders::{block, join_names}; use crate::format::builders::{block, join_names};
use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments}; use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments};
@ -101,13 +101,15 @@ fn format_class_def(
name: &str, name: &str,
bases: &[Expr], bases: &[Expr],
keywords: &[Keyword], keywords: &[Keyword],
body: &[Stmt], body: &Body,
decorator_list: &[Expr], decorator_list: &[Expr],
) -> FormatResult<()> { ) -> FormatResult<()> {
for decorator in decorator_list { for decorator in decorator_list {
write!(f, [text("@"), decorator.format(), hard_line_break()])?; write!(f, [text("@"), decorator.format(), hard_line_break()])?;
} }
write!(f, [leading_comments(body)])?;
write!( write!(
f, f,
[ [
@ -161,6 +163,7 @@ fn format_class_def(
)?; )?;
} }
write!(f, [end_of_line_comments(body)])?;
write!(f, [text(":"), block_indent(&block(body))]) write!(f, [text(":"), block_indent(&block(body))])
} }
@ -170,13 +173,16 @@ fn format_func_def(
name: &str, name: &str,
args: &Arguments, args: &Arguments,
returns: Option<&Expr>, returns: Option<&Expr>,
body: &[Stmt], body: &Body,
decorator_list: &[Expr], decorator_list: &[Expr],
async_: bool, async_: bool,
) -> FormatResult<()> { ) -> FormatResult<()> {
for decorator in decorator_list { for decorator in decorator_list {
write!(f, [text("@"), decorator.format(), hard_line_break()])?; write!(f, [text("@"), decorator.format(), hard_line_break()])?;
} }
write!(f, [leading_comments(body)])?;
if async_ { if async_ {
write!(f, [text("async"), space()])?; write!(f, [text("async"), space()])?;
} }
@ -202,10 +208,10 @@ fn format_func_def(
} }
write!(f, [text(":")])?; write!(f, [text(":")])?;
write!(f, [end_of_line_comments(body)])?;
write!(f, [block_indent(&block(body))])?;
write!(f, [end_of_line_comments(stmt)])?; Ok(())
write!(f, [block_indent(&format_args![block(body)])])
} }
fn format_assign( fn format_assign(
@ -310,8 +316,8 @@ fn format_for(
stmt: &Stmt, stmt: &Stmt,
target: &Expr, target: &Expr,
iter: &Expr, iter: &Expr,
body: &[Stmt], body: &Body,
orelse: &[Stmt], orelse: Option<&Body>,
_type_comment: Option<&str>, _type_comment: Option<&str>,
async_: bool, async_: bool,
) -> FormatResult<()> { ) -> FormatResult<()> {
@ -329,11 +335,19 @@ fn format_for(
space(), space(),
group(&iter.format()), group(&iter.format()),
text(":"), text(":"),
end_of_line_comments(body),
block_indent(&block(body)) block_indent(&block(body))
] ]
)?; )?;
if !orelse.is_empty() { if let Some(orelse) = orelse {
write!(f, [text("else:"), block_indent(&block(orelse))])?; write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
} }
Ok(()) Ok(())
} }
@ -342,8 +356,8 @@ fn format_while(
f: &mut Formatter<ASTFormatContext<'_>>, f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt, stmt: &Stmt,
test: &Expr, test: &Expr,
body: &[Stmt], body: &Body,
orelse: &[Stmt], orelse: Option<&Body>,
) -> FormatResult<()> { ) -> FormatResult<()> {
write!(f, [text("while"), space()])?; write!(f, [text("while"), space()])?;
if is_self_closing(test) { if is_self_closing(test) {
@ -358,9 +372,23 @@ fn format_while(
])] ])]
)?; )?;
} }
write!(f, [text(":"), block_indent(&block(body))])?; write!(
if !orelse.is_empty() { f,
write!(f, [text("else:"), block_indent(&block(orelse))])?; [
text(":"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
if let Some(orelse) = orelse {
write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
} }
Ok(()) Ok(())
} }
@ -368,10 +396,15 @@ fn format_while(
fn format_if( fn format_if(
f: &mut Formatter<ASTFormatContext<'_>>, f: &mut Formatter<ASTFormatContext<'_>>,
test: &Expr, test: &Expr,
body: &[Stmt], body: &Body,
orelse: &[Stmt], orelse: Option<&Body>,
is_elif: bool,
) -> FormatResult<()> { ) -> FormatResult<()> {
write!(f, [text("if"), space()])?; if is_elif {
write!(f, [text("elif"), space()])?;
} else {
write!(f, [text("if"), space()])?;
}
if is_self_closing(test) { if is_self_closing(test) {
write!(f, [test.format()])?; write!(f, [test.format()])?;
} else { } else {
@ -384,17 +417,43 @@ fn format_if(
])] ])]
)?; )?;
} }
write!(f, [text(":"), block_indent(&block(body))])?; write!(
if !orelse.is_empty() { f,
if orelse.len() == 1 { [
if let StmtKind::If { test, body, orelse } = &orelse[0].node { text(":"),
write!(f, [text("el")])?; end_of_line_comments(body),
format_if(f, test, body, orelse)?; block_indent(&block(body))
]
)?;
if let Some(orelse) = orelse {
if orelse.node.len() == 1 {
if let StmtKind::If {
test,
body,
orelse,
is_elif: true,
} = &orelse.node[0].node
{
format_if(f, test, body, orelse.as_ref(), true)?;
} else { } else {
write!(f, [text("else:"), block_indent(&block(orelse))])?; write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
} }
} else { } else {
write!(f, [text("else:"), block_indent(&block(orelse))])?; write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
} }
} }
Ok(()) Ok(())
@ -406,7 +465,16 @@ fn format_match(
subject: &Expr, subject: &Expr,
cases: &[MatchCase], cases: &[MatchCase],
) -> FormatResult<()> { ) -> FormatResult<()> {
write!(f, [text("match"), space(), subject.format(), text(":")])?; write!(
f,
[
text("match"),
space(),
subject.format(),
text(":"),
end_of_line_comments(stmt),
]
)?;
for case in cases { for case in cases {
write!(f, [block_indent(&case.format())])?; write!(f, [block_indent(&case.format())])?;
} }
@ -447,20 +515,31 @@ fn format_return(
fn format_try( fn format_try(
f: &mut Formatter<ASTFormatContext<'_>>, f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt, stmt: &Stmt,
body: &[Stmt], body: &Body,
handlers: &[Excepthandler], handlers: &[Excepthandler],
orelse: &[Stmt], orelse: Option<&Body>,
finalbody: &[Stmt], finalbody: Option<&Body>,
) -> FormatResult<()> { ) -> FormatResult<()> {
write!(f, [text("try:"), block_indent(&block(body))])?; write!(
f,
[
text("try:"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
for handler in handlers { for handler in handlers {
write!(f, [handler.format()])?; write!(f, [handler.format()])?;
} }
if !orelse.is_empty() { if let Some(orelse) = orelse {
write!(f, [text("else:"), block_indent(&block(orelse))])?; write!(f, [text("else:")])?;
write!(f, [end_of_line_comments(orelse)])?;
write!(f, [block_indent(&block(orelse))])?;
} }
if !finalbody.is_empty() { if let Some(finalbody) = finalbody {
write!(f, [text("finally:"), block_indent(&block(finalbody))])?; write!(f, [text("finally:")])?;
write!(f, [end_of_line_comments(finalbody)])?;
write!(f, [block_indent(&block(finalbody))])?;
} }
Ok(()) Ok(())
} }
@ -468,21 +547,42 @@ fn format_try(
fn format_try_star( fn format_try_star(
f: &mut Formatter<ASTFormatContext<'_>>, f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt, stmt: &Stmt,
body: &[Stmt], body: &Body,
handlers: &[Excepthandler], handlers: &[Excepthandler],
orelse: &[Stmt], orelse: Option<&Body>,
finalbody: &[Stmt], finalbody: Option<&Body>,
) -> FormatResult<()> { ) -> FormatResult<()> {
write!(f, [text("try:"), block_indent(&block(body))])?; write!(
f,
[
text("try:"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
for handler in handlers { for handler in handlers {
// TODO(charlie): Include `except*`. // TODO(charlie): Include `except*`.
write!(f, [handler.format()])?; write!(f, [handler.format()])?;
} }
if !orelse.is_empty() { if let Some(orelse) = orelse {
write!(f, [text("else:"), block_indent(&block(orelse))])?; write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
} }
if !finalbody.is_empty() { if let Some(finalbody) = finalbody {
write!(f, [text("finally:"), block_indent(&block(finalbody))])?; write!(
f,
[
text("finally:"),
end_of_line_comments(finalbody),
block_indent(&block(finalbody))
]
)?;
} }
Ok(()) Ok(())
} }
@ -640,7 +740,7 @@ fn format_with_(
f: &mut Formatter<ASTFormatContext<'_>>, f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt, stmt: &Stmt,
items: &[Withitem], items: &[Withitem],
body: &[Stmt], body: &Body,
type_comment: Option<&str>, type_comment: Option<&str>,
async_: bool, async_: bool,
) -> FormatResult<()> { ) -> FormatResult<()> {
@ -668,6 +768,7 @@ fn format_with_(
if_group_breaks(&text(")")), if_group_breaks(&text(")")),
]), ]),
text(":"), text(":"),
end_of_line_comments(body),
block_indent(&block(body)) block_indent(&block(body))
] ]
)?; )?;
@ -753,7 +854,7 @@ impl Format<ASTFormatContext<'_>> for FormatStmt<'_> {
target, target,
iter, iter,
body, body,
orelse, orelse.as_ref(),
type_comment.as_deref(), type_comment.as_deref(),
false, false,
), ),
@ -769,14 +870,19 @@ impl Format<ASTFormatContext<'_>> for FormatStmt<'_> {
target, target,
iter, iter,
body, body,
orelse, orelse.as_ref(),
type_comment.as_deref(), type_comment.as_deref(),
true, true,
), ),
StmtKind::While { test, body, orelse } => { StmtKind::While { test, body, orelse } => {
format_while(f, self.item, test, body, orelse) format_while(f, self.item, test, body, orelse.as_ref())
} }
StmtKind::If { test, body, orelse } => format_if(f, test, body, orelse), StmtKind::If {
test,
body,
orelse,
is_elif,
} => format_if(f, test, body, orelse.as_ref(), *is_elif),
StmtKind::With { StmtKind::With {
items, items,
body, body,
@ -810,13 +916,27 @@ impl Format<ASTFormatContext<'_>> for FormatStmt<'_> {
handlers, handlers,
orelse, orelse,
finalbody, finalbody,
} => format_try(f, self.item, body, handlers, orelse, finalbody), } => format_try(
f,
self.item,
body,
handlers,
orelse.as_ref(),
finalbody.as_ref(),
),
StmtKind::TryStar { StmtKind::TryStar {
body, body,
handlers, handlers,
orelse, orelse,
finalbody, finalbody,
} => format_try_star(f, self.item, body, handlers, orelse, finalbody), } => format_try_star(
f,
self.item,
body,
handlers,
orelse.as_ref(),
finalbody.as_ref(),
),
StmtKind::Assert { test, msg } => { StmtKind::Assert { test, msg } => {
format_assert(f, self.item, test, msg.as_ref().map(|expr| &**expr)) format_assert(f, self.item, test, msg.as_ref().map(|expr| &**expr))
} }

View file

@ -53,7 +53,7 @@ pub fn fmt(contents: &str) -> Result<Formatted<ASTFormatContext>> {
}, },
locator, locator,
), ),
[format::builders::block(&python_cst)] [format::builders::statements(&python_cst)]
) )
.map_err(Into::into) .map_err(Into::into)
} }

View file

@ -163,15 +163,14 @@ impl<'a> Visitor<'a> for StmtNormalizer {
self.trailer = Trailer::CompoundStatement; self.trailer = Trailer::CompoundStatement;
self.visit_body(body); self.visit_body(body);
if !orelse.is_empty() { if let Some(orelse) = orelse {
// If the previous body ended with a function or class definition, we need to // If the previous body ended with a function or class definition, we need to
// insert an empty line before the else block. Since the `else` itself isn't // insert an empty line before the else block. Since the `else` itself isn't
// a statement, we need to insert it into the last statement of the body. // a statement, we need to insert it into the last statement of the body.
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
let stmt = body.last_mut().unwrap(); body.trivia.push(Trivia {
stmt.trivia.push(Trivia {
kind: TriviaKind::EmptyLine, kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing, relationship: Relationship::Dangling,
}); });
} }
@ -185,12 +184,11 @@ impl<'a> Visitor<'a> for StmtNormalizer {
self.trailer = Trailer::CompoundStatement; self.trailer = Trailer::CompoundStatement;
self.visit_body(body); self.visit_body(body);
if !orelse.is_empty() { if let Some(orelse) = orelse {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
let stmt = body.last_mut().unwrap(); body.trivia.push(Trivia {
stmt.trivia.push(Trivia {
kind: TriviaKind::EmptyLine, kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing, relationship: Relationship::Dangling,
}); });
} }
@ -220,49 +218,44 @@ impl<'a> Visitor<'a> for StmtNormalizer {
self.depth = Depth::Nested; self.depth = Depth::Nested;
self.trailer = Trailer::CompoundStatement; self.trailer = Trailer::CompoundStatement;
self.visit_body(body); self.visit_body(body);
let mut last = body.last_mut();
let mut prev = &mut body.trivia;
for handler in handlers { for handler in handlers {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
if let Some(stmt) = last.as_mut() { prev.push(Trivia {
stmt.trivia.push(Trivia { kind: TriviaKind::EmptyLine,
kind: TriviaKind::EmptyLine, relationship: Relationship::Dangling,
relationship: Relationship::Trailing, });
});
}
} }
self.depth = Depth::Nested; self.depth = Depth::Nested;
self.trailer = Trailer::CompoundStatement; self.trailer = Trailer::CompoundStatement;
let ExcepthandlerKind::ExceptHandler { body, .. } = &mut handler.node; let ExcepthandlerKind::ExceptHandler { body, .. } = &mut handler.node;
self.visit_body(body); self.visit_body(body);
last = body.last_mut(); prev = &mut body.trivia;
} }
if !orelse.is_empty() { if let Some(orelse) = orelse {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
if let Some(stmt) = last.as_mut() { prev.push(Trivia {
stmt.trivia.push(Trivia { kind: TriviaKind::EmptyLine,
kind: TriviaKind::EmptyLine, relationship: Relationship::Dangling,
relationship: Relationship::Trailing, });
});
}
} }
self.depth = Depth::Nested; self.depth = Depth::Nested;
self.trailer = Trailer::CompoundStatement; self.trailer = Trailer::CompoundStatement;
self.visit_body(orelse); self.visit_body(orelse);
last = body.last_mut(); prev = &mut body.trivia;
} }
if !finalbody.is_empty() { if let Some(finalbody) = finalbody {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
if let Some(stmt) = last.as_mut() { prev.push(Trivia {
stmt.trivia.push(Trivia { kind: TriviaKind::EmptyLine,
kind: TriviaKind::EmptyLine, relationship: Relationship::Dangling,
relationship: Relationship::Trailing, });
});
}
} }
self.depth = Depth::Nested; self.depth = Depth::Nested;

View file

@ -194,5 +194,7 @@ impl<'a> Visitor<'a> for ParenthesesNormalizer<'_> {
/// during formatting) and `Parenthesize` (which are used during formatting). /// during formatting) and `Parenthesize` (which are used during formatting).
pub fn normalize_parentheses(python_cst: &mut [Stmt], locator: &Locator) { pub fn normalize_parentheses(python_cst: &mut [Stmt], locator: &Locator) {
let mut normalizer = ParenthesesNormalizer { locator }; let mut normalizer = ParenthesesNormalizer { locator };
normalizer.visit_body(python_cst); for stmt in python_cst {
normalizer.visit_stmt(stmt);
}
} }

View file

@ -196,15 +196,7 @@ instruction()#comment with bad spacing
"Generator", "Generator",
] ]
@@ -54,32 +54,39 @@ @@ -60,26 +60,32 @@
# for compiler in compilers.values():
# add_compiler(compiler)
add_compiler(compilers[(7.0, 32)])
- # add_compiler(compilers[(7.1, 64)])
+# add_compiler(compilers[(7.1, 64)])
+
# Comment before function. # Comment before function.
def inline_comments_in_brackets_ruin_everything(): def inline_comments_in_brackets_ruin_everything():
if typedargslist: if typedargslist:
@ -229,7 +221,7 @@ instruction()#comment with bad spacing
+ parameters.what_if_this_was_actually_long.children[0], + parameters.what_if_this_was_actually_long.children[0],
+ body, + body,
+ parameters.children[-1], + parameters.children[-1],
+ ] + ] # type: ignore
if ( if (
self._proc is not None self._proc is not None
- # has the child process finished? - # has the child process finished?
@ -246,7 +238,7 @@ instruction()#comment with bad spacing
): ):
pass pass
# no newline before or after # no newline before or after
@@ -103,42 +110,42 @@ @@ -103,35 +109,35 @@
############################################################################ ############################################################################
call2( call2(
@ -298,16 +290,7 @@ instruction()#comment with bad spacing
] ]
while True: while True:
if False: if False:
continue @@ -167,7 +173,7 @@
- # and round and round we go
- # and round and round we go
+ # and round and round we go
+ # and round and round we go
# let's return
return Node(
@@ -167,7 +174,7 @@
####################### #######################
@ -377,10 +360,9 @@ else:
# for compiler in compilers.values(): # for compiler in compilers.values():
# add_compiler(compiler) # add_compiler(compiler)
add_compiler(compilers[(7.0, 32)]) add_compiler(compilers[(7.0, 32)])
# add_compiler(compilers[(7.1, 64)])
# add_compiler(compilers[(7.1, 64)])
# Comment before function. # Comment before function.
def inline_comments_in_brackets_ruin_everything(): def inline_comments_in_brackets_ruin_everything():
if typedargslist: if typedargslist:
@ -400,7 +382,7 @@ def inline_comments_in_brackets_ruin_everything():
parameters.what_if_this_was_actually_long.children[0], parameters.what_if_this_was_actually_long.children[0],
body, body,
parameters.children[-1], parameters.children[-1],
] ] # type: ignore
if ( if (
self._proc is not None self._proc is not None
and # has the child process finished? and # has the child process finished?
@ -467,8 +449,8 @@ short
if False: if False:
continue continue
# and round and round we go # and round and round we go
# and round and round we go # and round and round we go
# let's return # let's return
return Node( return Node(

View file

@ -131,17 +131,15 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -2,8 +2,8 @@ @@ -2,7 +2,7 @@
def f( def f(
- a, # type: int - a, # type: int
-):
+ a, + a,
+): # type: int ):
pass pass
@@ -14,44 +14,42 @@ @@ -14,44 +14,42 @@
@ -155,7 +153,6 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
- g, # type: int - g, # type: int
- h, # type: int - h, # type: int
- i, # type: int - i, # type: int
-):
+ a, + a,
+ b, + b,
+ c, + c,
@ -165,7 +162,7 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
+ g, + g,
+ h, + h,
+ i, + i,
+): # type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int ):
# type: (...) -> None # type: (...) -> None
pass pass
@ -175,12 +172,11 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
- *args, # type: *Any - *args, # type: *Any
- default=False, # type: bool - default=False, # type: bool
- **kwargs, # type: **Any - **kwargs, # type: **Any
-):
+ arg, + arg,
+ *args, + *args,
+ default=False, + default=False,
+ **kwargs, + **kwargs,
+): # type: int# type: *Any ):
# type: (...) -> None # type: (...) -> None
pass pass
@ -190,12 +186,11 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
- b, # type: int - b, # type: int
- c, # type: int - c, # type: int
- d, # type: int - d, # type: int
-):
+ a, + a,
+ b, + b,
+ c, + c,
+ d, + d,
+): # type: int# type: int# type: int# type: int# type: int ):
# type: (...) -> None # type: (...) -> None
element = 0 # type: int element = 0 # type: int
@ -208,41 +203,32 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
an_element_with_a_long_value = calls() or more_calls() and more() # type: bool an_element_with_a_long_value = calls() or more_calls() and more() # type: bool
tup = ( tup = (
@@ -66,26 +64,26 @@ @@ -70,21 +68,21 @@
+ element
+ another_element
+ another_element_with_long_name
- ) # type: int
+ )
def f( def f(
- x, # not a type comment - x, # not a type comment
- y, # type: int - y, # type: int
-):
+ x, + x,
+ y, + y,
+): # not a type comment# type: int ):
# type: (...) -> None # type: (...) -> None
pass pass
def f( def f(
- x, # not a type comment - x, # not a type comment
-): # type: (int) -> None
+ x, + x,
+): # not a type comment# type: (int) -> None ): # type: (int) -> None
pass pass
def func( def func(
- a=some_list[0], # type: int - a=some_list[0], # type: int
-): # type: () -> int
+ a=some_list[0], + a=some_list[0],
+): ): # type: () -> int
c = call( c = call(
0.0123, 0.0123,
0.0456,
@@ -96,23 +94,37 @@ @@ -96,23 +94,37 @@
0.0123, 0.0123,
0.0456, 0.0456,
@ -298,7 +284,7 @@ from typing import Any, Tuple
def f( def f(
a, a,
): # type: int ):
pass pass
@ -318,7 +304,7 @@ def f(
g, g,
h, h,
i, i,
): # type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int ):
# type: (...) -> None # type: (...) -> None
pass pass
@ -328,7 +314,7 @@ def f(
*args, *args,
default=False, default=False,
**kwargs, **kwargs,
): # type: int# type: *Any ):
# type: (...) -> None # type: (...) -> None
pass pass
@ -338,7 +324,7 @@ def f(
b, b,
c, c,
d, d,
): # type: int# type: int# type: int# type: int# type: int ):
# type: (...) -> None # type: (...) -> None
element = 0 # type: int element = 0 # type: int
@ -359,26 +345,26 @@ def f(
+ element + element
+ another_element + another_element
+ another_element_with_long_name + another_element_with_long_name
) ) # type: int
def f( def f(
x, x,
y, y,
): # not a type comment# type: int ):
# type: (...) -> None # type: (...) -> None
pass pass
def f( def f(
x, x,
): # not a type comment# type: (int) -> None ): # type: (int) -> None
pass pass
def func( def func(
a=some_list[0], a=some_list[0],
): ): # type: () -> int
c = call( c = call(
0.0123, 0.0123,
0.0456, 0.0456,

View file

@ -164,7 +164,7 @@ def bar():
# This should be split from the above by two lines # This should be split from the above by two lines
class MyClassWithComplexLeadingComments: class MyClassWithComplexLeadingComments:
pass pass
@@ -57,13 +58,13 @@ @@ -57,11 +58,11 @@
# leading 1 # leading 1
@deco1 @deco1
@ -174,16 +174,13 @@ def bar():
-@deco2(with_args=True) -@deco2(with_args=True)
-# leading 3 -# leading 3
-@deco3 -@deco3
-# leading 4
+deco2(with_args=True) +deco2(with_args=True)
+@# leading 3 +@# leading 3
+deco3 +deco3
# leading 4
def decorated(): def decorated():
+ # leading 4
pass pass
@@ -72,11 +73,10 @@
@@ -72,13 +73,12 @@
# leading 1 # leading 1
@deco1 @deco1
@ -192,17 +189,14 @@ def bar():
- -
-# leading 3 that already has an empty line -# leading 3 that already has an empty line
-@deco3 -@deco3
-# leading 4
+@# leading 2 +@# leading 2
+deco2(with_args=True) +deco2(with_args=True)
+@# leading 3 that already has an empty line +@# leading 3 that already has an empty line
+deco3 +deco3
# leading 4
def decorated_with_split_leading_comments(): def decorated_with_split_leading_comments():
+ # leading 4
pass pass
@@ -87,10 +87,10 @@
@@ -87,18 +87,18 @@
# leading 1 # leading 1
@deco1 @deco1
@ -210,16 +204,14 @@ def bar():
-@deco2(with_args=True) -@deco2(with_args=True)
-# leading 3 -# leading 3
-@deco3 -@deco3
-
-# leading 4 that already has an empty line
+@# leading 2 +@# leading 2
+deco2(with_args=True) +deco2(with_args=True)
+@# leading 3 +@# leading 3
+deco3 +deco3
def decorated_with_split_leading_comments():
+ # leading 4 that already has an empty line
pass
# leading 4 that already has an empty line
def decorated_with_split_leading_comments():
@@ -99,6 +99,7 @@
def main(): def main():
if a: if a:
@ -227,7 +219,7 @@ def bar():
# Leading comment before inline function # Leading comment before inline function
def inline(): def inline():
pass pass
@@ -108,12 +108,14 @@ @@ -108,12 +109,14 @@
pass pass
else: else:
@ -242,7 +234,7 @@ def bar():
# Leading comment before "top-level inline" function # Leading comment before "top-level inline" function
def top_level_quote_inline(): def top_level_quote_inline():
pass pass
@@ -123,6 +125,7 @@ @@ -123,6 +126,7 @@
pass pass
else: else:
@ -250,37 +242,6 @@ def bar():
# More leading comments # More leading comments
def top_level_quote_inline_after_else(): def top_level_quote_inline_after_else():
pass pass
@@ -138,9 +141,11 @@
# Regression test for https://github.com/psf/black/issues/3454.
def foo():
pass
- # Trailing comment that belongs to this function
+# Trailing comment that belongs to this function
+
+
@decorator1
@decorator2 # fmt: skip
def bar():
@@ -150,12 +155,13 @@
# Regression test for https://github.com/psf/black/issues/3454.
def foo():
pass
- # Trailing comment that belongs to this function.
- # NOTE this comment only has one empty line below, and the formatter
- # should enforce two blank lines.
+# Trailing comment that belongs to this function.
+# NOTE this comment only has one empty line below, and the formatter
+# should enforce two blank lines.
+
@decorator1
-# A standalone comment
def bar():
+ # A standalone comment
pass
``` ```
## Ruff Output ## Ruff Output
@ -351,8 +312,8 @@ some = statement
deco2(with_args=True) deco2(with_args=True)
@# leading 3 @# leading 3
deco3 deco3
# leading 4
def decorated(): def decorated():
# leading 4
pass pass
@ -365,8 +326,8 @@ some = statement
deco2(with_args=True) deco2(with_args=True)
@# leading 3 that already has an empty line @# leading 3 that already has an empty line
deco3 deco3
# leading 4
def decorated_with_split_leading_comments(): def decorated_with_split_leading_comments():
# leading 4
pass pass
@ -379,8 +340,9 @@ some = statement
deco2(with_args=True) deco2(with_args=True)
@# leading 3 @# leading 3
deco3 deco3
# leading 4 that already has an empty line
def decorated_with_split_leading_comments(): def decorated_with_split_leading_comments():
# leading 4 that already has an empty line
pass pass
@ -429,9 +391,7 @@ class MyClass:
# Regression test for https://github.com/psf/black/issues/3454. # Regression test for https://github.com/psf/black/issues/3454.
def foo(): def foo():
pass pass
# Trailing comment that belongs to this function
# Trailing comment that belongs to this function
@decorator1 @decorator1
@ -443,15 +403,14 @@ def bar():
# Regression test for https://github.com/psf/black/issues/3454. # Regression test for https://github.com/psf/black/issues/3454.
def foo(): def foo():
pass pass
# Trailing comment that belongs to this function.
# NOTE this comment only has one empty line below, and the formatter
# should enforce two blank lines.
# Trailing comment that belongs to this function.
# NOTE this comment only has one empty line below, and the formatter
# should enforce two blank lines.
@decorator1 @decorator1
# A standalone comment
def bar(): def bar():
# A standalone comment
pass pass
``` ```

View file

@ -204,7 +204,7 @@ class C:
) )
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
@@ -22,133 +23,156 @@ @@ -22,133 +23,155 @@
if ( if (
# Rule 1 # Rule 1
i % 2 == 0 i % 2 == 0
@ -217,10 +217,10 @@ class C:
- while ( - while (
- # Just a comment - # Just a comment
- call() - call()
- # Another
- ):
+ while # Just a comment + while # Just a comment
+ call(): + call():
# Another
- ):
print(i) print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager, push_manager=context.request.resource_manager,
@ -460,7 +460,7 @@ class C:
"Not what we expected and the message is too long to fit in one line" "Not what we expected and the message is too long to fit in one line"
" because it's too long" " because it's too long"
) )
@@ -161,9 +185,8 @@ @@ -161,9 +184,8 @@
8 STORE_ATTR 0 (x) 8 STORE_ATTR 0 (x)
10 LOAD_CONST 0 (None) 10 LOAD_CONST 0 (None)
12 RETURN_VALUE 12 RETURN_VALUE
@ -508,7 +508,6 @@ class C:
): ):
while # Just a comment while # Just a comment
call(): call():
# Another
print(i) print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager, push_manager=context.request.resource_manager,

View file

@ -204,7 +204,7 @@ class C:
) )
self.assertEqual( self.assertEqual(
unstyle(str(report)), unstyle(str(report)),
@@ -22,133 +23,156 @@ @@ -22,133 +23,155 @@
if ( if (
# Rule 1 # Rule 1
i % 2 == 0 i % 2 == 0
@ -217,10 +217,10 @@ class C:
- while ( - while (
- # Just a comment - # Just a comment
- call() - call()
- # Another
- ):
+ while # Just a comment + while # Just a comment
+ call(): + call():
# Another
- ):
print(i) print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager, push_manager=context.request.resource_manager,
@ -460,7 +460,7 @@ class C:
"Not what we expected and the message is too long to fit in one line" "Not what we expected and the message is too long to fit in one line"
" because it's too long" " because it's too long"
) )
@@ -161,9 +185,8 @@ @@ -161,9 +184,8 @@
8 STORE_ATTR 0 (x) 8 STORE_ATTR 0 (x)
10 LOAD_CONST 0 (None) 10 LOAD_CONST 0 (None)
12 RETURN_VALUE 12 RETURN_VALUE
@ -508,7 +508,6 @@ class C:
): ):
while # Just a comment while # Just a comment
call(): call():
# Another
print(i) print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager, push_manager=context.request.resource_manager,

View file

@ -26,13 +26,12 @@ def f(): pass
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,10 +1,14 @@ @@ -1,8 +1,12 @@
# fmt: off # fmt: off
-@test([ -@test([
- 1, 2, - 1, 2,
- 3, 4, - 3, 4,
-]) -])
-# fmt: on
+@test( +@test(
+ [ + [
+ 1, + 1,
@ -41,11 +40,9 @@ def f(): pass
+ 4, + 4,
+ ] + ]
+) +)
# fmt: on
def f(): def f():
+ # fmt: on
pass pass
``` ```
## Ruff Output ## Ruff Output
@ -60,8 +57,8 @@ def f(): pass
4, 4,
] ]
) )
# fmt: on
def f(): def f():
# fmt: on
pass pass

View file

@ -165,7 +165,7 @@ elif unformatted:
print("This will be formatted") print("This will be formatted")
@@ -68,20 +62,21 @@ @@ -68,20 +62,19 @@
class Named(t.Protocol): class Named(t.Protocol):
# fmt: off # fmt: off
@property @property
@ -177,11 +177,9 @@ elif unformatted:
class Factory(t.Protocol): class Factory(t.Protocol):
def this_will_be_formatted(self, **kwargs) -> Named: def this_will_be_formatted(self, **kwargs) -> Named:
... ...
-
# fmt: on
- # fmt: on
+# fmt: on
+
# Regression test for https://github.com/psf/black/issues/3436. # Regression test for https://github.com/psf/black/issues/3436.
if x: if x:
@ -267,9 +265,7 @@ class Named(t.Protocol):
class Factory(t.Protocol): class Factory(t.Protocol):
def this_will_be_formatted(self, **kwargs) -> Named: def this_will_be_formatted(self, **kwargs) -> Named:
... ...
# fmt: on
# fmt: on
# Regression test for https://github.com/psf/black/issues/3436. # Regression test for https://github.com/psf/black/issues/3436.

View file

@ -1,49 +0,0 @@
---
source: crates/ruff_python_formatter/src/lib.rs
expression: snapshot
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/fmtskip6.py
---
## Input
```py
class A:
def f(self):
for line in range(10):
if True:
pass # fmt: skip
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -2,4 +2,4 @@
def f(self):
for line in range(10):
if True:
- pass # fmt: skip
+ pass
```
## Ruff Output
```py
class A:
def f(self):
for line in range(10):
if True:
pass
```
## Black Output
```py
class A:
def f(self):
for line in range(10):
if True:
pass # fmt: skip
```

View file

@ -117,12 +117,13 @@ def __await__(): return (yield)
def func_no_args(): def func_no_args():
@@ -64,19 +64,14 @@ @@ -64,19 +64,15 @@
def spaces2(result=_core.Value(None)): def spaces2(result=_core.Value(None)):
- assert fut is self._read_fut, (fut, self._read_fut) - assert fut is self._read_fut, (fut, self._read_fut)
+ assert fut is self._read_fut, fut, self._read_fut + assert fut is self._read_fut, fut, self._read_fut
+
def example(session): def example(session):
@ -142,7 +143,7 @@ def __await__(): return (yield)
def long_lines(): def long_lines():
@@ -135,14 +130,13 @@ @@ -135,14 +131,13 @@
a, a,
**kwargs, **kwargs,
) -> A: ) -> A:
@ -233,6 +234,7 @@ def spaces2(result=_core.Value(None)):
assert fut is self._read_fut, fut, self._read_fut assert fut is self._read_fut, fut, self._read_fut
def example(session): def example(session):
result = session.query(models.Customer.id).filter( result = session.query(models.Customer.id).filter(
models.Customer.account_id == account_id, models.Customer.account_id == account_id,

View file

@ -123,10 +123,9 @@ async def main():
+ await (asyncio.sleep(1)) # Hello + await (asyncio.sleep(1)) # Hello
-async def main(): async def main():
- await asyncio.sleep(1) # Hello - await asyncio.sleep(1) # Hello
+async def main(): # Hello + await (asyncio.sleep(1)) # Hello
+ await (asyncio.sleep(1))
# Long lines # Long lines
@ -231,8 +230,8 @@ async def main():
await (asyncio.sleep(1)) # Hello await (asyncio.sleep(1)) # Hello
async def main(): # Hello async def main():
await (asyncio.sleep(1)) await (asyncio.sleep(1)) # Hello
# Long lines # Long lines

View file

@ -5,13 +5,14 @@ use rustpython_parser::Tok;
use crate::core::types::Range; use crate::core::types::Range;
use crate::cst::{ use crate::cst::{
Alias, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind, SliceIndex, Alias, Body, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind,
SliceIndexKind, Stmt, StmtKind, SliceIndex, SliceIndexKind, Stmt, StmtKind,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Node<'a> { pub enum Node<'a> {
Mod(&'a [Stmt]), Mod(&'a [Stmt]),
Body(&'a Body),
Stmt(&'a Stmt), Stmt(&'a Stmt),
Expr(&'a Expr), Expr(&'a Expr),
Alias(&'a Alias), Alias(&'a Alias),
@ -24,6 +25,7 @@ impl Node<'_> {
pub fn id(&self) -> usize { pub fn id(&self) -> usize {
match self { match self {
Node::Mod(nodes) => nodes as *const _ as usize, Node::Mod(nodes) => nodes as *const _ as usize,
Node::Body(node) => node.id(),
Node::Stmt(node) => node.id(), Node::Stmt(node) => node.id(),
Node::Expr(node) => node.id(), Node::Expr(node) => node.id(),
Node::Alias(node) => node.id(), Node::Alias(node) => node.id(),
@ -227,6 +229,11 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Stmt(stmt)); result.push(Node::Stmt(stmt));
} }
} }
Node::Body(body) => {
for stmt in &body.node {
result.push(Node::Stmt(stmt));
}
}
Node::Stmt(stmt) => match &stmt.node { Node::Stmt(stmt) => match &stmt.node {
StmtKind::Return { value } => { StmtKind::Return { value } => {
if let Some(value) = value { if let Some(value) = value {
@ -294,9 +301,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
if let Some(returns) = returns { if let Some(returns) = returns {
result.push(Node::Expr(returns)); result.push(Node::Expr(returns));
} }
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt));
}
} }
StmtKind::ClassDef { StmtKind::ClassDef {
bases, bases,
@ -314,9 +319,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
for keyword in keywords { for keyword in keywords {
result.push(Node::Expr(&keyword.node.value)); result.push(Node::Expr(&keyword.node.value));
} }
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt));
}
} }
StmtKind::Delete { targets } => { StmtKind::Delete { targets } => {
for target in targets { for target in targets {
@ -355,29 +358,25 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
} => { } => {
result.push(Node::Expr(target)); result.push(Node::Expr(target));
result.push(Node::Expr(iter)); result.push(Node::Expr(iter));
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt)); if let Some(orelse) = orelse {
} result.push(Node::Body(orelse));
for stmt in orelse {
result.push(Node::Stmt(stmt));
} }
} }
StmtKind::While { test, body, orelse } => { StmtKind::While { test, body, orelse } => {
result.push(Node::Expr(test)); result.push(Node::Expr(test));
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt)); if let Some(orelse) = orelse {
} result.push(Node::Body(orelse));
for stmt in orelse {
result.push(Node::Stmt(stmt));
} }
} }
StmtKind::If { test, body, orelse } => { StmtKind::If {
test, body, orelse, ..
} => {
result.push(Node::Expr(test)); result.push(Node::Expr(test));
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt)); if let Some(orelse) = orelse {
} result.push(Node::Body(orelse));
for stmt in orelse {
result.push(Node::Stmt(stmt));
} }
} }
StmtKind::With { items, body, .. } | StmtKind::AsyncWith { items, body, .. } => { StmtKind::With { items, body, .. } | StmtKind::AsyncWith { items, body, .. } => {
@ -387,9 +386,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(expr)); result.push(Node::Expr(expr));
} }
} }
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt));
}
} }
StmtKind::Match { subject, cases } => { StmtKind::Match { subject, cases } => {
result.push(Node::Expr(subject)); result.push(Node::Expr(subject));
@ -398,9 +395,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
if let Some(expr) = &case.guard { if let Some(expr) = &case.guard {
result.push(Node::Expr(expr)); result.push(Node::Expr(expr));
} }
for stmt in &case.body { result.push(Node::Body(&case.body));
result.push(Node::Stmt(stmt));
}
} }
} }
StmtKind::Raise { exc, cause } => { StmtKind::Raise { exc, cause } => {
@ -431,17 +426,15 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
orelse, orelse,
finalbody, finalbody,
} => { } => {
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt));
}
for handler in handlers { for handler in handlers {
result.push(Node::Excepthandler(handler)); result.push(Node::Excepthandler(handler));
} }
for stmt in orelse { if let Some(orelse) = orelse {
result.push(Node::Stmt(stmt)); result.push(Node::Body(orelse));
} }
for stmt in finalbody { if let Some(finalbody) = finalbody {
result.push(Node::Stmt(stmt)); result.push(Node::Body(finalbody));
} }
} }
StmtKind::Import { names } => { StmtKind::Import { names } => {
@ -457,7 +450,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
StmtKind::Global { .. } => {} StmtKind::Global { .. } => {}
StmtKind::Nonlocal { .. } => {} StmtKind::Nonlocal { .. } => {}
}, },
// TODO(charlie): Actual logic, this doesn't do anything.
Node::Expr(expr) => match &expr.node { Node::Expr(expr) => match &expr.node {
ExprKind::BoolOp { values, .. } => { ExprKind::BoolOp { values, .. } => {
for value in values { for value in values {
@ -476,7 +468,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(operand)); result.push(Node::Expr(operand));
} }
ExprKind::Lambda { body, args, .. } => { ExprKind::Lambda { body, args, .. } => {
// TODO(charlie): Arguments.
for expr in &args.defaults { for expr in &args.defaults {
result.push(Node::Expr(expr)); result.push(Node::Expr(expr));
} }
@ -630,9 +621,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
if let Some(type_) = type_ { if let Some(type_) = type_ {
result.push(Node::Expr(type_)); result.push(Node::Expr(type_));
} }
for stmt in body { result.push(Node::Body(body));
result.push(Node::Stmt(stmt));
}
} }
Node::SliceIndex(slice_index) => { Node::SliceIndex(slice_index) => {
if let SliceIndexKind::Index { value } = &slice_index.node { if let SliceIndexKind::Index { value } = &slice_index.node {
@ -717,6 +706,7 @@ pub fn decorate_token<'a>(
let middle = (left + right) / 2; let middle = (left + right) / 2;
let child = &child_nodes[middle]; let child = &child_nodes[middle];
let start = match &child { let start = match &child {
Node::Body(node) => node.location,
Node::Stmt(node) => node.location, Node::Stmt(node) => node.location,
Node::Expr(node) => node.location, Node::Expr(node) => node.location,
Node::Alias(node) => node.location, Node::Alias(node) => node.location,
@ -726,6 +716,7 @@ pub fn decorate_token<'a>(
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
}; };
let end = match &child { let end = match &child {
Node::Body(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(), Node::Stmt(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(), Node::Expr(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end_location.unwrap(), Node::Alias(node) => node.end_location.unwrap(),
@ -739,6 +730,7 @@ pub fn decorate_token<'a>(
// Special-case: if we're dealing with a statement that's a single expression, // Special-case: if we're dealing with a statement that's a single expression,
// we want to treat the expression as the enclosed node. // we want to treat the expression as the enclosed node.
let existing_start = match &existing { let existing_start = match &existing {
Node::Body(node) => node.location,
Node::Stmt(node) => node.location, Node::Stmt(node) => node.location,
Node::Expr(node) => node.location, Node::Expr(node) => node.location,
Node::Alias(node) => node.location, Node::Alias(node) => node.location,
@ -748,6 +740,7 @@ pub fn decorate_token<'a>(
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
}; };
let existing_end = match &existing { let existing_end = match &existing {
Node::Body(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(), Node::Stmt(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(), Node::Expr(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end_location.unwrap(), Node::Alias(node) => node.end_location.unwrap(),
@ -809,6 +802,7 @@ pub fn decorate_token<'a>(
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct TriviaIndex { pub struct TriviaIndex {
pub body: FxHashMap<usize, Vec<Trivia>>,
pub stmt: FxHashMap<usize, Vec<Trivia>>, pub stmt: FxHashMap<usize, Vec<Trivia>>,
pub expr: FxHashMap<usize, Vec<Trivia>>, pub expr: FxHashMap<usize, Vec<Trivia>>,
pub alias: FxHashMap<usize, Vec<Trivia>>, pub alias: FxHashMap<usize, Vec<Trivia>>,
@ -820,6 +814,13 @@ pub struct TriviaIndex {
fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) { fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
match node { match node {
Node::Mod(_) => {} Node::Mod(_) => {}
Node::Body(node) => {
trivia
.body
.entry(node.id())
.or_insert_with(Vec::new)
.push(comment);
}
Node::Stmt(node) => { Node::Stmt(node) => {
trivia trivia
.stmt .stmt