From ca15eb2685f381e755de0f373e63828fb2b83df1 Mon Sep 17 00:00:00 2001 From: Benjamin Woodruff Date: Wed, 5 Jun 2019 15:23:22 -0700 Subject: [PATCH] Add nodedefs (no parsing yet) needed for Tuple This ended up being pretty complicated, so the parser stuff will come in another diff. Hopefully this should set up up nicely for dicts, sets, and lists too. --- libcst/__init__.py | 4 - libcst/nodes/__init__.py | 4 + libcst/nodes/_expression.py | 186 +++++++++++++++++++++++++- libcst/nodes/_op.py | 9 +- libcst/nodes/_statement.py | 5 +- libcst/nodes/tests/test_tuple.py | 221 +++++++++++++++++++++++++++++++ 6 files changed, 421 insertions(+), 8 deletions(-) delete mode 100644 libcst/__init__.py create mode 100644 libcst/nodes/tests/test_tuple.py diff --git a/libcst/__init__.py b/libcst/__init__.py deleted file mode 100644 index 62642369..00000000 --- a/libcst/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. diff --git a/libcst/nodes/__init__.py b/libcst/nodes/__init__.py index 74658912..43411dc3 100644 --- a/libcst/nodes/__init__.py +++ b/libcst/nodes/__init__.py @@ -22,6 +22,7 @@ from libcst.nodes._expression import ( Attribute, Await, BaseAtom, + BaseElement, BaseExpression, BaseFormattedStringContent, BinaryOperation, @@ -30,6 +31,7 @@ from libcst.nodes._expression import ( Comparison, ComparisonTarget, ConcatenatedString, + Element, Ellipses, ExtSlice, Float, @@ -54,7 +56,9 @@ from libcst.nodes._expression import ( SimpleString, Slice, Starred, + StarredElement, Subscript, + Tuple, UnaryOperation, Yield, ) diff --git a/libcst/nodes/_expression.py b/libcst/nodes/_expression.py index e037a906..3a4b24da 100644 --- a/libcst/nodes/_expression.py +++ b/libcst/nodes/_expression.py @@ -6,7 +6,7 @@ # pyre-strict import re -from abc import ABC +from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass from enum import Enum, auto @@ -2048,3 +2048,187 @@ class Yield(BaseExpression): value._codegen(state, default_space="") elif value is not None: value._codegen(state) + + +class BaseElement(CSTNode, ABC): + """ + An element of a literal list, tuple, or set. For elements of a literal dict, see + BaseMappingElement. (TODO) + """ + + # pyre-fixme[13]: Attribute `value` is never initialized. + value: BaseExpression + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + # Used if we don't have a comma, otherwise the parser will attach the whitespace to + # the comma. + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + @abstractmethod + def _codegen( + self, + state: CodegenState, + default_comma: bool = False, + default_comma_whitespace: bool = False, # False for a single-item tuple + ) -> None: + ... + + +@add_slots +@dataclass(frozen=True) +class Element(BaseElement): + value: BaseExpression + + # Any trailing comma + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Whitespace + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Element": + return Element( + value=visit_required("value", self.value, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ), + ) + + def _codegen( + self, + state: CodegenState, + default_comma: bool = False, + default_comma_whitespace: bool = False, + ) -> None: + self.value._codegen(state) + self.whitespace_after._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + if default_comma_whitespace: + state.tokens.append(", ") + else: + state.tokens.append(",") + elif isinstance(comma, Comma): + comma._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class StarredElement(BaseElement, _BaseParenthesizedNode): + value: BaseExpression + + # Any trailing comma + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Parentheses around the leading asterisk and the value. Functionally equivalent to + # parentheses around the value, but in a different position. + lpar: Sequence[LeftParen] = () + rpar: Sequence[RightParen] = () + + # Whitespace + whitespace_before_value: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "StarredElement": + return StarredElement( + lpar=visit_sequence("lpar", self.lpar, visitor), + whitespace_before_value=visit_required( + "whitespace_before_value", self.whitespace_before_value, visitor + ), + value=visit_required("value", self.value, visitor), + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ), + rpar=visit_sequence("rpar", self.rpar, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + ) + + def _codegen( + self, + state: CodegenState, + default_comma: bool = False, + default_comma_whitespace: bool = False, + ) -> None: + with self._parenthesize(state): + state.tokens.append("*") + self.whitespace_before_value._codegen(state) + self.value._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + if default_comma_whitespace: + state.tokens.append(", ") + else: + state.tokens.append(",") + elif isinstance(comma, Comma): + comma._codegen(state) + self.whitespace_after._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Tuple(BaseExpression): + elements: Sequence[Union[Element, StarredElement]] + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = (LeftParen(),) + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = (RightParen(),) + + def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: + if super(Tuple, self)._safe_to_use_with_word_operator(position): + # if we have parenthesis, we're safe. + return True + # elements[-1] and elements[0] must exist past this point, because + # we're not parenthesized, meaning we must have at least one element. + elements = self.elements + if position == ExpressionPosition.LEFT: + last_element = elements[-1] + return ( + not last_element.whitespace_after.empty + or isinstance(last_element.comma, Comma) + or ( + isinstance(last_element, StarredElement) + and len(last_element.rpar) > 0 + ) + or last_element.value._safe_to_use_with_word_operator(position) + ) + else: # ExpressionPosition.RIGHT + first_element = elements[0] + # starred elements are always safe because they begin with ( or * + return isinstance( + first_element, StarredElement + ) or first_element.value._safe_to_use_with_word_operator(position) + + def _validate(self) -> None: + # Paren validation and such + super(Tuple, self)._validate() + + if len(self.elements) == 0: + if len(self.lpar) == 0: # assumes len(lpar) == len(rpar), via superclass + raise CSTValidationError( + "A zero-length tuple must be wrapped in parentheses." + ) + # Invalid commas aren't possible, because MaybeSentinel will ensure that there + # is a comma where required. + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Tuple": + return Tuple( + lpar=visit_sequence("lpar", self.lpar, visitor), + elements=visit_sequence("elements", self.elements, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + elements = self.elements + if len(elements) == 1: + elements[0]._codegen( + state, default_comma=True, default_comma_whitespace=False + ) + else: + for idx, el in enumerate(elements): + el._codegen( + state, + default_comma=(idx < len(elements) - 1), + default_comma_whitespace=True, + ) diff --git a/libcst/nodes/_op.py b/libcst/nodes/_op.py index 373b4552..0bcffb7d 100644 --- a/libcst/nodes/_op.py +++ b/libcst/nodes/_op.py @@ -166,8 +166,13 @@ class Colon(_BaseOneTokenOp): @dataclass(frozen=True) class Comma(_BaseOneTokenOp): """ - Used by ImportAlias as a separator between subsequent ImportAliases contained - within a Import or ImportFrom. + Syntactic trivia used as a separator between subsequent items in various parts of + the grammar. + + Some use-cases are: + - Import or ImportFrom + - Function arguments + - Tuple/list/set/dict elements """ whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") diff --git a/libcst/nodes/_statement.py b/libcst/nodes/_statement.py index cfb1fba2..38bdaa11 100644 --- a/libcst/nodes/_statement.py +++ b/libcst/nodes/_statement.py @@ -27,6 +27,7 @@ from libcst.nodes._expression import ( Name, Parameters, RightParen, + Tuple, ) from libcst.nodes._internal import ( CodegenState, @@ -1578,7 +1579,9 @@ class For(BaseCompoundStatement): """ # The target of the iterator in the for statement. - target: Name # TODO: Should be a Union[Name, Tuple, List] once we support this. + target: Union[ + Name, Tuple + ] # TODO: Should be a Union[Name, Tuple, List] once we support this. # The iterable expression we will loop over. iter: BaseExpression diff --git a/libcst/nodes/tests/test_tuple.py b/libcst/nodes/tests/test_tuple.py new file mode 100644 index 00000000..ff2b80a0 --- /dev/null +++ b/libcst/nodes/tests/test_tuple.py @@ -0,0 +1,221 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class TupleTest(CSTNodeTest): + @data_provider( + ( + # zero-element tuple + (cst.Tuple([]), "()"), + # one-element tuple + (cst.Tuple([cst.Element(cst.Name("single_element"))]), "(single_element,)"), + # two-element tuple + ( + cst.Tuple([cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]), + "(one, two)", + ), + # remove parenthesis + ( + cst.Tuple( + [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))], + lpar=[], + rpar=[], + ), + "one, two", + ), + # add extra parenthesis + ( + cst.Tuple( + [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))], + lpar=[cst.LeftParen(), cst.LeftParen()], + rpar=[cst.RightParen(), cst.RightParen()], + ), + "((one, two))", + ), + # starred element + ( + cst.Tuple( + [cst.Element(cst.Name("one")), cst.StarredElement(cst.Name("two"))] + ), + "(one, *two)", + ), + # custom comma on Element + ( + cst.Tuple( + [ + cst.Element(cst.Name("one")), + cst.Element(cst.Name("two"), comma=cst.Comma()), + ] + ), + "(one, two,)", + ), + # custom comma on StarredElement + ( + cst.Tuple( + [ + cst.Element(cst.Name("one")), + cst.StarredElement(cst.Name("two"), comma=cst.Comma()), + ] + ), + "(one, *two,)", + ), + # custom parenthesis on StarredElement + ( + cst.Tuple( + [ + cst.StarredElement( + cst.Name("abc"), + lpar=[cst.LeftParen()], + rpar=[cst.RightParen()], + ) + ] + ), + "((*abc),)", + ), + # custom whitespace on Element + ( + cst.Tuple( + [ + cst.Element(cst.Name("one")), + cst.Element( + cst.Name("two"), whitespace_after=cst.SimpleWhitespace(" ") + ), + ], + lpar=[], + rpar=[], # rpar can't own the trailing whitespace if it's not there + ), + "one, two ", + ), + # custom whitespace on StarredElement + ( + cst.Tuple( + [ + cst.Element(cst.Name("one")), + cst.StarredElement( + cst.Name("two"), + whitespace_before_value=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + lpar=[cst.LeftParen()], + rpar=[cst.RightParen()], + ), + ], + lpar=[], + rpar=[], # rpar can't own the trailing whitespace if it's not there + ), + "one, (* two) ", + ), + # missing spaces around tuple, okay with parenthesis + ( + cst.For( + target=cst.Tuple( + [cst.Element(cst.Name("k")), cst.Element(cst.Name("v"))] + ), + iter=cst.Name("abc"), + body=cst.SimpleStatementSuite([cst.Pass()]), + whitespace_after_for=cst.SimpleWhitespace(""), + whitespace_before_in=cst.SimpleWhitespace(""), + ), + "for(k, v)in abc: pass\n", + ), + # no spaces around tuple, but using values that are parenthesized + ( + cst.For( + target=cst.Tuple( + [ + cst.Element( + cst.Name( + "k", lpar=[cst.LeftParen()], rpar=[cst.RightParen()] + ) + ), + cst.Element( + cst.Name( + "v", lpar=[cst.LeftParen()], rpar=[cst.RightParen()] + ) + ), + ], + lpar=[], + rpar=[], + ), + iter=cst.Name("abc"), + body=cst.SimpleStatementSuite([cst.Pass()]), + whitespace_after_for=cst.SimpleWhitespace(""), + whitespace_before_in=cst.SimpleWhitespace(""), + ), + "for(k), (v)in abc: pass\n", + ), + # starred elements are safe to use without a space before them + ( + cst.For( + target=cst.Tuple( + [cst.StarredElement(cst.Name("foo"))], lpar=[], rpar=[] + ), + iter=cst.Name("bar"), + body=cst.SimpleStatementSuite([cst.Pass()]), + whitespace_after_for=cst.SimpleWhitespace(""), + ), + "for*foo, in bar: pass\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + ( + lambda: cst.Tuple([], lpar=[], rpar=[]), + "A zero-length tuple must be wrapped in parentheses.", + ), + ( + lambda: cst.Tuple( + [cst.Element(cst.Name("mismatched"))], + lpar=[cst.LeftParen(), cst.LeftParen()], + rpar=[cst.RightParen()], + ), + "unbalanced parens", + ), + ( + lambda: cst.For( + target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]), + iter=cst.Name("it"), + body=cst.SimpleStatementSuite([cst.Pass()]), + whitespace_after_for=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'for' keyword.", + ), + ( + lambda: cst.For( + target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]), + iter=cst.Name("it"), + body=cst.SimpleStatementSuite([cst.Pass()]), + whitespace_before_in=cst.SimpleWhitespace(""), + ), + "Must have at least one space before 'in' keyword.", + ), + # an additional check for StarredElement, since it's a separate codepath + ( + lambda: cst.For( + target=cst.Tuple( + [cst.StarredElement(cst.Name("el"))], lpar=[], rpar=[] + ), + iter=cst.Name("it"), + body=cst.SimpleStatementSuite([cst.Pass()]), + whitespace_before_in=cst.SimpleWhitespace(""), + ), + "Must have at least one space before 'in' keyword.", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re)