Add node classes for ListComp

This adds support for

```
[a for b in c if d]
```

I added support for GeneratorExp in a previous commit, so this is a
simple extension of that.
This commit is contained in:
Benjamin Woodruff 2019-07-16 14:35:59 -07:00
parent 6557122db6
commit 2295d062fc
3 changed files with 131 additions and 7 deletions

View file

@ -52,6 +52,7 @@ from libcst.nodes._expression import (
LeftParen,
LeftSquareBracket,
List,
ListComp,
Name,
Number,
Param,

View file

@ -2241,10 +2241,11 @@ class Tuple(BaseAtom, BaseAssignTargetExpression, BaseDelTargetExpression):
)
@add_slots
@dataclass(frozen=True)
class List(BaseAtom, BaseAssignTargetExpression, BaseDelTargetExpression):
elements: Sequence[Union[Element, StarredElement]]
class BaseList(BaseAtom, ABC):
"""
A Base class for List and ListComp, which both result in a list object when
evaluated.
"""
# Open bracket surrounding the list
lbracket: LeftSquareBracket = LeftSquareBracket()
@ -2261,6 +2262,22 @@ class List(BaseAtom, BaseAssignTargetExpression, BaseDelTargetExpression):
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
return True
@contextmanager
def _bracketize(self, state: CodegenState) -> Generator[None, None, None]:
self.lbracket._codegen(state)
yield
self.rbracket._codegen(state)
@add_slots
@dataclass(frozen=True)
class List(BaseList, BaseAssignTargetExpression, BaseDelTargetExpression):
elements: Sequence[Union[Element, StarredElement]]
lbracket: LeftSquareBracket = LeftSquareBracket()
rbracket: RightSquareBracket = RightSquareBracket()
lpar: Sequence[LeftParen] = ()
rpar: Sequence[RightParen] = ()
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "List":
return List(
lpar=visit_sequence("lpar", self.lpar, visitor),
@ -2271,8 +2288,7 @@ class List(BaseAtom, BaseAssignTargetExpression, BaseDelTargetExpression):
)
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.lbracket._codegen(state)
with self._parenthesize(state), self._bracketize(state):
elements = self.elements
for idx, el in enumerate(elements):
el._codegen(
@ -2280,7 +2296,6 @@ class List(BaseAtom, BaseAssignTargetExpression, BaseDelTargetExpression):
default_comma=(idx < len(elements) - 1),
default_comma_whitespace=True,
)
self.rbracket._codegen(state)
@add_slots
@ -2532,3 +2547,33 @@ class GeneratorExp(BaseSimpleComp):
with self._parenthesize(state):
self.elt._codegen(state)
self.for_in._codegen(state)
@add_slots
@dataclass(frozen=True)
class ListComp(BaseList, BaseSimpleComp):
elt: BaseAssignTargetExpression
for_in: CompFor
lbracket: LeftSquareBracket = LeftSquareBracket()
rbracket: RightSquareBracket = RightSquareBracket()
lpar: Sequence[LeftParen] = ()
rpar: Sequence[RightParen] = ()
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
# ListComp is always surrounded in square brackets
return True
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ListComp":
return ListComp(
lpar=visit_sequence("lpar", self.lpar, visitor),
lbracket=visit_required("lbracket", self.lbracket, visitor),
elt=visit_required("elt", self.elt, visitor),
for_in=visit_required("for_in", self.for_in, visitor),
rbracket=visit_required("rbracket", self.rbracket, visitor),
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state), self._bracketize(state):
self.elt._codegen(state)
self.for_in._codegen(state)

View file

@ -21,6 +21,13 @@ class SimpleCompTest(CSTNodeTest):
),
"code": "(a for b in c)",
},
# simple ListComp
{
"node": cst.ListComp(
cst.Name("a"), cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))
),
"code": "[a for b in c]",
},
# async GeneratorExp
{
"node": cst.GeneratorExp(
@ -121,6 +128,24 @@ class SimpleCompTest(CSTNodeTest):
),
"code": "(\fa for b in c\tif\t\td\f\f)",
},
# custom whitespace around ListComp's brackets
{
"node": cst.ListComp(
cst.Name("a"),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
lbracket=cst.LeftSquareBracket(
whitespace_after=cst.SimpleWhitespace("\t")
),
rbracket=cst.RightSquareBracket(
whitespace_before=cst.SimpleWhitespace("\t\t")
),
lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
rpar=[
cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
],
),
"code": "(\f[\ta for b in c\t\t]\f\f)",
},
# no whitespace between elements
{
"node": cst.GeneratorExp(
@ -163,6 +188,50 @@ class SimpleCompTest(CSTNodeTest):
),
"code": "((a)for(b)in(c)if(d)for(e)in(f))",
},
# no whitespace before/after GeneratorExp is valid
{
"node": cst.Comparison(
cst.GeneratorExp(
cst.Name("a"),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
[
cst.ComparisonTarget(
cst.Is(
whitespace_before=cst.SimpleWhitespace(""),
whitespace_after=cst.SimpleWhitespace(""),
),
cst.GeneratorExp(
cst.Name("d"),
cst.CompFor(target=cst.Name("e"), iter=cst.Name("f")),
),
)
],
),
"code": "(a for b in c)is(d for e in f)",
},
# no whitespace before/after ListComp is valid
{
"node": cst.Comparison(
cst.ListComp(
cst.Name("a"),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
[
cst.ComparisonTarget(
cst.Is(
whitespace_before=cst.SimpleWhitespace(""),
whitespace_after=cst.SimpleWhitespace(""),
),
cst.ListComp(
cst.Name("d"),
cst.CompFor(target=cst.Name("e"), iter=cst.Name("f")),
),
)
],
),
"code": "[a for b in c]is[d for e in f]",
},
]
)
def test_valid(self, **kwargs: Any) -> None:
@ -179,6 +248,15 @@ class SimpleCompTest(CSTNodeTest):
),
"unbalanced parens",
),
(
lambda: cst.ListComp(
cst.Name("a"),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
lpar=[cst.LeftParen(), cst.LeftParen()],
rpar=[cst.RightParen()],
),
"unbalanced parens",
),
(
lambda: cst.GeneratorExp(
cst.Name("a"),