format StmtAsyncWith (#5376)

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
David Szotten 2023-06-28 11:21:44 +01:00 committed by GitHub
parent 1979103ec0
commit c7adb9117f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 71 deletions

View file

@ -1,5 +1,6 @@
use crate::{not_yet_implemented, FormatNodeRule, PyFormatter};
use ruff_formatter::{write, Buffer, FormatResult};
use crate::prelude::*;
use crate::statement::stmt_with::AnyStatementWith;
use crate::FormatNodeRule;
use rustpython_parser::ast::StmtAsyncWith;
#[derive(Default)]
@ -7,6 +8,15 @@ pub struct FormatStmtAsyncWith;
impl FormatNodeRule<StmtAsyncWith> for FormatStmtAsyncWith {
fn fmt_fields(&self, item: &StmtAsyncWith, f: &mut PyFormatter) -> FormatResult<()> {
write!(f, [not_yet_implemented(item)])
AnyStatementWith::from(item).fmt(f)
}
fn fmt_dangling_comments(
&self,
_node: &StmtAsyncWith,
_f: &mut PyFormatter,
) -> FormatResult<()> {
// Handled in `fmt_fields`
Ok(())
}
}

View file

@ -1,29 +1,79 @@
use rustpython_parser::ast::StmtWith;
use ruff_formatter::{write, Buffer, FormatResult};
use ruff_python_ast::node::AstNode;
use ruff_python_ast::node::AnyNodeRef;
use ruff_text_size::TextRange;
use rustpython_parser::ast::{Ranged, StmtAsyncWith, StmtWith, Suite, WithItem};
use crate::builders::optional_parentheses;
use crate::comments::trailing_comments;
use crate::prelude::*;
use crate::{FormatNodeRule, PyFormatter};
use crate::FormatNodeRule;
#[derive(Default)]
pub struct FormatStmtWith;
pub(super) enum AnyStatementWith<'a> {
With(&'a StmtWith),
AsyncWith(&'a StmtAsyncWith),
}
impl FormatNodeRule<StmtWith> for FormatStmtWith {
fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> {
let StmtWith {
range: _,
items,
body,
type_comment: _,
} = item;
impl<'a> AnyStatementWith<'a> {
const fn is_async(&self) -> bool {
matches!(self, AnyStatementWith::AsyncWith(_))
}
fn items(&self) -> &[WithItem] {
match self {
AnyStatementWith::With(with) => with.items.as_slice(),
AnyStatementWith::AsyncWith(with) => with.items.as_slice(),
}
}
fn body(&self) -> &Suite {
match self {
AnyStatementWith::With(with) => &with.body,
AnyStatementWith::AsyncWith(with) => &with.body,
}
}
}
impl Ranged for AnyStatementWith<'_> {
fn range(&self) -> TextRange {
match self {
AnyStatementWith::With(with) => with.range(),
AnyStatementWith::AsyncWith(with) => with.range(),
}
}
}
impl<'a> From<&'a StmtWith> for AnyStatementWith<'a> {
fn from(value: &'a StmtWith) -> Self {
AnyStatementWith::With(value)
}
}
impl<'a> From<&'a StmtAsyncWith> for AnyStatementWith<'a> {
fn from(value: &'a StmtAsyncWith) -> Self {
AnyStatementWith::AsyncWith(value)
}
}
impl<'a> From<&AnyStatementWith<'a>> for AnyNodeRef<'a> {
fn from(value: &AnyStatementWith<'a>) -> Self {
match value {
AnyStatementWith::With(with) => AnyNodeRef::StmtWith(with),
AnyStatementWith::AsyncWith(with) => AnyNodeRef::StmtAsyncWith(with),
}
}
}
impl Format<PyFormatContext<'_>> for AnyStatementWith<'_> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'_>>) -> FormatResult<()> {
let comments = f.context().comments().clone();
let dangling_comments = comments.dangling_comments(item.as_any_node_ref());
let dangling_comments = comments.dangling_comments(self);
let joined_items = format_with(|f| f.join_comma_separated().nodes(items.iter()).finish());
let joined_items =
format_with(|f| f.join_comma_separated().nodes(self.items().iter()).finish());
if self.is_async() {
write!(f, [text("async"), space()])?;
}
write!(
f,
@ -33,10 +83,19 @@ impl FormatNodeRule<StmtWith> for FormatStmtWith {
group(&optional_parentheses(&joined_items)),
text(":"),
trailing_comments(dangling_comments),
block_indent(&body.format())
block_indent(&self.body().format())
]
)
}
}
#[derive(Default)]
pub struct FormatStmtWith;
impl FormatNodeRule<StmtWith> for FormatStmtWith {
fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> {
AnyStatementWith::from(item).fmt(f)
}
fn fmt_dangling_comments(&self, _node: &StmtWith, _f: &mut PyFormatter) -> FormatResult<()> {
// Handled in `fmt_fields`

View file

@ -140,17 +140,7 @@ async def wat():
if inner_imports.are_evil():
# Explains why we have this if.
@@ -82,8 +82,7 @@
async def wat():
# This comment, for some reason \
# contains a trailing backslash.
- async with X.open_async() as x: # Some more comments
- result = await x.method1()
+ NOT_YET_IMPLEMENTED_StmtAsyncWith # Some more comments
# Comment after ending a block.
if result:
print("A OK", file=sys.stdout)
@@ -93,4 +92,4 @@
@@ -93,4 +93,4 @@
# Some closing comments.
# Maybe Vim or Emacs directives for formatting.
@ -246,7 +236,8 @@ class Foo:
async def wat():
# This comment, for some reason \
# contains a trailing backslash.
NOT_YET_IMPLEMENTED_StmtAsyncWith # Some more comments
async with X.open_async() as x: # Some more comments
result = await x.method1()
# Comment after ending a block.
if result:
print("A OK", file=sys.stdout)

View file

@ -221,7 +221,7 @@ d={'a':1,
# Comment 1
# Comment 2
@@ -18,30 +16,53 @@
@@ -18,30 +16,54 @@
# fmt: off
def func_no_args():
@ -253,7 +253,8 @@ d={'a':1,
- await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2)
- await asyncio.sleep(1)
+ "Single-line docstring. Multiline is harder to reformat."
+ NOT_YET_IMPLEMENTED_StmtAsyncWith
+ async with some_connection() as conn:
+ await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
+ await asyncio.sleep(1)
+
+
@ -296,7 +297,7 @@ d={'a':1,
def spaces_types(
@@ -51,7 +72,7 @@
@@ -51,7 +73,7 @@
d: dict = {},
e: bool = True,
f: int = -1,
@ -305,7 +306,7 @@ d={'a':1,
h: str = "",
i: str = r"",
):
@@ -64,55 +85,55 @@
@@ -64,55 +86,55 @@
something = {
# fmt: off
@ -381,7 +382,7 @@ d={'a':1,
# fmt: on
@@ -133,10 +154,10 @@
@@ -133,10 +155,10 @@
"""Another known limitation."""
# fmt: on
# fmt: off
@ -396,7 +397,7 @@ d={'a':1,
# fmt: on
# fmt: off
# ...but comments still get reformatted even though they should not be
@@ -151,12 +172,10 @@
@@ -151,12 +173,10 @@
ast_args.kw_defaults,
parameters,
implicit_default=True,
@ -411,7 +412,7 @@ d={'a':1,
# fmt: on
_type_comment_re = re.compile(
r"""
@@ -179,7 +198,8 @@
@@ -179,7 +199,8 @@
$
""",
# fmt: off
@ -421,7 +422,7 @@ d={'a':1,
# fmt: on
)
@@ -217,8 +237,7 @@
@@ -217,8 +238,7 @@
xxxxxxxxxx_xxxxxxxxxxx_xxxxxxx_xxxxxxxxx=5,
)
# fmt: off
@ -472,7 +473,8 @@ def func_no_args():
async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat."
NOT_YET_IMPLEMENTED_StmtAsyncWith
async with some_connection() as conn:
await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
await asyncio.sleep(1)

View file

@ -74,7 +74,7 @@ async def test_async_with():
```diff
--- Black
+++ Ruff
@@ -1,62 +1,61 @@
@@ -1,62 +1,62 @@
# Make sure a leading comment is not removed.
-def some_func( unformatted, args ): # fmt: skip
+def some_func(unformatted, args): # fmt: skip
@ -153,8 +153,8 @@ async def test_async_with():
async def test_async_with():
- async with give_me_async_context( unformatted, args ): # fmt: skip
- print("Do something")
+ NOT_YET_IMPLEMENTED_StmtAsyncWith # fmt: skip
+ async with give_me_async_context(unformatted, args): # fmt: skip
print("Do something")
```
## Ruff Output
@ -220,7 +220,8 @@ with give_me_context(unformatted, args): # fmt: skip
async def test_async_with():
NOT_YET_IMPLEMENTED_StmtAsyncWith # fmt: skip
async with give_me_async_context(unformatted, args): # fmt: skip
print("Do something")
```
## Black Output

View file

@ -107,25 +107,25 @@ def __await__(): return (yield)
```diff
--- Black
+++ Ruff
@@ -1,20 +1,19 @@
@@ -1,12 +1,11 @@
#!/usr/bin/env python3
-import asyncio
-import sys
+NOT_YET_IMPLEMENTED_StmtImport
+NOT_YET_IMPLEMENTED_StmtImport
-
-from third_party import X, Y, Z
+NOT_YET_IMPLEMENTED_StmtImportFrom
+NOT_YET_IMPLEMENTED_StmtImport
+NOT_YET_IMPLEMENTED_StmtImport
-from library import some_connection, some_decorator
+NOT_YET_IMPLEMENTED_StmtImportFrom
+NOT_YET_IMPLEMENTED_ExprJoinedStr
-f"trigger 3.6 mode"
-
+NOT_YET_IMPLEMENTED_StmtImportFrom
+NOT_YET_IMPLEMENTED_ExprJoinedStr
def func_no_args():
a
@@ -14,7 +13,7 @@
b
c
if True:
@ -134,17 +134,7 @@ def __await__(): return (yield)
if False:
...
for i in range(10):
@@ -26,8 +25,7 @@
async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat."
- async with some_connection() as conn:
- await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
+ NOT_YET_IMPLEMENTED_StmtAsyncWith
await asyncio.sleep(1)
@@ -41,12 +39,22 @@
@@ -41,12 +40,22 @@
debug: bool = False,
**kwargs,
) -> str:
@ -171,7 +161,7 @@ def __await__(): return (yield)
def spaces_types(
@@ -56,7 +64,7 @@
@@ -56,7 +65,7 @@
d: dict = {},
e: bool = True,
f: int = -1,
@ -180,7 +170,7 @@ def __await__(): return (yield)
h: str = "",
i: str = r"",
):
@@ -64,19 +72,16 @@
@@ -64,19 +73,16 @@
def spaces2(result=_core.Value(None)):
@ -207,7 +197,7 @@ def __await__(): return (yield)
def long_lines():
@@ -87,7 +92,7 @@
@@ -87,7 +93,7 @@
ast_args.kw_defaults,
parameters,
implicit_default=True,
@ -216,7 +206,7 @@ def __await__(): return (yield)
)
typedargslist.extend(
gen_annotated_params(
@@ -96,7 +101,7 @@
@@ -96,7 +102,7 @@
parameters,
implicit_default=True,
# trailing standalone comment
@ -225,7 +215,7 @@ def __await__(): return (yield)
)
_type_comment_re = re.compile(
r"""
@@ -118,7 +123,8 @@
@@ -118,7 +124,8 @@
)
$
""",
@ -235,7 +225,7 @@ def __await__(): return (yield)
)
@@ -135,14 +141,8 @@
@@ -135,14 +142,8 @@
a,
**kwargs,
) -> A:
@ -284,7 +274,8 @@ def func_no_args():
async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat."
NOT_YET_IMPLEMENTED_StmtAsyncWith
async with some_connection() as conn:
await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
await asyncio.sleep(1)