diff --git a/crates/ruff_python_formatter/src/builders.rs b/crates/ruff_python_formatter/src/builders.rs index 95a2a76ace..943b06efd8 100644 --- a/crates/ruff_python_formatter/src/builders.rs +++ b/crates/ruff_python_formatter/src/builders.rs @@ -92,11 +92,23 @@ impl Entries { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +pub(crate) enum TrailingComma { + /// Add a trailing comma if the group breaks and there's more than one element (or if the last + /// element has a trailing comma and the magical trailing comma option is enabled). + #[default] + MoreThanOne, + /// Add a trailing comma if the group breaks (or if the last element has a trailing comma and + /// the magical trailing comma option is enabled). + OneOrMore, +} + pub(crate) struct JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> { result: FormatResult<()>, fmt: &'fmt mut PyFormatter<'ast, 'buf>, entries: Entries, sequence_end: TextSize, + trailing_comma: TrailingComma, } impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> { @@ -106,9 +118,19 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> { result: Ok(()), entries: Entries::None, sequence_end, + trailing_comma: TrailingComma::default(), } } + /// Set the trailing comma behavior for the builder. Trailing commas will only be inserted if + /// the group breaks, and will _always_ be inserted if the last element has a trailing comma + /// (and the magical trailing comma option is enabled). However, this setting dictates whether + /// trailing commas are inserted for single element groups. + pub(crate) fn with_trailing_comma(mut self, trailing_comma: TrailingComma) -> Self { + self.trailing_comma = trailing_comma; + self + } + pub(crate) fn entry( &mut self, node: &T, @@ -194,8 +216,11 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> { }; // If there is a single entry, only keep the magic trailing comma, don't add it if - // it wasn't there. If there is more than one entry, always add it. - if magic_trailing_comma || self.entries.is_more_than_one() { + // it wasn't there -- unless the trailing comma behavior is set to one-or-more. + if magic_trailing_comma + || self.trailing_comma == TrailingComma::OneOrMore + || self.entries.is_more_than_one() + { if_group_breaks(&text(",")).fmt(self.fmt)?; } diff --git a/crates/ruff_python_formatter/src/statement/stmt_import_from.rs b/crates/ruff_python_formatter/src/statement/stmt_import_from.rs index 3249dd7a79..433293cfa9 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_import_from.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_import_from.rs @@ -3,7 +3,7 @@ use ruff_formatter::{write, Buffer, Format, FormatResult}; use ruff_python_ast::node::AstNode; use ruff_python_ast::{Ranged, StmtImportFrom}; -use crate::builders::{parenthesize_if_expands, PyFormatterExtensions}; +use crate::builders::{parenthesize_if_expands, PyFormatterExtensions, TrailingComma}; use crate::expression::parentheses::parenthesized; use crate::{AsFormat, FormatNodeRule, PyFormatter}; @@ -45,6 +45,7 @@ impl FormatNodeRule for FormatStmtImportFrom { let names = format_with(|f| { f.join_comma_separated(item.end()) + .with_trailing_comma(TrailingComma::OneOrMore) .entries(names.iter().map(|name| (name, name.format()))) .finish() }); diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__comments2.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__comments2.py.snap index 9628978a2c..0e98a11642 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__comments2.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__comments2.py.snap @@ -180,17 +180,6 @@ instruction()#comment with bad spacing ```diff --- Black +++ Ruff -@@ -1,8 +1,8 @@ - from com.my_lovely_company.my_lovely_team.my_lovely_project.my_lovely_component import ( -- MyLovelyCompanyTeamProjectComponent, # NOT DRY -+ MyLovelyCompanyTeamProjectComponent # NOT DRY - ) - from com.my_lovely_company.my_lovely_team.my_lovely_project.my_lovely_component import ( -- MyLovelyCompanyTeamProjectComponent as component, # DRY -+ MyLovelyCompanyTeamProjectComponent as component # DRY - ) - - # Please keep __all__ alphabetized within each category. @@ -60,8 +60,12 @@ # Comment before function. def inline_comments_in_brackets_ruin_everything(): @@ -259,10 +248,10 @@ instruction()#comment with bad spacing ```py from com.my_lovely_company.my_lovely_team.my_lovely_project.my_lovely_component import ( - MyLovelyCompanyTeamProjectComponent # NOT DRY + MyLovelyCompanyTeamProjectComponent, # NOT DRY ) from com.my_lovely_company.my_lovely_team.my_lovely_project.my_lovely_component import ( - MyLovelyCompanyTeamProjectComponent as component # DRY + MyLovelyCompanyTeamProjectComponent as component, # DRY ) # Please keep __all__ alphabetized within each category. diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__import_spacing.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__import_spacing.py.snap deleted file mode 100644 index f08a097d01..0000000000 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__import_spacing.py.snap +++ /dev/null @@ -1,213 +0,0 @@ ---- -source: crates/ruff_python_formatter/tests/fixtures.rs -input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/import_spacing.py ---- -## Input - -```py -"""The asyncio package, tracking PEP 3156.""" - -# flake8: noqa - -from logging import ( - WARNING -) -from logging import ( - ERROR, -) -import sys - -# This relies on each of the submodules having an __all__ variable. -from .base_events import * -from .coroutines import * -from .events import * # comment here - -from .futures import * -from .locks import * # comment here -from .protocols import * - -from ..runners import * # comment here -from ..queues import * -from ..streams import * - -from some_library import ( - Just, Enough, Libraries, To, Fit, In, This, Nice, Split, Which, We, No, Longer, Use -) -from name_of_a_company.extremely_long_project_name.component.ttypes import CuteLittleServiceHandlerFactoryyy -from name_of_a_company.extremely_long_project_name.extremely_long_component_name.ttypes import * - -from .a.b.c.subprocess import * -from . import (tasks) -from . import (A, B, C) -from . import SomeVeryLongNameAndAllOfItsAdditionalLetters1, \ - SomeVeryLongNameAndAllOfItsAdditionalLetters2 - -__all__ = ( - base_events.__all__ - + coroutines.__all__ - + events.__all__ - + futures.__all__ - + locks.__all__ - + protocols.__all__ - + runners.__all__ - + queues.__all__ - + streams.__all__ - + tasks.__all__ -) -``` - -## Black Differences - -```diff ---- Black -+++ Ruff -@@ -38,7 +38,7 @@ - Use, - ) - from name_of_a_company.extremely_long_project_name.component.ttypes import ( -- CuteLittleServiceHandlerFactoryyy, -+ CuteLittleServiceHandlerFactoryyy - ) - from name_of_a_company.extremely_long_project_name.extremely_long_component_name.ttypes import * - -``` - -## Ruff Output - -```py -"""The asyncio package, tracking PEP 3156.""" - -# flake8: noqa - -from logging import WARNING -from logging import ( - ERROR, -) -import sys - -# This relies on each of the submodules having an __all__ variable. -from .base_events import * -from .coroutines import * -from .events import * # comment here - -from .futures import * -from .locks import * # comment here -from .protocols import * - -from ..runners import * # comment here -from ..queues import * -from ..streams import * - -from some_library import ( - Just, - Enough, - Libraries, - To, - Fit, - In, - This, - Nice, - Split, - Which, - We, - No, - Longer, - Use, -) -from name_of_a_company.extremely_long_project_name.component.ttypes import ( - CuteLittleServiceHandlerFactoryyy -) -from name_of_a_company.extremely_long_project_name.extremely_long_component_name.ttypes import * - -from .a.b.c.subprocess import * -from . import tasks -from . import A, B, C -from . import ( - SomeVeryLongNameAndAllOfItsAdditionalLetters1, - SomeVeryLongNameAndAllOfItsAdditionalLetters2, -) - -__all__ = ( - base_events.__all__ - + coroutines.__all__ - + events.__all__ - + futures.__all__ - + locks.__all__ - + protocols.__all__ - + runners.__all__ - + queues.__all__ - + streams.__all__ - + tasks.__all__ -) -``` - -## Black Output - -```py -"""The asyncio package, tracking PEP 3156.""" - -# flake8: noqa - -from logging import WARNING -from logging import ( - ERROR, -) -import sys - -# This relies on each of the submodules having an __all__ variable. -from .base_events import * -from .coroutines import * -from .events import * # comment here - -from .futures import * -from .locks import * # comment here -from .protocols import * - -from ..runners import * # comment here -from ..queues import * -from ..streams import * - -from some_library import ( - Just, - Enough, - Libraries, - To, - Fit, - In, - This, - Nice, - Split, - Which, - We, - No, - Longer, - Use, -) -from name_of_a_company.extremely_long_project_name.component.ttypes import ( - CuteLittleServiceHandlerFactoryyy, -) -from name_of_a_company.extremely_long_project_name.extremely_long_component_name.ttypes import * - -from .a.b.c.subprocess import * -from . import tasks -from . import A, B, C -from . import ( - SomeVeryLongNameAndAllOfItsAdditionalLetters1, - SomeVeryLongNameAndAllOfItsAdditionalLetters2, -) - -__all__ = ( - base_events.__all__ - + coroutines.__all__ - + events.__all__ - + futures.__all__ - + locks.__all__ - + protocols.__all__ - + runners.__all__ - + queues.__all__ - + streams.__all__ - + tasks.__all__ -) -``` - - diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__import.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__import.py.snap index 0208606383..4c76352286 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__import.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__import.py.snap @@ -12,7 +12,7 @@ from a import aksjdhflsakhdflkjsadlfajkslhfdkjsaldajlahflashdfljahlfksajlhfajfjf ## Output ```py from a import ( - aksjdhflsakhdflkjsadlfajkslhfdkjsaldajlahflashdfljahlfksajlhfajfjfsaahflakjslhdfkjalhdskjfa + aksjdhflsakhdflkjsadlfajkslhfdkjsaldajlahflashdfljahlfksajlhfajfjfsaahflakjslhdfkjalhdskjfa, ) from a import ( aksjdhflsakhdflkjsadlfajkslhfdkjsaldajlahflashdfljahlfksajlhfajfjfsaahflakjslhdfkjalhdskjfa, diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__import_from.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__import_from.py.snap index 433e893297..6b3942ea0c 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__import_from.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__import_from.py.snap @@ -89,7 +89,7 @@ from a import ( # comment ) from a import ( # comment - bar + bar, ) from a import (