diff --git a/libcst/__init__.py b/libcst/__init__.py
deleted file mode 100644
index 62642369..00000000
--- a/libcst/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
diff --git a/libcst/nodes/__init__.py b/libcst/nodes/__init__.py
index 74658912..43411dc3 100644
--- a/libcst/nodes/__init__.py
+++ b/libcst/nodes/__init__.py
@@ -22,6 +22,7 @@ from libcst.nodes._expression import (
     Attribute,
     Await,
     BaseAtom,
+    BaseElement,
     BaseExpression,
     BaseFormattedStringContent,
     BinaryOperation,
@@ -30,6 +31,7 @@ from libcst.nodes._expression import (
     Comparison,
     ComparisonTarget,
     ConcatenatedString,
+    Element,
     Ellipses,
     ExtSlice,
     Float,
@@ -54,7 +56,9 @@ from libcst.nodes._expression import (
     SimpleString,
     Slice,
     Starred,
+    StarredElement,
     Subscript,
+    Tuple,
     UnaryOperation,
     Yield,
 )
diff --git a/libcst/nodes/_expression.py b/libcst/nodes/_expression.py
index e037a906..3a4b24da 100644
--- a/libcst/nodes/_expression.py
+++ b/libcst/nodes/_expression.py
@@ -6,7 +6,7 @@
 # pyre-strict
 
 import re
-from abc import ABC
+from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -2048,3 +2048,187 @@ class Yield(BaseExpression):
                 value._codegen(state, default_space="")
             elif value is not None:
                 value._codegen(state)
+
+
+class BaseElement(CSTNode, ABC):
+    """
+    An element of a literal list, tuple, or set. For elements of a literal dict, see
+    BaseMappingElement. (TODO)
+    """
+
+    # pyre-fixme[13]: Attribute `value` is never initialized.
+    value: BaseExpression
+    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
+    # Used if we don't have a comma, otherwise the parser will attach the whitespace to
+    # the comma.
+    whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")
+
+    @abstractmethod
+    def _codegen(
+        self,
+        state: CodegenState,
+        default_comma: bool = False,
+        default_comma_whitespace: bool = False,  # False for a single-item tuple
+    ) -> None:
+        ...
+
+
+@add_slots
+@dataclass(frozen=True)
+class Element(BaseElement):
+    value: BaseExpression
+
+    # Any trailing comma
+    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
+
+    # Whitespace emitted after the value and before any comma (see _codegen).
+    whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")
+
+    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Element":
+        return Element(
+            value=visit_required("value", self.value, visitor),
+            comma=visit_sentinel("comma", self.comma, visitor),
+            whitespace_after=visit_required(
+                "whitespace_after", self.whitespace_after, visitor
+            ),
+        )
+
+    def _codegen(
+        self,
+        state: CodegenState,
+        default_comma: bool = False,
+        default_comma_whitespace: bool = False,
+    ) -> None:
+        self.value._codegen(state)
+        self.whitespace_after._codegen(state)
+        comma = self.comma
+        if comma is MaybeSentinel.DEFAULT and default_comma:
+            if default_comma_whitespace:
+                state.tokens.append(", ")
+            else:
+                state.tokens.append(",")
+        elif isinstance(comma, Comma):
+            comma._codegen(state)
+
+
+@add_slots
+@dataclass(frozen=True)
+class StarredElement(BaseElement, _BaseParenthesizedNode):
+    value: BaseExpression
+
+    # Any trailing comma
+    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
+
+    # Parentheses around the leading asterisk and the value. Functionally equivalent to
+    # parentheses around the value, but in a different position.
+    lpar: Sequence[LeftParen] = ()
+    rpar: Sequence[RightParen] = ()
+
+    # Whitespace after the `*`, and whitespace after the rpar and any comma.
+    whitespace_before_value: BaseParenthesizableWhitespace = SimpleWhitespace("")
+    whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")
+
+    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "StarredElement":
+        return StarredElement(
+            lpar=visit_sequence("lpar", self.lpar, visitor),
+            whitespace_before_value=visit_required(
+                "whitespace_before_value", self.whitespace_before_value, visitor
+            ),
+            value=visit_required("value", self.value, visitor),
+            whitespace_after=visit_required(
+                "whitespace_after", self.whitespace_after, visitor
+            ),
+            rpar=visit_sequence("rpar", self.rpar, visitor),
+            comma=visit_sentinel("comma", self.comma, visitor),
+        )
+
+    def _codegen(
+        self,
+        state: CodegenState,
+        default_comma: bool = False,
+        default_comma_whitespace: bool = False,
+    ) -> None:
+        with self._parenthesize(state):
+            state.tokens.append("*")
+            self.whitespace_before_value._codegen(state)
+            self.value._codegen(state)
+        comma = self.comma
+        if comma is MaybeSentinel.DEFAULT and default_comma:
+            if default_comma_whitespace:
+                state.tokens.append(", ")
+            else:
+                state.tokens.append(",")
+        elif isinstance(comma, Comma):
+            comma._codegen(state)
+        self.whitespace_after._codegen(state)
+
+
+@add_slots
+@dataclass(frozen=True)
+class Tuple(BaseExpression):
+    elements: Sequence[Union[Element, StarredElement]]
+
+    # Sequence of open parentheses for precedence dictation.
+    lpar: Sequence[LeftParen] = (LeftParen(),)
+
+    # Sequence of close parentheses for precedence dictation.
+    rpar: Sequence[RightParen] = (RightParen(),)
+
+    def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
+        if super(Tuple, self)._safe_to_use_with_word_operator(position):
+            # if we have parentheses, we're safe.
+            return True
+        # elements[-1] and elements[0] must exist past this point, because
+        # we're not parenthesized, meaning we must have at least one element.
+        elements = self.elements
+        if position == ExpressionPosition.LEFT:
+            last_element = elements[-1]
+            return (
+                not last_element.whitespace_after.empty
+                or isinstance(last_element.comma, Comma)
+                or (
+                    isinstance(last_element, StarredElement)
+                    and len(last_element.rpar) > 0
+                )
+                or last_element.value._safe_to_use_with_word_operator(position)
+            )
+        else:  # ExpressionPosition.RIGHT
+            first_element = elements[0]
+            # starred elements are always safe because they begin with ( or *
+            return isinstance(
+                first_element, StarredElement
+            ) or first_element.value._safe_to_use_with_word_operator(position)
+
+    def _validate(self) -> None:
+        # Paren validation and such
+        super(Tuple, self)._validate()
+
+        if len(self.elements) == 0:
+            if len(self.lpar) == 0:  # assumes len(lpar) == len(rpar), via superclass
+                raise CSTValidationError(
+                    "A zero-length tuple must be wrapped in parentheses."
+                )
+        # Invalid commas aren't possible, because MaybeSentinel will ensure that there
+        # is a comma where required.
+
+    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Tuple":
+        return Tuple(
+            lpar=visit_sequence("lpar", self.lpar, visitor),
+            elements=visit_sequence("elements", self.elements, visitor),
+            rpar=visit_sequence("rpar", self.rpar, visitor),
+        )
+
+    def _codegen(self, state: CodegenState) -> None:
+        with self._parenthesize(state):
+            elements = self.elements
+            if len(elements) == 1:
+                elements[0]._codegen(
+                    state, default_comma=True, default_comma_whitespace=False
+                )
+            else:
+                for idx, el in enumerate(elements):
+                    el._codegen(
+                        state,
+                        default_comma=(idx < len(elements) - 1),
+                        default_comma_whitespace=True,
+                    )
diff --git a/libcst/nodes/_op.py b/libcst/nodes/_op.py
index 373b4552..0bcffb7d 100644
--- a/libcst/nodes/_op.py
+++ b/libcst/nodes/_op.py
@@ -166,8 +166,13 @@ class Colon(_BaseOneTokenOp):
 @dataclass(frozen=True)
 class Comma(_BaseOneTokenOp):
     """
-    Used by ImportAlias as a separator between subsequent ImportAliases contained
-    within a Import or ImportFrom.
+    Syntactic trivia used as a separator between subsequent items in various parts of
+    the grammar.
+
+    Some use-cases are:
+    - Import or ImportFrom
+    - Function arguments
+    - Tuple/list/set/dict elements
     """
 
     whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("")
diff --git a/libcst/nodes/_statement.py b/libcst/nodes/_statement.py
index cfb1fba2..38bdaa11 100644
--- a/libcst/nodes/_statement.py
+++ b/libcst/nodes/_statement.py
@@ -27,6 +27,7 @@ from libcst.nodes._expression import (
     Name,
     Parameters,
     RightParen,
+    Tuple,
 )
 from libcst.nodes._internal import (
     CodegenState,
@@ -1578,7 +1579,9 @@ class For(BaseCompoundStatement):
     """
 
     # The target of the iterator in the for statement.
-    target: Name  # TODO: Should be a Union[Name, Tuple, List] once we support this.
+    target: Union[
+        Name, Tuple
+    ]  # TODO: Add List to this Union once we support it.
 
     # The iterable expression we will loop over.
     iter: BaseExpression
diff --git a/libcst/nodes/tests/test_tuple.py b/libcst/nodes/tests/test_tuple.py
new file mode 100644
index 00000000..ff2b80a0
--- /dev/null
+++ b/libcst/nodes/tests/test_tuple.py
@@ -0,0 +1,221 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+from typing import Callable
+
+import libcst.nodes as cst
+from libcst.nodes.tests.base import CSTNodeTest
+from libcst.testing.utils import data_provider
+
+
+class TupleTest(CSTNodeTest):
+    @data_provider(
+        (
+            # zero-element tuple
+            (cst.Tuple([]), "()"),
+            # one-element tuple
+            (cst.Tuple([cst.Element(cst.Name("single_element"))]), "(single_element,)"),
+            # two-element tuple
+            (
+                cst.Tuple([cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]),
+                "(one, two)",
+            ),
+            # remove parentheses
+            (
+                cst.Tuple(
+                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
+                    lpar=[],
+                    rpar=[],
+                ),
+                "one, two",
+            ),
+            # add extra parentheses
+            (
+                cst.Tuple(
+                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
+                    lpar=[cst.LeftParen(), cst.LeftParen()],
+                    rpar=[cst.RightParen(), cst.RightParen()],
+                ),
+                "((one, two))",
+            ),
+            # starred element
+            (
+                cst.Tuple(
+                    [cst.Element(cst.Name("one")), cst.StarredElement(cst.Name("two"))]
+                ),
+                "(one, *two)",
+            ),
+            # custom comma on Element
+            (
+                cst.Tuple(
+                    [
+                        cst.Element(cst.Name("one")),
+                        cst.Element(cst.Name("two"), comma=cst.Comma()),
+                    ]
+                ),
+                "(one, two,)",
+            ),
+            # custom comma on StarredElement
+            (
+                cst.Tuple(
+                    [
+                        cst.Element(cst.Name("one")),
+                        cst.StarredElement(cst.Name("two"), comma=cst.Comma()),
+                    ]
+                ),
+                "(one, *two,)",
+            ),
+            # custom parentheses on StarredElement
+            (
+                cst.Tuple(
+                    [
+                        cst.StarredElement(
+                            cst.Name("abc"),
+                            lpar=[cst.LeftParen()],
+                            rpar=[cst.RightParen()],
+                        )
+                    ]
+                ),
+                "((*abc),)",
+            ),
+            # custom whitespace on Element
+            (
+                cst.Tuple(
+                    [
+                        cst.Element(cst.Name("one")),
+                        cst.Element(
+                            cst.Name("two"), whitespace_after=cst.SimpleWhitespace(" ")
+                        ),
+                    ],
+                    lpar=[],
+                    rpar=[],  # rpar can't own the trailing whitespace if it's not there
+                ),
+                "one, two ",
+            ),
+            # custom whitespace on StarredElement
+            (
+                cst.Tuple(
+                    [
+                        cst.Element(cst.Name("one")),
+                        cst.StarredElement(
+                            cst.Name("two"),
+                            whitespace_before_value=cst.SimpleWhitespace(" "),
+                            whitespace_after=cst.SimpleWhitespace(" "),
+                            lpar=[cst.LeftParen()],
+                            rpar=[cst.RightParen()],
+                        ),
+                    ],
+                    lpar=[],
+                    rpar=[],  # rpar can't own the trailing whitespace if it's not there
+                ),
+                "one, (* two) ",
+            ),
+            # missing spaces around tuple, okay with parentheses
+            (
+                cst.For(
+                    target=cst.Tuple(
+                        [cst.Element(cst.Name("k")), cst.Element(cst.Name("v"))]
+                    ),
+                    iter=cst.Name("abc"),
+                    body=cst.SimpleStatementSuite([cst.Pass()]),
+                    whitespace_after_for=cst.SimpleWhitespace(""),
+                    whitespace_before_in=cst.SimpleWhitespace(""),
+                ),
+                "for(k, v)in abc: pass\n",
+            ),
+            # no spaces around tuple, but using values that are parenthesized
+            (
+                cst.For(
+                    target=cst.Tuple(
+                        [
+                            cst.Element(
+                                cst.Name(
+                                    "k", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
+                                )
+                            ),
+                            cst.Element(
+                                cst.Name(
+                                    "v", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
+                                )
+                            ),
+                        ],
+                        lpar=[],
+                        rpar=[],
+                    ),
+                    iter=cst.Name("abc"),
+                    body=cst.SimpleStatementSuite([cst.Pass()]),
+                    whitespace_after_for=cst.SimpleWhitespace(""),
+                    whitespace_before_in=cst.SimpleWhitespace(""),
+                ),
+                "for(k), (v)in abc: pass\n",
+            ),
+            # starred elements are safe to use without a space before them
+            (
+                cst.For(
+                    target=cst.Tuple(
+                        [cst.StarredElement(cst.Name("foo"))], lpar=[], rpar=[]
+                    ),
+                    iter=cst.Name("bar"),
+                    body=cst.SimpleStatementSuite([cst.Pass()]),
+                    whitespace_after_for=cst.SimpleWhitespace(""),
+                ),
+                "for*foo, in bar: pass\n",
+            ),
+        )
+    )
+    def test_valid(self, node: cst.CSTNode, code: str) -> None:
+        self.validate_node(node, code)
+
+    @data_provider(
+        (
+            (
+                lambda: cst.Tuple([], lpar=[], rpar=[]),
+                "A zero-length tuple must be wrapped in parentheses.",
+            ),
+            (
+                lambda: cst.Tuple(
+                    [cst.Element(cst.Name("mismatched"))],
+                    lpar=[cst.LeftParen(), cst.LeftParen()],
+                    rpar=[cst.RightParen()],
+                ),
+                "unbalanced parens",
+            ),
+            (
+                lambda: cst.For(
+                    target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]),
+                    iter=cst.Name("it"),
+                    body=cst.SimpleStatementSuite([cst.Pass()]),
+                    whitespace_after_for=cst.SimpleWhitespace(""),
+                ),
+                "Must have at least one space after 'for' keyword.",
+            ),
+            (
+                lambda: cst.For(
+                    target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]),
+                    iter=cst.Name("it"),
+                    body=cst.SimpleStatementSuite([cst.Pass()]),
+                    whitespace_before_in=cst.SimpleWhitespace(""),
+                ),
+                "Must have at least one space before 'in' keyword.",
+            ),
+            # an additional check for StarredElement, since it's a separate codepath
+            (
+                lambda: cst.For(
+                    target=cst.Tuple(
+                        [cst.StarredElement(cst.Name("el"))], lpar=[], rpar=[]
+                    ),
+                    iter=cst.Name("it"),
+                    body=cst.SimpleStatementSuite([cst.Pass()]),
+                    whitespace_before_in=cst.SimpleWhitespace(""),
+                ),
+                "Must have at least one space before 'in' keyword.",
+            ),
+        )
+    )
+    def test_invalid(
+        self, get_node: Callable[[], cst.CSTNode], expected_re: str
+    ) -> None:
+        self.assert_invalid(get_node, expected_re)