Format AsyncFor (#5808)

This commit is contained in:
Luc Khai Hai 2023-07-17 17:38:59 +09:00 committed by GitHub
parent f5f8eb31ed
commit fb336898a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 110 additions and 146 deletions

View file

@ -1,5 +1,7 @@
use crate::{not_yet_implemented, FormatNodeRule, PyFormatter}; use crate::prelude::*;
use ruff_formatter::{write, Buffer, FormatResult}; use crate::statement::stmt_for::AnyStatementFor;
use crate::{FormatNodeRule, PyFormatter};
use ruff_formatter::FormatResult;
use rustpython_parser::ast::StmtAsyncFor; use rustpython_parser::ast::StmtAsyncFor;
#[derive(Default)] #[derive(Default)]
@ -7,6 +9,15 @@ pub struct FormatStmtAsyncFor;
impl FormatNodeRule<StmtAsyncFor> for FormatStmtAsyncFor { impl FormatNodeRule<StmtAsyncFor> for FormatStmtAsyncFor {
fn fmt_fields(&self, item: &StmtAsyncFor, f: &mut PyFormatter) -> FormatResult<()> { fn fmt_fields(&self, item: &StmtAsyncFor, f: &mut PyFormatter) -> FormatResult<()> {
write!(f, [not_yet_implemented(item)]) AnyStatementFor::from(item).fmt(f)
}
fn fmt_dangling_comments(
&self,
_node: &StmtAsyncFor,
_f: &mut PyFormatter,
) -> FormatResult<()> {
// Handled in `fmt_fields`
Ok(())
} }
} }

View file

@ -4,9 +4,10 @@ use crate::expression::maybe_parenthesize_expression;
use crate::expression::parentheses::Parenthesize; use crate::expression::parentheses::Parenthesize;
use crate::prelude::*; use crate::prelude::*;
use crate::{FormatNodeRule, PyFormatter}; use crate::{FormatNodeRule, PyFormatter};
use ruff_formatter::{write, Buffer, FormatResult}; use ruff_formatter::{format_args, write, Buffer, FormatResult};
use ruff_python_ast::node::AstNode; use ruff_python_ast::node::AnyNodeRef;
use rustpython_parser::ast::{Expr, Ranged, Stmt, StmtFor}; use ruff_text_size::TextRange;
use rustpython_parser::ast::{Expr, Ranged, Stmt, StmtAsyncFor, StmtFor, Suite};
#[derive(Debug)] #[derive(Debug)]
struct ExprTupleWithoutParentheses<'a>(&'a Expr); struct ExprTupleWithoutParentheses<'a>(&'a Expr);
@ -26,19 +27,85 @@ impl Format<PyFormatContext<'_>> for ExprTupleWithoutParentheses<'_> {
#[derive(Default)] #[derive(Default)]
pub struct FormatStmtFor; pub struct FormatStmtFor;
impl FormatNodeRule<StmtFor> for FormatStmtFor { pub(super) enum AnyStatementFor<'a> {
fn fmt_fields(&self, item: &StmtFor, f: &mut PyFormatter) -> FormatResult<()> { For(&'a StmtFor),
let StmtFor { AsyncFor(&'a StmtAsyncFor),
range: _, }
target,
iter, impl<'a> AnyStatementFor<'a> {
body, const fn is_async(&self) -> bool {
orelse, matches!(self, AnyStatementFor::AsyncFor(_))
type_comment: _, }
} = item;
fn target(&self) -> &Expr {
match self {
AnyStatementFor::For(stmt) => &stmt.target,
AnyStatementFor::AsyncFor(stmt) => &stmt.target,
}
}
#[allow(clippy::iter_not_returning_iterator)]
fn iter(&self) -> &Expr {
match self {
AnyStatementFor::For(stmt) => &stmt.iter,
AnyStatementFor::AsyncFor(stmt) => &stmt.iter,
}
}
fn body(&self) -> &Suite {
match self {
AnyStatementFor::For(stmt) => &stmt.body,
AnyStatementFor::AsyncFor(stmt) => &stmt.body,
}
}
fn orelse(&self) -> &Suite {
match self {
AnyStatementFor::For(stmt) => &stmt.orelse,
AnyStatementFor::AsyncFor(stmt) => &stmt.orelse,
}
}
}
impl Ranged for AnyStatementFor<'_> {
fn range(&self) -> TextRange {
match self {
AnyStatementFor::For(stmt) => stmt.range(),
AnyStatementFor::AsyncFor(stmt) => stmt.range(),
}
}
}
impl<'a> From<&'a StmtFor> for AnyStatementFor<'a> {
fn from(value: &'a StmtFor) -> Self {
AnyStatementFor::For(value)
}
}
impl<'a> From<&'a StmtAsyncFor> for AnyStatementFor<'a> {
fn from(value: &'a StmtAsyncFor) -> Self {
AnyStatementFor::AsyncFor(value)
}
}
impl<'a> From<&AnyStatementFor<'a>> for AnyNodeRef<'a> {
fn from(value: &AnyStatementFor<'a>) -> Self {
match value {
AnyStatementFor::For(stmt) => AnyNodeRef::StmtFor(stmt),
AnyStatementFor::AsyncFor(stmt) => AnyNodeRef::StmtAsyncFor(stmt),
}
}
}
impl Format<PyFormatContext<'_>> for AnyStatementFor<'_> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'_>>) -> FormatResult<()> {
let target = self.target();
let iter = self.iter();
let body = self.body();
let orelse = self.orelse();
let comments = f.context().comments().clone(); let comments = f.context().comments().clone();
let dangling_comments = comments.dangling_comments(item.as_any_node_ref()); let dangling_comments = comments.dangling_comments(self);
let body_start = body.first().map_or(iter.end(), Stmt::start); let body_start = body.first().map_or(iter.end(), Stmt::start);
let or_else_comments_start = let or_else_comments_start =
dangling_comments.partition_point(|comment| comment.slice().end() < body_start); dangling_comments.partition_point(|comment| comment.slice().end() < body_start);
@ -49,13 +116,15 @@ impl FormatNodeRule<StmtFor> for FormatStmtFor {
write!( write!(
f, f,
[ [
self.is_async()
.then_some(format_args![text("async"), space()]),
text("for"), text("for"),
space(), space(),
ExprTupleWithoutParentheses(target.as_ref()), ExprTupleWithoutParentheses(target),
space(), space(),
text("in"), text("in"),
space(), space(),
maybe_parenthesize_expression(iter, item, Parenthesize::IfBreaks), maybe_parenthesize_expression(iter, self, Parenthesize::IfBreaks),
text(":"), text(":"),
trailing_comments(trailing_condition_comments), trailing_comments(trailing_condition_comments),
block_indent(&body.format()) block_indent(&body.format())
@ -84,6 +153,12 @@ impl FormatNodeRule<StmtFor> for FormatStmtFor {
Ok(()) Ok(())
} }
}
impl FormatNodeRule<StmtFor> for FormatStmtFor {
fn fmt_fields(&self, item: &StmtFor, f: &mut PyFormatter) -> FormatResult<()> {
AnyStatementFor::from(item).fmt(f)
}
fn fmt_dangling_comments(&self, _node: &StmtFor, _f: &mut PyFormatter) -> FormatResult<()> { fn fmt_dangling_comments(&self, _node: &StmtFor, _f: &mut PyFormatter) -> FormatResult<()> {
// Handled in `fmt_fields` // Handled in `fmt_fields`

View file

@ -1,123 +0,0 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/py_310/starred_for_target.py
---
## Input
```py
for x in *a, *b:
print(x)
for x in a, b, *c:
print(x)
for x in *a, b, c:
print(x)
for x in *a, b, *c:
print(x)
async for x in *a, *b:
print(x)
async for x in *a, b, *c:
print(x)
async for x in a, b, *c:
print(x)
async for x in (
*loooooooooooooooooooooong,
very,
*loooooooooooooooooooooooooooooooooooooooooooooooong,
):
print(x)
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -10,18 +10,10 @@
for x in *a, b, *c:
print(x)
-async for x in *a, *b:
- print(x)
+NOT_YET_IMPLEMENTED_StmtAsyncFor
-async for x in *a, b, *c:
- print(x)
+NOT_YET_IMPLEMENTED_StmtAsyncFor
-async for x in a, b, *c:
- print(x)
+NOT_YET_IMPLEMENTED_StmtAsyncFor
-async for x in (
- *loooooooooooooooooooooong,
- very,
- *loooooooooooooooooooooooooooooooooooooooooooooooong,
-):
- print(x)
+NOT_YET_IMPLEMENTED_StmtAsyncFor
```
## Ruff Output
```py
for x in *a, *b:
print(x)
for x in a, b, *c:
print(x)
for x in *a, b, c:
print(x)
for x in *a, b, *c:
print(x)
NOT_YET_IMPLEMENTED_StmtAsyncFor
NOT_YET_IMPLEMENTED_StmtAsyncFor
NOT_YET_IMPLEMENTED_StmtAsyncFor
NOT_YET_IMPLEMENTED_StmtAsyncFor
```
## Black Output
```py
for x in *a, *b:
print(x)
for x in a, b, *c:
print(x)
for x in *a, b, c:
print(x)
for x in *a, b, *c:
print(x)
async for x in *a, *b:
print(x)
async for x in *a, b, *c:
print(x)
async for x in a, b, *c:
print(x)
async for x in (
*loooooooooooooooooooooong,
very,
*loooooooooooooooooooooooooooooooooooooooooooooooong,
):
print(x)
```

View file

@ -74,7 +74,7 @@ async def test_async_with():
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,62 +1,62 @@ @@ -1,62 +1,63 @@
# Make sure a leading comment is not removed. # Make sure a leading comment is not removed.
-def some_func( unformatted, args ): # fmt: skip -def some_func( unformatted, args ): # fmt: skip
+def some_func(unformatted, args): # fmt: skip +def some_func(unformatted, args): # fmt: skip
@ -129,8 +129,8 @@ async def test_async_with():
async def test_async_for(): async def test_async_for():
- async for i in some_async_iter( unformatted, args ): # fmt: skip - async for i in some_async_iter( unformatted, args ): # fmt: skip
- print("Do something") + async for i in some_async_iter(unformatted, args): # fmt: skip
+ NOT_YET_IMPLEMENTED_StmtAsyncFor # fmt: skip print("Do something")
-try : # fmt: skip -try : # fmt: skip
@ -203,7 +203,8 @@ for i in some_iter(unformatted, args): # fmt: skip
async def test_async_for(): async def test_async_for():
NOT_YET_IMPLEMENTED_StmtAsyncFor # fmt: skip async for i in some_async_iter(unformatted, args): # fmt: skip
print("Do something")
try: try: