Add nodedefs (no parsing yet) needed for Tuple

This ended up being pretty complicated, so the parser stuff will come in
another diff.

Hopefully this should set us up nicely for dicts, sets, and lists too.
This commit is contained in:
Benjamin Woodruff 2019-06-05 15:23:22 -07:00
parent 21ace9df33
commit ca15eb2685
6 changed files with 421 additions and 8 deletions

View file

@ -1,4 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

View file

@ -22,6 +22,7 @@ from libcst.nodes._expression import (
Attribute,
Await,
BaseAtom,
BaseElement,
BaseExpression,
BaseFormattedStringContent,
BinaryOperation,
@ -30,6 +31,7 @@ from libcst.nodes._expression import (
Comparison,
ComparisonTarget,
ConcatenatedString,
Element,
Ellipses,
ExtSlice,
Float,
@ -54,7 +56,9 @@ from libcst.nodes._expression import (
SimpleString,
Slice,
Starred,
StarredElement,
Subscript,
Tuple,
UnaryOperation,
Yield,
)

View file

@ -6,7 +6,7 @@
# pyre-strict
import re
from abc import ABC
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
@ -2048,3 +2048,187 @@ class Yield(BaseExpression):
value._codegen(state, default_space="")
elif value is not None:
value._codegen(state)
class BaseElement(CSTNode, ABC):
    """
    An element of a literal list, tuple, or set. For elements of a literal dict, see
    BaseMappingElement. (TODO)

    Concrete subclasses must implement ``_codegen``; the attributes declared here
    describe the fields every element is expected to carry.
    """

    # The expression wrapped by this element.
    # pyre-fixme[13]: Attribute `value` is never initialized.
    value: BaseExpression

    # Either an explicit Comma node, or MaybeSentinel.DEFAULT to let the code
    # generator decide (via its `default_comma` argument) whether one is emitted.
    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Used if we don't have a comma, otherwise the parser will attach the whitespace to
    # the comma.
    whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")

    @abstractmethod
    def _codegen(
        self,
        state: CodegenState,
        default_comma: bool = False,
        default_comma_whitespace: bool = False,  # False for a single-item tuple
    ) -> None:
        """
        Emit this element's source code into ``state``.

        ``default_comma`` tells the implementation whether a comma should be
        produced when ``comma`` is ``MaybeSentinel.DEFAULT``.
        ``default_comma_whitespace`` selects between ``", "`` and ``","`` for that
        generated comma.
        """
        ...
@add_slots
@dataclass(frozen=True)
class Element(BaseElement):
    """
    A plain (non-starred) element: an expression, optional trailing whitespace, and
    an optional trailing comma.
    """

    value: BaseExpression

    # Any trailing comma
    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Whitespace
    whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Element":
        """Visit each child in turn and build a new Element from the results."""
        # Children are visited in this exact order: value, comma, whitespace_after.
        new_value = visit_required("value", self.value, visitor)
        new_comma = visit_sentinel("comma", self.comma, visitor)
        new_whitespace_after = visit_required(
            "whitespace_after", self.whitespace_after, visitor
        )
        return Element(
            value=new_value, comma=new_comma, whitespace_after=new_whitespace_after
        )

    def _codegen(
        self,
        state: CodegenState,
        default_comma: bool = False,
        default_comma_whitespace: bool = False,
    ) -> None:
        """
        Emit the value, then ``whitespace_after``, then the comma (if any).

        An explicit Comma node is always emitted. When ``comma`` is
        ``MaybeSentinel.DEFAULT`` a comma is generated only if ``default_comma`` is
        set, as ``", "`` or ``","`` depending on ``default_comma_whitespace``.
        """
        self.value._codegen(state)
        # NOTE(review): whitespace_after is written *before* the comma here, while
        # StarredElement._codegen writes it after the comma. Confirm that the
        # difference is intentional.
        self.whitespace_after._codegen(state)

        trailing = self.comma
        if isinstance(trailing, Comma):
            trailing._codegen(state)
        elif trailing is MaybeSentinel.DEFAULT and default_comma:
            state.tokens.append(", " if default_comma_whitespace else ",")
@add_slots
@dataclass(frozen=True)
class StarredElement(BaseElement, _BaseParenthesizedNode):
    """
    A starred element (``*value``), optionally wrapped in its own parentheses and
    optionally followed by a comma and whitespace.
    """

    value: BaseExpression

    # Any trailing comma
    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Parentheses around the leading asterisk and the value. Functionally equivalent to
    # parentheses around the value, but in a different position.
    lpar: Sequence[LeftParen] = ()
    rpar: Sequence[RightParen] = ()

    # Whitespace
    whitespace_before_value: BaseParenthesizableWhitespace = SimpleWhitespace("")
    whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "StarredElement":
        """Visit each child in turn and build a new StarredElement from the results."""
        # Children are visited in this exact order: lpar, whitespace_before_value,
        # value, whitespace_after, rpar, comma.
        new_lpar = visit_sequence("lpar", self.lpar, visitor)
        new_whitespace_before_value = visit_required(
            "whitespace_before_value", self.whitespace_before_value, visitor
        )
        new_value = visit_required("value", self.value, visitor)
        new_whitespace_after = visit_required(
            "whitespace_after", self.whitespace_after, visitor
        )
        new_rpar = visit_sequence("rpar", self.rpar, visitor)
        new_comma = visit_sentinel("comma", self.comma, visitor)
        return StarredElement(
            lpar=new_lpar,
            whitespace_before_value=new_whitespace_before_value,
            value=new_value,
            whitespace_after=new_whitespace_after,
            rpar=new_rpar,
            comma=new_comma,
        )

    def _codegen(
        self,
        state: CodegenState,
        default_comma: bool = False,
        default_comma_whitespace: bool = False,
    ) -> None:
        """
        Emit ``*``, the inner whitespace and the value inside this node's own
        parentheses, then the comma (if any), then ``whitespace_after``.

        An explicit Comma node is always emitted. When ``comma`` is
        ``MaybeSentinel.DEFAULT`` a comma is generated only if ``default_comma`` is
        set, as ``", "`` or ``","`` depending on ``default_comma_whitespace``.
        """
        # Only the asterisk, the inner whitespace and the value sit inside lpar/rpar.
        with self._parenthesize(state):
            state.tokens.append("*")
            self.whitespace_before_value._codegen(state)
            self.value._codegen(state)

        trailing = self.comma
        if isinstance(trailing, Comma):
            trailing._codegen(state)
        elif trailing is MaybeSentinel.DEFAULT and default_comma:
            state.tokens.append(", " if default_comma_whitespace else ",")

        # NOTE(review): whitespace_after is written *after* the comma here, while
        # Element._codegen writes it before the comma. Confirm that the difference
        # is intentional.
        self.whitespace_after._codegen(state)
@add_slots
@dataclass(frozen=True)
class Tuple(BaseExpression):
    """
    A tuple expression made of Element / StarredElement children. By default it is
    wrapped in a single pair of parentheses.
    """

    elements: Sequence[Union[Element, StarredElement]]

    # Sequence of open parenthesis for precedence dictation.
    lpar: Sequence[LeftParen] = (LeftParen(),)

    # Sequence of close parenthesis for precedence dictation.
    rpar: Sequence[RightParen] = (RightParen(),)

    def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
        """
        Report whether this tuple may sit directly next to a word operator (such as
        ``for``/``in``) at ``position`` without separating whitespace.
        """
        # NOTE(review): the explicit two-argument super() is kept as written;
        # presumably required by the add_slots decorator -- verify before changing.
        if super(Tuple, self)._safe_to_use_with_word_operator(position):
            # if we have parenthesis, we're safe.
            return True

        # items[-1] and items[0] must exist past this point, because we're not
        # parenthesized, meaning we must have at least one element.
        items = self.elements

        if position != ExpressionPosition.LEFT:  # ExpressionPosition.RIGHT
            head = items[0]
            # starred elements are always safe because they begin with ( or *
            if isinstance(head, StarredElement):
                return True
            return head.value._safe_to_use_with_word_operator(position)

        tail = items[-1]
        if not tail.whitespace_after.empty:
            return True
        if isinstance(tail.comma, Comma):
            return True
        if isinstance(tail, StarredElement) and len(tail.rpar) > 0:
            return True
        return tail.value._safe_to_use_with_word_operator(position)

    def _validate(self) -> None:
        """
        Run the inherited validation, then reject an empty tuple that has no
        parentheses.
        """
        # Paren validation and such
        super(Tuple, self)._validate()

        # Checking lpar alone assumes len(lpar) == len(rpar), via superclass.
        if not self.elements and not self.lpar:
            raise CSTValidationError(
                "A zero-length tuple must be wrapped in parentheses."
            )
        # Invalid commas aren't possible, because MaybeSentinel will ensure that there
        # is a comma where required.

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Tuple":
        """Visit lpar, elements and rpar (in that order) and build a new Tuple."""
        new_lpar = visit_sequence("lpar", self.lpar, visitor)
        new_elements = visit_sequence("elements", self.elements, visitor)
        new_rpar = visit_sequence("rpar", self.rpar, visitor)
        return Tuple(lpar=new_lpar, elements=new_elements, rpar=new_rpar)

    def _codegen(self, state: CodegenState) -> None:
        """
        Emit every element inside this tuple's parentheses.

        A lone element gets a default comma without trailing space (``x,``);
        otherwise every element except the last gets a default ``", "``.
        """
        with self._parenthesize(state):
            items = self.elements
            is_single = len(items) == 1
            final_index = len(items) - 1
            for position, item in enumerate(items):
                item._codegen(
                    state,
                    default_comma=is_single or position < final_index,
                    default_comma_whitespace=not is_single,
                )

View file

@ -166,8 +166,13 @@ class Colon(_BaseOneTokenOp):
@dataclass(frozen=True)
class Comma(_BaseOneTokenOp):
"""
Used by ImportAlias as a separator between subsequent ImportAliases contained
within a Import or ImportFrom.
Syntactic trivia used as a separator between subsequent items in various parts of
the grammar.
Some use-cases are:
- Import or ImportFrom
- Function arguments
- Tuple/list/set/dict elements
"""
whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("")

View file

@ -27,6 +27,7 @@ from libcst.nodes._expression import (
Name,
Parameters,
RightParen,
Tuple,
)
from libcst.nodes._internal import (
CodegenState,
@ -1578,7 +1579,9 @@ class For(BaseCompoundStatement):
"""
# The target of the iterator in the for statement.
target: Name # TODO: Should be a Union[Name, Tuple, List] once we support this.
target: Union[
Name, Tuple
] # TODO: Should be a Union[Name, Tuple, List] once we support this.
# The iterable expression we will loop over.
iter: BaseExpression

View file

@ -0,0 +1,221 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable
import libcst.nodes as cst
from libcst.nodes.tests.base import CSTNodeTest
from libcst.testing.utils import data_provider
class TupleTest(CSTNodeTest):
    """
    Data-driven tests for the Tuple, Element and StarredElement nodes.

    Each ``test_valid`` row pairs a node with the exact source text it is expected
    to correspond to; each ``test_invalid`` row pairs a node factory with a pattern
    for the expected validation error.
    """

    @data_provider(
        (
            # zero-element tuple
            (cst.Tuple([]), "()"),
            # one-element tuple
            (cst.Tuple([cst.Element(cst.Name("single_element"))]), "(single_element,)"),
            # two-element tuple
            (
                cst.Tuple([cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]),
                "(one, two)",
            ),
            # remove parenthesis
            (
                cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
                    lpar=[],
                    rpar=[],
                ),
                "one, two",
            ),
            # add extra parenthesis
            (
                cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen(), cst.RightParen()],
                ),
                "((one, two))",
            ),
            # starred element
            (
                cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.StarredElement(cst.Name("two"))]
                ),
                "(one, *two)",
            ),
            # custom comma on Element
            (
                cst.Tuple(
                    [
                        cst.Element(cst.Name("one")),
                        cst.Element(cst.Name("two"), comma=cst.Comma()),
                    ]
                ),
                "(one, two,)",
            ),
            # custom comma on StarredElement
            (
                cst.Tuple(
                    [
                        cst.Element(cst.Name("one")),
                        cst.StarredElement(cst.Name("two"), comma=cst.Comma()),
                    ]
                ),
                "(one, *two,)",
            ),
            # custom parenthesis on StarredElement
            (
                cst.Tuple(
                    [
                        cst.StarredElement(
                            cst.Name("abc"),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                        )
                    ]
                ),
                "((*abc),)",
            ),
            # custom whitespace on Element
            (
                cst.Tuple(
                    [
                        cst.Element(cst.Name("one")),
                        cst.Element(
                            cst.Name("two"), whitespace_after=cst.SimpleWhitespace("  ")
                        ),
                    ],
                    lpar=[],
                    rpar=[],  # rpar can't own the trailing whitespace if it's not there
                ),
                "one, two  ",
            ),
            # custom whitespace on StarredElement
            (
                cst.Tuple(
                    [
                        cst.Element(cst.Name("one")),
                        cst.StarredElement(
                            cst.Name("two"),
                            whitespace_before_value=cst.SimpleWhitespace("  "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                        ),
                    ],
                    lpar=[],
                    rpar=[],  # rpar can't own the trailing whitespace if it's not there
                ),
                "one, (*  two)  ",
            ),
            # missing spaces around tuple, okay with parenthesis
            (
                cst.For(
                    target=cst.Tuple(
                        [cst.Element(cst.Name("k")), cst.Element(cst.Name("v"))]
                    ),
                    iter=cst.Name("abc"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "for(k, v)in abc: pass\n",
            ),
            # no spaces around tuple, but using values that are parenthesized
            (
                cst.For(
                    target=cst.Tuple(
                        [
                            cst.Element(
                                cst.Name(
                                    "k", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                                )
                            ),
                            cst.Element(
                                cst.Name(
                                    "v", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                                )
                            ),
                        ],
                        lpar=[],
                        rpar=[],
                    ),
                    iter=cst.Name("abc"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "for(k), (v)in abc: pass\n",
            ),
            # starred elements are safe to use without a space before them
            (
                cst.For(
                    target=cst.Tuple(
                        [cst.StarredElement(cst.Name("foo"))], lpar=[], rpar=[]
                    ),
                    iter=cst.Name("bar"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "for*foo, in bar: pass\n",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        """
        Check one (node, code) pair from the table above.

        Delegates to ``validate_node`` from the base test class -- presumably this
        asserts that ``node`` generates exactly ``code``; confirm against
        CSTNodeTest.
        """
        self.validate_node(node, code)

    @data_provider(
        (
            (
                lambda: cst.Tuple([], lpar=[], rpar=[]),
                "A zero-length tuple must be wrapped in parentheses.",
            ),
            (
                lambda: cst.Tuple(
                    [cst.Element(cst.Name("mismatched"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.For(
                    target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space after 'for' keyword.",
            ),
            (
                lambda: cst.For(
                    target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
            # an additional check for StarredElement, since it's a separate codepath
            (
                lambda: cst.For(
                    target=cst.Tuple(
                        [cst.StarredElement(cst.Name("el"))], lpar=[], rpar=[]
                    ),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Check one (factory, expected error pattern) pair from the table above.

        Each node is wrapped in a lambda so it is only constructed when the factory
        is called. Delegates to ``assert_invalid`` from the base test class --
        presumably this calls the factory and matches the raised validation error
        against ``expected_re``; confirm against CSTNodeTest.
        """
        self.assert_invalid(get_node, expected_re)