Fix formatting of lambda star arguments (#6257)

## Summary
Previously, the ruff formatter was removing the star argument of
`lambda` expressions when formatting.

Given the following code snippet
```python
lambda *a: ()
lambda **b: ()
```
it would be formatted to
```python
lambda: ()
lambda: ()
```

We fix this by checking for the presence of `args`, `vararg` or `kwarg`
in the `lambda` expression, before we were only checking for the
presence of `args`.

Fixes #5894

## Test Plan

Add new tests cases.

---------

Co-authored-by: Charlie Marsh <charlie.r.marsh@gmail.com>
This commit is contained in:
Victor Hugo Gomes 2023-08-02 16:31:20 -03:00 committed by GitHub
parent c362ea7fd4
commit 7c5791fb77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 212 additions and 242 deletions

View file

@ -75,3 +75,19 @@ def f(
# ...but we do preserve a trailing comma after the arguments
a = lambda b,: 0
lambda a,: 0
lambda *args,: 0
lambda **kwds,: 0
lambda a, *args,: 0
lambda a, **kwds,: 0
lambda *args, b,: 0
lambda *, b,: 0
lambda *args, **kwds,: 0
lambda a, *args, b,: 0
lambda a, *, b,: 0
lambda a, *args, **kwds,: 0
lambda *args, b, **kwds,: 0
lambda *, b, **kwds,: 0
lambda a, *args, b, **kwds,: 0
lambda a, *, b, **kwds,: 0

View file

@ -231,3 +231,21 @@ def f42(
c,
):
pass
# Check trailing commas are permitted in funcdef argument list.
def f(a, ): pass
def f(*args, ): pass
def f(**kwds, ): pass
def f(a, *args, ): pass
def f(a, **kwds, ): pass
def f(*args, b, ): pass
def f(*, b, ): pass
def f(*args, **kwds, ): pass
def f(a, *args, b, ): pass
def f(a, *, b, ): pass
def f(a, *args, **kwds, ): pass
def f(*args, b, **kwds, ): pass
def f(*, b, **kwds, ): pass
def f(a, *args, b, **kwds, ): pass
def f(a, *, b, **kwds, ): pass

View file

@ -22,7 +22,8 @@ impl FormatNodeRule<ExprLambda> for FormatExprLambda {
write!(f, [text("lambda")])?;
if !parameters.args.is_empty() {
if !parameters.args.is_empty() || parameters.vararg.is_some() || parameters.kwarg.is_some()
{
write!(
f,
[

View file

@ -278,9 +278,9 @@ pub(crate) struct ArgumentSeparator {
pub(crate) following_start: TextSize,
}
/// Finds slash and star in `f(a, /, b, *, c)`
/// Finds slash and star in `f(a, /, b, *, c)` or `lambda a, /, b, *, c: 1`.
///
/// Returns slash and star
/// Returns the location of the slash and star separators, if any.
pub(crate) fn find_argument_separators(
contents: &str,
parameters: &Parameters,
@ -347,14 +347,21 @@ pub(crate) fn find_argument_separators(
} else {
let mut tokens = SimpleTokenizer::new(contents, parameters.range).skip_trivia();
let lparen = tokens
.next()
.expect("The function definition can't end here");
debug_assert!(lparen.kind() == SimpleTokenKind::LParen, "{lparen:?}");
let star = tokens
let lparen_or_star = tokens
.next()
.expect("The function definition can't end here");
// In a function definition, the first token should always be a `(`; in a lambda
// definition, it _can't_ be a `(`.
let star = if lparen_or_star.kind == SimpleTokenKind::LParen {
tokens
.next()
.expect("The function definition can't end here")
} else {
lparen_or_star
};
debug_assert!(star.kind() == SimpleTokenKind::Star, "{star:?}");
Some(ArgumentSeparator {
preceding_end: parameters.range.start(),
separator: star.range,

View file

@ -1,234 +0,0 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/power_op_spacing.py
---
## Input
```py
def function(**kwargs):
t = a**2 + b**3
return t ** 2
def function_replace_spaces(**kwargs):
t = a **2 + b** 3 + c ** 4
def function_dont_replace_spaces():
{**a, **b, **c}
a = 5**~4
b = 5 ** f()
c = -(5**2)
d = 5 ** f["hi"]
e = lazy(lambda **kwargs: 5)
f = f() ** 5
g = a.b**c.d
h = 5 ** funcs.f()
i = funcs.f() ** 5
j = super().name ** 5
k = [(2**idx, value) for idx, value in pairs]
l = mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
m = [([2**63], [1, 2**63])]
n = count <= 10**5
o = settings(max_examples=10**6)
p = {(k, k**2): v**2 for k, v in pairs}
q = [10**i for i in range(6)]
r = x**y
a = 5.0**~4.0
b = 5.0 ** f()
c = -(5.0**2.0)
d = 5.0 ** f["hi"]
e = lazy(lambda **kwargs: 5)
f = f() ** 5.0
g = a.b**c.d
h = 5.0 ** funcs.f()
i = funcs.f() ** 5.0
j = super().name ** 5.0
k = [(2.0**idx, value) for idx, value in pairs]
l = mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
m = [([2.0**63.0], [1.0, 2**63.0])]
n = count <= 10**5.0
o = settings(max_examples=10**6.0)
p = {(k, k**2): v**2.0 for k, v in pairs}
q = [10.5**i for i in range(6)]
# WE SHOULD DEFINITELY NOT EAT THESE COMMENTS (https://github.com/psf/black/issues/2873)
if hasattr(view, "sum_of_weights"):
return np.divide( # type: ignore[no-any-return]
view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
)
return np.divide(
where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
)
```
## Black Differences
```diff
--- Black
+++ Ruff
@@ -15,7 +15,7 @@
b = 5 ** f()
c = -(5**2)
d = 5 ** f["hi"]
-e = lazy(lambda **kwargs: 5)
+e = lazy(lambda: 5)
f = f() ** 5
g = a.b**c.d
h = 5 ** funcs.f()
@@ -34,7 +34,7 @@
b = 5.0 ** f()
c = -(5.0**2.0)
d = 5.0 ** f["hi"]
-e = lazy(lambda **kwargs: 5)
+e = lazy(lambda: 5)
f = f() ** 5.0
g = a.b**c.d
h = 5.0 ** funcs.f()
```
## Ruff Output
```py
def function(**kwargs):
t = a**2 + b**3
return t**2
def function_replace_spaces(**kwargs):
t = a**2 + b**3 + c**4
def function_dont_replace_spaces():
{**a, **b, **c}
a = 5**~4
b = 5 ** f()
c = -(5**2)
d = 5 ** f["hi"]
e = lazy(lambda: 5)
f = f() ** 5
g = a.b**c.d
h = 5 ** funcs.f()
i = funcs.f() ** 5
j = super().name ** 5
k = [(2**idx, value) for idx, value in pairs]
l = mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
m = [([2**63], [1, 2**63])]
n = count <= 10**5
o = settings(max_examples=10**6)
p = {(k, k**2): v**2 for k, v in pairs}
q = [10**i for i in range(6)]
r = x**y
a = 5.0**~4.0
b = 5.0 ** f()
c = -(5.0**2.0)
d = 5.0 ** f["hi"]
e = lazy(lambda: 5)
f = f() ** 5.0
g = a.b**c.d
h = 5.0 ** funcs.f()
i = funcs.f() ** 5.0
j = super().name ** 5.0
k = [(2.0**idx, value) for idx, value in pairs]
l = mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
m = [([2.0**63.0], [1.0, 2**63.0])]
n = count <= 10**5.0
o = settings(max_examples=10**6.0)
p = {(k, k**2): v**2.0 for k, v in pairs}
q = [10.5**i for i in range(6)]
# WE SHOULD DEFINITELY NOT EAT THESE COMMENTS (https://github.com/psf/black/issues/2873)
if hasattr(view, "sum_of_weights"):
return np.divide( # type: ignore[no-any-return]
view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
)
return np.divide(
where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
)
```
## Black Output
```py
def function(**kwargs):
t = a**2 + b**3
return t**2
def function_replace_spaces(**kwargs):
t = a**2 + b**3 + c**4
def function_dont_replace_spaces():
{**a, **b, **c}
a = 5**~4
b = 5 ** f()
c = -(5**2)
d = 5 ** f["hi"]
e = lazy(lambda **kwargs: 5)
f = f() ** 5
g = a.b**c.d
h = 5 ** funcs.f()
i = funcs.f() ** 5
j = super().name ** 5
k = [(2**idx, value) for idx, value in pairs]
l = mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
m = [([2**63], [1, 2**63])]
n = count <= 10**5
o = settings(max_examples=10**6)
p = {(k, k**2): v**2 for k, v in pairs}
q = [10**i for i in range(6)]
r = x**y
a = 5.0**~4.0
b = 5.0 ** f()
c = -(5.0**2.0)
d = 5.0 ** f["hi"]
e = lazy(lambda **kwargs: 5)
f = f() ** 5.0
g = a.b**c.d
h = 5.0 ** funcs.f()
i = funcs.f() ** 5.0
j = super().name ** 5.0
k = [(2.0**idx, value) for idx, value in pairs]
l = mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
m = [([2.0**63.0], [1.0, 2**63.0])]
n = count <= 10**5.0
o = settings(max_examples=10**6.0)
p = {(k, k**2): v**2.0 for k, v in pairs}
q = [10.5**i for i in range(6)]
# WE SHOULD DEFINITELY NOT EAT THESE COMMENTS (https://github.com/psf/black/issues/2873)
if hasattr(view, "sum_of_weights"):
return np.divide( # type: ignore[no-any-return]
view.variance, # type: ignore[union-attr]
view.sum_of_weights, # type: ignore[union-attr]
out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr]
where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr]
)
return np.divide(
where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore
)
```

View file

@ -81,6 +81,22 @@ def f(
# ...but we do preserve a trailing comma after the arguments
a = lambda b,: 0
lambda a,: 0
lambda *args,: 0
lambda **kwds,: 0
lambda a, *args,: 0
lambda a, **kwds,: 0
lambda *args, b,: 0
lambda *, b,: 0
lambda *args, **kwds,: 0
lambda a, *args, b,: 0
lambda a, *, b,: 0
lambda a, *args, **kwds,: 0
lambda *args, b, **kwds,: 0
lambda *, b, **kwds,: 0
lambda a, *args, b, **kwds,: 0
lambda a, *, b, **kwds,: 0
```
## Output
@ -162,6 +178,22 @@ def f(
# ...but we do preserve a trailing comma after the arguments
a = lambda b,: 0
lambda a,: 0
lambda *args,: 0
lambda **kwds,: 0
lambda a, *args,: 0
lambda a, **kwds,: 0
lambda *args, b,: 0
lambda: 0
lambda *args, **kwds,: 0
lambda a, *args, b,: 0
lambda a, *, b,: 0
lambda a, *args, **kwds,: 0
lambda *args, b, **kwds,: 0
lambda *, b, **kwds,: 0
lambda a, *args, b, **kwds,: 0
lambda a, *, b, **kwds,: 0
```

View file

@ -237,6 +237,24 @@ def f42(
c,
):
pass
# Check trailing commas are permitted in funcdef argument list.
def f(a, ): pass
def f(*args, ): pass
def f(**kwds, ): pass
def f(a, *args, ): pass
def f(a, **kwds, ): pass
def f(*args, b, ): pass
def f(*, b, ): pass
def f(*args, **kwds, ): pass
def f(a, *args, b, ): pass
def f(a, *, b, ): pass
def f(a, *args, **kwds, ): pass
def f(*args, b, **kwds, ): pass
def f(*, b, **kwds, ): pass
def f(a, *args, b, **kwds, ): pass
def f(a, *, b, **kwds, ): pass
```
## Output
@ -513,6 +531,118 @@ def f42(
c,
):
pass
# Check trailing commas are permitted in funcdef argument list.
def f(
a,
):
pass
def f(
*args,
):
pass
def f(
**kwds,
):
pass
def f(
a,
*args,
):
pass
def f(
a,
**kwds,
):
pass
def f(
*args,
b,
):
pass
def f(
*,
b,
):
pass
def f(
*args,
**kwds,
):
pass
def f(
a,
*args,
b,
):
pass
def f(
a,
*,
b,
):
pass
def f(
a,
*args,
**kwds,
):
pass
def f(
*args,
b,
**kwds,
):
pass
def f(
*,
b,
**kwds,
):
pass
def f(
a,
*args,
b,
**kwds,
):
pass
def f(
a,
*,
b,
**kwds,
):
pass
```