Preserve comments on non-defaulted arguments (#3264)

This commit is contained in:
Charlie Marsh 2023-02-27 18:41:40 -05:00 committed by GitHub
parent 16be691712
commit 470e1c1754
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 87 additions and 189 deletions

View file

@ -1,6 +1,6 @@
use crate::core::visitor;
use crate::core::visitor::Visitor;
use crate::cst::{Alias, Body, Excepthandler, Expr, Pattern, SliceIndex, Stmt};
use crate::cst::{Alias, Arg, Body, Excepthandler, Expr, Pattern, SliceIndex, Stmt};
use crate::trivia::{decorate_trivia, TriviaIndex, TriviaToken};
struct AttachmentVisitor {
@ -40,6 +40,14 @@ impl<'a> Visitor<'a> for AttachmentVisitor {
visitor::walk_alias(self, alias);
}
fn visit_arg(&mut self, arg: &'a mut Arg) {
let trivia = self.index.arg.remove(&arg.id());
if let Some(comments) = trivia {
arg.trivia.extend(comments);
}
visitor::walk_arg(self, arg);
}
fn visit_excepthandler(&mut self, excepthandler: &'a mut Excepthandler) {
let trivia = self.index.excepthandler.remove(&excepthandler.id());
if let Some(comments) = trivia {

View file

@ -4,6 +4,7 @@ use ruff_text_size::TextSize;
use crate::context::ASTFormatContext;
use crate::cst::Arg;
use crate::format::comments::end_of_line_comments;
use crate::shared_traits::AsFormat;
pub struct FormatArg<'a> {
@ -27,6 +28,7 @@ impl Format<ASTFormatContext<'_>> for FormatArg<'_> {
write!(f, [text(": ")])?;
write!(f, [annotation.format()])?;
}
write!(f, [end_of_line_comments(arg)])?;
Ok(())
}

View file

@ -71,9 +71,8 @@ fn format_subscript(
Ok(())
}))])]
)?;
write!(f, [text("]")])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -286,6 +285,7 @@ fn format_list(
)?;
}
write!(f, [text("]")])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -613,6 +613,7 @@ fn format_joined_str(
_values: &[Expr],
) -> FormatResult<()> {
write!(f, [literal(Range::from_located(expr))])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -639,6 +640,7 @@ fn format_constant(
Constant::Complex { .. } => write!(f, [complex_literal(Range::from_located(expr))])?,
Constant::Tuple(_) => unreachable!("Constant::Tuple should be handled by format_tuple"),
}
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -713,9 +715,7 @@ fn format_attribute(
write!(f, [value.format()])?;
write!(f, [text(".")])?;
write!(f, [dynamic_text(attr, TextSize::default())])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -729,9 +729,7 @@ fn format_named_expr(
write!(f, [text(":=")])?;
write!(f, [space()])?;
write!(f, [group(&format_args![value.format()])])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -752,9 +750,7 @@ fn format_bool_op(
write!(f, [group(&format_args![value.format()])])?;
}
}
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -767,7 +763,6 @@ fn format_bin_op(
) -> FormatResult<()> {
// https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-breaks-binary-operators
let is_simple = matches!(op, Operator::Pow) && is_simple_power(left) && is_simple_power(right);
write!(f, [left.format()])?;
if !is_simple {
write!(f, [soft_line_break_or_space()])?;
@ -776,10 +771,8 @@ fn format_bin_op(
if !is_simple {
write!(f, [space()])?;
}
write!(f, [group(&format_args![right.format()])])?;
write!(f, [group(&right.format())])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -808,6 +801,7 @@ fn format_unary_op(
} else {
write!(f, [operand.format()])?;
}
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}
@ -825,6 +819,7 @@ fn format_lambda(
write!(f, [text(":")])?;
write!(f, [space()])?;
write!(f, [body.format()])?;
write!(f, [end_of_line_comments(expr)])?;
Ok(())
}

View file

@ -178,43 +178,9 @@ instruction()#comment with bad spacing
```diff
--- Black
+++ Ruff
@@ -13,7 +13,7 @@
"Callable",
"ClassVar",
# ABCs (from collections.abc).
- "AbstractSet", # collections.abc.Set.
+ "AbstractSet",
"ByteString",
"Container",
# Concrete collection types.
@@ -24,7 +24,7 @@
"List",
"Set",
"FrozenSet",
- "NamedTuple", # Not really a type.
+ "NamedTuple",
"Generator",
]
@@ -60,26 +60,32 @@
# Comment before function.
def inline_comments_in_brackets_ruin_everything():
if typedargslist:
- parameters.children = [children[0], body, children[-1]] # (1 # )1
+ parameters.children = [children[0], body, children[-1]]
parameters.children = [
children[0],
@@ -72,14 +72,20 @@
body,
- children[-1], # type: ignore
+ children[-1],
]
else:
parameters.children = [
- parameters.children[0], # (2 what if this was actually long
+ parameters.children[0],
body,
- parameters.children[-1], # )2
+ parameters.children[-1],
parameters.children[-1], # )2
]
- parameters.children = [parameters.what_if_this_was_actually_long.children[0], body, parameters.children[-1]] # type: ignore
+ parameters.children = [
@ -319,7 +285,7 @@ __all__ = [
"Callable",
"ClassVar",
# ABCs (from collections.abc).
"AbstractSet",
"AbstractSet", # collections.abc.Set.
"ByteString",
"Container",
# Concrete collection types.
@ -330,7 +296,7 @@ __all__ = [
"List",
"Set",
"FrozenSet",
"NamedTuple",
"NamedTuple", # Not really a type.
"Generator",
]
@ -366,17 +332,17 @@ else:
# Comment before function.
def inline_comments_in_brackets_ruin_everything():
if typedargslist:
parameters.children = [children[0], body, children[-1]]
parameters.children = [children[0], body, children[-1]] # (1 # )1
parameters.children = [
children[0],
body,
children[-1],
children[-1], # type: ignore
]
else:
parameters.children = [
parameters.children[0],
parameters.children[0], # (2 what if this was actually long
body,
parameters.children[-1],
parameters.children[-1], # )2
]
parameters.children = [
parameters.what_if_this_was_actually_long.children[0],

View file

@ -131,68 +131,18 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
```diff
--- Black
+++ Ruff
@@ -2,7 +2,7 @@
@@ -31,8 +31,8 @@
def f(
- a, # type: int
+ a,
):
pass
@@ -14,44 +14,42 @@
def f(
- a, # type: int
- b, # type: int
- c, # type: int
- d, # type: int
- e, # type: int
- f, # type: int
- g, # type: int
- h, # type: int
- i, # type: int
+ a,
+ b,
+ c,
+ d,
+ e,
+ f,
+ g,
+ h,
+ i,
):
# type: (...) -> None
pass
def f(
- arg, # type: int
- *args, # type: *Any
arg, # type: int
*args, # type: *Any
- default=False, # type: bool
- **kwargs, # type: **Any
+ arg,
+ *args,
+ default=False,
+ default=False, # type: bool # type: **Any
+ **kwargs,
):
# type: (...) -> None
pass
def f(
- a, # type: int
- b, # type: int
- c, # type: int
- d, # type: int
+ a,
+ b,
+ c,
+ d,
):
# type: (...) -> None
@@ -49,9 +49,7 @@
element = 0 # type: int
another_element = 1 # type: float
another_element_with_long_name = 2 # type: int
@ -203,38 +153,7 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
an_element_with_a_long_value = calls() or more_calls() and more() # type: bool
tup = (
@@ -70,21 +68,21 @@
def f(
- x, # not a type comment
- y, # type: int
+ x,
+ y,
):
# type: (...) -> None
pass
def f(
- x, # not a type comment
+ x,
): # type: (int) -> None
pass
def func(
- a=some_list[0], # type: int
+ a=some_list[0],
): # type: () -> int
c = call(
0.0123,
@@ -96,23 +94,37 @@
0.0123,
0.0456,
0.0789,
- a[-1], # type: ignore
+ a[-1],
@@ -100,19 +98,33 @@
)
c = call(
@ -245,7 +164,7 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
+ "aaaaaaaa",
+ "aaaaaaaa",
+ "aaaaaaaa",
+ "aaaaaaaa",
+ "aaaaaaaa", # type: ignore
)
@ -270,7 +189,7 @@ aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*ite
+ AAAAAAAAAAAAAAAAAAAAAAA,
+ AAAAAAAAAAAAAAAAAAAAAAA,
+ BBBBBBBBBBBB,
+ ],
+ ], # type: ignore
)
aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*items))) # type: ignore[arg-type]
@ -283,7 +202,7 @@ from typing import Any, Tuple
def f(
a,
a, # type: int
):
pass
@ -295,24 +214,24 @@ def f(a, b, c, d, e, f, g, h, i):
def f(
a,
b,
c,
d,
e,
f,
g,
h,
i,
a, # type: int
b, # type: int
c, # type: int
d, # type: int
e, # type: int
f, # type: int
g, # type: int
h, # type: int
i, # type: int
):
# type: (...) -> None
pass
def f(
arg,
*args,
default=False,
arg, # type: int
*args, # type: *Any
default=False, # type: bool # type: **Any
**kwargs,
):
# type: (...) -> None
@ -320,10 +239,10 @@ def f(
def f(
a,
b,
c,
d,
a, # type: int
b, # type: int
c, # type: int
d, # type: int
):
# type: (...) -> None
@ -349,21 +268,21 @@ def f(
def f(
x,
y,
x, # not a type comment
y, # type: int
):
# type: (...) -> None
pass
def f(
x,
x, # not a type comment
): # type: (int) -> None
pass
def func(
a=some_list[0],
a=some_list[0], # type: int
): # type: () -> int
c = call(
0.0123,
@ -375,7 +294,7 @@ def func(
0.0123,
0.0456,
0.0789,
a[-1],
a[-1], # type: ignore
)
c = call(
@ -385,7 +304,7 @@ def func(
"aaaaaaaa",
"aaaaaaaa",
"aaaaaaaa",
"aaaaaaaa",
"aaaaaaaa", # type: ignore
)
@ -405,7 +324,7 @@ call_to_some_function_asdf(
AAAAAAAAAAAAAAAAAAAAAAA,
AAAAAAAAAAAAAAAAAAAAAAA,
BBBBBBBBBBBB,
],
], # type: ignore
)
aaaaaaaaaaaaa, bbbbbbbbb = map(list, map(itertools.chain.from_iterable, zip(*items))) # type: ignore[arg-type]

View file

@ -106,7 +106,7 @@ elif unformatted:
- # fmt: on
- ] # Includes an formatted indentation.
+ "foo-bar" "=foo.bar.:main",
+ ],
+ ], # Includes an formatted indentation.
},
)
@ -200,7 +200,7 @@ setup(
# fmt: off
"console_scripts": [
"foo-bar" "=foo.bar.:main",
],
], # Includes an formatted indentation.
},
)

