Create dedicated Body nodes in the formatter CST (#3223)

This commit is contained in:
Charlie Marsh 2023-02-27 17:55:05 -05:00 committed by GitHub
parent cd6413ca09
commit 2261e194a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1239 additions and 611 deletions

View file

@ -1,6 +1,6 @@
use crate::core::visitor;
use crate::core::visitor::Visitor;
use crate::cst::{Alias, Excepthandler, Expr, Pattern, SliceIndex, Stmt};
use crate::cst::{Alias, Body, Excepthandler, Expr, Pattern, SliceIndex, Stmt};
use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken};
struct AttachmentVisitor {
@ -8,6 +8,14 @@ struct AttachmentVisitor {
}
impl<'a> Visitor<'a> for AttachmentVisitor {
fn visit_body(&mut self, body: &'a mut Body) {
let trivia = self.index.body.remove(&body.id());
if let Some(comments) = trivia {
body.trivia.extend(comments);
}
visitor::walk_body(self, body);
}
fn visit_stmt(&mut self, stmt: &'a mut Stmt) {
let trivia = self.index.stmt.remove(&stmt.id());
if let Some(comments) = trivia {
@ -59,5 +67,8 @@ impl<'a> Visitor<'a> for AttachmentVisitor {
pub fn attach(python_cst: &mut [Stmt], trivia: Vec<TriviaToken>) {
let index = decorate_trivia(trivia, python_cst);
AttachmentVisitor { index }.visit_body(python_cst);
let mut visitor = AttachmentVisitor { index };
for stmt in python_cst {
visitor.visit_stmt(stmt);
}
}

View file

@ -1,3 +1,8 @@
use rustpython_parser::ast::Location;
use crate::core::locator::Locator;
use crate::core::types::Range;
/// Return the leading quote for a string or byte literal (e.g., `"""`).
pub fn leading_quote(content: &str) -> Option<&str> {
if let Some(first_line) = content.lines().next() {
@ -32,6 +37,99 @@ pub fn is_radix_literal(content: &str) -> bool {
|| content.starts_with("0X")
}
/// Expand the range of a compound statement.
///
/// `location` is the start of the compound statement (e.g., the `if` in `if x:`).
/// `end_location` is the end of the last statement in the body.
pub fn expand_indented_block(
location: Location,
end_location: Location,
locator: &Locator,
) -> (Location, Location) {
let contents = locator.contents();
let start_index = locator.index(location);
let end_index = locator.index(end_location);
// Find the colon, which indicates the end of the header.
let mut nesting = 0;
let mut colon = None;
for (start, tok, _end) in rustpython_parser::lexer::lex_located(
&contents[start_index..end_index],
rustpython_parser::Mode::Module,
location,
)
.flatten()
{
match tok {
rustpython_parser::Tok::Colon if nesting == 0 => {
colon = Some(start);
break;
}
rustpython_parser::Tok::Lpar
| rustpython_parser::Tok::Lsqb
| rustpython_parser::Tok::Lbrace => nesting += 1,
rustpython_parser::Tok::Rpar
| rustpython_parser::Tok::Rsqb
| rustpython_parser::Tok::Rbrace => nesting -= 1,
_ => {}
}
}
let colon_location = colon.unwrap();
let colon_index = locator.index(colon_location);
// From here, we have two options: simple statement or compound statement.
let indent = rustpython_parser::lexer::lex_located(
&contents[colon_index..end_index],
rustpython_parser::Mode::Module,
colon_location,
)
.flatten()
.find_map(|(start, tok, _end)| match tok {
rustpython_parser::Tok::Indent => Some(start),
_ => None,
});
let Some(indent_location) = indent else {
// Simple statement: from the colon to the end of the line.
return (colon_location, Location::new(end_location.row() + 1, 0));
};
// Compound statement: from the colon to the end of the block.
let mut offset = 0;
for (index, line) in contents[end_index..].lines().skip(1).enumerate() {
if line.is_empty() {
continue;
}
if line
.chars()
.take(indent_location.column())
.all(char::is_whitespace)
{
offset = index + 1;
} else {
break;
}
}
let end_location = Location::new(end_location.row() + 1 + offset, 0);
(colon_location, end_location)
}
/// Return true if the `orelse` block of an `if` statement is an `elif` statement.
pub fn is_elif(orelse: &[rustpython_parser::ast::Stmt], locator: &Locator) -> bool {
if orelse.len() == 1 && matches!(orelse[0].node, rustpython_parser::ast::StmtKind::If { .. }) {
let (source, start, end) = locator.slice(Range::new(
orelse[0].location,
orelse[0].end_location.unwrap(),
));
if source[start..end].starts_with("elif") {
return true;
}
}
false
}
#[cfg(test)]
mod tests {
#[test]

View file

@ -108,6 +108,15 @@ impl<'a> Locator<'a> {
self.index.get_or_init(|| index(self.contents))
}
pub fn index(&self, location: Location) -> usize {
let index = self.get_or_init_index();
truncate(location, index, self.contents)
}
pub fn contents(&self) -> &str {
self.contents
}
/// Slice the source code at a [`Range`].
pub fn slice(&self, range: Range) -> (Rc<str>, usize, usize) {
let index = self.get_or_init_index();

View file

@ -1,8 +1,8 @@
use rustpython_parser::ast::Constant;
use crate::cst::{
Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind, Expr,
ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex,
Alias, Arg, Arguments, Body, Boolop, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind,
Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex,
SliceIndexKind, Stmt, StmtKind, Unaryop, Withitem,
};
@ -67,13 +67,13 @@ pub trait Visitor<'a> {
fn visit_pattern(&mut self, pattern: &'a mut Pattern) {
walk_pattern(self, pattern);
}
fn visit_body(&mut self, body: &'a mut [Stmt]) {
fn visit_body(&mut self, body: &'a mut Body) {
walk_body(self, body);
}
}
pub fn walk_body<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, body: &'a mut [Stmt]) {
for stmt in body {
pub fn walk_body<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, body: &'a mut Body) {
for stmt in &mut body.node {
visitor.visit_stmt(stmt);
}
}
@ -173,7 +173,9 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
visitor.visit_expr(iter);
visitor.visit_expr(target);
visitor.visit_body(body);
visitor.visit_body(orelse);
if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
}
StmtKind::AsyncFor {
target,
@ -185,17 +187,25 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
visitor.visit_expr(iter);
visitor.visit_expr(target);
visitor.visit_body(body);
visitor.visit_body(orelse);
if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
}
StmtKind::While { test, body, orelse } => {
visitor.visit_expr(test);
visitor.visit_body(body);
visitor.visit_body(orelse);
if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
}
StmtKind::If { test, body, orelse } => {
StmtKind::If {
test, body, orelse, ..
} => {
visitor.visit_expr(test);
visitor.visit_body(body);
visitor.visit_body(orelse);
if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
}
StmtKind::With { items, body, .. } => {
for withitem in items {
@ -210,7 +220,6 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
visitor.visit_body(body);
}
StmtKind::Match { subject, cases } => {
// TODO(charlie): Handle `cases`.
visitor.visit_expr(subject);
for match_case in cases {
visitor.visit_match_case(match_case);
@ -234,8 +243,12 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
for excepthandler in handlers {
visitor.visit_excepthandler(excepthandler);
}
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
if let Some(finalbody) = finalbody {
visitor.visit_body(finalbody);
}
}
StmtKind::TryStar {
body,
@ -247,8 +260,12 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm
for excepthandler in handlers {
visitor.visit_excepthandler(excepthandler);
}
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
if let Some(orelse) = orelse {
visitor.visit_body(orelse);
}
if let Some(finalbody) = finalbody {
visitor.visit_body(finalbody);
}
}
StmtKind::Assert { test, msg } => {
visitor.visit_expr(test);

File diff suppressed because it is too large Load diff

View file

@ -4,17 +4,55 @@ use ruff_text_size::{TextRange, TextSize};
use crate::context::ASTFormatContext;
use crate::core::types::Range;
use crate::cst::Stmt;
use crate::cst::{Body, Stmt};
use crate::shared_traits::AsFormat;
use crate::trivia::{Relationship, TriviaKind};
#[derive(Copy, Clone)]
pub struct Block<'a> {
body: &'a [Stmt],
body: &'a Body,
}
impl Format<ASTFormatContext<'_>> for Block<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
for (i, stmt) in self.body.iter().enumerate() {
for (i, stmt) in self.body.node.iter().enumerate() {
if i > 0 {
write!(f, [hard_line_break()])?;
}
write!(f, [stmt.format()])?;
}
for trivia in &self.body.trivia {
if matches!(trivia.relationship, Relationship::Dangling) {
match trivia.kind {
TriviaKind::EmptyLine => {
write!(f, [empty_line()])?;
}
TriviaKind::OwnLineComment(range) => {
write!(f, [literal(range), hard_line_break()])?;
}
_ => {}
}
}
}
Ok(())
}
}
#[inline]
pub fn block(body: &Body) -> Block {
Block { body }
}
#[derive(Copy, Clone)]
pub struct Statements<'a> {
suite: &'a [Stmt],
}
impl Format<ASTFormatContext<'_>> for Statements<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
for (i, stmt) in self.suite.iter().enumerate() {
if i > 0 {
write!(f, [hard_line_break()])?;
}
@ -24,9 +62,8 @@ impl Format<ASTFormatContext<'_>> for Block<'_> {
}
}
#[inline]
pub fn block(body: &[Stmt]) -> Block {
Block { body }
pub fn statements(suite: &[Stmt]) -> Statements {
Statements { suite }
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]

View file

@ -72,13 +72,12 @@ pub struct EndOfLineComments<'a, T> {
impl<T> Format<ASTFormatContext<'_>> for EndOfLineComments<'_, T> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let mut first = true;
for range in self.item.trivia.iter().filter_map(|trivia| {
if trivia.relationship.is_trailing() {
trivia.kind.end_of_line_comment()
} else {
None
}
}) {
for range in self
.item
.trivia
.iter()
.filter_map(|trivia| trivia.kind.end_of_line_comment())
{
if std::mem::take(&mut first) {
write!(f, [line_suffix(&text(" "))])?;
}

View file

@ -4,6 +4,7 @@ use ruff_formatter::write;
use crate::context::ASTFormatContext;
use crate::cst::MatchCase;
use crate::format::builders::block;
use crate::format::comments::{end_of_line_comments, leading_comments};
use crate::shared_traits::AsFormat;
pub struct FormatMatchCase<'a> {
@ -26,12 +27,16 @@ impl Format<ASTFormatContext<'_>> for FormatMatchCase<'_> {
body,
} = self.item;
write!(f, [leading_comments(pattern)])?;
write!(f, [text("case")])?;
write!(f, [space(), pattern.format()])?;
if let Some(guard) = &guard {
write!(f, [space(), text("if"), space(), guard.format()])?;
}
write!(f, [text(":")])?;
write!(f, [end_of_line_comments(body)])?;
write!(f, [block_indent(&block(body))])?;
Ok(())

View file

@ -6,8 +6,8 @@ use ruff_text_size::TextSize;
use crate::context::ASTFormatContext;
use crate::cst::{
Alias, Arguments, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt, StmtKind,
Withitem,
Alias, Arguments, Body, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt,
StmtKind, Withitem,
};
use crate::format::builders::{block, join_names};
use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments};
@ -101,13 +101,15 @@ fn format_class_def(
name: &str,
bases: &[Expr],
keywords: &[Keyword],
body: &[Stmt],
body: &Body,
decorator_list: &[Expr],
) -> FormatResult<()> {
for decorator in decorator_list {
write!(f, [text("@"), decorator.format(), hard_line_break()])?;
}
write!(f, [leading_comments(body)])?;
write!(
f,
[
@ -161,6 +163,7 @@ fn format_class_def(
)?;
}
write!(f, [end_of_line_comments(body)])?;
write!(f, [text(":"), block_indent(&block(body))])
}
@ -170,13 +173,16 @@ fn format_func_def(
name: &str,
args: &Arguments,
returns: Option<&Expr>,
body: &[Stmt],
body: &Body,
decorator_list: &[Expr],
async_: bool,
) -> FormatResult<()> {
for decorator in decorator_list {
write!(f, [text("@"), decorator.format(), hard_line_break()])?;
}
write!(f, [leading_comments(body)])?;
if async_ {
write!(f, [text("async"), space()])?;
}
@ -202,10 +208,10 @@ fn format_func_def(
}
write!(f, [text(":")])?;
write!(f, [end_of_line_comments(body)])?;
write!(f, [block_indent(&block(body))])?;
write!(f, [end_of_line_comments(stmt)])?;
write!(f, [block_indent(&format_args![block(body)])])
Ok(())
}
fn format_assign(
@ -310,8 +316,8 @@ fn format_for(
stmt: &Stmt,
target: &Expr,
iter: &Expr,
body: &[Stmt],
orelse: &[Stmt],
body: &Body,
orelse: Option<&Body>,
_type_comment: Option<&str>,
async_: bool,
) -> FormatResult<()> {
@ -329,11 +335,19 @@ fn format_for(
space(),
group(&iter.format()),
text(":"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
if !orelse.is_empty() {
write!(f, [text("else:"), block_indent(&block(orelse))])?;
if let Some(orelse) = orelse {
write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
}
Ok(())
}
@ -342,8 +356,8 @@ fn format_while(
f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt,
test: &Expr,
body: &[Stmt],
orelse: &[Stmt],
body: &Body,
orelse: Option<&Body>,
) -> FormatResult<()> {
write!(f, [text("while"), space()])?;
if is_self_closing(test) {
@ -358,9 +372,23 @@ fn format_while(
])]
)?;
}
write!(f, [text(":"), block_indent(&block(body))])?;
if !orelse.is_empty() {
write!(f, [text("else:"), block_indent(&block(orelse))])?;
write!(
f,
[
text(":"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
if let Some(orelse) = orelse {
write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
}
Ok(())
}
@ -368,10 +396,15 @@ fn format_while(
fn format_if(
f: &mut Formatter<ASTFormatContext<'_>>,
test: &Expr,
body: &[Stmt],
orelse: &[Stmt],
body: &Body,
orelse: Option<&Body>,
is_elif: bool,
) -> FormatResult<()> {
write!(f, [text("if"), space()])?;
if is_elif {
write!(f, [text("elif"), space()])?;
} else {
write!(f, [text("if"), space()])?;
}
if is_self_closing(test) {
write!(f, [test.format()])?;
} else {
@ -384,17 +417,43 @@ fn format_if(
])]
)?;
}
write!(f, [text(":"), block_indent(&block(body))])?;
if !orelse.is_empty() {
if orelse.len() == 1 {
if let StmtKind::If { test, body, orelse } = &orelse[0].node {
write!(f, [text("el")])?;
format_if(f, test, body, orelse)?;
write!(
f,
[
text(":"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
if let Some(orelse) = orelse {
if orelse.node.len() == 1 {
if let StmtKind::If {
test,
body,
orelse,
is_elif: true,
} = &orelse.node[0].node
{
format_if(f, test, body, orelse.as_ref(), true)?;
} else {
write!(f, [text("else:"), block_indent(&block(orelse))])?;
write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
}
} else {
write!(f, [text("else:"), block_indent(&block(orelse))])?;
write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
}
}
Ok(())
@ -406,7 +465,16 @@ fn format_match(
subject: &Expr,
cases: &[MatchCase],
) -> FormatResult<()> {
write!(f, [text("match"), space(), subject.format(), text(":")])?;
write!(
f,
[
text("match"),
space(),
subject.format(),
text(":"),
end_of_line_comments(stmt),
]
)?;
for case in cases {
write!(f, [block_indent(&case.format())])?;
}
@ -447,20 +515,31 @@ fn format_return(
fn format_try(
f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt,
body: &[Stmt],
body: &Body,
handlers: &[Excepthandler],
orelse: &[Stmt],
finalbody: &[Stmt],
orelse: Option<&Body>,
finalbody: Option<&Body>,
) -> FormatResult<()> {
write!(f, [text("try:"), block_indent(&block(body))])?;
write!(
f,
[
text("try:"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
for handler in handlers {
write!(f, [handler.format()])?;
}
if !orelse.is_empty() {
write!(f, [text("else:"), block_indent(&block(orelse))])?;
if let Some(orelse) = orelse {
write!(f, [text("else:")])?;
write!(f, [end_of_line_comments(orelse)])?;
write!(f, [block_indent(&block(orelse))])?;
}
if !finalbody.is_empty() {
write!(f, [text("finally:"), block_indent(&block(finalbody))])?;
if let Some(finalbody) = finalbody {
write!(f, [text("finally:")])?;
write!(f, [end_of_line_comments(finalbody)])?;
write!(f, [block_indent(&block(finalbody))])?;
}
Ok(())
}
@ -468,21 +547,42 @@ fn format_try(
fn format_try_star(
f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt,
body: &[Stmt],
body: &Body,
handlers: &[Excepthandler],
orelse: &[Stmt],
finalbody: &[Stmt],
orelse: Option<&Body>,
finalbody: Option<&Body>,
) -> FormatResult<()> {
write!(f, [text("try:"), block_indent(&block(body))])?;
write!(
f,
[
text("try:"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
for handler in handlers {
// TODO(charlie): Include `except*`.
write!(f, [handler.format()])?;
}
if !orelse.is_empty() {
write!(f, [text("else:"), block_indent(&block(orelse))])?;
if let Some(orelse) = orelse {
write!(
f,
[
text("else:"),
end_of_line_comments(orelse),
block_indent(&block(orelse))
]
)?;
}
if !finalbody.is_empty() {
write!(f, [text("finally:"), block_indent(&block(finalbody))])?;
if let Some(finalbody) = finalbody {
write!(
f,
[
text("finally:"),
end_of_line_comments(finalbody),
block_indent(&block(finalbody))
]
)?;
}
Ok(())
}
@ -640,7 +740,7 @@ fn format_with_(
f: &mut Formatter<ASTFormatContext<'_>>,
stmt: &Stmt,
items: &[Withitem],
body: &[Stmt],
body: &Body,
type_comment: Option<&str>,
async_: bool,
) -> FormatResult<()> {
@ -668,6 +768,7 @@ fn format_with_(
if_group_breaks(&text(")")),
]),
text(":"),
end_of_line_comments(body),
block_indent(&block(body))
]
)?;
@ -753,7 +854,7 @@ impl Format<ASTFormatContext<'_>> for FormatStmt<'_> {
target,
iter,
body,
orelse,
orelse.as_ref(),
type_comment.as_deref(),
false,
),
@ -769,14 +870,19 @@ impl Format<ASTFormatContext<'_>> for FormatStmt<'_> {
target,
iter,
body,
orelse,
orelse.as_ref(),
type_comment.as_deref(),
true,
),
StmtKind::While { test, body, orelse } => {
format_while(f, self.item, test, body, orelse)
format_while(f, self.item, test, body, orelse.as_ref())
}
StmtKind::If { test, body, orelse } => format_if(f, test, body, orelse),
StmtKind::If {
test,
body,
orelse,
is_elif,
} => format_if(f, test, body, orelse.as_ref(), *is_elif),
StmtKind::With {
items,
body,
@ -810,13 +916,27 @@ impl Format<ASTFormatContext<'_>> for FormatStmt<'_> {
handlers,
orelse,
finalbody,
} => format_try(f, self.item, body, handlers, orelse, finalbody),
} => format_try(
f,
self.item,
body,
handlers,
orelse.as_ref(),
finalbody.as_ref(),
),
StmtKind::TryStar {
body,
handlers,
orelse,
finalbody,
} => format_try_star(f, self.item, body, handlers, orelse, finalbody),
} => format_try_star(
f,
self.item,
body,
handlers,
orelse.as_ref(),
finalbody.as_ref(),
),
StmtKind::Assert { test, msg } => {
format_assert(f, self.item, test, msg.as_ref().map(|expr| &**expr))
}

View file

@ -53,7 +53,7 @@ pub fn fmt(contents: &str) -> Result<Formatted<ASTFormatContext>> {
},
locator,
),
[format::builders::block(&python_cst)]
[format::builders::statements(&python_cst)]
)
.map_err(Into::into)
}

View file

@ -163,15 +163,14 @@ impl<'a> Visitor<'a> for StmtNormalizer {
self.trailer = Trailer::CompoundStatement;
self.visit_body(body);
if !orelse.is_empty() {
if let Some(orelse) = orelse {
// If the previous body ended with a function or class definition, we need to
// insert an empty line before the else block. Since the `else` itself isn't
// a statement, we need to insert it into the last statement of the body.
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
let stmt = body.last_mut().unwrap();
stmt.trivia.push(Trivia {
body.trivia.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing,
relationship: Relationship::Dangling,
});
}
@ -185,12 +184,11 @@ impl<'a> Visitor<'a> for StmtNormalizer {
self.trailer = Trailer::CompoundStatement;
self.visit_body(body);
if !orelse.is_empty() {
if let Some(orelse) = orelse {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
let stmt = body.last_mut().unwrap();
stmt.trivia.push(Trivia {
body.trivia.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing,
relationship: Relationship::Dangling,
});
}
@ -220,49 +218,44 @@ impl<'a> Visitor<'a> for StmtNormalizer {
self.depth = Depth::Nested;
self.trailer = Trailer::CompoundStatement;
self.visit_body(body);
let mut last = body.last_mut();
let mut prev = &mut body.trivia;
for handler in handlers {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
if let Some(stmt) = last.as_mut() {
stmt.trivia.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing,
});
}
prev.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Dangling,
});
}
self.depth = Depth::Nested;
self.trailer = Trailer::CompoundStatement;
let ExcepthandlerKind::ExceptHandler { body, .. } = &mut handler.node;
self.visit_body(body);
last = body.last_mut();
prev = &mut body.trivia;
}
if !orelse.is_empty() {
if let Some(orelse) = orelse {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
if let Some(stmt) = last.as_mut() {
stmt.trivia.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing,
});
}
prev.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Dangling,
});
}
self.depth = Depth::Nested;
self.trailer = Trailer::CompoundStatement;
self.visit_body(orelse);
last = body.last_mut();
prev = &mut body.trivia;
}
if !finalbody.is_empty() {
if let Some(finalbody) = finalbody {
if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) {
if let Some(stmt) = last.as_mut() {
stmt.trivia.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Trailing,
});
}
prev.push(Trivia {
kind: TriviaKind::EmptyLine,
relationship: Relationship::Dangling,
});
}
self.depth = Depth::Nested;

View file

@ -194,5 +194,7 @@ impl<'a> Visitor<'a> for ParenthesesNormalizer<'_> {
/// during formatting) and `Parenthesize` (which are used during formatting).
pub fn normalize_parentheses(python_cst: &mut [Stmt], locator: &Locator) {
let mut normalizer = ParenthesesNormalizer { locator };
normalizer.visit_body(python_cst);
for stmt in python_cst {
normalizer.visit_stmt(stmt);
}
}

View file

@ -196,15 +196,7 @@ instruction()#comment with bad spacing
"Generator",
]
@@ -54,32 +54,39 @@
# for compiler in compilers.values():
# add_compiler(compiler)
add_compiler(compilers[(7.0, 32)])
- # add_compiler(compilers[(7.1, 64)])
+# add_compiler(compilers[(7.1, 64)])
+
@@ -60,26 +60,32 @@
# Comment before function.
def inline_comments_in_brackets_ruin_everything():
if typedargslist:
@ -229,7 +221,7 @@ instruction()#comment with bad spacing
+ parameters.what_if_this_was_actually_long.children[0],
+ body,
+ parameters.children[-1],
+ ]
+ ] # type: ignore
if (
self._proc is not None
- # has the child process finished?
@ -246,7 +238,7 @@ instruction()#comment with bad spacing
):
pass
# no newline before or after
@@ -103,42 +110,42 @@
@@ -103,35 +109,35 @@
############################################################################
call2(
@ -298,16 +290,7 @@ instruction()#comment with bad spacing
]
while True:
if False:
continue
- # and round and round we go
- # and round and round we go
+ # and round and round we go
+ # and round and round we go
# let's return
return Node(
@@ -167,7 +174,7 @@
@@ -167,7 +173,7 @@
#######################
@ -377,10 +360,9 @@ else:
# for compiler in compilers.values():
# add_compiler(compiler)
add_compiler(compilers[(7.0, 32)])
# add_compiler(compilers[(7.1, 64)])
# add_compiler(compilers[(7.1, 64)])
# Comment before function.
def inline_comments_in_brackets_ruin_everything():
if typedargslist:
@ -400,7 +382,7 @@ def inline_comments_in_brackets_ruin_everything():
parameters.what_if_this_was_actually_long.children[0],
body,
parameters.children[-1],
]
] # type: ignore
if (
self._proc is not None
and # has the child process finished?
@ -467,8 +449,8 @@ short
if False:
continue
# and round and round we go
# and round and round we go
# and round and round we go
# and round and round we go
# let's return
return Node(

View file

@ -131,17 +131,15 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
```diff
--- Black
+++ Ruff
@@ -2,8 +2,8 @@
@@ -2,7 +2,7 @@
def f(
- a, # type: int
-):
+ a,
+): # type: int
):
pass
@@ -14,44 +14,42 @@
@ -155,7 +153,6 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
- g, # type: int
- h, # type: int
- i, # type: int
-):
+ a,
+ b,
+ c,
@ -165,7 +162,7 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
+ g,
+ h,
+ i,
+): # type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int
):
# type: (...) -> None
pass
@ -175,12 +172,11 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
- *args, # type: *Any
- default=False, # type: bool
- **kwargs, # type: **Any
-):
+ arg,
+ *args,
+ default=False,
+ **kwargs,
+): # type: int# type: *Any
):
# type: (...) -> None
pass
@ -190,12 +186,11 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
- b, # type: int
- c, # type: int
- d, # type: int
-):
+ a,
+ b,
+ c,
+ d,
+): # type: int# type: int# type: int# type: int# type: int
):
# type: (...) -> None
element = 0 # type: int
@ -208,41 +203,32 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
an_element_with_a_long_value = calls() or more_calls() and more() # type: bool
tup = (
@@ -66,26 +64,26 @@
+ element
+ another_element
+ another_element_with_long_name
- ) # type: int
+ )
@@ -70,21 +68,21 @@
def f(
- x, # not a type comment
- y, # type: int
-):
+ x,
+ y,
+): # not a type comment# type: int
):
# type: (...) -> None
pass
def f(
- x, # not a type comment
-): # type: (int) -> None
+ x,
+): # not a type comment# type: (int) -> None
): # type: (int) -> None
pass
def func(
- a=some_list[0], # type: int
-): # type: () -> int
+ a=some_list[0],
+):
): # type: () -> int
c = call(
0.0123,
0.0456,
@@ -96,23 +94,37 @@
0.0123,
0.0456,
@ -298,7 +284,7 @@ from typing import Any, Tuple
def f(
a,
): # type: int
):
pass
@ -318,7 +304,7 @@ def f(
g,
h,
i,
): # type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int
):
# type: (...) -> None
pass
@ -328,7 +314,7 @@ def f(
*args,
default=False,
**kwargs,
): # type: int# type: *Any
):
# type: (...) -> None
pass
@ -338,7 +324,7 @@ def f(
b,
c,
d,
): # type: int# type: int# type: int# type: int# type: int
):
# type: (...) -> None
element = 0 # type: int
@ -359,26 +345,26 @@ def f(
+ element
+ another_element
+ another_element_with_long_name
)
) # type: int
def f(
x,
y,
): # not a type comment# type: int
):
# type: (...) -> None
pass
def f(
x,
): # not a type comment# type: (int) -> None
): # type: (int) -> None
pass
def func(
a=some_list[0],
):
): # type: () -> int
c = call(
0.0123,
0.0456,

View file

@ -164,7 +164,7 @@ def bar():
# This should be split from the above by two lines
class MyClassWithComplexLeadingComments:
pass
@@ -57,13 +58,13 @@
@@ -57,11 +58,11 @@
# leading 1
@deco1
@ -174,16 +174,13 @@ def bar():
-@deco2(with_args=True)
-# leading 3
-@deco3
-# leading 4
+deco2(with_args=True)
+@# leading 3
+deco3
# leading 4
def decorated():
+ # leading 4
pass
@@ -72,13 +73,12 @@
@@ -72,11 +73,10 @@
# leading 1
@deco1
@ -192,17 +189,14 @@ def bar():
-
-# leading 3 that already has an empty line
-@deco3
-# leading 4
+@# leading 2
+deco2(with_args=True)
+@# leading 3 that already has an empty line
+deco3
# leading 4
def decorated_with_split_leading_comments():
+ # leading 4
pass
@@ -87,18 +87,18 @@
@@ -87,10 +87,10 @@
# leading 1
@deco1
@ -210,16 +204,14 @@ def bar():
-@deco2(with_args=True)
-# leading 3
-@deco3
-
-# leading 4 that already has an empty line
+@# leading 2
+deco2(with_args=True)
+@# leading 3
+deco3
def decorated_with_split_leading_comments():
+ # leading 4 that already has an empty line
pass
# leading 4 that already has an empty line
def decorated_with_split_leading_comments():
@@ -99,6 +99,7 @@
def main():
if a:
@ -227,7 +219,7 @@ def bar():
# Leading comment before inline function
def inline():
pass
@@ -108,12 +108,14 @@
@@ -108,12 +109,14 @@
pass
else:
@ -242,7 +234,7 @@ def bar():
# Leading comment before "top-level inline" function
def top_level_quote_inline():
pass
@@ -123,6 +125,7 @@
@@ -123,6 +126,7 @@
pass
else:
@ -250,37 +242,6 @@ def bar():
# More leading comments
def top_level_quote_inline_after_else():
pass
@@ -138,9 +141,11 @@
# Regression test for https://github.com/psf/black/issues/3454.
def foo():
pass
- # Trailing comment that belongs to this function
+# Trailing comment that belongs to this function
+
+
@decorator1
@decorator2 # fmt: skip
def bar():
@@ -150,12 +155,13 @@
# Regression test for https://github.com/psf/black/issues/3454.
def foo():
pass
- # Trailing comment that belongs to this function.
- # NOTE this comment only has one empty line below, and the formatter
- # should enforce two blank lines.
+# Trailing comment that belongs to this function.
+# NOTE this comment only has one empty line below, and the formatter
+# should enforce two blank lines.
+
@decorator1
-# A standalone comment
def bar():
+ # A standalone comment
pass
```
## Ruff Output
@ -351,8 +312,8 @@ some = statement
deco2(with_args=True)
@# leading 3
deco3
# leading 4
def decorated():
# leading 4
pass
@ -365,8 +326,8 @@ some = statement
deco2(with_args=True)
@# leading 3 that already has an empty line
deco3
# leading 4
def decorated_with_split_leading_comments():
# leading 4
pass
@ -379,8 +340,9 @@ some = statement
deco2(with_args=True)
@# leading 3
deco3
# leading 4 that already has an empty line
def decorated_with_split_leading_comments():
# leading 4 that already has an empty line
pass
@ -429,9 +391,7 @@ class MyClass:
# Regression test for https://github.com/psf/black/issues/3454.
def foo():
pass
# Trailing comment that belongs to this function
# Trailing comment that belongs to this function
@decorator1
@ -443,15 +403,14 @@ def bar():
# Regression test for https://github.com/psf/black/issues/3454.
def foo():
pass
# Trailing comment that belongs to this function.
# NOTE this comment only has one empty line below, and the formatter
# should enforce two blank lines.
# Trailing comment that belongs to this function.
# NOTE this comment only has one empty line below, and the formatter
# should enforce two blank lines.
@decorator1
# A standalone comment
def bar():
# A standalone comment
pass
```

View file

@ -204,7 +204,7 @@ class C:
)
self.assertEqual(
unstyle(str(report)),
@@ -22,133 +23,156 @@
@@ -22,133 +23,155 @@
if (
# Rule 1
i % 2 == 0
@ -217,10 +217,10 @@ class C:
- while (
- # Just a comment
- call()
- # Another
- ):
+ while # Just a comment
+ call():
# Another
- ):
print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager,
@ -460,7 +460,7 @@ class C:
"Not what we expected and the message is too long to fit in one line"
" because it's too long"
)
@@ -161,9 +185,8 @@
@@ -161,9 +184,8 @@
8 STORE_ATTR 0 (x)
10 LOAD_CONST 0 (None)
12 RETURN_VALUE
@ -508,7 +508,6 @@ class C:
):
while # Just a comment
call():
# Another
print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager,

View file

@ -204,7 +204,7 @@ class C:
)
self.assertEqual(
unstyle(str(report)),
@@ -22,133 +23,156 @@
@@ -22,133 +23,155 @@
if (
# Rule 1
i % 2 == 0
@ -217,10 +217,10 @@ class C:
- while (
- # Just a comment
- call()
- # Another
- ):
+ while # Just a comment
+ call():
# Another
- ):
print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager,
@ -460,7 +460,7 @@ class C:
"Not what we expected and the message is too long to fit in one line"
" because it's too long"
)
@@ -161,9 +185,8 @@
@@ -161,9 +184,8 @@
8 STORE_ATTR 0 (x)
10 LOAD_CONST 0 (None)
12 RETURN_VALUE
@ -508,7 +508,6 @@ class C:
):
while # Just a comment
call():
# Another
print(i)
xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy(
push_manager=context.request.resource_manager,

View file

@ -26,13 +26,12 @@ def f(): pass
```diff
--- Black
+++ Ruff
@@ -1,10 +1,14 @@
@@ -1,8 +1,12 @@
# fmt: off
-@test([
- 1, 2,
- 3, 4,
-])
-# fmt: on
+@test(
+ [
+ 1,
@ -41,11 +40,9 @@ def f(): pass
+ 4,
+ ]
+)
# fmt: on
def f():
+ # fmt: on
pass
```
## Ruff Output
@ -60,8 +57,8 @@ def f(): pass
4,
]
)
# fmt: on
def f():
# fmt: on
pass

View file

@ -165,7 +165,7 @@ elif unformatted:
print("This will be formatted")
@@ -68,20 +62,21 @@
@@ -68,20 +62,19 @@
class Named(t.Protocol):
# fmt: off
@property
@ -177,11 +177,9 @@ elif unformatted:
class Factory(t.Protocol):
def this_will_be_formatted(self, **kwargs) -> Named:
...
-
# fmt: on
- # fmt: on
+# fmt: on
+
# Regression test for https://github.com/psf/black/issues/3436.
if x:
@ -267,9 +265,7 @@ class Named(t.Protocol):
class Factory(t.Protocol):
def this_will_be_formatted(self, **kwargs) -> Named:
...
# fmt: on
# fmt: on
# Regression test for https://github.com/psf/black/issues/3436.

View file

@ -1,49 +0,0 @@
---
source: crates/ruff_python_formatter/src/lib.rs
expression: snapshot
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/fmtskip6.py
---
## Input
```py
class A:
def f(self):
for line in range(10):
if True:
pass # fmt: skip
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -2,4 +2,4 @@
def f(self):
for line in range(10):
if True:
- pass # fmt: skip
+ pass
```
## Ruff Output
```py
class A:
def f(self):
for line in range(10):
if True:
pass
```
## Black Output
```py
class A:
def f(self):
for line in range(10):
if True:
pass # fmt: skip
```

View file

@ -117,12 +117,13 @@ def __await__(): return (yield)
def func_no_args():
@@ -64,19 +64,14 @@
@@ -64,19 +64,15 @@
def spaces2(result=_core.Value(None)):
- assert fut is self._read_fut, (fut, self._read_fut)
+ assert fut is self._read_fut, fut, self._read_fut
+
def example(session):
@ -142,7 +143,7 @@ def __await__(): return (yield)
def long_lines():
@@ -135,14 +130,13 @@
@@ -135,14 +131,13 @@
a,
**kwargs,
) -> A:
@ -233,6 +234,7 @@ def spaces2(result=_core.Value(None)):
assert fut is self._read_fut, fut, self._read_fut
def example(session):
result = session.query(models.Customer.id).filter(
models.Customer.account_id == account_id,

View file

@ -123,10 +123,9 @@ async def main():
+ await (asyncio.sleep(1)) # Hello
-async def main():
async def main():
- await asyncio.sleep(1) # Hello
+async def main(): # Hello
+ await (asyncio.sleep(1))
+ await (asyncio.sleep(1)) # Hello
# Long lines
@ -231,8 +230,8 @@ async def main():
await (asyncio.sleep(1)) # Hello
async def main(): # Hello
await (asyncio.sleep(1))
async def main():
await (asyncio.sleep(1)) # Hello
# Long lines

View file

@ -5,13 +5,14 @@ use rustpython_parser::Tok;
use crate::core::types::Range;
use crate::cst::{
Alias, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind, SliceIndex,
SliceIndexKind, Stmt, StmtKind,
Alias, Body, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind,
SliceIndex, SliceIndexKind, Stmt, StmtKind,
};
#[derive(Clone, Debug)]
pub enum Node<'a> {
Mod(&'a [Stmt]),
Body(&'a Body),
Stmt(&'a Stmt),
Expr(&'a Expr),
Alias(&'a Alias),
@ -24,6 +25,7 @@ impl Node<'_> {
pub fn id(&self) -> usize {
match self {
Node::Mod(nodes) => nodes as *const _ as usize,
Node::Body(node) => node.id(),
Node::Stmt(node) => node.id(),
Node::Expr(node) => node.id(),
Node::Alias(node) => node.id(),
@ -227,6 +229,11 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Stmt(stmt));
}
}
Node::Body(body) => {
for stmt in &body.node {
result.push(Node::Stmt(stmt));
}
}
Node::Stmt(stmt) => match &stmt.node {
StmtKind::Return { value } => {
if let Some(value) = value {
@ -294,9 +301,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
if let Some(returns) = returns {
result.push(Node::Expr(returns));
}
for stmt in body {
result.push(Node::Stmt(stmt));
}
result.push(Node::Body(body));
}
StmtKind::ClassDef {
bases,
@ -314,9 +319,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
for keyword in keywords {
result.push(Node::Expr(&keyword.node.value));
}
for stmt in body {
result.push(Node::Stmt(stmt));
}
result.push(Node::Body(body));
}
StmtKind::Delete { targets } => {
for target in targets {
@ -355,29 +358,25 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
} => {
result.push(Node::Expr(target));
result.push(Node::Expr(iter));
for stmt in body {
result.push(Node::Stmt(stmt));
}
for stmt in orelse {
result.push(Node::Stmt(stmt));
result.push(Node::Body(body));
if let Some(orelse) = orelse {
result.push(Node::Body(orelse));
}
}
StmtKind::While { test, body, orelse } => {
result.push(Node::Expr(test));
for stmt in body {
result.push(Node::Stmt(stmt));
}
for stmt in orelse {
result.push(Node::Stmt(stmt));
result.push(Node::Body(body));
if let Some(orelse) = orelse {
result.push(Node::Body(orelse));
}
}
StmtKind::If { test, body, orelse } => {
StmtKind::If {
test, body, orelse, ..
} => {
result.push(Node::Expr(test));
for stmt in body {
result.push(Node::Stmt(stmt));
}
for stmt in orelse {
result.push(Node::Stmt(stmt));
result.push(Node::Body(body));
if let Some(orelse) = orelse {
result.push(Node::Body(orelse));
}
}
StmtKind::With { items, body, .. } | StmtKind::AsyncWith { items, body, .. } => {
@ -387,9 +386,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(expr));
}
}
for stmt in body {
result.push(Node::Stmt(stmt));
}
result.push(Node::Body(body));
}
StmtKind::Match { subject, cases } => {
result.push(Node::Expr(subject));
@ -398,9 +395,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
if let Some(expr) = &case.guard {
result.push(Node::Expr(expr));
}
for stmt in &case.body {
result.push(Node::Stmt(stmt));
}
result.push(Node::Body(&case.body));
}
}
StmtKind::Raise { exc, cause } => {
@ -431,17 +426,15 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
orelse,
finalbody,
} => {
for stmt in body {
result.push(Node::Stmt(stmt));
}
result.push(Node::Body(body));
for handler in handlers {
result.push(Node::Excepthandler(handler));
}
for stmt in orelse {
result.push(Node::Stmt(stmt));
if let Some(orelse) = orelse {
result.push(Node::Body(orelse));
}
for stmt in finalbody {
result.push(Node::Stmt(stmt));
if let Some(finalbody) = finalbody {
result.push(Node::Body(finalbody));
}
}
StmtKind::Import { names } => {
@ -457,7 +450,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
StmtKind::Global { .. } => {}
StmtKind::Nonlocal { .. } => {}
},
// TODO(charlie): Actual logic, this doesn't do anything.
Node::Expr(expr) => match &expr.node {
ExprKind::BoolOp { values, .. } => {
for value in values {
@ -476,7 +468,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(operand));
}
ExprKind::Lambda { body, args, .. } => {
// TODO(charlie): Arguments.
for expr in &args.defaults {
result.push(Node::Expr(expr));
}
@ -630,9 +621,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
if let Some(type_) = type_ {
result.push(Node::Expr(type_));
}
for stmt in body {
result.push(Node::Stmt(stmt));
}
result.push(Node::Body(body));
}
Node::SliceIndex(slice_index) => {
if let SliceIndexKind::Index { value } = &slice_index.node {
@ -717,6 +706,7 @@ pub fn decorate_token<'a>(
let middle = (left + right) / 2;
let child = &child_nodes[middle];
let start = match &child {
Node::Body(node) => node.location,
Node::Stmt(node) => node.location,
Node::Expr(node) => node.location,
Node::Alias(node) => node.location,
@ -726,6 +716,7 @@ pub fn decorate_token<'a>(
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
};
let end = match &child {
Node::Body(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end_location.unwrap(),
@ -739,6 +730,7 @@ pub fn decorate_token<'a>(
// Special-case: if we're dealing with a statement that's a single expression,
// we want to treat the expression as the enclosed node.
let existing_start = match &existing {
Node::Body(node) => node.location,
Node::Stmt(node) => node.location,
Node::Expr(node) => node.location,
Node::Alias(node) => node.location,
@ -748,6 +740,7 @@ pub fn decorate_token<'a>(
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
};
let existing_end = match &existing {
Node::Body(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end_location.unwrap(),
@ -809,6 +802,7 @@ pub fn decorate_token<'a>(
#[derive(Debug, Default)]
pub struct TriviaIndex {
pub body: FxHashMap<usize, Vec<Trivia>>,
pub stmt: FxHashMap<usize, Vec<Trivia>>,
pub expr: FxHashMap<usize, Vec<Trivia>>,
pub alias: FxHashMap<usize, Vec<Trivia>>,
@ -820,6 +814,13 @@ pub struct TriviaIndex {
fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
match node {
Node::Mod(_) => {}
Node::Body(node) => {
trivia
.body
.entry(node.id())
.or_insert_with(Vec::new)
.push(comment);
}
Node::Stmt(node) => {
trivia
.stmt