diff --git a/crates/ruff_python_formatter/src/attachment.rs b/crates/ruff_python_formatter/src/attachment.rs index b11daa3bce..090d8f0939 100644 --- a/crates/ruff_python_formatter/src/attachment.rs +++ b/crates/ruff_python_formatter/src/attachment.rs @@ -1,6 +1,6 @@ use crate::core::visitor; use crate::core::visitor::Visitor; -use crate::cst::{Alias, Excepthandler, Expr, Pattern, SliceIndex, Stmt}; +use crate::cst::{Alias, Body, Excepthandler, Expr, Pattern, SliceIndex, Stmt}; use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken}; struct AttachmentVisitor { @@ -8,6 +8,14 @@ struct AttachmentVisitor { } impl<'a> Visitor<'a> for AttachmentVisitor { + fn visit_body(&mut self, body: &'a mut Body) { + let trivia = self.index.body.remove(&body.id()); + if let Some(comments) = trivia { + body.trivia.extend(comments); + } + visitor::walk_body(self, body); + } + fn visit_stmt(&mut self, stmt: &'a mut Stmt) { let trivia = self.index.stmt.remove(&stmt.id()); if let Some(comments) = trivia { @@ -59,5 +67,8 @@ impl<'a> Visitor<'a> for AttachmentVisitor { pub fn attach(python_cst: &mut [Stmt], trivia: Vec) { let index = decorate_trivia(trivia, python_cst); - AttachmentVisitor { index }.visit_body(python_cst); + let mut visitor = AttachmentVisitor { index }; + for stmt in python_cst { + visitor.visit_stmt(stmt); + } } diff --git a/crates/ruff_python_formatter/src/core/helpers.rs b/crates/ruff_python_formatter/src/core/helpers.rs index e6130fd59d..103e024d86 100644 --- a/crates/ruff_python_formatter/src/core/helpers.rs +++ b/crates/ruff_python_formatter/src/core/helpers.rs @@ -1,3 +1,8 @@ +use rustpython_parser::ast::Location; + +use crate::core::locator::Locator; +use crate::core::types::Range; + /// Return the leading quote for a string or byte literal (e.g., `"""`). pub fn leading_quote(content: &str) -> Option<&str> { if let Some(first_line) = content.lines().next() { @@ -32,6 +37,99 @@ pub fn is_radix_literal(content: &str) -> bool { || content.starts_with("0X") } +/// Expand the range of a compound statement. +/// +/// `location` is the start of the compound statement (e.g., the `if` in `if x:`). +/// `end_location` is the end of the last statement in the body. +pub fn expand_indented_block( + location: Location, + end_location: Location, + locator: &Locator, +) -> (Location, Location) { + let contents = locator.contents(); + let start_index = locator.index(location); + let end_index = locator.index(end_location); + + // Find the colon, which indicates the end of the header. + let mut nesting = 0; + let mut colon = None; + for (start, tok, _end) in rustpython_parser::lexer::lex_located( + &contents[start_index..end_index], + rustpython_parser::Mode::Module, + location, + ) + .flatten() + { + match tok { + rustpython_parser::Tok::Colon if nesting == 0 => { + colon = Some(start); + break; + } + rustpython_parser::Tok::Lpar + | rustpython_parser::Tok::Lsqb + | rustpython_parser::Tok::Lbrace => nesting += 1, + rustpython_parser::Tok::Rpar + | rustpython_parser::Tok::Rsqb + | rustpython_parser::Tok::Rbrace => nesting -= 1, + _ => {} + } + } + let colon_location = colon.unwrap(); + let colon_index = locator.index(colon_location); + + // From here, we have two options: simple statement or compound statement. + let indent = rustpython_parser::lexer::lex_located( + &contents[colon_index..end_index], + rustpython_parser::Mode::Module, + colon_location, + ) + .flatten() + .find_map(|(start, tok, _end)| match tok { + rustpython_parser::Tok::Indent => Some(start), + _ => None, + }); + + let Some(indent_location) = indent else { + // Simple statement: from the colon to the end of the line. + return (colon_location, Location::new(end_location.row() + 1, 0)); + }; + + // Compound statement: from the colon to the end of the block. + let mut offset = 0; + for (index, line) in contents[end_index..].lines().skip(1).enumerate() { + if line.is_empty() { + continue; + } + + if line + .chars() + .take(indent_location.column()) + .all(char::is_whitespace) + { + offset = index + 1; + } else { + break; + } + } + + let end_location = Location::new(end_location.row() + 1 + offset, 0); + (colon_location, end_location) +} + +/// Return true if the `orelse` block of an `if` statement is an `elif` statement. +pub fn is_elif(orelse: &[rustpython_parser::ast::Stmt], locator: &Locator) -> bool { + if orelse.len() == 1 && matches!(orelse[0].node, rustpython_parser::ast::StmtKind::If { .. }) { + let (source, start, end) = locator.slice(Range::new( + orelse[0].location, + orelse[0].end_location.unwrap(), + )); + if source[start..end].starts_with("elif") { + return true; + } + } + false +} + #[cfg(test)] mod tests { #[test] diff --git a/crates/ruff_python_formatter/src/core/locator.rs b/crates/ruff_python_formatter/src/core/locator.rs index e803e61e1c..060b9774ed 100644 --- a/crates/ruff_python_formatter/src/core/locator.rs +++ b/crates/ruff_python_formatter/src/core/locator.rs @@ -108,6 +108,15 @@ impl<'a> Locator<'a> { self.index.get_or_init(|| index(self.contents)) } + pub fn index(&self, location: Location) -> usize { + let index = self.get_or_init_index(); + truncate(location, index, self.contents) + } + + pub fn contents(&self) -> &str { + self.contents + } + /// Slice the source code at a [`Range`]. pub fn slice(&self, range: Range) -> (Rc, usize, usize) { let index = self.get_or_init_index(); diff --git a/crates/ruff_python_formatter/src/core/visitor.rs b/crates/ruff_python_formatter/src/core/visitor.rs index 51fc640a97..d7c7d519f3 100644 --- a/crates/ruff_python_formatter/src/core/visitor.rs +++ b/crates/ruff_python_formatter/src/core/visitor.rs @@ -1,8 +1,8 @@ use rustpython_parser::ast::Constant; use crate::cst::{ - Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind, Expr, - ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex, + Alias, Arg, Arguments, Body, Boolop, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind, + Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex, SliceIndexKind, Stmt, StmtKind, Unaryop, Withitem, }; @@ -67,13 +67,13 @@ pub trait Visitor<'a> { fn visit_pattern(&mut self, pattern: &'a mut Pattern) { walk_pattern(self, pattern); } - fn visit_body(&mut self, body: &'a mut [Stmt]) { + fn visit_body(&mut self, body: &'a mut Body) { walk_body(self, body); } } -pub fn walk_body<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, body: &'a mut [Stmt]) { - for stmt in body { +pub fn walk_body<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, body: &'a mut Body) { + for stmt in &mut body.node { visitor.visit_stmt(stmt); } } @@ -173,7 +173,9 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm visitor.visit_expr(iter); visitor.visit_expr(target); visitor.visit_body(body); - visitor.visit_body(orelse); + if let Some(orelse) = orelse { + visitor.visit_body(orelse); + } } StmtKind::AsyncFor { target, @@ -185,17 +187,25 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm visitor.visit_expr(iter); visitor.visit_expr(target); visitor.visit_body(body); - visitor.visit_body(orelse); + if let Some(orelse) = orelse { + visitor.visit_body(orelse); + } } StmtKind::While { test, body, orelse } => { visitor.visit_expr(test); visitor.visit_body(body); - visitor.visit_body(orelse); + if let Some(orelse) = orelse { + visitor.visit_body(orelse); + } } - StmtKind::If { test, body, orelse } => { + StmtKind::If { + test, body, orelse, .. + } => { visitor.visit_expr(test); visitor.visit_body(body); - visitor.visit_body(orelse); + if let Some(orelse) = orelse { + visitor.visit_body(orelse); + } } StmtKind::With { items, body, .. } => { for withitem in items { @@ -210,7 +220,6 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm visitor.visit_body(body); } StmtKind::Match { subject, cases } => { - // TODO(charlie): Handle `cases`. visitor.visit_expr(subject); for match_case in cases { visitor.visit_match_case(match_case); @@ -234,8 +243,12 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm for excepthandler in handlers { visitor.visit_excepthandler(excepthandler); } - visitor.visit_body(orelse); - visitor.visit_body(finalbody); + if let Some(orelse) = orelse { + visitor.visit_body(orelse); + } + if let Some(finalbody) = finalbody { + visitor.visit_body(finalbody); + } } StmtKind::TryStar { body, @@ -247,8 +260,12 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a mut Stm for excepthandler in handlers { visitor.visit_excepthandler(excepthandler); } - visitor.visit_body(orelse); - visitor.visit_body(finalbody); + if let Some(orelse) = orelse { + visitor.visit_body(orelse); + } + if let Some(finalbody) = finalbody { + visitor.visit_body(finalbody); + } } StmtKind::Assert { test, msg } => { visitor.visit_expr(test); diff --git a/crates/ruff_python_formatter/src/cst.rs b/crates/ruff_python_formatter/src/cst.rs index c1222552ec..204178a3b3 100644 --- a/crates/ruff_python_formatter/src/cst.rs +++ b/crates/ruff_python_formatter/src/cst.rs @@ -1,5 +1,6 @@ #![allow(clippy::derive_partial_eq_without_eq)] +use crate::core::helpers::{expand_indented_block, is_elif}; use rustpython_parser::ast::{Constant, Location}; use rustpython_parser::Mode; @@ -157,12 +158,29 @@ impl From for Cmpop { } } +pub type Body = Located>; + +impl From<(Vec, &Locator<'_>)> for Body { + fn from((body, locator): (Vec, &Locator)) -> Self { + Body { + location: body.first().unwrap().location, + end_location: body.last().unwrap().end_location, + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + } +} + #[derive(Clone, Debug, PartialEq)] pub enum StmtKind { FunctionDef { name: Ident, args: Box, - body: Vec, + body: Body, decorator_list: Vec, returns: Option>, type_comment: Option, @@ -170,7 +188,7 @@ pub enum StmtKind { AsyncFunctionDef { name: Ident, args: Box, - body: Vec, + body: Body, decorator_list: Vec, returns: Option>, type_comment: Option, @@ -179,7 +197,7 @@ pub enum StmtKind { name: Ident, bases: Vec, keywords: Vec, - body: Vec, + body: Body, decorator_list: Vec, }, Return { @@ -207,35 +225,36 @@ pub enum StmtKind { For { target: Box, iter: Box, - body: Vec, - orelse: Vec, + body: Body, + orelse: Option, type_comment: Option, }, AsyncFor { target: Box, iter: Box, - body: Vec, - orelse: Vec, + body: Body, + orelse: Option, type_comment: Option, }, While { test: Box, - body: Vec, - orelse: Vec, + body: Body, + orelse: Option, }, If { test: Box, - body: Vec, - orelse: Vec, + body: Body, + orelse: Option, + is_elif: bool, }, With { items: Vec, - body: Vec, + body: Body, type_comment: Option, }, AsyncWith { items: Vec, - body: Vec, + body: Body, type_comment: Option, }, Match { @@ -247,16 +266,16 @@ pub enum StmtKind { cause: Option>, }, Try { - body: Vec, + body: Body, handlers: Vec, - orelse: Vec, - finalbody: Vec, + orelse: Option, + finalbody: Option, }, TryStar { - body: Vec, + body: Body, handlers: Vec, - orelse: Vec, - finalbody: Vec, + orelse: Option, + finalbody: Option, }, Assert { test: Box, @@ -417,7 +436,7 @@ pub enum ExcepthandlerKind { ExceptHandler { type_: Option>, name: Option, - body: Vec, + body: Body, }, } @@ -479,7 +498,7 @@ pub struct Withitem { pub struct MatchCase { pub pattern: Pattern, pub guard: Option>, - pub body: Vec, + pub body: Body, } #[allow(clippy::enum_variant_names)] @@ -549,16 +568,33 @@ impl From<(rustpython_parser::ast::Excepthandler, &Locator<'_>)> for Excepthandl fn from((excepthandler, locator): (rustpython_parser::ast::Excepthandler, &Locator)) -> Self { let rustpython_parser::ast::ExcepthandlerKind::ExceptHandler { type_, name, body } = excepthandler.node; - Excepthandler { - location: excepthandler.location, - end_location: excepthandler.end_location, - node: ExcepthandlerKind::ExceptHandler { - type_: type_.map(|type_| Box::new((*type_, locator).into())), - name, - body: body + + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + excepthandler.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body .into_iter() .map(|node| (node, locator).into()) .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + Excepthandler { + location: excepthandler.location, + end_location: body.end_location, + node: ExcepthandlerKind::ExceptHandler { + type_: type_.map(|type_| Box::new((*type_, locator).into())), + name, + body, }, trivia: vec![], parentheses: Parenthesize::Never, @@ -641,16 +677,32 @@ impl From<(rustpython_parser::ast::Pattern, &Locator<'_>)> for Pattern { impl From<(rustpython_parser::ast::MatchCase, &Locator<'_>)> for MatchCase { fn from((match_case, locator): (rustpython_parser::ast::MatchCase, &Locator)) -> Self { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + match_case.pattern.location, + match_case.body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: match_case + .body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + MatchCase { pattern: (match_case.pattern, locator).into(), guard: match_case .guard .map(|guard| Box::new((*guard, locator).into())), - body: match_case - .body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), + body, } } } @@ -707,27 +759,144 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { keywords, body, decorator_list, - } => Stmt { + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + Stmt { + location: stmt.location, + end_location: body.end_location, + node: StmtKind::ClassDef { + name, + bases: bases + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + keywords: keywords + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + body, + decorator_list: decorator_list + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } + rustpython_parser::ast::StmtKind::If { test, body, orelse } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + if orelse.is_empty() { + // No `else` block. + Stmt { + location: stmt.location, + end_location: body.end_location, + node: StmtKind::If { + test: Box::new((*test, locator).into()), + body, + orelse: None, + is_elif: false, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } else { + if is_elif(&orelse, locator) { + // Find the start and end of the `elif`. + let mut elif: Body = (orelse, locator).into(); + if let StmtKind::If { is_elif, .. } = + &mut elif.node.first_mut().unwrap().node + { + *is_elif = true; + }; + + Stmt { + location: stmt.location, + end_location: elif.end_location, + node: StmtKind::If { + test: Box::new((*test, locator).into()), + body, + orelse: Some(elif), + is_elif: false, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } else { + // Find the start and end of the `else`. + let (orelse_location, orelse_end_location) = expand_indented_block( + body.end_location.unwrap(), + orelse.last().unwrap().end_location.unwrap(), + locator, + ); + let orelse = Body { + location: orelse_location, + end_location: Some(orelse_end_location), + node: orelse + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + }; + + Stmt { + location: stmt.location, + end_location: orelse.end_location, + node: StmtKind::If { + test: Box::new((*test, locator).into()), + body, + orelse: Some(orelse), + is_elif: false, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } + } + } + rustpython_parser::ast::StmtKind::Assert { test, msg } => Stmt { location: stmt.location, end_location: stmt.end_location, - node: StmtKind::ClassDef { - name, - bases: bases - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - keywords: keywords - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - decorator_list: decorator_list - .into_iter() - .map(|node| (node, locator).into()) - .collect(), + node: StmtKind::Assert { + test: Box::new((*test, locator).into()), + msg: msg.map(|node| Box::new((*node, locator).into())), }, trivia: vec![], parentheses: Parenthesize::Never, @@ -739,55 +908,44 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { decorator_list, returns, type_comment, - } => Stmt { - location: decorator_list - .first() - .map_or(stmt.location, |expr| expr.location), - end_location: stmt.end_location, - node: StmtKind::FunctionDef { - name, - args: Box::new((*args, locator).into()), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - decorator_list: decorator_list - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - returns: returns.map(|r| Box::new((*r, locator).into())), - type_comment, - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, - rustpython_parser::ast::StmtKind::If { test, body, orelse } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::If { - test: Box::new((*test, locator).into()), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - orelse: orelse - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, - rustpython_parser::ast::StmtKind::Assert { test, msg } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::Assert { - test: Box::new((*test, locator).into()), - msg: msg.map(|node| Box::new((*node, locator).into())), - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + Stmt { + location: decorator_list.first().map_or(stmt.location, |d| d.location), + end_location: body.end_location, + node: StmtKind::FunctionDef { + name, + args: Box::new((*args, locator).into()), + body, + decorator_list: decorator_list + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + returns: returns.map(|r| Box::new((*r, locator).into())), + type_comment, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::AsyncFunctionDef { name, args, @@ -795,26 +953,46 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { decorator_list, returns, type_comment, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::AsyncFunctionDef { - name, - args: Box::new((*args, locator).into()), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - decorator_list: decorator_list - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - returns: returns.map(|r| Box::new((*r, locator).into())), - type_comment, - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + Stmt { + location: decorator_list + .first() + .map_or(stmt.location, |expr| expr.location), + end_location: body.end_location, + node: StmtKind::AsyncFunctionDef { + name, + args: Box::new((*args, locator).into()), + body, + decorator_list: decorator_list + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + returns: returns.map(|r| Box::new((*r, locator).into())), + type_comment, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::Delete { targets } => Stmt { location: stmt.location, end_location: stmt.end_location, @@ -861,109 +1039,247 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { body, orelse, type_comment, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::For { - target: Box::new((*target, locator).into()), - iter: Box::new((*iter, locator).into()), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - orelse: orelse - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - type_comment, - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + // Find the start and end of the `orelse`. + let orelse = (!orelse.is_empty()).then(|| { + let (orelse_location, orelse_end_location) = expand_indented_block( + body.end_location.unwrap(), + orelse.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: orelse_location, + end_location: Some(orelse_end_location), + node: orelse + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + Stmt { + location: stmt.location, + end_location: orelse.as_ref().unwrap_or(&body).end_location, + node: StmtKind::For { + target: Box::new((*target, locator).into()), + iter: Box::new((*iter, locator).into()), + body, + orelse, + type_comment, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::AsyncFor { target, iter, body, orelse, type_comment, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::AsyncFor { - target: Box::new((*target, locator).into()), - iter: Box::new((*iter, locator).into()), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - orelse: orelse - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - type_comment, - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, - rustpython_parser::ast::StmtKind::While { test, body, orelse } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::While { - test: Box::new((*test, locator).into()), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - orelse: orelse - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + // Find the start and end of the `orelse`. + let orelse = (!orelse.is_empty()).then(|| { + let (orelse_location, orelse_end_location) = expand_indented_block( + body.end_location.unwrap(), + orelse.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: orelse_location, + end_location: Some(orelse_end_location), + node: orelse + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + Stmt { + location: stmt.location, + end_location: orelse.as_ref().unwrap_or(&body).end_location, + node: StmtKind::AsyncFor { + target: Box::new((*target, locator).into()), + iter: Box::new((*iter, locator).into()), + body, + orelse, + type_comment, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } + rustpython_parser::ast::StmtKind::While { test, body, orelse } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + // Find the start and end of the `orelse`. + let orelse = (!orelse.is_empty()).then(|| { + let (orelse_location, orelse_end_location) = expand_indented_block( + body.end_location.unwrap(), + orelse.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: orelse_location, + end_location: Some(orelse_end_location), + node: orelse + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + Stmt { + location: stmt.location, + end_location: orelse.as_ref().unwrap_or(&body).end_location, + node: StmtKind::While { + test: Box::new((*test, locator).into()), + body, + orelse, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::With { items, body, type_comment, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::With { - items: items - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - type_comment, - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + Stmt { + location: stmt.location, + end_location: body.end_location, + node: StmtKind::With { + items: items + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + body, + type_comment, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::AsyncWith { items, body, type_comment, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::AsyncWith { - items: items - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - type_comment, - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + Stmt { + location: stmt.location, + end_location: body.end_location, + node: StmtKind::AsyncWith { + items: items + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + body, + type_comment, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::Match { subject, cases } => Stmt { location: stmt.location, end_location: stmt.end_location, @@ -992,59 +1308,209 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { handlers, orelse, finalbody, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::Try { - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - handlers: handlers - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - orelse: orelse - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - finalbody: finalbody - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + let handlers: Vec = handlers + .into_iter() + .map(|node| (node, locator).into()) + .collect(); + + // Find the start and end of the `orelse`. + let orelse = (!orelse.is_empty()).then(|| { + let (orelse_location, orelse_end_location) = expand_indented_block( + handlers + .last() + .map_or(body.end_location.unwrap(), |handler| { + handler.end_location.unwrap() + }), + orelse.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: orelse_location, + end_location: Some(orelse_end_location), + node: orelse + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + // Find the start and end of the `finalbody`. + let finalbody = (!finalbody.is_empty()).then(|| { + let (finalbody_location, finalbody_end_location) = expand_indented_block( + orelse.as_ref().map_or( + handlers + .last() + .map_or(body.end_location.unwrap(), |handler| { + handler.end_location.unwrap() + }), + |orelse| orelse.end_location.unwrap(), + ), + finalbody.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: finalbody_location, + end_location: Some(finalbody_end_location), + node: finalbody + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + let end_location = finalbody.as_ref().map_or( + orelse.as_ref().map_or( + handlers + .last() + .map_or(body.end_location.unwrap(), |handler| { + handler.end_location.unwrap() + }), + |orelse| orelse.end_location.unwrap(), + ), + |finalbody| finalbody.end_location.unwrap(), + ); + + Stmt { + location: stmt.location, + end_location: Some(end_location), + node: StmtKind::Try { + body, + handlers, + orelse, + finalbody, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::TryStar { body, handlers, orelse, finalbody, - } => Stmt { - location: stmt.location, - end_location: stmt.end_location, - node: StmtKind::TryStar { - body: body - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - handlers: handlers - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - orelse: orelse - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - finalbody: finalbody - .into_iter() - .map(|node| (node, locator).into()) - .collect(), - }, - trivia: vec![], - parentheses: Parenthesize::Never, - }, + } => { + // Find the start and end of the `body`. + let body = { + let (body_location, body_end_location) = expand_indented_block( + stmt.location, + body.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: body_location, + end_location: Some(body_end_location), + node: body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }; + + let handlers: Vec = handlers + .into_iter() + .map(|node| (node, locator).into()) + .collect(); + + // Find the start and end of the `orelse`. + let orelse = (!orelse.is_empty()).then(|| { + let (orelse_location, orelse_end_location) = expand_indented_block( + handlers + .last() + .map_or(body.end_location.unwrap(), |handler| { + handler.end_location.unwrap() + }), + orelse.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: orelse_location, + end_location: Some(orelse_end_location), + node: orelse + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + // Find the start and end of the `finalbody`. + let finalbody = (!finalbody.is_empty()).then(|| { + let (finalbody_location, finalbody_end_location) = expand_indented_block( + orelse.as_ref().map_or( + handlers + .last() + .map_or(body.end_location.unwrap(), |handler| { + handler.end_location.unwrap() + }), + |orelse| orelse.end_location.unwrap(), + ), + finalbody.last().unwrap().end_location.unwrap(), + locator, + ); + Body { + location: finalbody_location, + end_location: Some(finalbody_end_location), + node: finalbody + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + trivia: vec![], + parentheses: Parenthesize::Never, + } + }); + + let end_location = finalbody.as_ref().map_or( + orelse.as_ref().map_or( + handlers + .last() + .map_or(body.end_location.unwrap(), |handler| { + handler.end_location.unwrap() + }), + |orelse| orelse.end_location.unwrap(), + ), + |finalbody| finalbody.end_location.unwrap(), + ); + + Stmt { + location: stmt.location, + end_location: Some(end_location), + node: StmtKind::TryStar { + body, + handlers, + orelse, + finalbody, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } rustpython_parser::ast::StmtKind::Import { names } => Stmt { location: stmt.location, end_location: stmt.end_location, diff --git a/crates/ruff_python_formatter/src/format/builders.rs b/crates/ruff_python_formatter/src/format/builders.rs index 17d3502c54..166459c12f 100644 --- a/crates/ruff_python_formatter/src/format/builders.rs +++ b/crates/ruff_python_formatter/src/format/builders.rs @@ -4,17 +4,55 @@ use ruff_text_size::{TextRange, TextSize}; use crate::context::ASTFormatContext; use crate::core::types::Range; -use crate::cst::Stmt; +use crate::cst::{Body, Stmt}; use crate::shared_traits::AsFormat; +use crate::trivia::{Relationship, TriviaKind}; #[derive(Copy, Clone)] pub struct Block<'a> { - body: &'a [Stmt], + body: &'a Body, } impl Format> for Block<'_> { fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { - for (i, stmt) in self.body.iter().enumerate() { + for (i, stmt) in self.body.node.iter().enumerate() { + if i > 0 { + write!(f, [hard_line_break()])?; + } + write!(f, [stmt.format()])?; + } + + for trivia in &self.body.trivia { + if matches!(trivia.relationship, Relationship::Dangling) { + match trivia.kind { + TriviaKind::EmptyLine => { + write!(f, [empty_line()])?; + } + TriviaKind::OwnLineComment(range) => { + write!(f, [literal(range), hard_line_break()])?; + } + _ => {} + } + } + } + + Ok(()) + } +} + +#[inline] +pub fn block(body: &Body) -> Block { + Block { body } +} + +#[derive(Copy, Clone)] +pub struct Statements<'a> { + suite: &'a [Stmt], +} + +impl Format> for Statements<'_> { + fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { + for (i, stmt) in self.suite.iter().enumerate() { if i > 0 { write!(f, [hard_line_break()])?; } @@ -24,9 +62,8 @@ impl Format> for Block<'_> { } } -#[inline] -pub fn block(body: &[Stmt]) -> Block { - Block { body } +pub fn statements(suite: &[Stmt]) -> Statements { + Statements { suite } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] diff --git a/crates/ruff_python_formatter/src/format/comments.rs b/crates/ruff_python_formatter/src/format/comments.rs index e43ad651c4..9d7b3cc42b 100644 --- a/crates/ruff_python_formatter/src/format/comments.rs +++ b/crates/ruff_python_formatter/src/format/comments.rs @@ -72,13 +72,12 @@ pub struct EndOfLineComments<'a, T> { impl Format> for EndOfLineComments<'_, T> { fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { let mut first = true; - for range in self.item.trivia.iter().filter_map(|trivia| { - if trivia.relationship.is_trailing() { - trivia.kind.end_of_line_comment() - } else { - None - } - }) { + for range in self + .item + .trivia + .iter() + .filter_map(|trivia| trivia.kind.end_of_line_comment()) + { if std::mem::take(&mut first) { write!(f, [line_suffix(&text(" "))])?; } diff --git a/crates/ruff_python_formatter/src/format/match_case.rs b/crates/ruff_python_formatter/src/format/match_case.rs index 98ccd968d8..6161e0f764 100644 --- a/crates/ruff_python_formatter/src/format/match_case.rs +++ b/crates/ruff_python_formatter/src/format/match_case.rs @@ -4,6 +4,7 @@ use ruff_formatter::write; use crate::context::ASTFormatContext; use crate::cst::MatchCase; use crate::format::builders::block; +use crate::format::comments::{end_of_line_comments, leading_comments}; use crate::shared_traits::AsFormat; pub struct FormatMatchCase<'a> { @@ -26,12 +27,16 @@ impl Format> for FormatMatchCase<'_> { body, } = self.item; + write!(f, [leading_comments(pattern)])?; + write!(f, [text("case")])?; write!(f, [space(), pattern.format()])?; if let Some(guard) = &guard { write!(f, [space(), text("if"), space(), guard.format()])?; } write!(f, [text(":")])?; + + write!(f, [end_of_line_comments(body)])?; write!(f, [block_indent(&block(body))])?; Ok(()) diff --git a/crates/ruff_python_formatter/src/format/stmt.rs b/crates/ruff_python_formatter/src/format/stmt.rs index 0b93f386b3..85babf5846 100644 --- a/crates/ruff_python_formatter/src/format/stmt.rs +++ b/crates/ruff_python_formatter/src/format/stmt.rs @@ -6,8 +6,8 @@ use ruff_text_size::TextSize; use crate::context::ASTFormatContext; use crate::cst::{ - Alias, Arguments, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt, StmtKind, - Withitem, + Alias, Arguments, Body, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt, + StmtKind, Withitem, }; use crate::format::builders::{block, join_names}; use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments}; @@ -101,13 +101,15 @@ fn format_class_def( name: &str, bases: &[Expr], keywords: &[Keyword], - body: &[Stmt], + body: &Body, decorator_list: &[Expr], ) -> FormatResult<()> { for decorator in decorator_list { write!(f, [text("@"), decorator.format(), hard_line_break()])?; } + write!(f, [leading_comments(body)])?; + write!( f, [ @@ -161,6 +163,7 @@ fn format_class_def( )?; } + write!(f, [end_of_line_comments(body)])?; write!(f, [text(":"), block_indent(&block(body))]) } @@ -170,13 +173,16 @@ fn format_func_def( name: &str, args: &Arguments, returns: Option<&Expr>, - body: &[Stmt], + body: &Body, decorator_list: &[Expr], async_: bool, ) -> FormatResult<()> { for decorator in decorator_list { write!(f, [text("@"), decorator.format(), hard_line_break()])?; } + + write!(f, [leading_comments(body)])?; + if async_ { write!(f, [text("async"), space()])?; } @@ -202,10 +208,10 @@ fn format_func_def( } write!(f, [text(":")])?; + write!(f, [end_of_line_comments(body)])?; + write!(f, [block_indent(&block(body))])?; - write!(f, [end_of_line_comments(stmt)])?; - - write!(f, [block_indent(&format_args![block(body)])]) + Ok(()) } fn format_assign( @@ -310,8 +316,8 @@ fn format_for( stmt: &Stmt, target: &Expr, iter: &Expr, - body: &[Stmt], - orelse: &[Stmt], + body: &Body, + orelse: Option<&Body>, _type_comment: Option<&str>, async_: bool, ) -> FormatResult<()> { @@ -329,11 +335,19 @@ fn format_for( space(), group(&iter.format()), text(":"), + end_of_line_comments(body), block_indent(&block(body)) ] )?; - if !orelse.is_empty() { - write!(f, [text("else:"), block_indent(&block(orelse))])?; + if let Some(orelse) = orelse { + write!( + f, + [ + text("else:"), + end_of_line_comments(orelse), + block_indent(&block(orelse)) + ] + )?; } Ok(()) } @@ -342,8 +356,8 @@ fn format_while( f: &mut Formatter>, stmt: &Stmt, test: &Expr, - body: &[Stmt], - orelse: &[Stmt], + body: &Body, + orelse: Option<&Body>, ) -> FormatResult<()> { write!(f, [text("while"), space()])?; if is_self_closing(test) { @@ -358,9 +372,23 @@ fn format_while( ])] )?; } - write!(f, [text(":"), block_indent(&block(body))])?; - if !orelse.is_empty() { - write!(f, [text("else:"), block_indent(&block(orelse))])?; + write!( + f, + [ + text(":"), + end_of_line_comments(body), + block_indent(&block(body)) + ] + )?; + if let Some(orelse) = orelse { + write!( + f, + [ + text("else:"), + end_of_line_comments(orelse), + block_indent(&block(orelse)) + ] + )?; } Ok(()) } @@ -368,10 +396,15 @@ fn format_while( fn format_if( f: &mut Formatter>, test: &Expr, - body: &[Stmt], - orelse: &[Stmt], + body: &Body, + orelse: Option<&Body>, + is_elif: bool, ) -> FormatResult<()> { - write!(f, [text("if"), space()])?; + if is_elif { + write!(f, [text("elif"), space()])?; + } else { + write!(f, [text("if"), space()])?; + } if is_self_closing(test) { write!(f, [test.format()])?; } else { @@ -384,17 +417,43 @@ fn format_if( ])] )?; } - write!(f, [text(":"), block_indent(&block(body))])?; - if !orelse.is_empty() { - if orelse.len() == 1 { - if let StmtKind::If { test, body, orelse } = &orelse[0].node { - write!(f, [text("el")])?; - format_if(f, test, body, orelse)?; + write!( + f, + [ + text(":"), + end_of_line_comments(body), + block_indent(&block(body)) + ] + )?; + if let Some(orelse) = orelse { + if orelse.node.len() == 1 { + if let StmtKind::If { + test, + body, + orelse, + is_elif: true, + } = &orelse.node[0].node + { + format_if(f, test, body, orelse.as_ref(), true)?; } else { - write!(f, [text("else:"), block_indent(&block(orelse))])?; + write!( + f, + [ + text("else:"), + end_of_line_comments(orelse), + block_indent(&block(orelse)) + ] + )?; } } else { - write!(f, [text("else:"), block_indent(&block(orelse))])?; + write!( + f, + [ + text("else:"), + end_of_line_comments(orelse), + block_indent(&block(orelse)) + ] + )?; } } Ok(()) @@ -406,7 +465,16 @@ fn format_match( subject: &Expr, cases: &[MatchCase], ) -> FormatResult<()> { - write!(f, [text("match"), space(), subject.format(), text(":")])?; + write!( + f, + [ + text("match"), + space(), + subject.format(), + text(":"), + end_of_line_comments(stmt), + ] + )?; for case in cases { write!(f, [block_indent(&case.format())])?; } @@ -447,20 +515,31 @@ fn format_return( fn format_try( f: &mut Formatter>, stmt: &Stmt, - body: &[Stmt], + body: &Body, handlers: &[Excepthandler], - orelse: &[Stmt], - finalbody: &[Stmt], + orelse: Option<&Body>, + finalbody: Option<&Body>, ) -> FormatResult<()> { - write!(f, [text("try:"), block_indent(&block(body))])?; + write!( + f, + [ + text("try:"), + end_of_line_comments(body), + block_indent(&block(body)) + ] + )?; for handler in handlers { write!(f, [handler.format()])?; } - if !orelse.is_empty() { - write!(f, [text("else:"), block_indent(&block(orelse))])?; + if let Some(orelse) = orelse { + write!(f, [text("else:")])?; + write!(f, [end_of_line_comments(orelse)])?; + write!(f, [block_indent(&block(orelse))])?; } - if !finalbody.is_empty() { - write!(f, [text("finally:"), block_indent(&block(finalbody))])?; + if let Some(finalbody) = finalbody { + write!(f, [text("finally:")])?; + write!(f, [end_of_line_comments(finalbody)])?; + write!(f, [block_indent(&block(finalbody))])?; } Ok(()) } @@ -468,21 +547,42 @@ fn format_try( fn format_try_star( f: &mut Formatter>, stmt: &Stmt, - body: &[Stmt], + body: &Body, handlers: &[Excepthandler], - orelse: &[Stmt], - finalbody: &[Stmt], + orelse: Option<&Body>, + finalbody: Option<&Body>, ) -> FormatResult<()> { - write!(f, [text("try:"), block_indent(&block(body))])?; + write!( + f, + [ + text("try:"), + end_of_line_comments(body), + block_indent(&block(body)) + ] + )?; for handler in handlers { // TODO(charlie): Include `except*`. write!(f, [handler.format()])?; } - if !orelse.is_empty() { - write!(f, [text("else:"), block_indent(&block(orelse))])?; + if let Some(orelse) = orelse { + write!( + f, + [ + text("else:"), + end_of_line_comments(orelse), + block_indent(&block(orelse)) + ] + )?; } - if !finalbody.is_empty() { - write!(f, [text("finally:"), block_indent(&block(finalbody))])?; + if let Some(finalbody) = finalbody { + write!( + f, + [ + text("finally:"), + end_of_line_comments(finalbody), + block_indent(&block(finalbody)) + ] + )?; } Ok(()) } @@ -640,7 +740,7 @@ fn format_with_( f: &mut Formatter>, stmt: &Stmt, items: &[Withitem], - body: &[Stmt], + body: &Body, type_comment: Option<&str>, async_: bool, ) -> FormatResult<()> { @@ -668,6 +768,7 @@ fn format_with_( if_group_breaks(&text(")")), ]), text(":"), + end_of_line_comments(body), block_indent(&block(body)) ] )?; @@ -753,7 +854,7 @@ impl Format> for FormatStmt<'_> { target, iter, body, - orelse, + orelse.as_ref(), type_comment.as_deref(), false, ), @@ -769,14 +870,19 @@ impl Format> for FormatStmt<'_> { target, iter, body, - orelse, + orelse.as_ref(), type_comment.as_deref(), true, ), StmtKind::While { test, body, orelse } => { - format_while(f, self.item, test, body, orelse) + format_while(f, self.item, test, body, orelse.as_ref()) } - StmtKind::If { test, body, orelse } => format_if(f, test, body, orelse), + StmtKind::If { + test, + body, + orelse, + is_elif, + } => format_if(f, test, body, orelse.as_ref(), *is_elif), StmtKind::With { items, body, @@ -810,13 +916,27 @@ impl Format> for FormatStmt<'_> { handlers, orelse, finalbody, - } => format_try(f, self.item, body, handlers, orelse, finalbody), + } => format_try( + f, + self.item, + body, + handlers, + orelse.as_ref(), + finalbody.as_ref(), + ), StmtKind::TryStar { body, handlers, orelse, finalbody, - } => format_try_star(f, self.item, body, handlers, orelse, finalbody), + } => format_try_star( + f, + self.item, + body, + handlers, + orelse.as_ref(), + finalbody.as_ref(), + ), StmtKind::Assert { test, msg } => { format_assert(f, self.item, test, msg.as_ref().map(|expr| &**expr)) } diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index 04a64a9b02..6d8f6839b8 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -53,7 +53,7 @@ pub fn fmt(contents: &str) -> Result> { }, locator, ), - [format::builders::block(&python_cst)] + [format::builders::statements(&python_cst)] ) .map_err(Into::into) } diff --git a/crates/ruff_python_formatter/src/newlines.rs b/crates/ruff_python_formatter/src/newlines.rs index ae711b6601..6baf69f214 100644 --- a/crates/ruff_python_formatter/src/newlines.rs +++ b/crates/ruff_python_formatter/src/newlines.rs @@ -163,15 +163,14 @@ impl<'a> Visitor<'a> for StmtNormalizer { self.trailer = Trailer::CompoundStatement; self.visit_body(body); - if !orelse.is_empty() { + if let Some(orelse) = orelse { // If the previous body ended with a function or class definition, we need to // insert an empty line before the else block. Since the `else` itself isn't // a statement, we need to insert it into the last statement of the body. if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { - let stmt = body.last_mut().unwrap(); - stmt.trivia.push(Trivia { + body.trivia.push(Trivia { kind: TriviaKind::EmptyLine, - relationship: Relationship::Trailing, + relationship: Relationship::Dangling, }); } @@ -185,12 +184,11 @@ impl<'a> Visitor<'a> for StmtNormalizer { self.trailer = Trailer::CompoundStatement; self.visit_body(body); - if !orelse.is_empty() { + if let Some(orelse) = orelse { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { - let stmt = body.last_mut().unwrap(); - stmt.trivia.push(Trivia { + body.trivia.push(Trivia { kind: TriviaKind::EmptyLine, - relationship: Relationship::Trailing, + relationship: Relationship::Dangling, }); } @@ -220,49 +218,44 @@ impl<'a> Visitor<'a> for StmtNormalizer { self.depth = Depth::Nested; self.trailer = Trailer::CompoundStatement; self.visit_body(body); - let mut last = body.last_mut(); + + let mut prev = &mut body.trivia; for handler in handlers { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { - if let Some(stmt) = last.as_mut() { - stmt.trivia.push(Trivia { - kind: TriviaKind::EmptyLine, - relationship: Relationship::Trailing, - }); - } + prev.push(Trivia { + kind: TriviaKind::EmptyLine, + relationship: Relationship::Dangling, + }); } self.depth = Depth::Nested; self.trailer = Trailer::CompoundStatement; let ExcepthandlerKind::ExceptHandler { body, .. } = &mut handler.node; self.visit_body(body); - last = body.last_mut(); + prev = &mut body.trivia; } - if !orelse.is_empty() { + if let Some(orelse) = orelse { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { - if let Some(stmt) = last.as_mut() { - stmt.trivia.push(Trivia { - kind: TriviaKind::EmptyLine, - relationship: Relationship::Trailing, - }); - } + prev.push(Trivia { + kind: TriviaKind::EmptyLine, + relationship: Relationship::Dangling, + }); } self.depth = Depth::Nested; self.trailer = Trailer::CompoundStatement; self.visit_body(orelse); - last = body.last_mut(); + prev = &mut body.trivia; } - if !finalbody.is_empty() { + if let Some(finalbody) = finalbody { if matches!(self.trailer, Trailer::ClassDef | Trailer::FunctionDef) { - if let Some(stmt) = last.as_mut() { - stmt.trivia.push(Trivia { - kind: TriviaKind::EmptyLine, - relationship: Relationship::Trailing, - }); - } + prev.push(Trivia { + kind: TriviaKind::EmptyLine, + relationship: Relationship::Dangling, + }); } self.depth = Depth::Nested; diff --git a/crates/ruff_python_formatter/src/parentheses.rs b/crates/ruff_python_formatter/src/parentheses.rs index 52d04f013f..751d83dea4 100644 --- a/crates/ruff_python_formatter/src/parentheses.rs +++ b/crates/ruff_python_formatter/src/parentheses.rs @@ -194,5 +194,7 @@ impl<'a> Visitor<'a> for ParenthesesNormalizer<'_> { /// during formatting) and `Parenthesize` (which are used during formatting). pub fn normalize_parentheses(python_cst: &mut [Stmt], locator: &Locator) { let mut normalizer = ParenthesesNormalizer { locator }; - normalizer.visit_body(python_cst); + for stmt in python_cst { + normalizer.visit_stmt(stmt); + } } diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments2_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments2_py.snap index a0f5f2619b..2a33c22477 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments2_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments2_py.snap @@ -196,15 +196,7 @@ instruction()#comment with bad spacing "Generator", ] -@@ -54,32 +54,39 @@ - # for compiler in compilers.values(): - # add_compiler(compiler) - add_compiler(compilers[(7.0, 32)]) -- # add_compiler(compilers[(7.1, 64)]) - - -+# add_compiler(compilers[(7.1, 64)]) -+ +@@ -60,26 +60,32 @@ # Comment before function. def inline_comments_in_brackets_ruin_everything(): if typedargslist: @@ -229,7 +221,7 @@ instruction()#comment with bad spacing + parameters.what_if_this_was_actually_long.children[0], + body, + parameters.children[-1], -+ ] ++ ] # type: ignore if ( self._proc is not None - # has the child process finished? @@ -246,7 +238,7 @@ instruction()#comment with bad spacing ): pass # no newline before or after -@@ -103,42 +110,42 @@ +@@ -103,35 +109,35 @@ ############################################################################ call2( @@ -298,16 +290,7 @@ instruction()#comment with bad spacing ] while True: if False: - continue - -- # and round and round we go -- # and round and round we go -+ # and round and round we go -+ # and round and round we go - - # let's return - return Node( -@@ -167,7 +174,7 @@ +@@ -167,7 +173,7 @@ ####################### @@ -377,10 +360,9 @@ else: # for compiler in compilers.values(): # add_compiler(compiler) add_compiler(compilers[(7.0, 32)]) + # add_compiler(compilers[(7.1, 64)]) -# add_compiler(compilers[(7.1, 64)]) - # Comment before function. def inline_comments_in_brackets_ruin_everything(): if typedargslist: @@ -400,7 +382,7 @@ def inline_comments_in_brackets_ruin_everything(): parameters.what_if_this_was_actually_long.children[0], body, parameters.children[-1], - ] + ] # type: ignore if ( self._proc is not None and # has the child process finished? @@ -467,8 +449,8 @@ short if False: continue - # and round and round we go - # and round and round we go + # and round and round we go + # and round and round we go # let's return return Node( diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments6_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments6_py.snap index e25a53fb03..f735782266 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments6_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments6_py.snap @@ -131,17 +131,15 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite ```diff --- Black +++ Ruff -@@ -2,8 +2,8 @@ +@@ -2,7 +2,7 @@ def f( - a, # type: int --): + a, -+): # type: int + ): pass - @@ -14,44 +14,42 @@ @@ -155,7 +153,6 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite - g, # type: int - h, # type: int - i, # type: int --): + a, + b, + c, @@ -165,7 +162,7 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite + g, + h, + i, -+): # type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int + ): # type: (...) -> None pass @@ -175,12 +172,11 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite - *args, # type: *Any - default=False, # type: bool - **kwargs, # type: **Any --): + arg, + *args, + default=False, + **kwargs, -+): # type: int# type: *Any + ): # type: (...) -> None pass @@ -190,12 +186,11 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite - b, # type: int - c, # type: int - d, # type: int --): + a, + b, + c, + d, -+): # type: int# type: int# type: int# type: int# type: int + ): # type: (...) -> None element = 0 # type: int @@ -208,41 +203,32 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite an_element_with_a_long_value = calls() or more_calls() and more() # type: bool tup = ( -@@ -66,26 +64,26 @@ - + element - + another_element - + another_element_with_long_name -- ) # type: int -+ ) +@@ -70,21 +68,21 @@ def f( - x, # not a type comment - y, # type: int --): + x, + y, -+): # not a type comment# type: int + ): # type: (...) -> None pass def f( - x, # not a type comment --): # type: (int) -> None + x, -+): # not a type comment# type: (int) -> None + ): # type: (int) -> None pass def func( - a=some_list[0], # type: int --): # type: () -> int + a=some_list[0], -+): + ): # type: () -> int c = call( 0.0123, - 0.0456, @@ -96,23 +94,37 @@ 0.0123, 0.0456, @@ -298,7 +284,7 @@ from typing import Any, Tuple def f( a, -): # type: int +): pass @@ -318,7 +304,7 @@ def f( g, h, i, -): # type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int# type: int +): # type: (...) -> None pass @@ -328,7 +314,7 @@ def f( *args, default=False, **kwargs, -): # type: int# type: *Any +): # type: (...) -> None pass @@ -338,7 +324,7 @@ def f( b, c, d, -): # type: int# type: int# type: int# type: int# type: int +): # type: (...) -> None element = 0 # type: int @@ -359,26 +345,26 @@ def f( + element + another_element + another_element_with_long_name - ) + ) # type: int def f( x, y, -): # not a type comment# type: int +): # type: (...) -> None pass def f( x, -): # not a type comment# type: (int) -> None +): # type: (int) -> None pass def func( a=some_list[0], -): +): # type: () -> int c = call( 0.0123, 0.0456, diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments9_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments9_py.snap index 5572327616..53ba320c73 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments9_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments9_py.snap @@ -164,7 +164,7 @@ def bar(): # This should be split from the above by two lines class MyClassWithComplexLeadingComments: pass -@@ -57,13 +58,13 @@ +@@ -57,11 +58,11 @@ # leading 1 @deco1 @@ -174,16 +174,13 @@ def bar(): -@deco2(with_args=True) -# leading 3 -@deco3 --# leading 4 +deco2(with_args=True) +@# leading 3 +deco3 + # leading 4 def decorated(): -+ # leading 4 pass - - -@@ -72,13 +73,12 @@ +@@ -72,11 +73,10 @@ # leading 1 @deco1 @@ -192,17 +189,14 @@ def bar(): - -# leading 3 that already has an empty line -@deco3 --# leading 4 +@# leading 2 +deco2(with_args=True) +@# leading 3 that already has an empty line +deco3 + # leading 4 def decorated_with_split_leading_comments(): -+ # leading 4 pass - - -@@ -87,18 +87,18 @@ +@@ -87,10 +87,10 @@ # leading 1 @deco1 @@ -210,16 +204,14 @@ def bar(): -@deco2(with_args=True) -# leading 3 -@deco3 -- --# leading 4 that already has an empty line +@# leading 2 +deco2(with_args=True) +@# leading 3 +deco3 - def decorated_with_split_leading_comments(): -+ # leading 4 that already has an empty line - pass + # leading 4 that already has an empty line + def decorated_with_split_leading_comments(): +@@ -99,6 +99,7 @@ def main(): if a: @@ -227,7 +219,7 @@ def bar(): # Leading comment before inline function def inline(): pass -@@ -108,12 +108,14 @@ +@@ -108,12 +109,14 @@ pass else: @@ -242,7 +234,7 @@ def bar(): # Leading comment before "top-level inline" function def top_level_quote_inline(): pass -@@ -123,6 +125,7 @@ +@@ -123,6 +126,7 @@ pass else: @@ -250,37 +242,6 @@ def bar(): # More leading comments def top_level_quote_inline_after_else(): pass -@@ -138,9 +141,11 @@ - # Regression test for https://github.com/psf/black/issues/3454. - def foo(): - pass -- # Trailing comment that belongs to this function - - -+# Trailing comment that belongs to this function -+ -+ - @decorator1 - @decorator2 # fmt: skip - def bar(): -@@ -150,12 +155,13 @@ - # Regression test for https://github.com/psf/black/issues/3454. - def foo(): - pass -- # Trailing comment that belongs to this function. -- # NOTE this comment only has one empty line below, and the formatter -- # should enforce two blank lines. - - -+# Trailing comment that belongs to this function. -+# NOTE this comment only has one empty line below, and the formatter -+# should enforce two blank lines. -+ - @decorator1 --# A standalone comment - def bar(): -+ # A standalone comment - pass ``` ## Ruff Output @@ -351,8 +312,8 @@ some = statement deco2(with_args=True) @# leading 3 deco3 +# leading 4 def decorated(): - # leading 4 pass @@ -365,8 +326,8 @@ some = statement deco2(with_args=True) @# leading 3 that already has an empty line deco3 +# leading 4 def decorated_with_split_leading_comments(): - # leading 4 pass @@ -379,8 +340,9 @@ some = statement deco2(with_args=True) @# leading 3 deco3 + +# leading 4 that already has an empty line def decorated_with_split_leading_comments(): - # leading 4 that already has an empty line pass @@ -429,9 +391,7 @@ class MyClass: # Regression test for https://github.com/psf/black/issues/3454. def foo(): pass - - -# Trailing comment that belongs to this function + # Trailing comment that belongs to this function @decorator1 @@ -443,15 +403,14 @@ def bar(): # Regression test for https://github.com/psf/black/issues/3454. def foo(): pass + # Trailing comment that belongs to this function. + # NOTE this comment only has one empty line below, and the formatter + # should enforce two blank lines. -# Trailing comment that belongs to this function. -# NOTE this comment only has one empty line below, and the formatter -# should enforce two blank lines. - @decorator1 +# A standalone comment def bar(): - # A standalone comment pass ``` diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_no_trailing_comma_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_no_trailing_comma_py.snap index ef60ad17b4..d7a8674a2a 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_no_trailing_comma_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_no_trailing_comma_py.snap @@ -204,7 +204,7 @@ class C: ) self.assertEqual( unstyle(str(report)), -@@ -22,133 +23,156 @@ +@@ -22,133 +23,155 @@ if ( # Rule 1 i % 2 == 0 @@ -217,10 +217,10 @@ class C: - while ( - # Just a comment - call() +- # Another +- ): + while # Just a comment + call(): - # Another -- ): print(i) xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( push_manager=context.request.resource_manager, @@ -460,7 +460,7 @@ class C: "Not what we expected and the message is too long to fit in one line" " because it's too long" ) -@@ -161,9 +185,8 @@ +@@ -161,9 +184,8 @@ 8 STORE_ATTR 0 (x) 10 LOAD_CONST 0 (None) 12 RETURN_VALUE @@ -508,7 +508,6 @@ class C: ): while # Just a comment call(): - # Another print(i) xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( push_manager=context.request.resource_manager, diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_py.snap index 8841f50a16..774829451c 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__composition_py.snap @@ -204,7 +204,7 @@ class C: ) self.assertEqual( unstyle(str(report)), -@@ -22,133 +23,156 @@ +@@ -22,133 +23,155 @@ if ( # Rule 1 i % 2 == 0 @@ -217,10 +217,10 @@ class C: - while ( - # Just a comment - call() +- # Another +- ): + while # Just a comment + call(): - # Another -- ): print(i) xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( push_manager=context.request.resource_manager, @@ -460,7 +460,7 @@ class C: "Not what we expected and the message is too long to fit in one line" " because it's too long" ) -@@ -161,9 +185,8 @@ +@@ -161,9 +184,8 @@ 8 STORE_ATTR 0 (x) 10 LOAD_CONST 0 (None) 12 RETURN_VALUE @@ -508,7 +508,6 @@ class C: ): while # Just a comment call(): - # Another print(i) xxxxxxxxxxxxxxxx = Yyyy2YyyyyYyyyyy( push_manager=context.request.resource_manager, diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff4_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff4_py.snap index 50daa356bc..68ffcf8c4e 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff4_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff4_py.snap @@ -26,13 +26,12 @@ def f(): pass ```diff --- Black +++ Ruff -@@ -1,10 +1,14 @@ +@@ -1,8 +1,12 @@ # fmt: off -@test([ - 1, 2, - 3, 4, -]) --# fmt: on +@test( + [ + 1, @@ -41,11 +40,9 @@ def f(): pass + 4, + ] +) + # fmt: on def f(): -+ # fmt: on pass - - ``` ## Ruff Output @@ -60,8 +57,8 @@ def f(): pass 4, ] ) +# fmt: on def f(): - # fmt: on pass diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff5_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff5_py.snap index 171a290ea4..ac06179de6 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff5_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtonoff5_py.snap @@ -165,7 +165,7 @@ elif unformatted: print("This will be formatted") -@@ -68,20 +62,21 @@ +@@ -68,20 +62,19 @@ class Named(t.Protocol): # fmt: off @property @@ -177,11 +177,9 @@ elif unformatted: class Factory(t.Protocol): def this_will_be_formatted(self, **kwargs) -> Named: ... +- + # fmt: on -- # fmt: on - -+# fmt: on -+ # Regression test for https://github.com/psf/black/issues/3436. if x: @@ -267,9 +265,7 @@ class Named(t.Protocol): class Factory(t.Protocol): def this_will_be_formatted(self, **kwargs) -> Named: ... - - -# fmt: on + # fmt: on # Regression test for https://github.com/psf/black/issues/3436. diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtskip6_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtskip6_py.snap deleted file mode 100644 index 46b8920531..0000000000 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__fmtskip6_py.snap +++ /dev/null @@ -1,49 +0,0 @@ ---- -source: crates/ruff_python_formatter/src/lib.rs -expression: snapshot -input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/fmtskip6.py ---- -## Input - -```py -class A: - def f(self): - for line in range(10): - if True: - pass # fmt: skip -``` - -## Black Differences - -```diff ---- Black -+++ Ruff -@@ -2,4 +2,4 @@ - def f(self): - for line in range(10): - if True: -- pass # fmt: skip -+ pass -``` - -## Ruff Output - -```py -class A: - def f(self): - for line in range(10): - if True: - pass -``` - -## Black Output - -```py -class A: - def f(self): - for line in range(10): - if True: - pass # fmt: skip -``` - - diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__function_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__function_py.snap index 376701cf3f..aacde3d87b 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__function_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__function_py.snap @@ -117,12 +117,13 @@ def __await__(): return (yield) def func_no_args(): -@@ -64,19 +64,14 @@ +@@ -64,19 +64,15 @@ def spaces2(result=_core.Value(None)): - assert fut is self._read_fut, (fut, self._read_fut) + assert fut is self._read_fut, fut, self._read_fut ++ def example(session): @@ -142,7 +143,7 @@ def __await__(): return (yield) def long_lines(): -@@ -135,14 +130,13 @@ +@@ -135,14 +131,13 @@ a, **kwargs, ) -> A: @@ -233,6 +234,7 @@ def spaces2(result=_core.Value(None)): assert fut is self._read_fut, fut, self._read_fut + def example(session): result = session.query(models.Customer.id).filter( models.Customer.account_id == account_id, diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__remove_await_parens_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__remove_await_parens_py.snap index e074631133..644787aaab 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__remove_await_parens_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__remove_await_parens_py.snap @@ -123,10 +123,9 @@ async def main(): + await (asyncio.sleep(1)) # Hello --async def main(): + async def main(): - await asyncio.sleep(1) # Hello -+async def main(): # Hello -+ await (asyncio.sleep(1)) ++ await (asyncio.sleep(1)) # Hello # Long lines @@ -231,8 +230,8 @@ async def main(): await (asyncio.sleep(1)) # Hello -async def main(): # Hello - await (asyncio.sleep(1)) +async def main(): + await (asyncio.sleep(1)) # Hello # Long lines diff --git a/crates/ruff_python_formatter/src/trivia.rs b/crates/ruff_python_formatter/src/trivia.rs index 4d21f81abf..a6ebe292f1 100644 --- a/crates/ruff_python_formatter/src/trivia.rs +++ b/crates/ruff_python_formatter/src/trivia.rs @@ -5,13 +5,14 @@ use rustpython_parser::Tok; use crate::core::types::Range; use crate::cst::{ - Alias, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind, SliceIndex, - SliceIndexKind, Stmt, StmtKind, + Alias, Body, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind, + SliceIndex, SliceIndexKind, Stmt, StmtKind, }; #[derive(Clone, Debug)] pub enum Node<'a> { Mod(&'a [Stmt]), + Body(&'a Body), Stmt(&'a Stmt), Expr(&'a Expr), Alias(&'a Alias), @@ -24,6 +25,7 @@ impl Node<'_> { pub fn id(&self) -> usize { match self { Node::Mod(nodes) => nodes as *const _ as usize, + Node::Body(node) => node.id(), Node::Stmt(node) => node.id(), Node::Expr(node) => node.id(), Node::Alias(node) => node.id(), @@ -227,6 +229,11 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { result.push(Node::Stmt(stmt)); } } + Node::Body(body) => { + for stmt in &body.node { + result.push(Node::Stmt(stmt)); + } + } Node::Stmt(stmt) => match &stmt.node { StmtKind::Return { value } => { if let Some(value) = value { @@ -294,9 +301,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { if let Some(returns) = returns { result.push(Node::Expr(returns)); } - for stmt in body { - result.push(Node::Stmt(stmt)); - } + result.push(Node::Body(body)); } StmtKind::ClassDef { bases, @@ -314,9 +319,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { for keyword in keywords { result.push(Node::Expr(&keyword.node.value)); } - for stmt in body { - result.push(Node::Stmt(stmt)); - } + result.push(Node::Body(body)); } StmtKind::Delete { targets } => { for target in targets { @@ -355,29 +358,25 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { } => { result.push(Node::Expr(target)); result.push(Node::Expr(iter)); - for stmt in body { - result.push(Node::Stmt(stmt)); - } - for stmt in orelse { - result.push(Node::Stmt(stmt)); + result.push(Node::Body(body)); + if let Some(orelse) = orelse { + result.push(Node::Body(orelse)); } } StmtKind::While { test, body, orelse } => { result.push(Node::Expr(test)); - for stmt in body { - result.push(Node::Stmt(stmt)); - } - for stmt in orelse { - result.push(Node::Stmt(stmt)); + result.push(Node::Body(body)); + if let Some(orelse) = orelse { + result.push(Node::Body(orelse)); } } - StmtKind::If { test, body, orelse } => { + StmtKind::If { + test, body, orelse, .. + } => { result.push(Node::Expr(test)); - for stmt in body { - result.push(Node::Stmt(stmt)); - } - for stmt in orelse { - result.push(Node::Stmt(stmt)); + result.push(Node::Body(body)); + if let Some(orelse) = orelse { + result.push(Node::Body(orelse)); } } StmtKind::With { items, body, .. } | StmtKind::AsyncWith { items, body, .. } => { @@ -387,9 +386,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { result.push(Node::Expr(expr)); } } - for stmt in body { - result.push(Node::Stmt(stmt)); - } + result.push(Node::Body(body)); } StmtKind::Match { subject, cases } => { result.push(Node::Expr(subject)); @@ -398,9 +395,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { if let Some(expr) = &case.guard { result.push(Node::Expr(expr)); } - for stmt in &case.body { - result.push(Node::Stmt(stmt)); - } + result.push(Node::Body(&case.body)); } } StmtKind::Raise { exc, cause } => { @@ -431,17 +426,15 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { orelse, finalbody, } => { - for stmt in body { - result.push(Node::Stmt(stmt)); - } + result.push(Node::Body(body)); for handler in handlers { result.push(Node::Excepthandler(handler)); } - for stmt in orelse { - result.push(Node::Stmt(stmt)); + if let Some(orelse) = orelse { + result.push(Node::Body(orelse)); } - for stmt in finalbody { - result.push(Node::Stmt(stmt)); + if let Some(finalbody) = finalbody { + result.push(Node::Body(finalbody)); } } StmtKind::Import { names } => { @@ -457,7 +450,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { StmtKind::Global { .. } => {} StmtKind::Nonlocal { .. } => {} }, - // TODO(charlie): Actual logic, this doesn't do anything. Node::Expr(expr) => match &expr.node { ExprKind::BoolOp { values, .. } => { for value in values { @@ -476,7 +468,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { result.push(Node::Expr(operand)); } ExprKind::Lambda { body, args, .. } => { - // TODO(charlie): Arguments. for expr in &args.defaults { result.push(Node::Expr(expr)); } @@ -630,9 +621,7 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { if let Some(type_) = type_ { result.push(Node::Expr(type_)); } - for stmt in body { - result.push(Node::Stmt(stmt)); - } + result.push(Node::Body(body)); } Node::SliceIndex(slice_index) => { if let SliceIndexKind::Index { value } = &slice_index.node { @@ -717,6 +706,7 @@ pub fn decorate_token<'a>( let middle = (left + right) / 2; let child = &child_nodes[middle]; let start = match &child { + Node::Body(node) => node.location, Node::Stmt(node) => node.location, Node::Expr(node) => node.location, Node::Alias(node) => node.location, @@ -726,6 +716,7 @@ pub fn decorate_token<'a>( Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; let end = match &child { + Node::Body(node) => node.end_location.unwrap(), Node::Stmt(node) => node.end_location.unwrap(), Node::Expr(node) => node.end_location.unwrap(), Node::Alias(node) => node.end_location.unwrap(), @@ -739,6 +730,7 @@ pub fn decorate_token<'a>( // Special-case: if we're dealing with a statement that's a single expression, // we want to treat the expression as the enclosed node. let existing_start = match &existing { + Node::Body(node) => node.location, Node::Stmt(node) => node.location, Node::Expr(node) => node.location, Node::Alias(node) => node.location, @@ -748,6 +740,7 @@ pub fn decorate_token<'a>( Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; let existing_end = match &existing { + Node::Body(node) => node.end_location.unwrap(), Node::Stmt(node) => node.end_location.unwrap(), Node::Expr(node) => node.end_location.unwrap(), Node::Alias(node) => node.end_location.unwrap(), @@ -809,6 +802,7 @@ pub fn decorate_token<'a>( #[derive(Debug, Default)] pub struct TriviaIndex { + pub body: FxHashMap>, pub stmt: FxHashMap>, pub expr: FxHashMap>, pub alias: FxHashMap>, @@ -820,6 +814,13 @@ pub struct TriviaIndex { fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) { match node { Node::Mod(_) => {} + Node::Body(node) => { + trivia + .body + .entry(node.id()) + .or_insert_with(Vec::new) + .push(comment); + } Node::Stmt(node) => { trivia .stmt