View file

@ -136,9 +136,8 @@ def foo() -> tuple[int, int, int,]:
# Don't lose the comments
-def double(a: int) -> int: # Hello
def double(a: int) -> int: # Hello
- return 2 * a
+def double(a: int) -> int:
+ return 2
+ * a
@ -276,7 +275,7 @@ def double(a: int) -> int:
# Don't lose the comments
def double(a: int) -> int:
def double(a: int) -> int: # Hello
return 2
* a

View file

@ -5,7 +5,7 @@ use rustpython_parser::Tok;
use crate::core::types::Range;
use crate::cst::{
Alias, Body, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind,
Alias, Arg, Body, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Pattern, PatternKind,
SliceIndex, SliceIndexKind, Stmt, StmtKind,
};
@ -16,6 +16,7 @@ pub enum Node<'a> {
Stmt(&'a Stmt),
Expr(&'a Expr),
Alias(&'a Alias),
Arg(&'a Arg),
Excepthandler(&'a Excepthandler),
SliceIndex(&'a SliceIndex),
Pattern(&'a Pattern),
@ -29,6 +30,7 @@ impl Node<'_> {
Node::Stmt(node) => node.id(),
Node::Expr(node) => node.id(),
Node::Alias(node) => node.id(),
Node::Arg(node) => node.id(),
Node::Excepthandler(node) => node.id(),
Node::SliceIndex(node) => node.id(),
Node::Pattern(node) => node.id(),
@ -268,29 +270,19 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
result.push(Node::Expr(decorator));
}
for arg in &args.posonlyargs {
if let Some(expr) = &arg.node.annotation {
result.push(Node::Expr(expr));
}
result.push(Node::Arg(arg));
}
for arg in &args.args {
if let Some(expr) = &arg.node.annotation {
result.push(Node::Expr(expr));
}
result.push(Node::Arg(arg));
}
if let Some(arg) = &args.vararg {
if let Some(expr) = &arg.node.annotation {
result.push(Node::Expr(expr));
}
result.push(Node::Arg(arg));
}
for arg in &args.kwonlyargs {
if let Some(expr) = &arg.node.annotation {
result.push(Node::Expr(expr));
}
result.push(Node::Arg(arg));
}
if let Some(arg) = &args.kwarg {
if let Some(expr) = &arg.node.annotation {
result.push(Node::Expr(expr));
}
result.push(Node::Arg(arg));
}
for expr in &args.defaults {
result.push(Node::Expr(expr));
@ -450,6 +442,11 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec<Node<'a>>) {
StmtKind::Global { .. } => {}
StmtKind::Nonlocal { .. } => {}
},
Node::Arg(arg) => {
if let Some(annotation) = &arg.node.annotation {
result.push(Node::Expr(annotation));
}
}
Node::Expr(expr) => match &expr.node {
ExprKind::BoolOp { values, .. } => {
for value in values {
@ -710,6 +707,7 @@ pub fn decorate_token<'a>(
Node::Stmt(node) => node.location,
Node::Expr(node) => node.location,
Node::Alias(node) => node.location,
Node::Arg(node) => node.location,
Node::Excepthandler(node) => node.location,
Node::SliceIndex(node) => node.location,
Node::Pattern(node) => node.location,
@ -720,6 +718,7 @@ pub fn decorate_token<'a>(
Node::Stmt(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end_location.unwrap(),
Node::Arg(node) => node.end_location.unwrap(),
Node::Excepthandler(node) => node.end_location.unwrap(),
Node::SliceIndex(node) => node.end_location.unwrap(),
Node::Pattern(node) => node.end_location.unwrap(),
@ -734,6 +733,7 @@ pub fn decorate_token<'a>(
Node::Stmt(node) => node.location,
Node::Expr(node) => node.location,
Node::Alias(node) => node.location,
Node::Arg(node) => node.location,
Node::Excepthandler(node) => node.location,
Node::SliceIndex(node) => node.location,
Node::Pattern(node) => node.location,
@ -744,6 +744,7 @@ pub fn decorate_token<'a>(
Node::Stmt(node) => node.end_location.unwrap(),
Node::Expr(node) => node.end_location.unwrap(),
Node::Alias(node) => node.end_location.unwrap(),
Node::Arg(node) => node.end_location.unwrap(),
Node::Excepthandler(node) => node.end_location.unwrap(),
Node::SliceIndex(node) => node.end_location.unwrap(),
Node::Pattern(node) => node.end_location.unwrap(),
@ -806,6 +807,7 @@ pub struct TriviaIndex {
pub stmt: FxHashMap<usize, Vec<Trivia>>,
pub expr: FxHashMap<usize, Vec<Trivia>>,
pub alias: FxHashMap<usize, Vec<Trivia>>,
pub arg: FxHashMap<usize, Vec<Trivia>>,
pub excepthandler: FxHashMap<usize, Vec<Trivia>>,
pub slice_index: FxHashMap<usize, Vec<Trivia>>,
pub pattern: FxHashMap<usize, Vec<Trivia>>,
@ -842,6 +844,13 @@ fn add_comment(comment: Trivia, node: &Node, trivia: &mut TriviaIndex) {
.or_insert_with(Vec::new)
.push(comment);
}
Node::Arg(node) => {
trivia
.arg
.entry(node.id())
.or_insert_with(Vec::new)
.push(comment);
}
Node::Excepthandler(node) => {
trivia
.excepthandler