From 1c750711360eb7b3f296f645b962ab5e9175b47a Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sun, 26 Feb 2023 00:05:56 -0500 Subject: [PATCH] Implement basic rendering of remaining AST nodes (#3233) --- .../ruff_python_formatter/src/attachment.rs | 10 +- crates/ruff_python_formatter/src/cst.rs | 105 +++++++++++- .../ruff_python_formatter/src/format/expr.rs | 45 ++++- .../src/format/match_case.rs | 39 +++++ .../ruff_python_formatter/src/format/mod.rs | 2 + .../src/format/pattern.rs | 158 ++++++++++++++++++ .../ruff_python_formatter/src/format/stmt.rs | 142 ++++++++++++---- crates/ruff_python_formatter/src/trivia.rs | 74 +++++++- 8 files changed, 527 insertions(+), 48 deletions(-) create mode 100644 crates/ruff_python_formatter/src/format/match_case.rs create mode 100644 crates/ruff_python_formatter/src/format/pattern.rs diff --git a/crates/ruff_python_formatter/src/attachment.rs b/crates/ruff_python_formatter/src/attachment.rs index 233f1006ea..b11daa3bce 100644 --- a/crates/ruff_python_formatter/src/attachment.rs +++ b/crates/ruff_python_formatter/src/attachment.rs @@ -1,6 +1,6 @@ use crate::core::visitor; use crate::core::visitor::Visitor; -use crate::cst::{Alias, Excepthandler, Expr, SliceIndex, Stmt}; +use crate::cst::{Alias, Excepthandler, Expr, Pattern, SliceIndex, Stmt}; use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken}; struct AttachmentVisitor { @@ -47,6 +47,14 @@ impl<'a> Visitor<'a> for AttachmentVisitor { } visitor::walk_slice_index(self, slice_index); } + + fn visit_pattern(&mut self, pattern: &'a mut Pattern) { + let trivia = self.index.pattern.remove(&pattern.id()); + if let Some(comments) = trivia { + pattern.trivia.extend(comments); + } + visitor::walk_pattern(self, pattern); + } } pub fn attach(python_cst: &mut [Stmt], trivia: Vec) { diff --git a/crates/ruff_python_formatter/src/cst.rs b/crates/ruff_python_formatter/src/cst.rs index b193da3d06..c1222552ec 100644 --- a/crates/ruff_python_formatter/src/cst.rs +++ b/crates/ruff_python_formatter/src/cst.rs @@ -566,6 +566,95 @@ impl From<(rustpython_parser::ast::Excepthandler, &Locator<'_>)> for Excepthandl } } +impl From<(rustpython_parser::ast::Pattern, &Locator<'_>)> for Pattern { + fn from((pattern, locator): (rustpython_parser::ast::Pattern, &Locator)) -> Self { + Pattern { + location: pattern.location, + end_location: pattern.end_location, + node: match pattern.node { + rustpython_parser::ast::PatternKind::MatchValue { value } => { + PatternKind::MatchValue { + value: Box::new((*value, locator).into()), + } + } + rustpython_parser::ast::PatternKind::MatchSingleton { value } => { + PatternKind::MatchSingleton { value } + } + rustpython_parser::ast::PatternKind::MatchSequence { patterns } => { + PatternKind::MatchSequence { + patterns: patterns + .into_iter() + .map(|pattern| (pattern, locator).into()) + .collect(), + } + } + rustpython_parser::ast::PatternKind::MatchMapping { + keys, + patterns, + rest, + } => PatternKind::MatchMapping { + keys: keys.into_iter().map(|key| (key, locator).into()).collect(), + patterns: patterns + .into_iter() + .map(|pattern| (pattern, locator).into()) + .collect(), + rest, + }, + rustpython_parser::ast::PatternKind::MatchClass { + cls, + patterns, + kwd_attrs, + kwd_patterns, + } => PatternKind::MatchClass { + cls: Box::new((*cls, locator).into()), + patterns: patterns + .into_iter() + .map(|pattern| (pattern, locator).into()) + .collect(), + kwd_attrs, + kwd_patterns: kwd_patterns + .into_iter() + .map(|pattern| (pattern, locator).into()) + .collect(), + }, + rustpython_parser::ast::PatternKind::MatchStar { name } => { + PatternKind::MatchStar { name } + } + rustpython_parser::ast::PatternKind::MatchAs { pattern, name } => { + PatternKind::MatchAs { + pattern: pattern.map(|pattern| Box::new((*pattern, locator).into())), + name, + } + } + rustpython_parser::ast::PatternKind::MatchOr { patterns } => PatternKind::MatchOr { + patterns: patterns + .into_iter() + .map(|pattern| (pattern, locator).into()) + .collect(), + }, + }, + trivia: vec![], + parentheses: Parenthesize::Never, + } + } +} + +impl From<(rustpython_parser::ast::MatchCase, &Locator<'_>)> for MatchCase { + fn from((match_case, locator): (rustpython_parser::ast::MatchCase, &Locator)) -> Self { + MatchCase { + pattern: (match_case.pattern, locator).into(), + guard: match_case + .guard + .map(|guard| Box::new((*guard, locator).into())), + body: match_case + .body + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + } + } +} + impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { fn from((stmt, locator): (rustpython_parser::ast::Stmt, &Locator)) -> Self { match stmt.node { @@ -875,9 +964,19 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt { trivia: vec![], parentheses: Parenthesize::Never, }, - rustpython_parser::ast::StmtKind::Match { .. } => { - todo!("match statement"); - } + rustpython_parser::ast::StmtKind::Match { subject, cases } => Stmt { + location: stmt.location, + end_location: stmt.end_location, + node: StmtKind::Match { + subject: Box::new((*subject, locator).into()), + cases: cases + .into_iter() + .map(|node| (node, locator).into()) + .collect(), + }, + trivia: vec![], + parentheses: Parenthesize::Never, + }, rustpython_parser::ast::StmtKind::Raise { exc, cause } => Stmt { location: stmt.location, end_location: stmt.end_location, diff --git a/crates/ruff_python_formatter/src/format/expr.rs b/crates/ruff_python_formatter/src/format/expr.rs index 63f11e28a3..222493380f 100644 --- a/crates/ruff_python_formatter/src/format/expr.rs +++ b/crates/ruff_python_formatter/src/format/expr.rs @@ -166,7 +166,7 @@ fn format_slice( upper: &SliceIndex, step: Option<&SliceIndex>, ) -> FormatResult<()> { - // // https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices + // https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices let lower_is_simple = if let SliceIndexKind::Index { value } = &lower.node { is_simple_slice(value) } else { @@ -242,6 +242,22 @@ fn format_slice( Ok(()) } +fn format_formatted_value( + f: &mut Formatter>, + expr: &Expr, + value: &Expr, + _conversion: usize, + format_spec: Option<&Expr>, +) -> FormatResult<()> { + write!(f, [text("!")])?; + write!(f, [value.format()])?; + if let Some(format_spec) = format_spec { + write!(f, [text(":")])?; + write!(f, [format_spec.format()])?; + } + Ok(()) +} + fn format_list( f: &mut Formatter>, expr: &Expr, @@ -703,6 +719,22 @@ fn format_attribute( Ok(()) } +fn format_named_expr( + f: &mut Formatter>, + expr: &Expr, + target: &Expr, + value: &Expr, +) -> FormatResult<()> { + write!(f, [target.format()])?; + write!(f, [text(":=")])?; + write!(f, [space()])?; + write!(f, [group(&format_args![value.format()])])?; + + write!(f, [end_of_line_comments(expr)])?; + + Ok(()) +} + fn format_bool_op( f: &mut Formatter>, expr: &Expr, @@ -825,7 +857,7 @@ impl Format> for FormatExpr<'_> { match &self.item.node { ExprKind::BoolOp { op, values } => format_bool_op(f, self.item, op, values), - // ExprKind::NamedExpr { .. } => {} + ExprKind::NamedExpr { target, value } => format_named_expr(f, self.item, target, value), ExprKind::BinOp { left, op, right } => format_bin_op(f, self.item, left, op, right), ExprKind::UnaryOp { op, operand } => format_unary_op(f, self.item, op, operand), ExprKind::Lambda { args, body } => format_lambda(f, self.item, args, body), @@ -859,7 +891,6 @@ impl Format> for FormatExpr<'_> { args, keywords, } => format_call(f, self.item, func, args, keywords), - // ExprKind::FormattedValue { .. } => {} ExprKind::JoinedStr { values } => format_joined_str(f, self.item, values), ExprKind::Constant { value, kind } => { format_constant(f, self.item, value, kind.as_deref()) @@ -875,9 +906,11 @@ impl Format> for FormatExpr<'_> { ExprKind::Slice { lower, upper, step } => { format_slice(f, self.item, lower, upper, step.as_ref()) } - _ => { - unimplemented!("Implement ExprKind: {:?}", self.item.node) - } + ExprKind::FormattedValue { + value, + conversion, + format_spec, + } => format_formatted_value(f, self.item, value, *conversion, format_spec.as_deref()), }?; // Any trailing comments come on the lines after. diff --git a/crates/ruff_python_formatter/src/format/match_case.rs b/crates/ruff_python_formatter/src/format/match_case.rs new file mode 100644 index 0000000000..98ccd968d8 --- /dev/null +++ b/crates/ruff_python_formatter/src/format/match_case.rs @@ -0,0 +1,39 @@ +use ruff_formatter::prelude::*; +use ruff_formatter::write; + +use crate::context::ASTFormatContext; +use crate::cst::MatchCase; +use crate::format::builders::block; +use crate::shared_traits::AsFormat; + +pub struct FormatMatchCase<'a> { + item: &'a MatchCase, +} + +impl AsFormat> for MatchCase { + type Format<'a> = FormatMatchCase<'a>; + + fn format(&self) -> Self::Format<'_> { + FormatMatchCase { item: self } + } +} + +impl Format> for FormatMatchCase<'_> { + fn fmt(&self, f: &mut Formatter) -> FormatResult<()> { + let MatchCase { + pattern, + guard, + body, + } = self.item; + + write!(f, [text("case")])?; + write!(f, [space(), pattern.format()])?; + if let Some(guard) = &guard { + write!(f, [space(), text("if"), space(), guard.format()])?; + } + write!(f, [text(":")])?; + write!(f, [block_indent(&block(body))])?; + + Ok(()) + } +} diff --git a/crates/ruff_python_formatter/src/format/mod.rs b/crates/ruff_python_formatter/src/format/mod.rs index 54401b6540..a207fbfb11 100644 --- a/crates/ruff_python_formatter/src/format/mod.rs +++ b/crates/ruff_python_formatter/src/format/mod.rs @@ -9,8 +9,10 @@ mod comprehension; mod excepthandler; mod expr; mod helpers; +mod match_case; mod numbers; mod operator; +mod pattern; mod stmt; mod strings; mod unaryop; diff --git a/crates/ruff_python_formatter/src/format/pattern.rs b/crates/ruff_python_formatter/src/format/pattern.rs new file mode 100644 index 0000000000..4fe8f5e396 --- /dev/null +++ b/crates/ruff_python_formatter/src/format/pattern.rs @@ -0,0 +1,158 @@ +use ruff_formatter::prelude::*; +use ruff_formatter::write; +use ruff_text_size::TextSize; +use rustpython_parser::ast::Constant; + +use crate::context::ASTFormatContext; +use crate::cst::{Pattern, PatternKind}; +use crate::shared_traits::AsFormat; + +pub struct FormatPattern<'a> { + item: &'a Pattern, +} + +impl AsFormat> for Pattern { + type Format<'a> = FormatPattern<'a>; + + fn format(&self) -> Self::Format<'_> { + FormatPattern { item: self } + } +} + +impl Format> for FormatPattern<'_> { + fn fmt(&self, f: &mut Formatter) -> FormatResult<()> { + let pattern = self.item; + + match &pattern.node { + PatternKind::MatchValue { value } => { + write!(f, [value.format()])?; + } + PatternKind::MatchSingleton { value } => match value { + Constant::None => write!(f, [text("None")])?, + Constant::Bool(value) => { + if *value { + write!(f, [text("True")])?; + } else { + write!(f, [text("False")])?; + } + } + _ => unreachable!("singleton pattern must be None or bool"), + }, + PatternKind::MatchSequence { patterns } => { + write!(f, [text("[")])?; + if let Some(pattern) = patterns.first() { + write!(f, [pattern.format()])?; + } + for pattern in patterns.iter().skip(1) { + write!(f, [text(","), space(), pattern.format()])?; + } + write!(f, [text("]")])?; + } + PatternKind::MatchMapping { + keys, + patterns, + rest, + } => { + write!(f, [text("{")])?; + if let Some(pattern) = patterns.first() { + write!(f, [keys[0].format(), text(":"), space(), pattern.format()])?; + } + for (key, pattern) in keys.iter().skip(1).zip(patterns.iter().skip(1)) { + write!( + f, + [ + text(","), + space(), + key.format(), + text(":"), + space(), + pattern.format() + ] + )?; + } + if let Some(rest) = &rest { + write!( + f, + [ + text(","), + space(), + text("**"), + space(), + dynamic_text(rest, TextSize::default()) + ] + )?; + } + write!(f, [text("}")])?; + } + PatternKind::MatchClass { + cls, + patterns, + kwd_attrs, + kwd_patterns, + } => { + write!(f, [cls.format()])?; + if !patterns.is_empty() { + write!(f, [text("(")])?; + if let Some(pattern) = patterns.first() { + write!(f, [pattern.format()])?; + } + for pattern in patterns.iter().skip(1) { + write!(f, [text(","), space(), pattern.format()])?; + } + write!(f, [text(")")])?; + } + if !kwd_attrs.is_empty() { + write!(f, [text("(")])?; + if let Some(attr) = kwd_attrs.first() { + write!(f, [dynamic_text(attr, TextSize::default())])?; + } + for attr in kwd_attrs.iter().skip(1) { + write!( + f, + [text(","), space(), dynamic_text(attr, TextSize::default())] + )?; + } + write!(f, [text(")")])?; + } + if !kwd_patterns.is_empty() { + write!(f, [text("(")])?; + if let Some(pattern) = kwd_patterns.first() { + write!(f, [pattern.format()])?; + } + for pattern in kwd_patterns.iter().skip(1) { + write!(f, [text(","), space(), pattern.format()])?; + } + write!(f, [text(")")])?; + } + } + PatternKind::MatchStar { name } => { + if let Some(name) = &name { + write!(f, [text("*"), dynamic_text(name, TextSize::default())])?; + } else { + write!(f, [text("*_")])?; + } + } + PatternKind::MatchAs { pattern, name } => { + if let Some(pattern) = &pattern { + write!(f, [pattern.format()])?; + write!(f, [space()])?; + write!(f, [text("as")])?; + write!(f, [space()])?; + } + if let Some(name) = &name { + write!(f, [dynamic_text(name, TextSize::default())])?; + } else { + write!(f, [text("_")])?; + } + } + PatternKind::MatchOr { patterns } => { + write!(f, [patterns[0].format()])?; + for pattern in patterns.iter().skip(1) { + write!(f, [space(), text("|"), space(), pattern.format()])?; + } + } + } + + Ok(()) + } +} diff --git a/crates/ruff_python_formatter/src/format/stmt.rs b/crates/ruff_python_formatter/src/format/stmt.rs index ffc57b864d..0b93f386b3 100644 --- a/crates/ruff_python_formatter/src/format/stmt.rs +++ b/crates/ruff_python_formatter/src/format/stmt.rs @@ -6,52 +6,67 @@ use ruff_text_size::TextSize; use crate::context::ASTFormatContext; use crate::cst::{ - Alias, Arguments, Excepthandler, Expr, ExprKind, Keyword, Stmt, StmtKind, Withitem, + Alias, Arguments, Excepthandler, Expr, ExprKind, Keyword, MatchCase, Operator, Stmt, StmtKind, + Withitem, }; use crate::format::builders::{block, join_names}; use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments}; use crate::format::helpers::is_self_closing; use crate::shared_traits::AsFormat; -fn format_break(f: &mut Formatter>) -> FormatResult<()> { - write!(f, [text("break")]) -} - -fn format_pass(f: &mut Formatter>, stmt: &Stmt) -> FormatResult<()> { - // Write the statement body. - write!(f, [text("pass")])?; - +fn format_break(f: &mut Formatter>, stmt: &Stmt) -> FormatResult<()> { + write!(f, [text("break")])?; write!(f, [end_of_line_comments(stmt)])?; - Ok(()) } -fn format_continue(f: &mut Formatter>) -> FormatResult<()> { - write!(f, [text("continue")]) +fn format_pass(f: &mut Formatter>, stmt: &Stmt) -> FormatResult<()> { + write!(f, [text("pass")])?; + write!(f, [end_of_line_comments(stmt)])?; + Ok(()) } -fn format_global(f: &mut Formatter>, names: &[String]) -> FormatResult<()> { +fn format_continue(f: &mut Formatter>, stmt: &Stmt) -> FormatResult<()> { + write!(f, [text("continue")])?; + write!(f, [end_of_line_comments(stmt)])?; + Ok(()) +} + +fn format_global( + f: &mut Formatter>, + stmt: &Stmt, + names: &[String], +) -> FormatResult<()> { write!(f, [text("global")])?; if !names.is_empty() { write!(f, [space(), join_names(names)])?; } + write!(f, [end_of_line_comments(stmt)])?; Ok(()) } -fn format_nonlocal(f: &mut Formatter>, names: &[String]) -> FormatResult<()> { +fn format_nonlocal( + f: &mut Formatter>, + stmt: &Stmt, + names: &[String], +) -> FormatResult<()> { write!(f, [text("nonlocal")])?; if !names.is_empty() { write!(f, [space(), join_names(names)])?; } + write!(f, [end_of_line_comments(stmt)])?; Ok(()) } -fn format_delete(f: &mut Formatter>, targets: &[Expr]) -> FormatResult<()> { +fn format_delete( + f: &mut Formatter>, + stmt: &Stmt, + targets: &[Expr], +) -> FormatResult<()> { write!(f, [text("del")])?; - match targets.len() { - 0 => Ok(()), - 1 => write!(f, [space(), targets[0].format()]), + 0 => {} + 1 => write!(f, [space(), targets[0].format()])?, _ => { write!( f, @@ -74,9 +89,11 @@ fn format_delete(f: &mut Formatter>, targets: &[Expr]) -> F if_group_breaks(&text(")")), ]) ] - ) + )?; } } + write!(f, [end_of_line_comments(stmt)])?; + Ok(()) } fn format_class_def( @@ -223,6 +240,34 @@ fn format_assign( Ok(()) } +fn format_aug_assign( + f: &mut Formatter>, + stmt: &Stmt, + target: &Expr, + op: &Operator, + value: &Expr, +) -> FormatResult<()> { + write!(f, [target.format()])?; + write!(f, [text(" "), op.format(), text("=")])?; + if is_self_closing(value) { + write!(f, [space(), group(&value.format())])?; + } else { + write!( + f, + [ + space(), + group(&format_args![ + if_group_breaks(&text("(")), + soft_block_indent(&value.format()), + if_group_breaks(&text(")")), + ]) + ] + )?; + } + write!(f, [end_of_line_comments(stmt)])?; + Ok(()) +} + fn format_ann_assign( f: &mut Formatter>, stmt: &Stmt, @@ -268,7 +313,11 @@ fn format_for( body: &[Stmt], orelse: &[Stmt], _type_comment: Option<&str>, + async_: bool, ) -> FormatResult<()> { + if async_ { + write!(f, [text("async"), space()])?; + } write!( f, [ @@ -351,6 +400,19 @@ fn format_if( Ok(()) } +fn format_match( + f: &mut Formatter>, + stmt: &Stmt, + subject: &Expr, + cases: &[MatchCase], +) -> FormatResult<()> { + write!(f, [text("match"), space(), subject.format(), text(":")])?; + for case in cases { + write!(f, [block_indent(&case.format())])?; + } + Ok(()) +} + fn format_raise( f: &mut Formatter>, stmt: &Stmt, @@ -585,7 +647,6 @@ fn format_with_( if async_ { write!(f, [text("async"), space()])?; } - write!( f, [ @@ -609,7 +670,8 @@ fn format_with_( text(":"), block_indent(&block(body)) ] - ) + )?; + Ok(()) } pub struct FormatStmt<'a> { @@ -622,10 +684,10 @@ impl Format> for FormatStmt<'_> { match &self.item.node { StmtKind::Pass => format_pass(f, self.item), - StmtKind::Break => format_break(f), - StmtKind::Continue => format_continue(f), - StmtKind::Global { names } => format_global(f, names), - StmtKind::Nonlocal { names } => format_nonlocal(f, names), + StmtKind::Break => format_break(f, self.item), + StmtKind::Continue => format_continue(f, self.item), + StmtKind::Global { names } => format_global(f, self.item, names), + StmtKind::Nonlocal { names } => format_nonlocal(f, self.item, names), StmtKind::FunctionDef { name, args, @@ -668,9 +730,11 @@ impl Format> for FormatStmt<'_> { decorator_list, } => format_class_def(f, name, bases, keywords, body, decorator_list), StmtKind::Return { value } => format_return(f, self.item, value.as_ref()), - StmtKind::Delete { targets } => format_delete(f, targets), + StmtKind::Delete { targets } => format_delete(f, self.item, targets), StmtKind::Assign { targets, value, .. } => format_assign(f, self.item, targets, value), - // StmtKind::AugAssign { .. } => {} + StmtKind::AugAssign { target, op, value } => { + format_aug_assign(f, self.item, target, op, value) + } StmtKind::AnnAssign { target, annotation, @@ -691,8 +755,24 @@ impl Format> for FormatStmt<'_> { body, orelse, type_comment.as_deref(), + false, + ), + StmtKind::AsyncFor { + target, + iter, + body, + orelse, + type_comment, + } => format_for( + f, + self.item, + target, + iter, + body, + orelse, + type_comment.as_deref(), + true, ), - // StmtKind::AsyncFor { .. } => {} StmtKind::While { test, body, orelse } => { format_while(f, self.item, test, body, orelse) } @@ -721,7 +801,7 @@ impl Format> for FormatStmt<'_> { type_comment.as_ref().map(String::as_str), true, ), - // StmtKind::Match { .. } => {} + StmtKind::Match { subject, cases } => format_match(f, self.item, subject, cases), StmtKind::Raise { exc, cause } => { format_raise(f, self.item, exc.as_deref(), cause.as_deref()) } @@ -752,11 +832,7 @@ impl Format> for FormatStmt<'_> { names, level.as_ref(), ), - // StmtKind::Nonlocal { .. } => {} StmtKind::Expr { value } => format_expr(f, self.item, value), - _ => { - unimplemented!("Implement StmtKind: {:?}", self.item.node) - } }?; write!(f, [hard_line_break()])?; diff --git a/crates/ruff_python_formatter/src/trivia.rs b/crates/ruff_python_formatter/src/trivia.rs index de6f2422f5..4d21f81abf 100644 --- a/crates/ruff_python_formatter/src/trivia.rs +++ b/crates/ruff_python_formatter/src/trivia.rs @@ -5,8 +5,8 @@ use rustpython_parser::Tok; use crate::core::types::Range; use crate::cst::{ - Alias, Excepthandler, ExcepthandlerKind, Expr, ExprKind, SliceIndex, SliceIndexKind, Stmt, - StmtKind, + Alias, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind, SliceIndex, + SliceIndexKind, Stmt, StmtKind, }; #[derive(Clone, Debug)] @@ -17,6 +17,7 @@ pub enum Node<'a> { Alias(&'a Alias), Excepthandler(&'a Excepthandler), SliceIndex(&'a SliceIndex), + Pattern(&'a Pattern), } impl Node<'_> { @@ -28,6 +29,7 @@ impl Node<'_> { Node::Alias(node) => node.id(), Node::Excepthandler(node) => node.id(), Node::SliceIndex(node) => node.id(), + Node::Pattern(node) => node.id(), } } } @@ -389,8 +391,17 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { result.push(Node::Stmt(stmt)); } } - StmtKind::Match { .. } => { - todo!("Support match statements"); + StmtKind::Match { subject, cases } => { + result.push(Node::Expr(subject)); + for case in cases { + result.push(Node::Pattern(&case.pattern)); + if let Some(expr) = &case.guard { + result.push(Node::Expr(expr)); + } + for stmt in &case.body { + result.push(Node::Stmt(stmt)); + } + } } StmtKind::Raise { exc, cause } => { if let Some(exc) = exc { @@ -615,7 +626,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { }, Node::Alias(..) => {} Node::Excepthandler(excepthandler) => { - // TODO(charlie): Ident. let ExcepthandlerKind::ExceptHandler { type_, body, .. } = &excepthandler.node; if let Some(type_) = type_ { result.push(Node::Expr(type_)); @@ -629,6 +639,48 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { result.push(Node::Expr(value)); } } + Node::Pattern(pattern) => match &pattern.node { + PatternKind::MatchValue { value } => { + result.push(Node::Expr(value)); + } + PatternKind::MatchSingleton { .. } => {} + PatternKind::MatchSequence { patterns } => { + for pattern in patterns { + result.push(Node::Pattern(pattern)); + } + } + PatternKind::MatchMapping { keys, patterns, .. } => { + for (key, pattern) in keys.iter().zip(patterns.iter()) { + result.push(Node::Expr(key)); + result.push(Node::Pattern(pattern)); + } + } + PatternKind::MatchClass { + cls, + patterns, + kwd_patterns, + .. + } => { + result.push(Node::Expr(cls)); + for pattern in patterns { + result.push(Node::Pattern(pattern)); + } + for pattern in kwd_patterns { + result.push(Node::Pattern(pattern)); + } + } + PatternKind::MatchStar { .. } => {} + PatternKind::MatchAs { pattern, .. } => { + if let Some(pattern) = pattern { + result.push(Node::Pattern(pattern)); + } + } + PatternKind::MatchOr { patterns } => { + for pattern in patterns { + result.push(Node::Pattern(pattern)); + } + } + }, } } @@ -670,6 +722,7 @@ pub fn decorate_token<'a>( Node::Alias(node) => node.location, Node::Excepthandler(node) => node.location, Node::SliceIndex(node) => node.location, + Node::Pattern(node) => node.location, Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; let end = match &child { @@ -678,6 +731,7 @@ pub fn decorate_token<'a>( Node::Alias(node) => node.end_location.unwrap(), Node::Excepthandler(node) => node.end_location.unwrap(), Node::SliceIndex(node) => node.end_location.unwrap(), + Node::Pattern(node) => node.end_location.unwrap(), Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; @@ -690,6 +744,7 @@ pub fn decorate_token<'a>( Node::Alias(node) => node.location, Node::Excepthandler(node) => node.location, Node::SliceIndex(node) => node.location, + Node::Pattern(node) => node.location, Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; let existing_end = match &existing { @@ -698,6 +753,7 @@ pub fn decorate_token<'a>( Node::Alias(node) => node.end_location.unwrap(), Node::Excepthandler(node) => node.end_location.unwrap(), Node::SliceIndex(node) => node.end_location.unwrap(), + Node::Pattern(node) => node.end_location.unwrap(), Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; if start == existing_start && end == existing_end { @@ -758,6 +814,7 @@ pub struct TriviaIndex { pub alias: FxHashMap>, pub excepthandler: FxHashMap>, pub slice_index: FxHashMap>, + pub pattern: FxHashMap>, } fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) { @@ -798,6 +855,13 @@ fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) { .or_insert_with(Vec::new) .push(comment); } + Node::Pattern(node) => { + trivia + .pattern + .entry(node.id()) + .or_insert_with(Vec::new) + .push(comment); + } } }