Introduce dedicated CST tokens for other operator kinds (#3267)

This commit is contained in:
Charlie Marsh 2023-02-27 23:54:57 -05:00 committed by GitHub
parent 061495a9eb
commit f5f09b489b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 425 additions and 217 deletions

View file

@ -1,7 +1,8 @@
use crate::core::visitor;
use crate::core::visitor::Visitor;
use crate::cst::{
Alias, Arg, Body, BoolOp, Excepthandler, Expr, Keyword, Pattern, SliceIndex, Stmt,
Alias, Arg, Body, BoolOp, CmpOp, Excepthandler, Expr, Keyword, Operator, Pattern, SliceIndex,
Stmt, UnaryOp,
};
use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken};
@ -74,6 +75,30 @@ impl<'a> Visitor<'a> for AttachmentVisitor {
visitor::walk_bool_op(self, bool_op);
}
fn visit_unary_op(&mut self, unary_op: &'a mut UnaryOp) {
let trivia = self.index.unary_op.remove(&unary_op.id());
if let Some(comments) = trivia {
unary_op.trivia.extend(comments);
}
visitor::walk_unary_op(self, unary_op);
}
fn visit_cmp_op(&mut self, cmp_op: &'a mut CmpOp) {
let trivia = self.index.cmp_op.remove(&cmp_op.id());
if let Some(comments) = trivia {
cmp_op.trivia.extend(comments);
}
visitor::walk_cmp_op(self, cmp_op);
}
fn visit_operator(&mut self, operator: &'a mut Operator) {
let trivia = self.index.operator.remove(&operator.id());
if let Some(comments) = trivia {
operator.trivia.extend(comments);
}
visitor::walk_operator(self, operator);
}
fn visit_slice_index(&mut self, slice_index: &'a mut SliceIndex) {
let trivia = self.index.slice_index.remove(&slice_index.id());
if let Some(comments) = trivia {

View file

@ -57,8 +57,10 @@ pub fn find_tok(
return (start, end);
}
}
unreachable!()
unreachable!(
"Failed to find token in range {:?}..{:?}",
location, end_location
)
}
/// Expand the range of a compound statement.

View file

