Properly group assignment targets (#5728)

This commit is contained in:
Micha Reiser 2023-07-13 16:00:49 +02:00 committed by GitHub
parent f48ab2d621
commit 5dd5ee0c5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 15 deletions

View file

@ -8,3 +8,11 @@ a2 = (
# Break the last element
a = asdf = fjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfal = 1
aa = [
bakjdshflkjahdslkfjlasfdahjlfds
] = dddd = ddd = fkjaödkjaföjfahlfdalfhaöfaöfhaöfha = g = [3]
aaaa = ( # trailing
# comment
bbbbb) = cccccccccccccccc = 3

View file

@ -443,6 +443,10 @@ impl<'input> PreorderVisitor<'input> for CanOmitOptionalParenthesesVisitor<'inpu
}
fn has_parentheses(expr: &Expr, source: &str) -> bool {
has_own_parentheses(expr) || is_expression_parenthesized(AnyNodeRef::from(expr), source)
}
pub(crate) const fn has_own_parentheses(expr: &Expr) -> bool {
matches!(
expr,
Expr::Dict(_)
@ -454,7 +458,7 @@ fn has_parentheses(expr: &Expr, source: &str) -> bool {
| Expr::DictComp(_)
| Expr::Call(_)
| Expr::Subscript(_)
) || is_expression_parenthesized(AnyNodeRef::from(expr), source)
)
}
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]

View file

@ -1,16 +1,13 @@
use rustpython_parser::ast::StmtAssign;
use rustpython_parser::ast::{Expr, StmtAssign};
use ruff_formatter::write;
use ruff_formatter::{format_args, write, FormatError};
use crate::expression::maybe_parenthesize_expression;
use crate::expression::parentheses::Parenthesize;
use crate::context::NodeLevel;
use crate::expression::parentheses::{Parentheses, Parenthesize};
use crate::expression::{has_own_parentheses, maybe_parenthesize_expression};
use crate::prelude::*;
use crate::FormatNodeRule;
// Note: This currently does wrap but not the black way so the types below likely need to be
// replaced entirely
//
#[derive(Default)]
pub struct FormatStmtAssign;
@ -23,9 +20,18 @@ impl FormatNodeRule<StmtAssign> for FormatStmtAssign {
type_comment: _,
} = item;
for target in targets {
write!(f, [target.format(), space(), text("="), space()])?;
}
let (first, rest) = targets.split_first().ok_or(FormatError::SyntaxError)?;
write!(
f,
[
first.format(),
space(),
text("="),
space(),
FormatTargets { targets: rest }
]
)?;
write!(
f,
@ -37,3 +43,58 @@ impl FormatNodeRule<StmtAssign> for FormatStmtAssign {
)
}
}
struct FormatTargets<'a> {
targets: &'a [Expr],
}
impl Format<PyFormatContext<'_>> for FormatTargets<'_> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'_>>) -> FormatResult<()> {
if let Some((first, rest)) = self.targets.split_first() {
let can_omit_parentheses = has_own_parentheses(first);
let group_id = if can_omit_parentheses {
Some(f.group_id("assignment_parentheses"))
} else {
None
};
let saved_level = f.context().node_level();
f.context_mut()
.set_node_level(NodeLevel::Expression(group_id));
let format_first = format_with(|f: &mut PyFormatter| {
let result = if can_omit_parentheses {
first.format().with_options(Parentheses::Never).fmt(f)
} else {
write!(
f,
[
if_group_breaks(&text("(")),
soft_block_indent(&first.format().with_options(Parentheses::Never)),
if_group_breaks(&text(")"))
]
)
};
f.context_mut().set_node_level(saved_level);
result
});
write!(
f,
[group(&format_args![
format_first,
space(),
text("="),
space(),
FormatTargets { targets: rest }
])
.with_group_id(group_id)]
)
} else {
Ok(())
}
}
}

View file

@ -14,18 +14,43 @@ a2 = (
# Break the last element
a = asdf = fjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfal = 1
aa = [
bakjdshflkjahdslkfjlasfdahjlfds
] = dddd = ddd = fkjaödkjaföjfahlfdalfhaöfaöfhaöfha = g = [3]
aaaa = ( # trailing
# comment
bbbbb) = cccccccccccccccc = 3
```
## Output
```py
# break left hand side
a1akjdshflkjahdslkfjlasfdahjlfds = bakjdshflkjahdslkfjlasfdahjlfds = cakjdshflkjahdslkfjlasfdahjlfds = kjaödkjaföjfahlfdalfhaöfaöfhaöfha = fkjaödkjaföjfahlfdalfhaöfaöfhaöfha = g = 3
a1akjdshflkjahdslkfjlasfdahjlfds = (
bakjdshflkjahdslkfjlasfdahjlfds
) = (
cakjdshflkjahdslkfjlasfdahjlfds
) = kjaödkjaföjfahlfdalfhaöfaöfhaöfha = fkjaödkjaföjfahlfdalfhaöfaöfhaöfha = g = 3
# join left hand side
a2 = (b2) = 2
a2 = b2 = 2
# Break the last element
a = asdf = fjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfal = 1
a = (
asdf
) = (
fjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfalflaflapamsakjsdhflakjdslfjhalsdljfal
) = 1
aa = [
bakjdshflkjahdslkfjlasfdahjlfds
] = dddd = ddd = fkjaödkjaföjfahlfdalfhaöfaöfhaöfha = g = [3]
aaaa = ( # trailing
# comment
bbbbb
) = cccccccccccccccc = 3
```