diff --git a/crates/ruff_python_formatter/src/statement/stmt_async_with.rs b/crates/ruff_python_formatter/src/statement/stmt_async_with.rs index 0555642a10..e8d0e0ebd9 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_async_with.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_async_with.rs @@ -1,5 +1,6 @@ -use crate::{not_yet_implemented, FormatNodeRule, PyFormatter}; -use ruff_formatter::{write, Buffer, FormatResult}; +use crate::prelude::*; +use crate::statement::stmt_with::AnyStatementWith; +use crate::FormatNodeRule; use rustpython_parser::ast::StmtAsyncWith; #[derive(Default)] @@ -7,6 +8,15 @@ pub struct FormatStmtAsyncWith; impl FormatNodeRule for FormatStmtAsyncWith { fn fmt_fields(&self, item: &StmtAsyncWith, f: &mut PyFormatter) -> FormatResult<()> { - write!(f, [not_yet_implemented(item)]) + AnyStatementWith::from(item).fmt(f) + } + + fn fmt_dangling_comments( + &self, + _node: &StmtAsyncWith, + _f: &mut PyFormatter, + ) -> FormatResult<()> { + // Handled in `fmt_fields` + Ok(()) } } diff --git a/crates/ruff_python_formatter/src/statement/stmt_with.rs b/crates/ruff_python_formatter/src/statement/stmt_with.rs index 443337d934..c4adb3de4d 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_with.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_with.rs @@ -1,29 +1,79 @@ -use rustpython_parser::ast::StmtWith; - use ruff_formatter::{write, Buffer, FormatResult}; -use ruff_python_ast::node::AstNode; +use ruff_python_ast::node::AnyNodeRef; +use ruff_text_size::TextRange; +use rustpython_parser::ast::{Ranged, StmtAsyncWith, StmtWith, Suite, WithItem}; use crate::builders::optional_parentheses; use crate::comments::trailing_comments; use crate::prelude::*; -use crate::{FormatNodeRule, PyFormatter}; +use crate::FormatNodeRule; -#[derive(Default)] -pub struct FormatStmtWith; +pub(super) enum AnyStatementWith<'a> { + With(&'a StmtWith), + AsyncWith(&'a StmtAsyncWith), +} -impl FormatNodeRule for FormatStmtWith { - fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> { - let StmtWith { - range: _, - items, - body, - type_comment: _, - } = item; +impl<'a> AnyStatementWith<'a> { + const fn is_async(&self) -> bool { + matches!(self, AnyStatementWith::AsyncWith(_)) + } + fn items(&self) -> &[WithItem] { + match self { + AnyStatementWith::With(with) => with.items.as_slice(), + AnyStatementWith::AsyncWith(with) => with.items.as_slice(), + } + } + + fn body(&self) -> &Suite { + match self { + AnyStatementWith::With(with) => &with.body, + AnyStatementWith::AsyncWith(with) => &with.body, + } + } +} + +impl Ranged for AnyStatementWith<'_> { + fn range(&self) -> TextRange { + match self { + AnyStatementWith::With(with) => with.range(), + AnyStatementWith::AsyncWith(with) => with.range(), + } + } +} + +impl<'a> From<&'a StmtWith> for AnyStatementWith<'a> { + fn from(value: &'a StmtWith) -> Self { + AnyStatementWith::With(value) + } +} + +impl<'a> From<&'a StmtAsyncWith> for AnyStatementWith<'a> { + fn from(value: &'a StmtAsyncWith) -> Self { + AnyStatementWith::AsyncWith(value) + } +} + +impl<'a> From<&AnyStatementWith<'a>> for AnyNodeRef<'a> { + fn from(value: &AnyStatementWith<'a>) -> Self { + match value { + AnyStatementWith::With(with) => AnyNodeRef::StmtWith(with), + AnyStatementWith::AsyncWith(with) => AnyNodeRef::StmtAsyncWith(with), + } + } +} + +impl Format> for AnyStatementWith<'_> { + fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { let comments = f.context().comments().clone(); - let dangling_comments = comments.dangling_comments(item.as_any_node_ref()); + let dangling_comments = comments.dangling_comments(self); - let joined_items = format_with(|f| f.join_comma_separated().nodes(items.iter()).finish()); + let joined_items = + format_with(|f| f.join_comma_separated().nodes(self.items().iter()).finish()); + + if self.is_async() { + write!(f, [text("async"), space()])?; + } write!( f, @@ -33,10 +83,19 @@ impl FormatNodeRule for FormatStmtWith { group(&optional_parentheses(&joined_items)), text(":"), trailing_comments(dangling_comments), - block_indent(&body.format()) + block_indent(&self.body().format()) ] ) } +} + +#[derive(Default)] +pub struct FormatStmtWith; + +impl FormatNodeRule for FormatStmtWith { + fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> { + AnyStatementWith::from(item).fmt(f) + } fn fmt_dangling_comments(&self, _node: &StmtWith, _f: &mut PyFormatter) -> FormatResult<()> { // Handled in `fmt_fields` diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@comments.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@comments.py.snap index ebd4a1b917..f5c40ca6bc 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@comments.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@comments.py.snap @@ -140,17 +140,7 @@ async def wat(): if inner_imports.are_evil(): # Explains why we have this if. -@@ -82,8 +82,7 @@ - async def wat(): - # This comment, for some reason \ - # contains a trailing backslash. -- async with X.open_async() as x: # Some more comments -- result = await x.method1() -+ NOT_YET_IMPLEMENTED_StmtAsyncWith # Some more comments - # Comment after ending a block. - if result: - print("A OK", file=sys.stdout) -@@ -93,4 +92,4 @@ +@@ -93,4 +93,4 @@ # Some closing comments. # Maybe Vim or Emacs directives for formatting. @@ -246,7 +236,8 @@ class Foo: async def wat(): # This comment, for some reason \ # contains a trailing backslash. - NOT_YET_IMPLEMENTED_StmtAsyncWith # Some more comments + async with X.open_async() as x: # Some more comments + result = await x.method1() # Comment after ending a block. if result: print("A OK", file=sys.stdout) diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtonoff.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtonoff.py.snap index e8c3771da6..460ce0e39e 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtonoff.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtonoff.py.snap @@ -221,7 +221,7 @@ d={'a':1, # Comment 1 # Comment 2 -@@ -18,30 +16,53 @@ +@@ -18,30 +16,54 @@ # fmt: off def func_no_args(): @@ -253,7 +253,8 @@ d={'a':1, - await conn.do_what_i_mean('SELECT bobby, tables FROM xkcd', timeout=2) - await asyncio.sleep(1) + "Single-line docstring. Multiline is harder to reformat." -+ NOT_YET_IMPLEMENTED_StmtAsyncWith ++ async with some_connection() as conn: ++ await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2) + await asyncio.sleep(1) + + @@ -296,7 +297,7 @@ d={'a':1, def spaces_types( -@@ -51,7 +72,7 @@ +@@ -51,7 +73,7 @@ d: dict = {}, e: bool = True, f: int = -1, @@ -305,7 +306,7 @@ d={'a':1, h: str = "", i: str = r"", ): -@@ -64,55 +85,55 @@ +@@ -64,55 +86,55 @@ something = { # fmt: off @@ -381,7 +382,7 @@ d={'a':1, # fmt: on -@@ -133,10 +154,10 @@ +@@ -133,10 +155,10 @@ """Another known limitation.""" # fmt: on # fmt: off @@ -396,7 +397,7 @@ d={'a':1, # fmt: on # fmt: off # ...but comments still get reformatted even though they should not be -@@ -151,12 +172,10 @@ +@@ -151,12 +173,10 @@ ast_args.kw_defaults, parameters, implicit_default=True, @@ -411,7 +412,7 @@ d={'a':1, # fmt: on _type_comment_re = re.compile( r""" -@@ -179,7 +198,8 @@ +@@ -179,7 +199,8 @@ $ """, # fmt: off @@ -421,7 +422,7 @@ d={'a':1, # fmt: on ) -@@ -217,8 +237,7 @@ +@@ -217,8 +238,7 @@ xxxxxxxxxx_xxxxxxxxxxx_xxxxxxx_xxxxxxxxx=5, ) # fmt: off @@ -472,7 +473,8 @@ def func_no_args(): async def coroutine(arg, exec=False): "Single-line docstring. Multiline is harder to reformat." - NOT_YET_IMPLEMENTED_StmtAsyncWith + async with some_connection() as conn: + await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2) await asyncio.sleep(1) diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtskip8.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtskip8.py.snap index 084370e171..4d65173c25 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtskip8.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@fmtskip8.py.snap @@ -74,7 +74,7 @@ async def test_async_with(): ```diff --- Black +++ Ruff -@@ -1,62 +1,61 @@ +@@ -1,62 +1,62 @@ # Make sure a leading comment is not removed. -def some_func( unformatted, args ): # fmt: skip +def some_func(unformatted, args): # fmt: skip @@ -153,8 +153,8 @@ async def test_async_with(): async def test_async_with(): - async with give_me_async_context( unformatted, args ): # fmt: skip -- print("Do something") -+ NOT_YET_IMPLEMENTED_StmtAsyncWith # fmt: skip ++ async with give_me_async_context(unformatted, args): # fmt: skip + print("Do something") ``` ## Ruff Output @@ -220,7 +220,8 @@ with give_me_context(unformatted, args): # fmt: skip async def test_async_with(): - NOT_YET_IMPLEMENTED_StmtAsyncWith # fmt: skip + async with give_me_async_context(unformatted, args): # fmt: skip + print("Do something") ``` ## Black Output diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@function.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@function.py.snap index 7562820a71..98e3796d47 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@function.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@function.py.snap @@ -107,25 +107,25 @@ def __await__(): return (yield) ```diff --- Black +++ Ruff -@@ -1,20 +1,19 @@ +@@ -1,12 +1,11 @@ #!/usr/bin/env python3 -import asyncio -import sys -+NOT_YET_IMPLEMENTED_StmtImport -+NOT_YET_IMPLEMENTED_StmtImport - +- -from third_party import X, Y, Z -+NOT_YET_IMPLEMENTED_StmtImportFrom ++NOT_YET_IMPLEMENTED_StmtImport ++NOT_YET_IMPLEMENTED_StmtImport -from library import some_connection, some_decorator +NOT_YET_IMPLEMENTED_StmtImportFrom -+NOT_YET_IMPLEMENTED_ExprJoinedStr -f"trigger 3.6 mode" -- ++NOT_YET_IMPLEMENTED_StmtImportFrom ++NOT_YET_IMPLEMENTED_ExprJoinedStr + def func_no_args(): - a +@@ -14,7 +13,7 @@ b c if True: @@ -134,17 +134,7 @@ def __await__(): return (yield) if False: ... for i in range(10): -@@ -26,8 +25,7 @@ - - async def coroutine(arg, exec=False): - "Single-line docstring. Multiline is harder to reformat." -- async with some_connection() as conn: -- await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2) -+ NOT_YET_IMPLEMENTED_StmtAsyncWith - await asyncio.sleep(1) - - -@@ -41,12 +39,22 @@ +@@ -41,12 +40,22 @@ debug: bool = False, **kwargs, ) -> str: @@ -171,7 +161,7 @@ def __await__(): return (yield) def spaces_types( -@@ -56,7 +64,7 @@ +@@ -56,7 +65,7 @@ d: dict = {}, e: bool = True, f: int = -1, @@ -180,7 +170,7 @@ def __await__(): return (yield) h: str = "", i: str = r"", ): -@@ -64,19 +72,16 @@ +@@ -64,19 +73,16 @@ def spaces2(result=_core.Value(None)): @@ -207,7 +197,7 @@ def __await__(): return (yield) def long_lines(): -@@ -87,7 +92,7 @@ +@@ -87,7 +93,7 @@ ast_args.kw_defaults, parameters, implicit_default=True, @@ -216,7 +206,7 @@ def __await__(): return (yield) ) typedargslist.extend( gen_annotated_params( -@@ -96,7 +101,7 @@ +@@ -96,7 +102,7 @@ parameters, implicit_default=True, # trailing standalone comment @@ -225,7 +215,7 @@ def __await__(): return (yield) ) _type_comment_re = re.compile( r""" -@@ -118,7 +123,8 @@ +@@ -118,7 +124,8 @@ ) $ """, @@ -235,7 +225,7 @@ def __await__(): return (yield) ) -@@ -135,14 +141,8 @@ +@@ -135,14 +142,8 @@ a, **kwargs, ) -> A: @@ -284,7 +274,8 @@ def func_no_args(): async def coroutine(arg, exec=False): "Single-line docstring. Multiline is harder to reformat." - NOT_YET_IMPLEMENTED_StmtAsyncWith + async with some_connection() as conn: + await conn.do_what_i_mean("SELECT bobby, tables FROM xkcd", timeout=2) await asyncio.sleep(1)