Handle right parens in join comma builder (#5711)

This commit is contained in:
Micha Reiser 2023-07-12 18:21:28 +02:00 committed by GitHub
parent f0aa6bd4d3
commit 653429bef9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 159 additions and 136 deletions

View file

@ -1,9 +1,12 @@
use ruff_text_size::{TextRange, TextSize};
use rustpython_parser::ast::Ranged;
use ruff_formatter::{format_args, write, Argument, Arguments};
use crate::context::NodeLevel;
use crate::prelude::*;
use crate::trivia::{first_non_trivia_token, lines_after, skip_trailing_trivia, Token, TokenKind};
use ruff_formatter::{format_args, write, Argument, Arguments};
use ruff_text_size::TextSize;
use rustpython_parser::ast::Ranged;
use crate::trivia::{lines_after, skip_trailing_trivia, SimpleTokenizer, Token, TokenKind};
use crate::MagicTrailingComma;
/// Adds parentheses and indents `content` if it doesn't fit on a line.
pub(crate) fn parenthesize_if_expands<'ast, T>(content: &T) -> ParenthesizeIfExpands<'_, 'ast>
@ -53,7 +56,10 @@ pub(crate) trait PyFormatterExtensions<'ast, 'buf> {
/// A builder that separates each element by a `,` and a [`soft_line_break_or_space`].
/// It emits a trailing `,` that is only shown if the enclosing group expands. It forces the enclosing
/// group to expand if the last item has a trailing `comma` and the magical comma option is enabled.
fn join_comma_separated<'fmt>(&'fmt mut self) -> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf>;
fn join_comma_separated<'fmt>(
&'fmt mut self,
sequence_end: TextSize,
) -> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf>;
}
impl<'buf, 'ast> PyFormatterExtensions<'ast, 'buf> for PyFormatter<'ast, 'buf> {
@ -61,8 +67,11 @@ impl<'buf, 'ast> PyFormatterExtensions<'ast, 'buf> for PyFormatter<'ast, 'buf> {
JoinNodesBuilder::new(self, level)
}
fn join_comma_separated<'fmt>(&'fmt mut self) -> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
JoinCommaSeparatedBuilder::new(self)
fn join_comma_separated<'fmt>(
&'fmt mut self,
sequence_end: TextSize,
) -> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
JoinCommaSeparatedBuilder::new(self, sequence_end)
}
}
@ -194,18 +203,20 @@ pub(crate) struct JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
result: FormatResult<()>,
fmt: &'fmt mut PyFormatter<'ast, 'buf>,
end_of_last_entry: Option<TextSize>,
sequence_end: TextSize,
/// We need to track whether we have more than one entry since a sole entry doesn't get a
/// magic trailing comma even when expanded
len: usize,
}
impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
fn new(f: &'fmt mut PyFormatter<'ast, 'buf>) -> Self {
fn new(f: &'fmt mut PyFormatter<'ast, 'buf>, sequence_end: TextSize) -> Self {
Self {
fmt: f,
result: Ok(()),
end_of_last_entry: None,
len: 0,
sequence_end,
}
}
@ -236,7 +247,7 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
where
T: Ranged,
F: Format<PyFormatContext<'ast>>,
I: Iterator<Item = (T, F)>,
I: IntoIterator<Item = (T, F)>,
{
for (node, content) in entries {
self.entry(&node, &content);
@ -248,7 +259,7 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
pub(crate) fn nodes<'a, T, I>(&mut self, entries: I) -> &mut Self
where
T: Ranged + AsFormat<PyFormatContext<'ast>> + 'a,
I: Iterator<Item = &'a T>,
I: IntoIterator<Item = &'a T>,
{
for node in entries {
self.entry(node, &node.format());
@ -260,14 +271,26 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
pub(crate) fn finish(&mut self) -> FormatResult<()> {
self.result.and_then(|_| {
if let Some(last_end) = self.end_of_last_entry.take() {
let magic_trailing_comma = self.fmt.options().magic_trailing_comma().is_respect()
&& matches!(
first_non_trivia_token(last_end, self.fmt.context().source()),
Some(Token {
kind: TokenKind::Comma,
..
})
);
let magic_trailing_comma = match self.fmt.options().magic_trailing_comma() {
MagicTrailingComma::Respect => {
let first_token = SimpleTokenizer::new(
self.fmt.context().source(),
TextRange::new(last_end, self.sequence_end),
)
.skip_trivia()
// Skip over any closing parentheses belonging to the expression
.find(|token| token.kind() != TokenKind::RParen);
matches!(
first_token,
Some(Token {
kind: TokenKind::Comma,
..
})
)
}
MagicTrailingComma::Ignore => false,
};
// If there is a single entry, only keep the magic trailing comma, don't add it if
// it wasn't there. If there is more than one entry, always add it.
@ -287,13 +310,15 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
#[cfg(test)]
mod tests {
use rustpython_parser::ast::ModModule;
use rustpython_parser::Parse;
use ruff_formatter::format;
use crate::comments::Comments;
use crate::context::{NodeLevel, PyFormatContext};
use crate::prelude::*;
use crate::PyFormatOptions;
use ruff_formatter::format;
use rustpython_parser::ast::ModModule;
use rustpython_parser::Parse;
fn format_ranged(level: NodeLevel) -> String {
let source = r#"

View file

@ -5,10 +5,12 @@ use crate::expression::parentheses::{
default_expression_needs_parentheses, parenthesized, NeedsParentheses, Parentheses,
Parenthesize,
};
use crate::trivia::{SimpleTokenizer, TokenKind};
use crate::{AsFormat, FormatNodeRule, PyFormatter};
use ruff_formatter::prelude::{format_with, group, text};
use ruff_formatter::{write, Buffer, FormatResult};
use rustpython_parser::ast::ExprCall;
use ruff_text_size::{TextRange, TextSize};
use rustpython_parser::ast::{Expr, ExprCall, Ranged};
#[derive(Default)]
pub struct FormatExprCall;
@ -43,15 +45,25 @@ impl FormatNodeRule<ExprCall> for FormatExprCall {
);
}
let all_args = format_with(|f| {
f.join_comma_separated()
.entries(
// We have the parentheses from the call so the arguments never need any
args.iter()
.map(|arg| (arg, arg.format().with_options(Parenthesize::Never))),
)
.nodes(keywords.iter())
.finish()
let all_args = format_with(|f: &mut PyFormatter| {
let source = f.context().source();
let mut joiner = f.join_comma_separated(item.end());
match args.as_slice() {
[argument] if keywords.is_empty() => {
let parentheses =
if is_single_argument_parenthesized(argument, item.end(), source) {
Parenthesize::Always
} else {
Parenthesize::Never
};
joiner.entry(argument, &argument.format().with_options(parentheses));
}
arguments => {
joiner.nodes(arguments).nodes(keywords.iter());
}
}
joiner.finish()
});
write!(
@ -97,3 +109,28 @@ impl NeedsParentheses for ExprCall {
}
}
}
fn is_single_argument_parenthesized(argument: &Expr, call_end: TextSize, source: &str) -> bool {
let mut has_seen_r_paren = false;
for token in
SimpleTokenizer::new(source, TextRange::new(argument.end(), call_end)).skip_trivia()
{
match token.kind() {
TokenKind::RParen => {
if has_seen_r_paren {
return true;
}
has_seen_r_paren = true;
}
// Skip over any trailing comma
TokenKind::Comma => continue,
_ => {
// Passed the arguments
break;
}
}
}
false
}

View file

@ -77,7 +77,7 @@ impl FormatNodeRule<ExprDict> for FormatExprDict {
}
let format_pairs = format_with(|f| {
let mut joiner = f.join_comma_separated();
let mut joiner = f.join_comma_separated(item.end());
for (key, value) in keys.iter().zip(values) {
let key_value_pair = KeyValuePair { key, value };

View file

@ -6,7 +6,7 @@ use crate::expression::parentheses::{
use crate::prelude::*;
use crate::FormatNodeRule;
use ruff_formatter::{format_args, write};
use rustpython_parser::ast::ExprList;
use rustpython_parser::ast::{ExprList, Ranged};
#[derive(Default)]
pub struct FormatExprList;
@ -53,7 +53,11 @@ impl FormatNodeRule<ExprList> for FormatExprList {
"A non-empty expression list has dangling comments"
);
let items = format_with(|f| f.join_comma_separated().nodes(elts.iter()).finish());
let items = format_with(|f| {
f.join_comma_separated(item.end())
.nodes(elts.iter())
.finish()
});
parenthesized("[", &items, "]").fmt(f)
}

View file

@ -108,14 +108,14 @@ impl FormatNodeRule<ExprTuple> for FormatExprTuple {
//
// Unlike other expression parentheses, tuple parentheses are part of the range of the
// tuple itself.
elts if is_parenthesized(*range, elts, f.context().source())
_ if is_parenthesized(*range, elts, f.context().source())
&& self.parentheses != TupleParentheses::StripInsideForLoop =>
{
parenthesized("(", &ExprSequence::new(elts), ")").fmt(f)
parenthesized("(", &ExprSequence::new(item), ")").fmt(f)
}
elts => match self.parentheses {
TupleParentheses::Subscript => group(&ExprSequence::new(elts)).fmt(f),
_ => parenthesize_if_expands(&ExprSequence::new(elts)).fmt(f),
_ => match self.parentheses {
TupleParentheses::Subscript => group(&ExprSequence::new(item)).fmt(f),
_ => parenthesize_if_expands(&ExprSequence::new(item)).fmt(f),
},
}
}
@ -128,18 +128,20 @@ impl FormatNodeRule<ExprTuple> for FormatExprTuple {
#[derive(Debug)]
struct ExprSequence<'a> {
elts: &'a [Expr],
tuple: &'a ExprTuple,
}
impl<'a> ExprSequence<'a> {
const fn new(elts: &'a [Expr]) -> Self {
Self { elts }
const fn new(expr: &'a ExprTuple) -> Self {
Self { tuple: expr }
}
}
impl Format<PyFormatContext<'_>> for ExprSequence<'_> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'_>>) -> FormatResult<()> {
f.join_comma_separated().nodes(self.elts.iter()).finish()
f.join_comma_separated(self.tuple.end())
.nodes(&self.tuple.elts)
.finish()
}
}

View file

@ -76,12 +76,13 @@ impl Format<PyFormatContext<'_>> for FormatInheritanceClause<'_> {
bases,
keywords,
name,
body,
..
} = self.class_definition;
let source = f.context().source();
let mut joiner = f.join_comma_separated();
let mut joiner = f.join_comma_separated(body.first().unwrap().start());
if let Some((first, rest)) = bases.split_first() {
// Manually handle parentheses for the first expression because the logic in `FormatExpr`

View file

@ -4,7 +4,7 @@ use crate::expression::parentheses::Parenthesize;
use crate::{AsFormat, FormatNodeRule, PyFormatter};
use ruff_formatter::prelude::{block_indent, format_with, space, text};
use ruff_formatter::{write, Buffer, Format, FormatResult};
use rustpython_parser::ast::StmtDelete;
use rustpython_parser::ast::{Ranged, StmtDelete};
#[derive(Default)]
pub struct FormatStmtDelete;
@ -35,7 +35,11 @@ impl FormatNodeRule<StmtDelete> for FormatStmtDelete {
write!(f, [single.format().with_options(Parenthesize::IfBreaks)])
}
targets => {
let item = format_with(|f| f.join_comma_separated().nodes(targets.iter()).finish());
let item = format_with(|f| {
f.join_comma_separated(item.end())
.nodes(targets.iter())
.finish()
});
parenthesize_if_expands(&item).fmt(f)
}
}

View file

@ -2,7 +2,7 @@ use crate::builders::{parenthesize_if_expands, PyFormatterExtensions};
use crate::{AsFormat, FormatNodeRule, PyFormatter};
use ruff_formatter::prelude::{dynamic_text, format_with, space, text};
use ruff_formatter::{write, Buffer, Format, FormatResult};
use rustpython_parser::ast::StmtImportFrom;
use rustpython_parser::ast::{Ranged, StmtImportFrom};
#[derive(Default)]
pub struct FormatStmtImportFrom;
@ -39,7 +39,7 @@ impl FormatNodeRule<StmtImportFrom> for FormatStmtImportFrom {
}
}
let names = format_with(|f| {
f.join_comma_separated()
f.join_comma_separated(item.end())
.entries(names.iter().map(|name| (name, name.format())))
.finish()
});

View file

@ -68,8 +68,11 @@ impl Format<PyFormatContext<'_>> for AnyStatementWith<'_> {
let comments = f.context().comments().clone();
let dangling_comments = comments.dangling_comments(self);
let joined_items =
format_with(|f| f.join_comma_separated().nodes(self.items().iter()).finish());
let joined_items = format_with(|f| {
f.join_comma_separated(self.body().first().unwrap().start())
.nodes(self.items().iter())
.finish()
});
if self.is_async() {
write!(f, [text("async"), space()])?;

View file

@ -315,30 +315,7 @@ long_unmergable_string_with_pragma = (
bad_split_func1(
"But what should happen when code has already "
@@ -96,15 +96,13 @@
)
bad_split_func3(
- (
- "But what should happen when code has already "
- r"been formatted but in the wrong way? Like "
- "with a space at the end instead of the "
- r"beginning. Or what about when it is split too "
- r"soon? In the case of a split that is too "
- "short, black will try to honer the custom "
- "split."
- ),
+ "But what should happen when code has already "
+ r"been formatted but in the wrong way? Like "
+ "with a space at the end instead of the "
+ r"beginning. Or what about when it is split too "
+ r"soon? In the case of a split that is too "
+ "short, black will try to honer the custom "
+ "split.",
xxx,
yyy,
zzz,
@@ -143,9 +141,9 @@
@@ -143,9 +143,9 @@
)
)
@ -350,7 +327,7 @@ long_unmergable_string_with_pragma = (
comment_string = "Long lines with inline comments should have their comments appended to the reformatted string's enclosing right parentheses." # This comment gets thrown to the top.
@@ -165,25 +163,13 @@
@@ -165,25 +165,13 @@
triple_quote_string = """This is a really really really long triple quote string assignment and it should not be touched."""
@ -380,53 +357,18 @@ long_unmergable_string_with_pragma = (
some_function_call(
"With a reallly generic name and with a really really long string that is, at some point down the line, "
@@ -212,29 +198,25 @@
)
@@ -221,8 +209,8 @@
func_with_bad_comma(
- (
- "This is a really long string argument to a function that has a trailing comma"
- " which should NOT be there."
- ),
+ "This is a really long string argument to a function that has a trailing comma"
+ " which should NOT be there."
)
func_with_bad_comma(
- (
- "This is a really long string argument to a function that has a trailing comma"
(
"This is a really long string argument to a function that has a trailing comma"
- " which should NOT be there."
- ), # comment after comma
+ "This is a really long string argument to a function that has a trailing comma"
+ " which should NOT be there." # comment after comma
+ " which should NOT be there." # comment after comma
+ ),
)
func_with_bad_parens_that_wont_fit_in_one_line(
- ("short string that should have parens stripped"), x, y, z
+ "short string that should have parens stripped", x, y, z
)
func_with_bad_parens_that_wont_fit_in_one_line(
- x, y, ("short string that should have parens stripped"), z
+ x, y, "short string that should have parens stripped", z
)
func_with_bad_parens(
- ("short string that should have parens stripped"),
+ "short string that should have parens stripped",
x,
y,
z,
@@ -243,7 +225,7 @@
func_with_bad_parens(
x,
y,
- ("short string that should have parens stripped"),
+ "short string that should have parens stripped",
z,
)
@@ -271,10 +253,10 @@
@@ -271,10 +259,10 @@
def foo():
@ -542,13 +484,15 @@ bad_split_func2(
)
bad_split_func3(
"But what should happen when code has already "
r"been formatted but in the wrong way? Like "
"with a space at the end instead of the "
r"beginning. Or what about when it is split too "
r"soon? In the case of a split that is too "
"short, black will try to honer the custom "
"split.",
(
"But what should happen when code has already "
r"been formatted but in the wrong way? Like "
"with a space at the end instead of the "
r"beginning. Or what about when it is split too "
r"soon? In the case of a split that is too "
"short, black will try to honer the custom "
"split."
),
xxx,
yyy,
zzz,
@ -644,25 +588,29 @@ func_with_bad_comma(
)
func_with_bad_comma(
"This is a really long string argument to a function that has a trailing comma"
" which should NOT be there."
(
"This is a really long string argument to a function that has a trailing comma"
" which should NOT be there."
),
)
func_with_bad_comma(
"This is a really long string argument to a function that has a trailing comma"
" which should NOT be there." # comment after comma
(
"This is a really long string argument to a function that has a trailing comma"
" which should NOT be there." # comment after comma
),
)
func_with_bad_parens_that_wont_fit_in_one_line(
"short string that should have parens stripped", x, y, z
("short string that should have parens stripped"), x, y, z
)
func_with_bad_parens_that_wont_fit_in_one_line(
x, y, "short string that should have parens stripped", z
x, y, ("short string that should have parens stripped"), z
)
func_with_bad_parens(
"short string that should have parens stripped",
("short string that should have parens stripped"),
x,
y,
z,
@ -671,7 +619,7 @@ func_with_bad_parens(
func_with_bad_parens(
x,
y,
"short string that should have parens stripped",
("short string that should have parens stripped"),
z,
)

View file

@ -83,7 +83,7 @@ while x := f(x):
x = (y := 0)
(z := (y := (x := 0)))
(info := (name, phone, *rest))
@@ -31,17 +31,17 @@
@@ -31,9 +31,9 @@
len(lines := f.readlines())
foo(x := 3, cat="vector")
foo(cat=(category := "vector"))
@ -95,9 +95,8 @@ while x := f(x):
return env_base
if self._is_special and (ans := self._check_nans(context=context)):
return ans
foo(b := 2, a=1)
-foo((b := 2), a=1)
+foo(b := 2, a=1)
@@ -41,7 +41,7 @@
foo((b := 2), a=1)
foo(c=(b := 2), a=1)
-while x := f(x):
@ -151,7 +150,7 @@ if (env_base := os.environ.get("PYTHONUSERBASE", None)):
if self._is_special and (ans := self._check_nans(context=context)):
return ans
foo(b := 2, a=1)
foo(b := 2, a=1)
foo((b := 2), a=1)
foo(c=(b := 2), a=1)
while (x := f(x)):