format StmtAsyncWith (#5376)

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
David Szotten 2023-06-28 11:21:44 +01:00 committed by GitHub
parent 1979103ec0
commit c7adb9117f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 71 deletions

View file

@ -1,5 +1,6 @@
use crate::{not_yet_implemented, FormatNodeRule, PyFormatter}; use crate::prelude::*;
use ruff_formatter::{write, Buffer, FormatResult}; use crate::statement::stmt_with::AnyStatementWith;
use crate::FormatNodeRule;
use rustpython_parser::ast::StmtAsyncWith; use rustpython_parser::ast::StmtAsyncWith;
#[derive(Default)] #[derive(Default)]
@ -7,6 +8,15 @@ pub struct FormatStmtAsyncWith;
impl FormatNodeRule<StmtAsyncWith> for FormatStmtAsyncWith { impl FormatNodeRule<StmtAsyncWith> for FormatStmtAsyncWith {
fn fmt_fields(&self, item: &StmtAsyncWith, f: &mut PyFormatter) -> FormatResult<()> { fn fmt_fields(&self, item: &StmtAsyncWith, f: &mut PyFormatter) -> FormatResult<()> {
write!(f, [not_yet_implemented(item)]) AnyStatementWith::from(item).fmt(f)
}
fn fmt_dangling_comments(
&self,
_node: &StmtAsyncWith,
_f: &mut PyFormatter,
) -> FormatResult<()> {
// Handled in `fmt_fields`
Ok(())
} }
} }

View file

@ -1,29 +1,79 @@
use rustpython_parser::ast::StmtWith;
use ruff_formatter::{write, Buffer, FormatResult}; use ruff_formatter::{write, Buffer, FormatResult};
use ruff_python_ast::node::AstNode; use ruff_python_ast::node::AnyNodeRef;
use ruff_text_size::TextRange;
use rustpython_parser::ast::{Ranged, StmtAsyncWith, StmtWith, Suite, WithItem};
use crate::builders::optional_parentheses; use crate::builders::optional_parentheses;
use crate::comments::trailing_comments; use crate::comments::trailing_comments;
use crate::prelude::*; use crate::prelude::*;
use crate::{FormatNodeRule, PyFormatter}; use crate::FormatNodeRule;
#[derive(Default)] pub(super) enum AnyStatementWith<'a> {
pub struct FormatStmtWith; With(&'a StmtWith),
AsyncWith(&'a StmtAsyncWith),
}
impl FormatNodeRule<StmtWith> for FormatStmtWith { impl<'a> AnyStatementWith<'a> {
fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> { const fn is_async(&self) -> bool {
let StmtWith { matches!(self, AnyStatementWith::AsyncWith(_))
range: _, }
items,
body,
type_comment: _,
} = item;
fn items(&self) -> &[WithItem] {
match self {
AnyStatementWith::With(with) => with.items.as_slice(),
AnyStatementWith::AsyncWith(with) => with.items.as_slice(),
}
}
fn body(&self) -> &Suite {
match self {
AnyStatementWith::With(with) => &with.body,
AnyStatementWith::AsyncWith(with) => &with.body,
}
}
}
impl Ranged for AnyStatementWith<'_> {
fn range(&self) -> TextRange {
match self {
AnyStatementWith::With(with) => with.range(),
AnyStatementWith::AsyncWith(with) => with.range(),
}
}
}
impl<'a> From<&'a StmtWith> for AnyStatementWith<'a> {
fn from(value: &'a StmtWith) -> Self {
AnyStatementWith::With(value)
}
}
impl<'a> From<&'a StmtAsyncWith> for AnyStatementWith<'a> {
fn from(value: &'a StmtAsyncWith) -> Self {
AnyStatementWith::AsyncWith(value)
}
}
impl<'a> From<&AnyStatementWith<'a>> for AnyNodeRef<'a> {
fn from(value: &AnyStatementWith<'a>) -> Self {
match value {
AnyStatementWith::With(with) => AnyNodeRef::StmtWith(with),
AnyStatementWith::AsyncWith(with) => AnyNodeRef::StmtAsyncWith(with),
}
}
}
impl Format<PyFormatContext<'_>> for AnyStatementWith<'_> {
fn fmt(&self, f: &mut Formatter<PyFormatContext<'_>>) -> FormatResult<()> {
let comments = f.context().comments().clone(); let comments = f.context().comments().clone();
let dangling_comments = comments.dangling_comments(item.as_any_node_ref()); let dangling_comments = comments.dangling_comments(self);
let joined_items = format_with(|f| f.join_comma_separated().nodes(items.iter()).finish()); let joined_items =
format_with(|f| f.join_comma_separated().nodes(self.items().iter()).finish());
if self.is_async() {
write!(f, [text("async"), space()])?;
}
write!( write!(
f, f,
@ -33,10 +83,19 @@ impl FormatNodeRule<StmtWith> for FormatStmtWith {
group(&optional_parentheses(&joined_items)), group(&optional_parentheses(&joined_items)),
text(":"), text(":"),
trailing_comments(dangling_comments), trailing_comments(dangling_comments),
block_indent(&body.format()) block_indent(&self.body().format())
] ]
) )
} }
}
#[derive(Default)]
pub struct FormatStmtWith;
impl FormatNodeRule<StmtWith> for FormatStmtWith {
fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> {
AnyStatementWith::from(item).fmt(f)
}
fn fmt_dangling_comments(&self, _node: &StmtWith, _f: &mut PyFormatter) -> FormatResult<()> { fn fmt_dangling_comments(&self, _node: &StmtWith, _f: &mut PyFormatter) -> FormatResult<()> {
// Handled in `fmt_fields` // Handled in `fmt_fields`

View file

@ -140,17 +140,7 @@ async def wat():
if inner_imports.are_evil(): if inner_imports.are_evil():
# Explains why we have this if. # Explains why we have this if.
@@ -82,8 +82,7 @@ @@ -93,4 +93,4 @@
async def wat():
# This comment, for some reason \
# contains a trailing backslash.
- async with X.open_async() as x: # Some more comments
- result = await x.method1()
+ NOT_YET_IMPLEMENTED_StmtAsyncWith # Some more comments
# Comment after ending a block.
if result:
print("A OK", file=sys.stdout)
@@ -93,4 +92,4 @@
# Some closing comments. # Some closing comments.
# Maybe Vim or Emacs directives for formatting. # Maybe Vim or Emacs directives for formatting.
@ -246,7 +236,8 @@ class Foo:
async def wat(): async def wat():
# This comment, for some reason \ # This comment, for some reason \
# contains a trailing backslash. # contains a trailing backslash.
NOT_YET_IMPLEMENTED_StmtAsyncWith # Some more comments async with X.open_async() as x: # Some more comments
result = await x.method1()
# Comment after ending a block. # Comment after ending a block.
if result: if result:
print("A OK", file=sys.stdout) print("A OK", file=sys.stdout)

View file

@ -221,7 +221,7 @@ d={'a':1,
# Comment 1 # Comment 1
# Comment 2 # Comment 2
@@ -18,30 +16,53 @@ @@ -18,30 +16,54 @@
# fmt: off # fmt: off
def func_no_args(): def func_no_args():
@ -253,7 +253,8 @@ d={'a':1,
- await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2) - await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2)
- await asyncio.sleep(1) - await asyncio.sleep(1)
+ "Single-line docstring. Multiline is harder to reformat." + "Single-line docstring. Multiline is harder to reformat."
+ NOT_YET_IMPLEMENTED_StmtAsyncWith + async with some_connection() as conn:
+ await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
+ await asyncio.sleep(1) + await asyncio.sleep(1)
+ +
+ +
@ -296,7 +297,7 @@ d={'a':1,
def spaces_types( def spaces_types(
@@ -51,7 +72,7 @@ @@ -51,7 +73,7 @@
d: dict = {}, d: dict = {},
e: bool = True, e: bool = True,
f: int = -1, f: int = -1,
@ -305,7 +306,7 @@ d={'a':1,
h: str = "", h: str = "",
i: str = r"", i: str = r"",
): ):
@@ -64,55 +85,55 @@ @@ -64,55 +86,55 @@
something = { something = {
# fmt: off # fmt: off
@ -381,7 +382,7 @@ d={'a':1,
# fmt: on # fmt: on
@@ -133,10 +154,10 @@ @@ -133,10 +155,10 @@
"""Another known limitation.""" """Another known limitation."""
# fmt: on # fmt: on
# fmt: off # fmt: off
@ -396,7 +397,7 @@ d={'a':1,
# fmt: on # fmt: on
# fmt: off # fmt: off
# ...but comments still get reformatted even though they should not be # ...but comments still get reformatted even though they should not be
@@ -151,12 +172,10 @@ @@ -151,12 +173,10 @@
ast_args.kw_defaults, ast_args.kw_defaults,
parameters, parameters,
implicit_default=True, implicit_default=True,
@ -411,7 +412,7 @@ d={'a':1,
# fmt: on # fmt: on
_type_comment_re = re.compile( _type_comment_re = re.compile(
r""" r"""
@@ -179,7 +198,8 @@ @@ -179,7 +199,8 @@
$ $
""", """,
# fmt: off # fmt: off
@ -421,7 +422,7 @@ d={'a':1,
# fmt: on # fmt: on
) )
@@ -217,8 +237,7 @@ @@ -217,8 +238,7 @@
xxxxxxxxxx_xxxxxxxxxxx_xxxxxxx_xxxxxxxxx=5, xxxxxxxxxx_xxxxxxxxxxx_xxxxxxx_xxxxxxxxx=5,
) )
# fmt: off # fmt: off
@ -472,7 +473,8 @@ def func_no_args():
async def coroutine(arg, exec=False): async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat." "Single-line docstring. Multiline is harder to reformat."
NOT_YET_IMPLEMENTED_StmtAsyncWith async with some_connection() as conn:
await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
await asyncio.sleep(1) await asyncio.sleep(1)

View file

@ -74,7 +74,7 @@ async def test_async_with():
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,62 +1,61 @@ @@ -1,62 +1,62 @@
# Make sure a leading comment is not removed. # Make sure a leading comment is not removed.
-def some_func( unformatted, args ): # fmt: skip -def some_func( unformatted, args ): # fmt: skip
+def some_func(unformatted, args): # fmt: skip +def some_func(unformatted, args): # fmt: skip
@ -153,8 +153,8 @@ async def test_async_with():
async def test_async_with(): async def test_async_with():
- async with give_me_async_context( unformatted, args ): # fmt: skip - async with give_me_async_context( unformatted, args ): # fmt: skip
- print("Do something") + async with give_me_async_context(unformatted, args): # fmt: skip
+ NOT_YET_IMPLEMENTED_StmtAsyncWith # fmt: skip print("Do something")
``` ```
## Ruff Output ## Ruff Output
@ -220,7 +220,8 @@ with give_me_context(unformatted, args): # fmt: skip
async def test_async_with(): async def test_async_with():
NOT_YET_IMPLEMENTED_StmtAsyncWith # fmt: skip async with give_me_async_context(unformatted, args): # fmt: skip
print("Do something")
``` ```
## Black Output ## Black Output

View file

@ -107,25 +107,25 @@ def __await__(): return (yield)
```diff ```diff
--- Black --- Black
+++ Ruff +++ Ruff
@@ -1,20 +1,19 @@ @@ -1,12 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
-import asyncio -import asyncio
-import sys -import sys
+NOT_YET_IMPLEMENTED_StmtImport -
+NOT_YET_IMPLEMENTED_StmtImport
-from third_party import X, Y, Z -from third_party import X, Y, Z
+NOT_YET_IMPLEMENTED_StmtImportFrom +NOT_YET_IMPLEMENTED_StmtImport
+NOT_YET_IMPLEMENTED_StmtImport
-from library import some_connection, some_decorator -from library import some_connection, some_decorator
+NOT_YET_IMPLEMENTED_StmtImportFrom +NOT_YET_IMPLEMENTED_StmtImportFrom
+NOT_YET_IMPLEMENTED_ExprJoinedStr
-f"trigger 3.6 mode" -f"trigger 3.6 mode"
- +NOT_YET_IMPLEMENTED_StmtImportFrom
+NOT_YET_IMPLEMENTED_ExprJoinedStr
def func_no_args(): def func_no_args():
a @@ -14,7 +13,7 @@
b b
c c
if True: if True:
@ -134,17 +134,7 @@ def __await__(): return (yield)
if False: if False:
... ...
for i in range(10): for i in range(10):
@@ -26,8 +25,7 @@ @@ -41,12 +40,22 @@
async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat."
- async with some_connection() as conn:
- await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
+ NOT_YET_IMPLEMENTED_StmtAsyncWith
await asyncio.sleep(1)
@@ -41,12 +39,22 @@
debug: bool = False, debug: bool = False,
**kwargs, **kwargs,
) -> str: ) -> str:
@ -171,7 +161,7 @@ def __await__(): return (yield)
def spaces_types( def spaces_types(
@@ -56,7 +64,7 @@ @@ -56,7 +65,7 @@
d: dict = {}, d: dict = {},
e: bool = True, e: bool = True,
f: int = -1, f: int = -1,
@ -180,7 +170,7 @@ def __await__(): return (yield)
h: str = "", h: str = "",
i: str = r"", i: str = r"",
): ):
@@ -64,19 +72,16 @@ @@ -64,19 +73,16 @@
def spaces2(result=_core.Value(None)): def spaces2(result=_core.Value(None)):
@ -207,7 +197,7 @@ def __await__(): return (yield)
def long_lines(): def long_lines():
@@ -87,7 +92,7 @@ @@ -87,7 +93,7 @@
ast_args.kw_defaults, ast_args.kw_defaults,
parameters, parameters,
implicit_default=True, implicit_default=True,
@ -216,7 +206,7 @@ def __await__(): return (yield)
) )
typedargslist.extend( typedargslist.extend(
gen_annotated_params( gen_annotated_params(
@@ -96,7 +101,7 @@ @@ -96,7 +102,7 @@
parameters, parameters,
implicit_default=True, implicit_default=True,
# trailing standalone comment # trailing standalone comment
@ -225,7 +215,7 @@ def __await__(): return (yield)
) )
_type_comment_re = re.compile( _type_comment_re = re.compile(
r""" r"""
@@ -118,7 +123,8 @@ @@ -118,7 +124,8 @@
) )
$ $
""", """,
@ -235,7 +225,7 @@ def __await__(): return (yield)
) )
@@ -135,14 +141,8 @@ @@ -135,14 +142,8 @@
a, a,
**kwargs, **kwargs,
) -> A: ) -> A:
@ -284,7 +274,8 @@ def func_no_args():
async def coroutine(arg, exec=False): async def coroutine(arg, exec=False):
"Single-line docstring. Multiline is harder to reformat." "Single-line docstring. Multiline is harder to reformat."
NOT_YET_IMPLEMENTED_StmtAsyncWith async with some_connection() as conn:
await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2)
await asyncio.sleep(1) await asyncio.sleep(1)