diff --git a/libcst/nodes/__init__.py b/libcst/nodes/__init__.py index 60259719..13438a70 100644 --- a/libcst/nodes/__init__.py +++ b/libcst/nodes/__init__.py @@ -24,6 +24,8 @@ from libcst.nodes._expression import ( # noqa: F401 Await, BaseAtom, BaseComp, + BaseDict, + BaseDictElement, BaseElement, BaseExpression, BaseFormattedStringContent, @@ -39,6 +41,8 @@ from libcst.nodes._expression import ( # noqa: F401 CompFor, CompIf, ConcatenatedString, + Dict, + DictElement, Element, Ellipses, ExtSlice, @@ -69,6 +73,7 @@ from libcst.nodes._expression import ( # noqa: F401 SetComp, SimpleString, Slice, + StarredDictElement, StarredElement, Subscript, Tuple, diff --git a/libcst/nodes/_expression.py b/libcst/nodes/_expression.py index e0cd6da5..833f895c 100644 --- a/libcst/nodes/_expression.py +++ b/libcst/nodes/_expression.py @@ -2125,32 +2125,70 @@ class Yield(BaseExpression): value._codegen(state) -class BaseElement(CSTNode, ABC): +class _BaseElementImpl(CSTNode, ABC): """ - An element of a literal list, tuple, or set. For elements of a literal dict, see - BaseMappingElement. (TODO) + An internal base class for :class:`.Element` and :class:`DictElement`. """ # pyre-fixme[13]: Attribute `value` is never initialized. value: BaseExpression comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + def _codegen_comma( + self, + state: CodegenState, + default_comma: bool = False, + default_comma_whitespace: bool = False, # False for a single-item collection + ) -> None: + """ + Called by `_codegen_impl` in subclasses to generate the comma. 
+ """ + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + if default_comma_whitespace: + state.add_token(", ") + else: + state.add_token(",") + elif isinstance(comma, Comma): + comma._codegen(state) + @abstractmethod def _codegen_impl( self, state: CodegenState, default_comma: bool = False, - default_comma_whitespace: bool = False, # False for a single-item tuple + default_comma_whitespace: bool = False, # False for a single-item collection ) -> None: ... +class BaseElement(_BaseElementImpl, ABC): + """ + An element of a literal list, tuple, or set. For elements of a literal dict, see + BaseDictElement. + """ + + +class BaseDictElement(_BaseElementImpl, ABC): + """ + An element of a literal dict. For elements of a list, tuple, or set, see + BaseElement. + """ + + @add_slots @dataclass(frozen=True) class Element(BaseElement): + """ + A simple value in a literal :class:`.List`, :class:`.Tuple`, or :class:`.Set`. + These a literal collection may also contain a :class:`.StarredElement`. + + If you're using a literal :class:`.Dict`, see :class:`.DictElement` instead. + """ + value: BaseExpression - # Any trailing comma + #: A trailing comma. By default, we'll only insert a comma if one is required. comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Element": @@ -2167,31 +2205,95 @@ class Element(BaseElement): ) -> None: with state.record_syntactic_position(self): self.value._codegen(state) + self._codegen_comma(state, default_comma, default_comma_whitespace) - comma = self.comma - if comma is MaybeSentinel.DEFAULT and default_comma: - if default_comma_whitespace: - state.add_token(", ") - else: - state.add_token(",") - elif isinstance(comma, Comma): - comma._codegen(state) + +@add_slots +@dataclass(frozen=True) +class DictElement(BaseDictElement): + """ + A simple ``key: value`` pair that represents a single entry in a literal + :class:`.Dict`. 
:class:`.Dict` nodes may also contain a + :class:`.StarredDictElement`. + + If you're using a literal :class:`.List`, :class:`.Tuple`, or :class:`.Set`, + see :class:`.Element` instead. + """ + + key: BaseExpression + value: BaseExpression + + #: A trailing comma. By default, we'll only insert a comma if one is required. + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + #: Whitespace after the key, but before the colon in ``key : value``. + whitespace_before_colon: BaseParenthesizableWhitespace = SimpleWhitespace("") + #: Whitespace after the colon, but before the value in ``key : value``. + whitespace_after_colon: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "DictElement": + return DictElement( + key=visit_required("key", self.key, visitor), + whitespace_before_colon=visit_required( + "whitespace_before_colon", self.whitespace_before_colon, visitor + ), + whitespace_after_colon=visit_required( + "whitespace_after_colon", self.whitespace_after_colon, visitor + ), + value=visit_required("value", self.value, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + ) + + def _codegen_impl( + self, + state: CodegenState, + default_comma: bool = False, + default_comma_whitespace: bool = False, + ) -> None: + with state.record_syntactic_position(self): + self.key._codegen(state) + self.whitespace_before_colon._codegen(state) + state.add_token(":") + self.whitespace_after_colon._codegen(state) + self.value._codegen(state) + self._codegen_comma(state, default_comma, default_comma_whitespace) @add_slots @dataclass(frozen=True) class StarredElement(BaseElement, _BaseParenthesizedNode): + """ + A starred ``*value`` element that expands to represent multiple values in a literal + :class:`.List`, :class:`.Tuple`, or :class:`.Set`. + + If you're using a literal :class:`.Dict`, see :class:`.StarredDictElement` instead. 
+ + If this node owns parentheses, those parentheses wrap the leading asterisk, but not + the trailing comma.
For example:: + + StarredDictElement( + cst.Name("el"), + comma=cst.Comma(), + lpar=[cst.LeftParen()], + rpar=[cst.RightParen()], + ) + + will generate:: + + (**el), + """ + + value: BaseExpression + + #: A trailing comma. By default, we'll only insert a comma if one is required. + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + #: Parenthesis at the beginning of the node, before the leading asterisk. + lpar: Sequence[LeftParen] = () + #: Parentheses after the value, but before a comma (if there is one). + rpar: Sequence[RightParen] = () + + #: Whitespace between the leading asterisks and the value expression. + whitespace_before_value: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "StarredDictElement": + return StarredDictElement( + lpar=visit_sequence("lpar", self.lpar, visitor), + whitespace_before_value=visit_required( + "whitespace_before_value", self.whitespace_before_value, visitor + ), + value=visit_required("value", self.value, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + ) + + def _codegen_impl( + self, + state: CodegenState, + default_comma: bool = False, + default_comma_whitespace: bool = False, + ) -> None: + with self._parenthesize(state): + state.add_token("**") + self.whitespace_before_value._codegen(state) + self.value._codegen(state) + self._codegen_comma(state, default_comma, default_comma_whitespace) @add_slots @dataclass(frozen=True) class Tuple(BaseAtom, BaseAssignTargetExpression, BaseDelTargetExpression): - elements: Sequence[Union[Element, StarredElement]] + elements: Sequence[BaseElement] #: Sequence of open parenthesis for precedence dictation. 
lpar: Sequence[LeftParen] = (LeftParen(),) @@ -2327,7 +2483,7 @@ class BaseList(BaseAtom, ABC): @add_slots @dataclass(frozen=True) class List(BaseList, BaseAssignTargetExpression, BaseDelTargetExpression): - elements: Sequence[Union[Element, StarredElement]] + elements: Sequence[BaseElement] lbracket: LeftSquareBracket = LeftSquareBracket() rbracket: RightSquareBracket = RightSquareBracket() lpar: Sequence[LeftParen] = () @@ -2353,7 +2509,15 @@ class List(BaseList, BaseAssignTargetExpression, BaseDelTargetExpression): ) -class BaseSet(BaseAtom, ABC): +class _BaseSetOrDict(BaseAtom, ABC): + """ + An abstract base class for :class:`.BaseSet` and :class:`.BaseDict`. + + Literal sets and dicts are syntactically similar (hence this shared base class), but + are semantically different. This base class is an implementation detail and + shouldn't be exported. + """ + #: Open brace surrounding the list lbrace: LeftCurlyBrace = LeftCurlyBrace() @@ -2377,10 +2541,16 @@ class BaseSet(BaseAtom, ABC): self.rbrace._codegen(state) +class BaseSet(_BaseSetOrDict, ABC): + """ + An abstract base class for :class:`.Set` and :class:`.SetComp`. + """ + + @add_slots @dataclass(frozen=True) class Set(BaseSet): - elements: Sequence[Union[Element, StarredElement]] + elements: Sequence[BaseElement] lbrace: LeftCurlyBrace = LeftCurlyBrace() rbrace: RightCurlyBrace = RightCurlyBrace() lpar: Sequence[LeftParen] = () @@ -2415,6 +2585,61 @@ class Set(BaseSet): ) +class BaseDict(_BaseSetOrDict, ABC): + """ + An abstract base class for :class:`.Dict` and :class:`.DictComp`. + """ + + +@add_slots +@dataclass(frozen=True) +class Dict(BaseDict): + """ + A dictionary. Key-value pairs are stored in ``elements`` using :class:`.DictElement` + nodes. + + It's possible to expand one dictionary into another, as in ``{k: v, **expanded}``. + Expanded elements are stored as :class:`.StarredDictElement` nodes. 
+ + :: + + Dict([ + DictElement(Name("k1"), Name("v1")), + DictElement(Name("k2"), Name("v2")), + StarredDictElement(Name("expanded")), + ]) + + generates the following code:: + + {k1: v1, k2: v2, **expanded} + """ + + elements: Sequence[BaseDictElement] + lbrace: LeftCurlyBrace = LeftCurlyBrace() + rbrace: RightCurlyBrace = RightCurlyBrace() + lpar: Sequence[LeftParen] = () + rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Dict": + return Dict( + lpar=visit_sequence("lpar", self.lpar, visitor), + lbrace=visit_required("lbrace", self.lbrace, visitor), + elements=visit_sequence("elements", self.elements, visitor), + rbrace=visit_required("rbrace", self.rbrace, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen_impl(self, state: CodegenState) -> None: + with self._parenthesize(state), self._braceize(state): + elements = self.elements + for idx, el in enumerate(elements): + el._codegen( + state, + default_comma=(idx < len(elements) - 1), + default_comma_whitespace=True, + ) + + @add_slots @dataclass(frozen=True) class CompFor(CSTNode): diff --git a/libcst/nodes/tests/test_dict.py b/libcst/nodes/tests/test_dict.py new file mode 100644 index 00000000..005fb59d --- /dev/null +++ b/libcst/nodes/tests/test_dict.py @@ -0,0 +1,189 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Any + +import libcst.nodes as cst +from libcst.nodes._internal import CodeRange +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class DictTest(CSTNodeTest): + @data_provider( + [ + # zero-element dict + { + "node": cst.Dict([]), + "code": "{}", + "expected_position": CodeRange.create((1, 0), (1, 2)), + }, + # one-element dict, sentinel comma value + { + "node": cst.Dict([cst.DictElement(cst.Name("k"), cst.Name("v"))]), + "code": "{k: v}", + "expected_position": CodeRange.create((1, 0), (1, 6)), + }, + { + "node": cst.Dict([cst.StarredDictElement(cst.Name("expanded"))]), + "code": "{**expanded}", + "expected_position": CodeRange.create((1, 0), (1, 12)), + }, + # two-element dict, sentinel comma value + { + "node": cst.Dict( + [ + cst.DictElement(cst.Name("k1"), cst.Name("v1")), + cst.DictElement(cst.Name("k2"), cst.Name("v2")), + ] + ), + "code": "{k1: v1, k2: v2}", + "expected_position": CodeRange.create((1, 0), (1, 16)), + }, + # custom whitespace between brackets + { + "node": cst.Dict( + [cst.DictElement(cst.Name("k"), cst.Name("v"))], + lbrace=cst.LeftCurlyBrace( + whitespace_after=cst.SimpleWhitespace("\t") + ), + rbrace=cst.RightCurlyBrace( + whitespace_before=cst.SimpleWhitespace("\t\t") + ), + ), + "code": "{\tk: v\t\t}", + "expected_position": CodeRange.create((1, 0), (1, 9)), + }, + # with parenthesis + { + "node": cst.Dict( + [cst.DictElement(cst.Name("k"), cst.Name("v"))], + lpar=[cst.LeftParen()], + rpar=[cst.RightParen()], + ), + "code": "({k: v})", + "expected_position": CodeRange.create((1, 1), (1, 7)), + }, + # starred element + { + "node": cst.Dict( + [ + cst.StarredDictElement(cst.Name("one")), + cst.StarredDictElement(cst.Name("two")), + ] + ), + "code": "{**one, **two}", + "expected_position": CodeRange.create((1, 0), (1, 14)), + }, + # custom comma on DictElement + { + "node": cst.Dict( + [cst.DictElement(cst.Name("k"), cst.Name("v"), comma=cst.Comma())] 
+ ), + "code": "{k: v,}", + "expected_position": CodeRange.create((1, 0), (1, 7)), + }, + # custom comma on StarredDictElement + { + "node": cst.Dict( + [cst.StarredDictElement(cst.Name("expanded"), comma=cst.Comma())] + ), + "code": "{**expanded,}", + "expected_position": CodeRange.create((1, 0), (1, 13)), + }, + # custom whitespace on DictElement + { + "node": cst.Dict( + [ + cst.DictElement( + cst.Name("k"), + cst.Name("v"), + whitespace_before_colon=cst.SimpleWhitespace("\t"), + whitespace_after_colon=cst.SimpleWhitespace("\t\t"), + ) + ] + ), + "code": "{k\t:\t\tv}", + "expected_position": CodeRange.create((1, 0), (1, 8)), + }, + # custom parenthesis on StarredDictElement + { + "node": cst.Dict( + [ + cst.StarredDictElement( + cst.Name("abc"), + lpar=[cst.LeftParen()], + rpar=[cst.RightParen()], + comma=cst.Comma(), + ) + ] + ), + "code": "{(**abc),}", + "expected_position": CodeRange.create((1, 0), (1, 10)), + }, + # custom whitespace on StarredDictElement + { + "node": cst.Dict( + [ + cst.DictElement( + cst.Name("k"), cst.Name("v"), comma=cst.Comma() + ), + cst.StarredDictElement( + cst.Name("expanded"), + whitespace_before_value=cst.SimpleWhitespace(" "), + lpar=[cst.LeftParen()], + rpar=[cst.RightParen()], + ), + ] + ), + "code": "{k: v,(** expanded)}", + "expected_position": CodeRange.create((1, 0), (1, 21)), + }, + # missing spaces around dict is always okay + { + "node": cst.GeneratorExp( + cst.Name("a"), + cst.CompFor( + cst.Name("b"), + cst.Dict([cst.DictElement(cst.Name("k"), cst.Name("v"))]), + ifs=[ + cst.CompIf( + cst.Name("c"), + whitespace_before=cst.SimpleWhitespace(""), + ) + ], + whitespace_after_in=cst.SimpleWhitespace(""), + ), + ), + "code": "(a for b in{k: v}if c)", + }, + ] + ) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) + + @data_provider( + [ + # unbalanced Dict + { + "get_node": lambda: cst.Dict([], lpar=[cst.LeftParen()]), + "expected_re": "left paren without right paren", + }, + # unbalanced 
StarredDictElement + { + "get_node": lambda: cst.Dict( + [ + cst.StarredDictElement( + cst.Name("unbalanced"), lpar=[cst.LeftParen()] + ) + ] + ), + "expected_re": "left paren without right paren", + }, + ] + ) + def test_invalid(self, **kwargs: Any) -> None: + self.assert_invalid(**kwargs)