@ -1,9 +1,9 @@
use rustpython_parser::ast::Constant;
use crate::cst::{
Alias, Arg, Arguments, Body, BoolOp, Cmpop, Comprehension, Excepthandler, ExcepthandlerKind,
Alias, Arg, Arguments, Body, BoolOp, CmpOp, Comprehension, Excepthandler, ExcepthandlerKind,
Expr, ExprContext, ExprKind, Keyword, MatchCase, Operator, Pattern, PatternKind, SliceIndex,
SliceIndexKind, Stmt, StmtKind, Unaryop, Withitem,
SliceIndexKind, Stmt, StmtKind, UnaryOp, Withitem,
};
pub trait Visitor<'a> {
@ -28,11 +28,11 @@ pub trait Visitor<'a> {
fn visit_operator(&mut self, operator: &'a mut Operator) {
walk_operator(self, operator);
}
fn visit_unaryop(&mut self, unaryop: &'a mut Unaryop) {
walk_unaryop(self, unaryop);
fn visit_unary_op(&mut self, unary_op: &'a mut UnaryOp) {
walk_unary_op(self, unary_op);
}
fn visit_cmpop(&mut self, cmpop: &'a mut Cmpop) {
walk_cmpop(self, cmpop);
fn visit_cmp_op(&mut self, cmp_op: &'a mut CmpOp) {
walk_cmp_op(self, cmp_op);
}
fn visit_comprehension(&mut self, comprehension: &'a mut Comprehension) {
walk_comprehension(self, comprehension);
@ -312,7 +312,7 @@ pub fn walk_expr<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, expr: &'a mut Exp
visitor.visit_expr(right);
}
ExprKind::UnaryOp { op, operand } => {
visitor.visit_unaryop(op);
visitor.visit_unary_op(op);
visitor.visit_expr(operand);
}
ExprKind::Lambda { args, body } => {
@ -380,7 +380,7 @@ pub fn walk_expr<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, expr: &'a mut Exp
} => {
visitor.visit_expr(left);
for cmpop in ops {
visitor.visit_cmpop(cmpop);
visitor.visit_cmp_op(cmpop);
}
for expr in comparators {
visitor.visit_expr(expr);
@ -608,10 +608,10 @@ pub fn walk_bool_op<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, bool_op: &'a m
pub fn walk_operator<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, operator: &'a mut Operator) {}
#[allow(unused_variables)]
pub fn walk_unaryop<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, unaryop: &'a mut Unaryop) {}
pub fn walk_unary_op<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, unary_op: &'a mut UnaryOp) {}
#[allow(unused_variables)]
pub fn walk_cmpop<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, cmpop: &'a mut Cmpop) {}
pub fn walk_cmp_op<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, cmp_op: &'a mut CmpOp) {}
#[allow(unused_variables)]
pub fn walk_alias<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, alias: &'a mut Alias) {}

View file

@ -1,5 +1,7 @@
#![allow(clippy::derive_partial_eq_without_eq)]
use std::iter;
use rustpython_parser::ast::{Constant, Location};
use rustpython_parser::Mode;
@ -76,7 +78,7 @@ impl From<&rustpython_parser::ast::Boolop> for BoolOpKind {
pub type BoolOp = Located<BoolOpKind>;
#[derive(Clone, Debug, PartialEq)]
pub enum Operator {
pub enum OperatorKind {
Add,
Sub,
Mult,
@ -92,8 +94,10 @@ pub enum Operator {
FloorDiv,
}
impl From<rustpython_parser::ast::Operator> for Operator {
fn from(op: rustpython_parser::ast::Operator) -> Self {
pub type Operator = Located<OperatorKind>;
impl From<&rustpython_parser::ast::Operator> for OperatorKind {
fn from(op: &rustpython_parser::ast::Operator) -> Self {
match op {
rustpython_parser::ast::Operator::Add => Self::Add,
rustpython_parser::ast::Operator::Sub => Self::Sub,
@ -113,15 +117,17 @@ impl From<rustpython_parser::ast::Operator> for Operator {
}
#[derive(Clone, Debug, PartialEq)]
pub enum Unaryop {
pub enum UnaryOpKind {
Invert,
Not,
UAdd,
USub,
}
impl From<rustpython_parser::ast::Unaryop> for Unaryop {
fn from(op: rustpython_parser::ast::Unaryop) -> Self {
pub type UnaryOp = Located<UnaryOpKind>;
impl From<&rustpython_parser::ast::Unaryop> for UnaryOpKind {
fn from(op: &rustpython_parser::ast::Unaryop) -> Self {
match op {
rustpython_parser::ast::Unaryop::Invert => Self::Invert,
rustpython_parser::ast::Unaryop::Not => Self::Not,
@ -132,7 +138,7 @@ impl From<rustpython_parser::ast::Unaryop> for Unaryop {
}
#[derive(Clone, Debug, PartialEq)]
pub enum Cmpop {
pub enum CmpOpKind {
Eq,
NotEq,
Lt,
@ -145,8 +151,10 @@ pub enum Cmpop {
NotIn,
}
impl From<rustpython_parser::ast::Cmpop> for Cmpop {
fn from(op: rustpython_parser::ast::Cmpop) -> Self {
pub type CmpOp = Located<CmpOpKind>;
impl From<&rustpython_parser::ast::Cmpop> for CmpOpKind {
fn from(op: &rustpython_parser::ast::Cmpop) -> Self {
match op {
rustpython_parser::ast::Cmpop::Eq => Self::Eq,
rustpython_parser::ast::Cmpop::NotEq => Self::NotEq,
@ -325,7 +333,7 @@ pub enum ExprKind {
right: Box<Expr>,
},
UnaryOp {
op: Unaryop,
op: UnaryOp,
operand: Box<Expr>,
},
Lambda {
@ -372,7 +380,7 @@ pub enum ExprKind {
},
Compare {
left: Box<Expr>,
ops: Vec<Cmpop>,
ops: Vec<CmpOp>,
comparators: Vec<Expr>,
},
Call {
@ -1013,8 +1021,57 @@ impl From<(rustpython_parser::ast::Stmt, &Locator<'_>)> for Stmt {
location: stmt.location,
end_location: stmt.end_location,
node: StmtKind::AugAssign {
op: {
let target_tok = match &op {
rustpython_parser::ast::Operator::Add => {
rustpython_parser::Tok::PlusEqual
}
rustpython_parser::ast::Operator::Sub => {
rustpython_parser::Tok::MinusEqual
}
rustpython_parser::ast::Operator::Mult => {
rustpython_parser::Tok::StarEqual
}
rustpython_parser::ast::Operator::MatMult => {
rustpython_parser::Tok::AtEqual
}
rustpython_parser::ast::Operator::Div => {
rustpython_parser::Tok::SlashEqual
}
rustpython_parser::ast::Operator::Mod => {
rustpython_parser::Tok::PercentEqual
}
rustpython_parser::ast::Operator::Pow => {
rustpython_parser::Tok::DoubleStarEqual
}
rustpython_parser::ast::Operator::LShift => {
rustpython_parser::Tok::LeftShiftEqual
}
rustpython_parser::ast::Operator::RShift => {
rustpython_parser::Tok::RightShiftEqual
}
rustpython_parser::ast::Operator::BitOr => {
rustpython_parser::Tok::VbarEqual
}
rustpython_parser::ast::Operator::BitXor => {
rustpython_parser::Tok::CircumflexEqual
}
rustpython_parser::ast::Operator::BitAnd => {
rustpython_parser::Tok::AmperEqual
}
rustpython_parser::ast::Operator::FloorDiv => {
rustpython_parser::Tok::DoubleSlashEqual
}
};
let (op_location, op_end_location) = find_tok(
target.end_location.unwrap(),
value.location,
locator,
|tok| tok == target_tok,
);
Operator::new(op_location, op_end_location, (&op).into())
},
target: Box::new((*target, locator).into()),
op: op.into(),
value: Box::new((*value, locator).into()),
},
trivia: vec![],
@ -1685,7 +1742,7 @@ impl From<(rustpython_parser::ast::Expr, &Locator<'_>)> for Expr {
.iter()
.tuple_windows()
.map(|(left, right)| {
let target = match &op {
let target_tok = match &op {
rustpython_parser::ast::Boolop::And => rustpython_parser::Tok::And,
rustpython_parser::ast::Boolop::Or => rustpython_parser::Tok::Or,
};
@ -1693,7 +1750,7 @@ impl From<(rustpython_parser::ast::Expr, &Locator<'_>)> for Expr {
left.end_location.unwrap(),
right.location,
locator,
|tok| tok == target,
|tok| tok == target_tok,
);
BoolOp::new(op_location, op_end_location, (&op).into())
})
@ -1720,8 +1777,43 @@ impl From<(rustpython_parser::ast::Expr, &Locator<'_>)> for Expr {
location: expr.location,
end_location: expr.end_location,
node: ExprKind::BinOp {
op: {
let target_tok = match &op {
rustpython_parser::ast::Operator::Add => rustpython_parser::Tok::Plus,
rustpython_parser::ast::Operator::Sub => rustpython_parser::Tok::Minus,
rustpython_parser::ast::Operator::Mult => rustpython_parser::Tok::Star,
rustpython_parser::ast::Operator::MatMult => rustpython_parser::Tok::At,
rustpython_parser::ast::Operator::Div => rustpython_parser::Tok::Slash,
rustpython_parser::ast::Operator::Mod => {
rustpython_parser::Tok::Percent
}
rustpython_parser::ast::Operator::Pow => {
rustpython_parser::Tok::DoubleStar
}
rustpython_parser::ast::Operator::LShift => {
rustpython_parser::Tok::LeftShift
}
rustpython_parser::ast::Operator::RShift => {
rustpython_parser::Tok::RightShift
}
rustpython_parser::ast::Operator::BitOr => rustpython_parser::Tok::Vbar,
rustpython_parser::ast::Operator::BitXor => {
rustpython_parser::Tok::CircumFlex
}
rustpython_parser::ast::Operator::BitAnd => {
rustpython_parser::Tok::Amper
}
rustpython_parser::ast::Operator::FloorDiv => {
rustpython_parser::Tok::DoubleSlash
}
};
let (op_location, op_end_location) =
find_tok(left.end_location.unwrap(), right.location, locator, |tok| {
tok == target_tok
});
Operator::new(op_location, op_end_location, (&op).into())
},
left: Box::new((*left, locator).into()),
op: op.into(),
right: Box::new((*right, locator).into()),
},
trivia: vec![],
@ -1731,7 +1823,21 @@ impl From<(rustpython_parser::ast::Expr, &Locator<'_>)> for Expr {
location: expr.location,
end_location: expr.end_location,
node: ExprKind::UnaryOp {
op: op.into(),
op: {
let target_tok = match &op {
rustpython_parser::ast::Unaryop::Invert => {
rustpython_parser::Tok::Tilde
}
rustpython_parser::ast::Unaryop::Not => rustpython_parser::Tok::Not,
rustpython_parser::ast::Unaryop::UAdd => rustpython_parser::Tok::Plus,
rustpython_parser::ast::Unaryop::USub => rustpython_parser::Tok::Minus,
};
let (op_location, op_end_location) =
find_tok(expr.location, operand.location, locator, |tok| {
tok == target_tok
});
UnaryOp::new(op_location, op_end_location, (&op).into())
},
operand: Box::new((*operand, locator).into()),
},
trivia: vec![],
@ -1878,8 +1984,45 @@ impl From<(rustpython_parser::ast::Expr, &Locator<'_>)> for Expr {
location: expr.location,
end_location: expr.end_location,
node: ExprKind::Compare {
ops: iter::once(left.as_ref())
.chain(comparators.iter())
.tuple_windows()
.zip(ops.into_iter())
.map(|((left, right), op)| {
let target_tok = match &op {
rustpython_parser::ast::Cmpop::Eq => {
rustpython_parser::Tok::EqEqual
}
rustpython_parser::ast::Cmpop::NotEq => {
rustpython_parser::Tok::NotEqual
}
rustpython_parser::ast::Cmpop::Lt => rustpython_parser::Tok::Less,
rustpython_parser::ast::Cmpop::LtE => {
rustpython_parser::Tok::LessEqual
}
rustpython_parser::ast::Cmpop::Gt => {
rustpython_parser::Tok::Greater
}
rustpython_parser::ast::Cmpop::GtE => {
rustpython_parser::Tok::GreaterEqual
}
rustpython_parser::ast::Cmpop::Is => rustpython_parser::Tok::Is,
// TODO(charlie): Break this into two tokens.
rustpython_parser::ast::Cmpop::IsNot => rustpython_parser::Tok::Is,
rustpython_parser::ast::Cmpop::In => rustpython_parser::Tok::In,
// TODO(charlie): Break this into two tokens.
rustpython_parser::ast::Cmpop::NotIn => rustpython_parser::Tok::In,
};
let (op_location, op_end_location) = find_tok(
left.end_location.unwrap(),
right.location,
locator,
|tok| tok == target_tok,
);
CmpOp::new(op_location, op_end_location, (&op).into())
})
.collect(),
left: Box::new((*left, locator).into()),
ops: ops.into_iter().map(Into::into).collect(),
comparators: comparators
.into_iter()
.map(|node| (node, locator).into())

View file

@ -20,19 +20,17 @@ impl AsFormat<ASTFormatContext<'_>> for BoolOp {
impl Format<ASTFormatContext<'_>> for FormatBoolOp<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext>) -> FormatResult<()> {
let boolop = self.item;
write!(f, [leading_comments(boolop)])?;
let bool_op = self.item;
write!(f, [leading_comments(bool_op)])?;
write!(
f,
[text(match boolop.node {
[text(match bool_op.node {
BoolOpKind::And => "and",
BoolOpKind::Or => "or",
})]
)?;
write!(f, [end_of_line_comments(boolop)])?;
write!(f, [trailing_comments(boolop)])?;
write!(f, [end_of_line_comments(bool_op)])?;
write!(f, [trailing_comments(bool_op)])?;
Ok(())
}
}

View file

@ -0,0 +1,44 @@
use ruff_formatter::prelude::*;
use ruff_formatter::write;
use crate::context::ASTFormatContext;
use crate::cst::{CmpOp, CmpOpKind};
use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments};
use crate::shared_traits::AsFormat;
pub struct FormatCmpOp<'a> {
item: &'a CmpOp,
}
impl AsFormat<ASTFormatContext<'_>> for CmpOp {
type Format<'a> = FormatCmpOp<'a>;
fn format(&self) -> Self::Format<'_> {
FormatCmpOp { item: self }
}
}
impl Format<ASTFormatContext<'_>> for FormatCmpOp<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext>) -> FormatResult<()> {
let cmp_op = self.item;
write!(f, [leading_comments(cmp_op)])?;
write!(
f,
[text(match cmp_op.node {
CmpOpKind::Eq => "==",
CmpOpKind::NotEq => "!=",
CmpOpKind::Lt => "<",
CmpOpKind::LtE => "<=",
CmpOpKind::Gt => ">",
CmpOpKind::GtE => ">=",
CmpOpKind::Is => "is",
CmpOpKind::IsNot => "is not",
CmpOpKind::In => "in",
CmpOpKind::NotIn => "not in",
})]
)?;
write!(f, [end_of_line_comments(cmp_op)])?;
write!(f, [trailing_comments(cmp_op)])?;
Ok(())
}
}

View file

@ -1,40 +0,0 @@
use ruff_formatter::prelude::*;
use ruff_formatter::write;
use crate::context::ASTFormatContext;
use crate::cst::Cmpop;
use crate::shared_traits::AsFormat;
pub struct FormatCmpop<'a> {
item: &'a Cmpop,
}
impl AsFormat<ASTFormatContext<'_>> for Cmpop {
type Format<'a> = FormatCmpop<'a>;
fn format(&self) -> Self::Format<'_> {
FormatCmpop { item: self }
}
}
impl Format<ASTFormatContext<'_>> for FormatCmpop<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext>) -> FormatResult<()> {
let unaryop = self.item;
write!(
f,
[text(match unaryop {
Cmpop::Eq => "==",
Cmpop::NotEq => "!=",
Cmpop::Lt => "<",
Cmpop::LtE => "<=",
Cmpop::Gt => ">",
Cmpop::GtE => ">=",
Cmpop::Is => "is",
Cmpop::IsNot => "is not",
Cmpop::In => "in",
Cmpop::NotIn => "not in",
})]
)?;
Ok(())
}
}

View file

@ -9,8 +9,8 @@ use ruff_text_size::TextSize;
use crate::context::ASTFormatContext;
use crate::core::types::Range;
use crate::cst::{
Arguments, BoolOp, Cmpop, Comprehension, Expr, ExprKind, Keyword, Operator, SliceIndex,
SliceIndexKind, Unaryop,
Arguments, BoolOp, CmpOp, Comprehension, Expr, ExprKind, Keyword, Operator, OperatorKind,
SliceIndex, SliceIndexKind, UnaryOp, UnaryOpKind,
};
use crate::format::builders::literal;
use crate::format::comments::{dangling_comments, end_of_line_comments, leading_comments};
@ -556,7 +556,7 @@ fn format_compare(
f: &mut Formatter<ASTFormatContext<'_>>,
expr: &Expr,
left: &Expr,
ops: &[Cmpop],
ops: &[CmpOp],
comparators: &[Expr],
) -> FormatResult<()> {
write!(f, [group(&format_args![left.format()])])?;
@ -723,7 +723,8 @@ fn format_bin_op(
right: &Expr,
) -> FormatResult<()> {
// https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-breaks-binary-operators
let is_simple = matches!(op, Operator::Pow) && is_simple_power(left) && is_simple_power(right);
let is_simple =
matches!(op.node, OperatorKind::Pow) && is_simple_power(left) && is_simple_power(right);
write!(f, [left.format()])?;
if !is_simple {
write!(f, [soft_line_break_or_space()])?;
@ -740,12 +741,12 @@ fn format_bin_op(
fn format_unary_op(
f: &mut Formatter<ASTFormatContext<'_>>,
expr: &Expr,
op: &Unaryop,
op: &UnaryOp,
operand: &Expr,
) -> FormatResult<()> {
write!(f, [op.format()])?;
// TODO(charlie): Do this in the normalization pass.
if !matches!(op, Unaryop::Not)
if !matches!(op.node, UnaryOpKind::Not)
&& matches!(
operand.node,
ExprKind::BoolOp { .. } | ExprKind::Compare { .. } | ExprKind::BinOp { .. }

View file

@ -1,4 +1,4 @@
use crate::cst::{Expr, ExprKind, Unaryop};
use crate::cst::{Expr, ExprKind, UnaryOpKind};
pub fn is_self_closing(expr: &Expr) -> bool {
match &expr.node {
@ -56,7 +56,7 @@ pub fn is_self_closing(expr: &Expr) -> bool {
pub fn is_simple_slice(expr: &Expr) -> bool {
match &expr.node {
ExprKind::UnaryOp { op, operand } => {
if matches!(op, Unaryop::Not) {
if matches!(op.node, UnaryOpKind::Not) {
false
} else {
is_simple_slice(operand)
@ -73,7 +73,7 @@ pub fn is_simple_slice(expr: &Expr) -> bool {
pub fn is_simple_power(expr: &Expr) -> bool {
match &expr.node {
ExprKind::UnaryOp { op, operand } => {
if matches!(op, Unaryop::Not) {
if matches!(op.node, UnaryOpKind::Not) {
false
} else {
is_simple_slice(operand)

View file

@ -3,7 +3,7 @@ mod arg;
mod arguments;
mod bool_op;
pub mod builders;
mod cmpop;
mod cmp_op;
mod comments;
mod comprehension;
mod excepthandler;
@ -16,5 +16,5 @@ mod operator;
mod pattern;
mod stmt;
mod strings;
mod unaryop;
mod unary_op;
mod withitem;

View file

@ -2,7 +2,8 @@ use ruff_formatter::prelude::*;
use ruff_formatter::write;
use crate::context::ASTFormatContext;
use crate::cst::Operator;
use crate::cst::{Operator, OperatorKind};
use crate::format::comments::{end_of_line_comments, leading_comments, trailing_comments};
use crate::shared_traits::AsFormat;
pub struct FormatOperator<'a> {
@ -20,26 +21,27 @@ impl AsFormat<ASTFormatContext<'_>> for Operator {
impl Format<ASTFormatContext<'_>> for FormatOperator<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext<'_>>) -> FormatResult<()> {
let operator = self.item;
write!(f, [leading_comments(operator)])?;
write!(
f,
[text(match operator {
Operator::Add => "+",
Operator::Sub => "-",
Operator::Mult => "*",
Operator::MatMult => "@",
Operator::Div => "/",
Operator::Mod => "%",
Operator::Pow => "**",
Operator::LShift => "<<",
Operator::RShift => ">>",
Operator::BitOr => "|",
Operator::BitXor => "^",
Operator::BitAnd => "&",
Operator::FloorDiv => "//",
[text(match operator.node {
OperatorKind::Add => "+",
OperatorKind::Sub => "-",
OperatorKind::Mult => "*",
OperatorKind::MatMult => "@",
OperatorKind::Div => "/",
OperatorKind::Mod => "%",
OperatorKind::Pow => "**",
OperatorKind::LShift => "<<",
OperatorKind::RShift => ">>",
OperatorKind::BitOr => "|",
OperatorKind::BitXor => "^",
OperatorKind::BitAnd => "&",
OperatorKind::FloorDiv => "//",
})]
)?;
write!(f, [end_of_line_comments(operator)])?;
write!(f, [trailing_comments(operator)])?;
Ok(())
}
}

View file

@ -0,0 +1,37 @@
use ruff_formatter::prelude::*;
use ruff_formatter::write;
use crate::context::ASTFormatContext;
use crate::cst::{UnaryOp, UnaryOpKind};
use crate::shared_traits::AsFormat;
pub struct FormatUnaryOp<'a> {
item: &'a UnaryOp,
}
impl AsFormat<ASTFormatContext<'_>> for UnaryOp {
type Format<'a> = FormatUnaryOp<'a>;
fn format(&self) -> Self::Format<'_> {
FormatUnaryOp { item: self }
}
}
impl Format<ASTFormatContext<'_>> for FormatUnaryOp<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext>) -> FormatResult<()> {
let unary_op = self.item;
write!(
f,
[
text(match unary_op.node {
UnaryOpKind::Invert => "~",
UnaryOpKind::Not => "not",
UnaryOpKind::UAdd => "+",
UnaryOpKind::USub => "-",
}),
matches!(unary_op.node, UnaryOpKind::Not).then_some(space())
]
)?;
Ok(())
}
}

View file

@ -1,37 +0,0 @@
use ruff_formatter::prelude::*;
use ruff_formatter::write;
use crate::context::ASTFormatContext;
use crate::cst::Unaryop;
use crate::shared_traits::AsFormat;
pub struct FormatUnaryop<'a> {
item: &'a Unaryop,
}
impl AsFormat<ASTFormatContext<'_>> for Unaryop {
type Format<'a> = FormatUnaryop<'a>;
fn format(&self) -> Self::Format<'_> {
FormatUnaryop { item: self }
}
}
impl Format<ASTFormatContext<'_>> for FormatUnaryop<'_> {
fn fmt(&self, f: &mut Formatter<ASTFormatContext>) -> FormatResult<()> {
let unaryop = self.item;
write!(
f,
[
text(match unaryop {
Unaryop::Invert => "~",
Unaryop::Not => "not",
Unaryop::UAdd => "+",
Unaryop::USub => "-",
}),
matches!(unaryop, Unaryop::Not).then_some(space())
]
)?;
Ok(())
}
}

View file

@ -3,8 +3,8 @@ use rustpython_parser::ast::Constant;
use crate::core::visitor;
use crate::core::visitor::Visitor;
use crate::cst::{
Alias, Arg, BoolOp, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword, Pattern,
SliceIndex, Stmt, StmtKind,
Alias, Arg, BoolOp, CmpOp, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword, Operator,
Pattern, SliceIndex, Stmt, StmtKind, UnaryOp,
};
use crate::trivia::{Relationship, Trivia, TriviaKind};
@ -327,6 +327,21 @@ impl<'a> Visitor<'a> for ExprNormalizer {
visitor::walk_bool_op(self, bool_op);
}
fn visit_unary_op(&mut self, unary_op: &'a mut UnaryOp) {
unary_op.trivia.retain(|c| !c.kind.is_empty_line());
visitor::walk_unary_op(self, unary_op);
}
fn visit_cmp_op(&mut self, cmp_op: &'a mut CmpOp) {
cmp_op.trivia.retain(|c| !c.kind.is_empty_line());
visitor::walk_cmp_op(self, cmp_op);
}
fn visit_operator(&mut self, operator: &'a mut Operator) {
operator.trivia.retain(|c| !c.kind.is_empty_line());
visitor::walk_operator(self, operator);
}
fn visit_slice_index(&mut self, slice_index: &'a mut SliceIndex) {
slice_index.trivia.retain(|c| !c.kind.is_empty_line());
visitor::walk_slice_index(self, slice_index);

View file

@ -110,17 +110,6 @@ elif unformatted:
},
)
@@ -18,8 +16,8 @@
"ls",
"-la",
]
- # fmt: on
- + path,
+ + # fmt: on
+ path,
check=True,
)
@@ -27,9 +25,8 @@
# Regression test for https://github.com/psf/black/issues/3026.
def test_func():
@ -212,8 +201,8 @@ run(
"ls",
"-la",
]
+ # fmt: on
path,
# fmt: on
+ path,
check=True,
)

View file

@ -5,8 +5,8 @@ use rustpython_parser::Tok;
use crate::core::types::Range;
use crate::cst::{
Alias, Arg, Body, BoolOp, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword, Pattern,
PatternKind, SliceIndex, SliceIndexKind, Stmt, StmtKind,
Alias, Arg, Body, BoolOp, CmpOp, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword,
Operator, Pattern, PatternKind, SliceIndex, SliceIndexKind, Stmt, StmtKind, UnaryOp,
};
#[derive(Clone, Debug)]
@ -15,13 +15,16 @@ pub enum Node<'a> {
Arg(&'a Arg),
Body(&'a Body),
BoolOp(&'a BoolOp),
CmpOp(&'a CmpOp),
Excepthandler(&'a Excepthandler),
Expr(&'a Expr),
Keyword(&'a Keyword),
Mod(&'a [Stmt]),
Operator(&'a Operator),
Pattern(&'a Pattern),
SliceIndex(&'a SliceIndex),
Stmt(&'a Stmt),
UnaryOp(&'a UnaryOp),
}
impl Node<'_> {
@ -31,13 +34,54 @@ impl Node<'_> {
Node::Arg(node) => node.id(),
Node::Body(node) => node.id(),
Node::BoolOp(node) => node.id(),
Node::CmpOp(node) => node.id(),
Node::Excepthandler(node) => node.id(),
Node::Expr(node) => node.id(),
Node::Keyword(node) => node.id(),
Node::Mod(nodes) => nodes as *const _ as usize,
Node::Operator(node) => node.id(),
Node::Pattern(node) => node.id(),
Node::SliceIndex(node) => node.id(),
Node::Stmt(node) => node.id(),
Node::UnaryOp(node) => node.id(),
}
}
pub fn location(&self) -> Location {
match self {
Node::Alias(node) => node.location,
Node::Arg(node) => node.location,
Node::Body(node) => node.location,
Node::BoolOp(node) => node.location,
Node::CmpOp(node) => node.location,
Node::Excepthandler(node) => node.location,
Node::Expr(node) => node.location,
Node::Keyword(node) => node.location,
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Operator(node) => node.location,
Node::Pattern(node) => node.location,
Node::SliceIndex(node) => node.location,
Node::Stmt(node) => node.location,
Node::UnaryOp(node) => node.location,
}
}
pub fn end_location(&self) -> Location {
match self {
Node::Alias(node) => node.end_location.unwrap(),
Node::Arg(node) => node.end_location.unwrap(),
Node::Body(node) => node.end_location.unwrap(),
Node::BoolOp(node) => node.end_location.unwrap(),
Node::CmpOp(node) => node.end_location.unwrap(),
Node::Excepthandler(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Keyword(node) => node.end_location.unwrap(),
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Operator(node) => node.end_location.unwrap(),
Node::Pattern(node) => node.end_location.unwrap(),
Node::SliceIndex(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(),
Node::UnaryOp(node) => node.end_location.unwrap(),
}
}
}
@ -240,7 +284,6 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Stmt(stmt));
}
}
Node::BoolOp(..) => {}
Node::Stmt(stmt) => match &stmt.node {
StmtKind::Return { value } => {
if let Some(value) = value {
@ -323,8 +366,9 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(target));
}
}
StmtKind::AugAssign { target, value, .. } => {
StmtKind::AugAssign { target, op, value } => {
result.push(Node::Expr(target));
result.push(Node::Operator(op));
result.push(Node::Expr(value));
}
StmtKind::AnnAssign {
@ -464,11 +508,13 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(target));
result.push(Node::Expr(value));
}
ExprKind::BinOp { left, right, .. } => {
ExprKind::BinOp { left, op, right } => {
result.push(Node::Expr(left));
result.push(Node::Operator(op));
result.push(Node::Expr(right));
}
ExprKind::UnaryOp { operand, .. } => {
ExprKind::UnaryOp { op, operand } => {
result.push(Node::UnaryOp(op));
result.push(Node::Expr(operand));
}
ExprKind::Lambda { body, args, .. } => {
@ -555,10 +601,13 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(value));
}
ExprKind::Compare {
left, comparators, ..
left,
ops,
comparators,
} => {
result.push(Node::Expr(left));
for comparator in comparators {
for (op, comparator) in ops.iter().zip(comparators) {
result.push(Node::CmpOp(op));
result.push(Node::Expr(comparator));
}
}
@ -677,6 +726,10 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
}
}
},
Node::BoolOp(..) => {}
Node::UnaryOp(..) => {}
Node::Operator(..) => {}
Node::CmpOp(..) => {}
}
}
@ -712,62 +765,14 @@ pub fn decorate_token<'a>(
while left < right {
let middle = (left + right) / 2;
let child = &child_nodes[middle];
let start = match &child {
Node::Alias(node) => node.location,
Node::Arg(node) => node.location,
Node::Body(node) => node.location,
Node::BoolOp(node) => node.location,
Node::Excepthandler(node) => node.location,
Node::Expr(node) => node.location,
Node::Keyword(node) => node.location,
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Pattern(node) => node.location,
Node::SliceIndex(node) => node.location,
Node::Stmt(node) => node.location,
};
let end = match &child {
Node::Alias(node) => node.end_location.unwrap(),
Node::Arg(node) => node.end_location.unwrap(),
Node::Body(node) => node.end_location.unwrap(),
Node::BoolOp(node) => node.end_location.unwrap(),
Node::Excepthandler(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Keyword(node) => node.end_location.unwrap(),
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Pattern(node) => node.end_location.unwrap(),
Node::SliceIndex(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(),
};
let start = child.location();
let end = child.end_location();
if let Some(existing) = &enclosed_node {
// Special-case: if we're dealing with a statement that's a single expression,
// we want to treat the expression as the enclosed node.
let existing_start = match &existing {
Node::Alias(node) => node.location,
Node::Arg(node) => node.location,
Node::Body(node) => node.location,
Node::BoolOp(node) => node.location,
Node::Excepthandler(node) => node.location,
Node::Expr(node) => node.location,
Node::Keyword(node) => node.location,
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Pattern(node) => node.location,
Node::SliceIndex(node) => node.location,
Node::Stmt(node) => node.location,
};
let existing_end = match &existing {
Node::Alias(node) => node.end_location.unwrap(),
Node::Arg(node) => node.end_location.unwrap(),
Node::Body(node) => node.end_location.unwrap(),
Node::BoolOp(node) => node.end_location.unwrap(),
Node::Excepthandler(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Keyword(node) => node.end_location.unwrap(),
Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"),
Node::Pattern(node) => node.end_location.unwrap(),
Node::SliceIndex(node) => node.end_location.unwrap(),
Node::Stmt(node) => node.end_location.unwrap(),
};
let existing_start = existing.location();
let existing_end = existing.end_location();
if start == existing_start && end == existing_end {
enclosed_node = Some(child.clone());
}
@ -825,12 +830,15 @@ pub struct TriviaIndex {
pub arg: FxHashMap<usize, Vec<Trivia>>,
pub body: FxHashMap<usize, Vec<Trivia>>,
pub bool_op: FxHashMap<usize, Vec<Trivia>>,
pub cmp_op: FxHashMap<usize, Vec<Trivia>>,
pub excepthandler: FxHashMap<usize, Vec<Trivia>>,
pub expr: FxHashMap<usize, Vec<Trivia>>,
pub keyword: FxHashMap<usize, Vec<Trivia>>,
pub operator: FxHashMap<usize, Vec<Trivia>>,
pub pattern: FxHashMap<usize, Vec<Trivia>>,
pub slice_index: FxHashMap<usize, Vec<Trivia>>,
pub stmt: FxHashMap<usize, Vec<Trivia>>,
pub unary_op: FxHashMap<usize, Vec<Trivia>>,
}
fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
@ -863,6 +871,13 @@ fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
.or_insert_with(Vec::new)
.push(comment);
}
Node::CmpOp(node) => {
trivia
.cmp_op
.entry(node.id())
.or_insert_with(Vec::new)
.push(comment);
}
Node::Excepthandler(node) => {
trivia
.excepthandler
@ -884,6 +899,13 @@ fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
.or_insert_with(Vec::new)
.push(comment);
}
Node::Operator(node) => {
trivia
.operator
.entry(node.id())
.or_insert_with(Vec::new)
.push(comment);
}
Node::Pattern(node) => {
trivia
.pattern
@ -905,6 +927,13 @@ fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
.or_insert_with(Vec::new)
.push(comment);
}
Node::UnaryOp(node) => {
trivia
.unary_op
.entry(node.id())
.or_insert_with(Vec::new)
.push(comment);
}
Node::Mod(_) => {}
}
}