diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..e4e112f2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.swp +*.swo +*.pyc +*.pyo +.pyre/ +__pycache__/ diff --git a/README.md b/README.md index 7f90b67d..280604e1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,29 @@ # LibCST TODO: Add documentation. + +## Auto-formatting code with isort and Black + +We use isort and black to format code. To format changes to be conformant, run the following in the root: + +` +isort -q -w 88 -m 3 -tc -fgw 0 -lai 2 -ca -ns __init__.py -y ; black --target-version py36 libcst/ +` + +## Running tests + +To run all tests, do the following in the root: + +`find -name "test_*.py" -printf '%P\n' | xargs python3 -m unittest` + +## Verifying types with Pyre + +To verify types for the library, do the following in the root: + +`pyre --source-directory . --search-path stubs/ check` + +## Examining a sample tree + +To examine the tree that is parsed from a particular file, do the following: + +`python -m libcst.tool print ` diff --git a/libcst/__init__.py b/libcst/__init__.py new file mode 100644 index 00000000..62642369 --- /dev/null +++ b/libcst/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/libcst/_add_slots.py b/libcst/_add_slots.py new file mode 100644 index 00000000..d4a0f461 --- /dev/null +++ b/libcst/_add_slots.py @@ -0,0 +1,43 @@ +# This file is derived from github.com/ericvsmith/dataclasses, and is Apache 2 licensed. 
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f188f452/LICENSE.txt
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f/dataclass_tools.py

# pyre-strict
import dataclasses
from typing import Type, TypeVar


_T = TypeVar("_T")


def add_slots(cls: Type[_T]) -> Type[_T]:
    """
    Class decorator intended to rebuild a dataclass so that it declares `__slots__`
    for each of its dataclass fields.

    NOTE: The rebuild is currently disabled. The very first statement returns `cls`
    unchanged, so every statement after that `return` is unreachable. The remainder
    is kept as the implementation to switch back on once the Python 3.7 problem
    mentioned in the TODO below is resolved.

    When enabled, the code below raises `TypeError` if `cls` already defines
    `__slots__` in its own `__dict__`.
    """
    # TODO: This doesn't work under Python3.7 due to the type(cls) call below.
    return cls

    # Need to create a new class, since we can't set __slots__
    # after a class has been created.

    # Make sure __slots__ isn't already set.
    if "__slots__" in cls.__dict__:
        raise TypeError(f"{cls.__name__} already specifies __slots__")

    # Create a new dict for our new class.
    cls_dict = dict(cls.__dict__)
    field_names = tuple(f.name for f in dataclasses.fields(cls))
    cls_dict["__slots__"] = field_names
    for field_name in field_names:
        # Remove our attributes, if present. They'll still be
        # available in _MARKER.
        cls_dict.pop(field_name, None)
    # Remove __dict__ itself.
    cls_dict.pop("__dict__", None)
    # And finally create the class.
    qualname = getattr(cls, "__qualname__", None)
    # GenericMeta in py3.6 requires us to track __orig_bases__. This is fixed in py3.7
    # by the removal of GenericMeta. We should just be able to use cls.__bases__ in the
    # future.
    bases = getattr(cls, "__orig_bases__", cls.__bases__)
    # Re-create the class through its own metaclass so the new __slots__ takes effect.
    cls = type(cls)(cls.__name__, bases, cls_dict)
    if qualname is not None:
        # Keep the qualified name of the class that was decorated.
        cls.__qualname__ = qualname
    return cls


# diff --git a/libcst/_base_visitor.py b/libcst/_base_visitor.py
# new file mode 100644
# index 00000000..23cabae1
# --- /dev/null
# +++ b/libcst/_base_visitor.py
# @@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from abc import ABC
from typing import TYPE_CHECKING, TypeVar, Union

from libcst._removal_sentinel import RemovalSentinel


if TYPE_CHECKING:
    # Circular dependency for typing reasons only
    from libcst.nodes._base import CSTNode

# RemovalSentinel is re-exported
__all__ = ["CSTVisitor", "RemovalSentinel"]


# Lets `on_leave` declare that the node it returns has the same type as the node it
# was given.
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")


class CSTVisitor(ABC):
    """
    The low-level base visitor class.

    This shouldn't be used directly, instead we should provide a more user-friendly
    subclass.

    The default callbacks visit every child and leave every node unchanged.
    """

    def on_visit(self, node: "CSTNode") -> bool:
        """
        Called every time a node is visited, before we've visited its children.

        Returns `True` if children should be visited, and returns `False` otherwise.
        The default implementation always returns `True`.
        """
        return True

    def on_leave(
        self, original_node: CSTNodeT, updated_node: CSTNodeT
    ) -> Union[CSTNodeT, RemovalSentinel]:
        """
        Called every time we leave a node, after we've visited its children.

        A RemovalSentinel indicates that the node should be removed from its parent.
        This is not always possible, and may raise an exception.

        The default implementation returns `updated_node` unchanged.
        """
        return updated_node


# diff --git a/libcst/_maybe_sentinel.py b/libcst/_maybe_sentinel.py
# new file mode 100644
# index 00000000..b2e03f8c
# --- /dev/null
# +++ b/libcst/_maybe_sentinel.py
# @@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum, auto


class MaybeSentinel(Enum):
    """
    A MaybeSentinel value is used as the default constructor for some attributes to
    denote that on generating code we should optionally include this element in order
    to generate valid code.
    """

    DEFAULT = auto()


# diff --git a/libcst/_removal_sentinel.py b/libcst/_removal_sentinel.py
# new file mode 100644
# index 00000000..863933b9
# --- /dev/null
# +++ b/libcst/_removal_sentinel.py
# @@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
"""
Used and re-exported by visitor. This is hoisted into a separate module to avoid some
circular dependencies in the definition of CSTNode.
"""

from enum import Enum, auto


class RemovalSentinel(Enum):
    """
    A RemovalSentinel value should be returned by an `on_leave` method when we want to
    remove that child from its parent.

    The parent's _visit_and_replace_children method should make a best-effort to remove
    the child, but may raise an exception when removing the child doesn't make sense, or
    could change the semantics in an unexpected way. E.g. a function with no name
    doesn't make sense.

    If we can't automatically remove the child, the developer should instead remove the
    child by constructing a new parent in the parent's on_leave call.

    We use this instead of `None` to force developers to be more explicit about
    deletions, because `None` is the default return value for a function with no return
    statement.

    In the future, we may extend this to support other forms of removal, but that's
    unlikely.
    """

    REMOVE = auto()


# diff --git a/libcst/_tabs.py b/libcst/_tabs.py
# new file mode 100644
# index 00000000..c18ecc74
# --- /dev/null
# +++ b/libcst/_tabs.py
# @@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict


def expand_tabs(line: str) -> str:
    """
    Tabs are treated as 1-8 spaces according to
    https://docs.python.org/3/reference/lexical_analysis.html#indentation

    Given a string with tabs, this removes all tab characters and replaces them with the
    appropriate number of spaces, so that the character following each tab starts on
    the next column that is a multiple of eight.
    """
    result_list = []
    # Width of the expanded output produced so far.
    total = 0
    for ch in line:
        if ch == "\t":
            prev_total = total
            # Advance to the next multiple of 8 (always by at least one column).
            total = ((total + 8) // 8) * 8
            result_list.append(" " * (total - prev_total))
        else:
            total += 1
            result_list.append(ch)

    return "".join(result_list)


# diff --git a/libcst/exceptions.py b/libcst/exceptions.py
# new file mode 100644
# index 00000000..c62580a3
# --- /dev/null
# +++ b/libcst/exceptions.py
# @@ -0,0 +1,70 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from dataclasses import dataclass
from typing import Iterable, Optional, Sequence, Tuple

from libcst._tabs import expand_tabs


_EOF_STR: str = "end of file (EOF)"
_NEWLINE_CHARS: str = "\r\n"


# NOTE(review): `frozen=True` makes plain attribute assignment on an instance raise
# `dataclasses.FrozenInstanceError`; confirm nothing needs to assign attributes on a
# raised ParserSyntaxError from Python code.
@dataclass(frozen=True, eq=False)
class ParserSyntaxError(Exception):
    """
    Contains error information about the parser tree.
    """

    message: str
    encountered: Optional[str]  # None means EOF
    expected: Optional[Iterable[str]]  # None means EOF
    pos: Tuple[int, int]  # (one-indexed line, zero-indexed column)
    lines: Sequence[str]  # source code, used to generate a human-readable output

    def __str__(self) -> str:
        """
        A human-readable error message of where the syntax error is in their code.

        The message has the error location, what was encountered versus expected, the
        offending source line (tabs expanded, trailing newline stripped), and a caret
        under the offending column.
        """
        if self.encountered is not None:
            encountered_str = repr(self.encountered)
        else:
            encountered_str = _EOF_STR

        if self.expected is not None:
            expected_str = f"one of {list(self.expected)!r}"
        else:
            expected_str = _EOF_STR

        editor_line = self.editor_line
        editor_column = self.editor_column
        source_line = expand_tabs(self._source_line()).rstrip(_NEWLINE_CHARS)

        return (
            f"Syntax Error: {self.message} @ {editor_line}:{editor_column}.\n"
            + f"Encountered {encountered_str}, but expected {expected_str}.\n\n"
            + f"{source_line}\n"
            + f"{' ' * (editor_column - 1)}^"
        )

    def _source_line(self) -> str:
        """
        The raw source line that `pos` refers to, or an empty string when that line is
        not present in `lines`.
        """
        try:
            return self.lines[self.pos[0] - 1]
        except IndexError:
            # A failed lookup while formatting must not mask the syntax error itself.
            return ""

    @property
    def editor_line(self) -> int:
        """
        The one-indexed line in the user's editor.
        """
        return self.pos[0]  # the line in pos is already one-indexed.

    @property
    def editor_column(self) -> int:
        """
        The one-indexed column in the user's editor, assuming tabs expand to 1-8 spaces.
        """
        _, pos_column = self.pos
        tab_adjusted_column = len(expand_tabs(self._source_line()[:pos_column]))
        # Text editors use a one-indexed column, so we need to add one to our
        # zero-indexed column to get a human-readable result.
        return tab_adjusted_column + 1


# diff --git a/libcst/helpers.py b/libcst/helpers.py
# new file mode 100644
# index 00000000..6470e892
# --- /dev/null
# +++ b/libcst/helpers.py
# @@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import libcst.nodes as libcst


def get_fully_qualified_name(node: libcst.BaseExpression) -> str:
    """
    Return the dotted name spelled by `node`.

    A `Name` yields its `value`; an `Attribute` yields the qualified name of its
    `value` followed by "." and the `value` of its `attr`. Any other node type
    raises `Exception`.
    """
    if isinstance(node, libcst.Name):
        return node.value
    elif isinstance(node, libcst.Attribute):
        return get_fully_qualified_name(node.value) + "." + node.attr.value
    else:
        raise Exception(f"Invalid node type {type(node)}!")


# diff --git a/libcst/nodes/__init__.py b/libcst/nodes/__init__.py
# new file mode 100644
# index 00000000..74658912
# --- /dev/null
# +++ b/libcst/nodes/__init__.py
# @@ -0,0 +1,164 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This package contains CSTNode and all of the subclasses needed to express Python's full
grammar in a whitespace-sensitive fashion, forming a "Concrete" Syntax Tree (CST).
"""

# We don't export BaseLeaf/BaseValueToken from _base, because we consider those
# implementation details. Those base classes shouldn't be useful outside of this
# package.

from libcst.nodes._base import CSTNode, CSTValidationError
from libcst.nodes._dummy import DummyNode
from libcst.nodes._expression import (
    Annotation,
    Arg,
    Attribute,
    Await,
    BaseAtom,
    BaseExpression,
    BaseFormattedStringContent,
    BinaryOperation,
    BooleanOperation,
    Call,
    Comparison,
    ComparisonTarget,
    ConcatenatedString,
    Ellipses,
    ExtSlice,
    Float,
    FormattedString,
    FormattedStringExpression,
    FormattedStringText,
    From,
    IfExp,
    Imaginary,
    Index,
    Integer,
    Lambda,
    LeftParen,
    LeftSquareBracket,
    Name,
    Number,
    Param,
    Parameters,
    ParamStar,
    RightParen,
    RightSquareBracket,
    SimpleString,
    Slice,
    Starred,
    Subscript,
    UnaryOperation,
    Yield,
)
from libcst.nodes._module import Module
from libcst.nodes._op import (
    Add,
    AddAssign,
    And,
    AssignEqual,
    BaseAugOp,
    BaseBinaryOp,
    BaseBooleanOp,
    BaseCompOp,
    BaseUnaryOp,
    BitAnd,
    BitAndAssign,
    BitInvert,
    BitOr,
    BitOrAssign,
    BitXor,
    BitXorAssign,
    Colon,
    Comma,
    Divide,
    DivideAssign,
    Dot,
    Equal,
    FloorDivide,
    FloorDivideAssign,
    GreaterThan,
    GreaterThanEqual,
    ImportStar,
    In,
    Is,
    IsNot,
    LeftShift,
    LeftShiftAssign,
    LessThan,
    LessThanEqual,
    MatrixMultiply,
    MatrixMultiplyAssign,
    Minus,
    Modulo,
    ModuloAssign,
    Multiply,
    MultiplyAssign,
    Not,
    NotEqual,
    NotIn,
    Or,
    Plus,
    Power,
    PowerAssign,
    RightShift,
    RightShiftAssign,
    Semicolon,
    Subtract,
    SubtractAssign,
)
from libcst.nodes._statement import (
    AnnAssign,
    AsName,
    Assert,
    Assign,
    AssignTarget,
    Asynchronous,
    AugAssign,
    BaseCompoundStatement,
    BaseSmallStatement,
    BaseSuite,
    Break,
    ClassDef,
    Continue,
    Decorator,
    Del,
    Else,
    ExceptHandler,
    Expr,
    Finally,
    For,
    FunctionDef,
    Global,
    If,
    Import,
    ImportAlias,
    ImportFrom,
    IndentedBlock,
    NameItem,
    Nonlocal,
    Pass,
    Raise,
    Return,
    SimpleStatementLine,
    SimpleStatementSuite,
    Try,
    While,
    With,
    WithItem,
)
from libcst.nodes._whitespace import (
    Comment,
    EmptyLine,
    Newline,
    ParenthesizedWhitespace,
    SimpleWhitespace,
    TrailingWhitespace,
)


# diff --git a/libcst/nodes/_base.py b/libcst/nodes/_base.py
# new file mode 100644
# index 00000000..a9eb4d85
# --- /dev/null
# +++ b/libcst/nodes/_base.py
# @@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import fields, replace
from enum import Enum, auto
from typing import Any, List, Sequence, TypeVar, Union, cast

from libcst._base_visitor import CSTVisitor
from libcst._removal_sentinel import RemovalSentinel
from libcst.nodes._internal import CodegenState


_CSTNodeSelfT = TypeVar("_CSTNodeSelfT", bound="CSTNode")
# Shared immutable result for nodes that have no children (see BaseLeaf.children).
_EMPTY_SEQUENCE: Sequence["CSTNode"] = ()


class CSTValidationError(SyntaxError):
    """
    Error type for node validation failures.
    """

    pass


class CSTCodegenError(SyntaxError):
    """
    Error type for code generation failures.
    """

    pass


class _ChildrenCollectionVisitor(CSTVisitor):
    """
    Records every node it is asked to visit without descending into it. Used by
    `CSTNode.children` to collect a node's immediate children.
    """

    def __init__(self) -> None:
        self.children: List[CSTNode] = []

    def on_visit(self, node: "CSTNode") -> bool:
        self.children.append(node)
        return False  # Don't include transitive children


def _pretty_repr(value: object) -> str:
    """
    `repr` for a field value: non-string sequences are rendered one element per line,
    everything else uses the builtin `repr`.
    """
    if not isinstance(value, str) and isinstance(value, Sequence):
        return _pretty_repr_sequence(value)
    else:
        return repr(value)


def _pretty_repr_sequence(seq: Sequence[object]) -> str:
    """
    Render `seq` as a bracketed block with one indented, comma-terminated `repr` per
    element. An empty sequence renders as "[]".
    """
    if len(seq) == 0:
        return "[]"
    else:
        return "\n".join(["[", *[f"{_indent(repr(el))}," for el in seq], "]"])


def _indent(value: str) -> str:
    """
    Prefix every line of `value` with the indentation used by the pretty repr.
    """
    # NOTE(review): the indent literal below is reproduced as it appears in this view;
    # runs of whitespace may have been collapsed during extraction -- confirm the
    # intended width against the real file.
    return "\n".join(f" {l}" for l in value.split("\n"))


class CSTNode(ABC):
    """
    Abstract base class of every node in the tree.

    Subclasses implement `_visit_and_replace_children` and `_codegen`, and may override
    `_validate`. The use of `dataclasses.fields`/`replace` below and the
    `__post_init__` hook show that concrete nodes are expected to be dataclasses.
    """

    def __post_init__(self) -> None:
        # PERF: It might make more sense to move validation work into the visitor, which
        # would allow us to avoid validating the tree when parsing a file.
        self._validate()

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        HACK: Add our implementation of `__repr__`, `__hash__`, and `__eq__` to the
        class's __dict__ to prevent dataclass from generating its own `__repr__`,
        `__hash__`, and `__eq__`.

        The alternative is to require each implementation of a node to remember to add
        `repr=False, eq=False`, which is more error-prone.
        """
        super().__init_subclass__(**kwargs)

        # Only fill in what the subclass body didn't define itself.
        if "__repr__" not in cls.__dict__:
            cls.__repr__ = CSTNode.__repr__
        if "__eq__" not in cls.__dict__:
            cls.__eq__ = CSTNode.__eq__
        if "__hash__" not in cls.__dict__:
            cls.__hash__ = CSTNode.__hash__

    def _validate(self) -> None:
        """
        Override this to perform runtime validation of a newly created node.

        The function is called during `__init__`. It should check for possible mistakes
        that wouldn't be caught by a static type checker.
        """
        pass

    @property
    def children(self) -> Sequence["CSTNode"]:
        """
        The immediate (not transitive) child CSTNodes of the current node. Various
        properties on the nodes, such as string values, will not be visited if they are
        not a subclass of CSTNode.

        Iterable properties of the node (e.g. an IndentedBlock's body) will be flattened
        into the children's sequence.

        The children will always be returned in the same order that they appear
        lexically in the code.
        """

        # We're hooking into _visit_and_replace_children, which means that our current
        # implementation is slow. We may need to rethink and/or cache this if it becomes
        # a frequently accessed property.
        #
        # This probably won't be called frequently, because most child access will
        # probably be through visit, or directly through named property access, not
        # through children.

        visitor = _ChildrenCollectionVisitor()
        # Only the visitor's side effect matters here; the returned node is discarded.
        self._visit_and_replace_children(visitor)
        return visitor.children

    def visit(
        self: _CSTNodeSelfT, visitor: CSTVisitor
    ) -> Union[_CSTNodeSelfT, RemovalSentinel]:
        """
        Visits the current node, its children, and all transitive children using the
        given CSTVisitor's callbacks.

        Returns whatever `visitor.on_leave` returns for this node. Raises `Exception`
        if that value is neither a CSTNode nor a RemovalSentinel.
        """
        # visit self
        should_visit_children = visitor.on_visit(self)

        # visit children (optionally)
        if should_visit_children:
            # It's not possible to define `_visit_and_replace_children` with the correct
            # return type in any sane way, so we're using this cast. See the
            # explanation above the declaration of `_visit_and_replace_children`.
            with_updated_children = cast(
                _CSTNodeSelfT, self._visit_and_replace_children(visitor)
            )
        else:
            with_updated_children = self

        leave_result = visitor.on_leave(self, with_updated_children)

        # validate return type of the user-defined `visitor.on_leave` method
        if not isinstance(leave_result, (CSTNode, RemovalSentinel)):
            raise Exception(
                f"Expected a node of type CSTNode or a RemovalSentinel, "
                + f"but got a return value of {type(leave_result).__name__}"
            )

        # TODO: Run runtime typechecks against updated nodes

        return leave_result

    # The return type of `_visit_and_replace_children` is `CSTNode`, not
    # `_CSTNodeSelfT`. This is because pyre currently doesn't have a way to annotate
    # classes as final. https://mypy.readthedocs.io/en/latest/final_attrs.html
    #
    # The issue is that any reasonable implementation of `_visit_and_replace_children`
    # needs to refer to the class' own constructor:
    #
    #     class While(CSTNode):
    #         def _visit_and_replace_children(self, visitor: CSTVisitor) -> While:
    #             return While(...)
    #
    # You'll notice that because this implementation needs to call the `While`
    # constructor, the return type is also `While`. This function is a valid subtype of
    # `Callable[[CSTVisitor], CSTNode]`.
    #
    # It is not a valid subtype of `Callable[[CSTVisitor], _CSTNodeSelfT]`. That's
    # because the return type of this function wouldn't be valid for any subclasses.
    # In practice, that's not an issue, because we don't have any subclasses of `While`,
    # but there's no way to tell pyre that without a `@final` annotation.
    #
    # Instead, we're just relying on an unchecked call to `cast()` in the `visit`
    # method.
    @abstractmethod
    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "CSTNode":
        """
        Intended to be overridden by subclasses to provide a low-level hook for the
        visitor API.

        Don't call this directly. Instead, use `visitor.visit_and_replace_node` or
        `visitor.visit_and_replace_module`. If you need list of children, access the
        `children` property instead.

        The general expectation is that children should be visited in the order in which
        they appear lexically.

        NOTE(review): confirm where `visit_and_replace_node` and
        `visit_and_replace_module` are defined; within this module the public entry
        point that calls this hook is `CSTNode.visit`.
        """
        ...

    @abstractmethod
    def _codegen(self, state: CodegenState) -> None:
        # Subclasses emit their source text into `state`.
        ...

    def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:
        """
        A convenience method for performing mutation-like operations on immutable nodes.
        Creates a new object of the same type, replacing fields with values from the
        supplied keyword arguments.

        For example, to update the test of an if conditional, you could do:

            def leave_If(self, old_node: cst.If) -> cst.If:
                new_node = old_node.with_changes(test=new_conditional)
                return new_node

        `new_node` will have the same `body`, `orelse`, and whitespace fields as
        `old_node`, but with the updated `test` field.

        The accepted arguments match the arguments given to `__init__`, however there
        are no required or positional arguments.

        TODO: This API is untyped. There's probably no sane way to type it using pyre's
        current feature-set, but we should still think about ways to type this or a
        similar API in the future.
        """
        return replace(self, **changes)

    def deep_equals(self: _CSTNodeSelfT, other: _CSTNodeSelfT) -> bool:
        """
        Recursively inspects the entire tree under `self` and `other` to determine if
        the two trees are equal by value.
        """
        # Imported here rather than at module level because `_deep_equals` imports
        # `CSTNode` from this module.
        from libcst.nodes._deep_equals import deep_equals as deep_equals_impl

        return deep_equals_impl(self, other)

    def __eq__(self: _CSTNodeSelfT, other: object) -> bool:
        """
        CSTNodes are only treated as equal by identity. This matches the behavior of
        CPython's AST nodes.

        If you actually want to compare the value instead of the identity of the current
        node with another, use `node.deep_equals`. Because `deep_equals` must traverse
        the entire tree, it can have an unexpectedly large time complexity.

        We're not exposing value equality as the default behavior because of
        `deep_equals`'s large time complexity.
        """
        return self is other

    def __hash__(self) -> int:
        # Equality of nodes is based on identity, so the hash should be too.
        return id(self)

    def __repr__(self) -> str:
        """
        Multi-line representation listing every dataclass field as `name=value,` on its
        own indented line; a node without fields renders as `TypeName()`.
        """
        if len(fields(self)) == 0:
            return f"{type(self).__name__}()"

        lines = []
        lines.append(f"{type(self).__name__}(")
        for field in fields(self):
            key = field.name
            value = getattr(self, key)
            lines.append(_indent(f"{key}={_pretty_repr(value)},"))
        lines.append(")")
        return "\n".join(lines)


class BaseLeaf(CSTNode, ABC):
    """
    Base class for nodes without child nodes: `children` is always empty and visiting
    the children returns the node itself.
    """

    @property
    def children(self) -> Sequence[CSTNode]:
        # override this with an optimized implementation
        return _EMPTY_SEQUENCE

    def _visit_and_replace_children(
        self: _CSTNodeSelfT, visitor: CSTVisitor
    ) -> _CSTNodeSelfT:
        # Nothing to visit, so the node is returned as-is.
        return self


class BaseValueToken(BaseLeaf, ABC):
    """
    Represents the subset of nodes that only contain a value. Not all tokens from the
    tokenizer will exist as BaseValueTokens. In places where the token is always a
    constant value (e.g. a COLON token), the token's value will be implicitly folded
    into the parent CSTNode, and hard-coded into the implementation of _codegen.
    """

    # The exact text appended to the output by `_codegen`.
    value: str

    def _codegen(self, state: CodegenState) -> None:
        state.tokens.append(self.value)


class AnnotationIndicatorSentinel(Enum):
    """
    An AnnotationIndicatorSentinel indicates that the underlying codegen should choose
    the correct annotation indicator (":" or "->") based on where the annotation is
    used.
    """

    DEFAULT = auto()


# diff --git a/libcst/nodes/_deep_equals.py b/libcst/nodes/_deep_equals.py
# new file mode 100644
# index 00000000..c32145b3
# --- /dev/null
# +++ b/libcst/nodes/_deep_equals.py
# @@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
"""
Provides the implementation of `CSTNode.deep_equals`.
"""

from dataclasses import fields
from typing import Sequence

from libcst.nodes._base import CSTNode


def deep_equals(a: object, b: object) -> bool:
    """
    Compare two values structurally.

    Two CSTNodes are compared field by field and two non-string sequences element by
    element; every other pairing falls back to `==`.
    """
    if isinstance(a, CSTNode) and isinstance(b, CSTNode):
        return _deep_equals_cst_node(a, b)

    if not (isinstance(a, Sequence) and isinstance(b, Sequence)):
        return a == b

    # str and bytes are sequences too, but they are compared as plain values.
    if isinstance(a, (str, bytes)) or isinstance(b, (str, bytes)):
        return a == b

    return _deep_equals_sequence(a, b)


def _deep_equals_sequence(a: Sequence[object], b: Sequence[object]) -> bool:
    """
    A helper function for `CSTNode.deep_equals`.

    Normalizes and compares sequences. Because we only ever expose `Sequence[]`
    types, and not `List[]`, `Tuple[]`, or `Iterable[]` values, all sequences should
    be treated as equal if they have the same values.
    """
    if a is b:  # short-circuit
        return True
    if len(a) != len(b):
        return False

    for left, right in zip(a, b):
        if not deep_equals(left, right):
            return False
    return True


def _deep_equals_cst_node(a: "CSTNode", b: "CSTNode") -> bool:
    """
    Two nodes are deeply equal when they have exactly the same type and every
    dataclass field compares deeply equal.
    """
    if type(a) is not type(b):
        return False
    if a is b:  # short-circuit
        return True

    return all(
        deep_equals(getattr(a, field.name), getattr(b, field.name))
        for field in fields(a)
    )


# diff --git a/libcst/nodes/_dummy.py b/libcst/nodes/_dummy.py
# new file mode 100644
# index 00000000..cd358bd5
# --- /dev/null
# +++ b/libcst/nodes/_dummy.py
# @@ -0,0 +1,69 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+ +# pyre-strict + +from dataclasses import dataclass +from typing import Any, List, Sequence, Union + +from libcst._add_slots import add_slots +from libcst._base_visitor import CSTVisitor +from libcst._removal_sentinel import RemovalSentinel +from libcst.nodes._base import CSTNode, CSTValidationError +from libcst.nodes._expression import LeftParen, RightParen +from libcst.nodes._internal import CodegenState, visit_sequence +from libcst.nodes._whitespace import EmptyLine, TrailingWhitespace + + +@add_slots +@dataclass(frozen=True) +class DummyNode(CSTNode): + + children: Sequence[Union[CSTNode, str]] = () + + # HACK: So that we can support being used as an expression + lpar: Sequence[LeftParen] = () + + # HACK: So that we can support being used as an expression + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + if self.lpar and not self.rpar: + raise CSTValidationError("Cannot have left paren without right paren.") + if not self.lpar and self.rpar: + raise CSTValidationError("Cannot have right paren without left paren.") + if len(self.lpar) != len(self.rpar): + raise CSTValidationError("Cannot have unbalanced parens.") + if len(self.children) < 1: + raise CSTValidationError("Must have at least one child for dummy node.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "DummyNode": + # Preserve traversal order + lpar = visit_sequence("lpar", self.lpar, visitor) + + new_children: List[Union[CSTNode, str]] = [] + for child in self.children: + if isinstance(child, CSTNode): + new_child = child.visit(visitor) + if not isinstance(new_child, RemovalSentinel): + new_children.append(new_child) + else: + new_children.append(child) + return DummyNode( + lpar=lpar, + children=new_children, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState, **kwargs: Any) -> None: + for lpar in self.lpar: + lpar._codegen(state) + for child in self.children: + if isinstance(child, CSTNode): + 
child._codegen(state) + else: + state.tokens.append(child) + for rpar in self.rpar: + rpar._codegen(state) diff --git a/libcst/nodes/_expression.py b/libcst/nodes/_expression.py new file mode 100644 index 00000000..d77b6730 --- /dev/null +++ b/libcst/nodes/_expression.py @@ -0,0 +1,2070 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import re +from abc import ABC +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum, auto +from tokenize import ( + Floatnumber as FLOATNUMBER_RE, + Imagnumber as IMAGNUMBER_RE, + Intnumber as INTNUMBER_RE, +) +from typing import Callable, Generator, List, Optional, Sequence, Union + +from typing_extensions import Literal + +from libcst._add_slots import add_slots +from libcst._base_visitor import CSTVisitor +from libcst._maybe_sentinel import MaybeSentinel +from libcst.nodes._base import ( + AnnotationIndicatorSentinel, + CSTCodegenError, + CSTNode, + CSTValidationError, +) +from libcst.nodes._internal import ( + CodegenState, + visit_optional, + visit_required, + visit_sentinel, + visit_sequence, +) +from libcst.nodes._op import ( + AssignEqual, + BaseBinaryOp, + BaseBooleanOp, + BaseCompOp, + BaseUnaryOp, + Colon, + Comma, + Dot, + In, + Is, + IsNot, + Minus, + Not, + NotIn, + Plus, +) +from libcst.nodes._whitespace import BaseParenthesizableWhitespace, SimpleWhitespace + + +@add_slots +@dataclass(frozen=True) +class LeftSquareBracket(CSTNode): + """ + Used by various nodes to denote a subscript or list section. This doesn't own + the whitespace to the left of it since this is owned by the parent node. 
+ """ + + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "LeftSquareBracket": + return LeftSquareBracket( + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ) + ) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append("[") + self.whitespace_after._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class RightSquareBracket(CSTNode): + """ + Used by various nodes to denote a subscript or list section. This doesn't own + the whitespace to the right of it since this is owned by the parent node. + """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "RightSquareBracket": + return RightSquareBracket( + whitespace_before=visit_required( + "whitespace_before", self.whitespace_before, visitor + ) + ) + + def _codegen(self, state: CodegenState) -> None: + self.whitespace_before._codegen(state) + state.tokens.append("]") + + +@add_slots +@dataclass(frozen=True) +class LeftParen(CSTNode): + """ + Used by various nodes to denote a parenthesized section. This doesn't own + the whitespace to the left of it since this is owned by the parent node. + """ + + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "LeftParen": + return LeftParen( + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ) + ) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append("(") + self.whitespace_after._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class RightParen(CSTNode): + """ + Used by various nodes to denote a parenthesized section. This doesn't own + the whitespace to the right of it since this is owned by the parent node. 
+ """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "RightParen": + return RightParen( + whitespace_before=visit_required( + "whitespace_before", self.whitespace_before, visitor + ) + ) + + def _codegen(self, state: CodegenState) -> None: + self.whitespace_before._codegen(state) + state.tokens.append(")") + + +class _BaseParenthesizedNode(CSTNode, ABC): + """ + We don't want to have another level of indirection for parenthesis in + our tree, since that makes us more of a CST than an AST. So, all the + expressions or atoms that can be wrapped in parenthesis will subclass + this to get that functionality. + """ + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + if self.lpar and not self.rpar: + raise CSTValidationError("Cannot have left paren without right paren.") + if not self.lpar and self.rpar: + raise CSTValidationError("Cannot have right paren without left paren.") + if len(self.lpar) != len(self.rpar): + raise CSTValidationError("Cannot have unbalanced parens.") + + @contextmanager + def _parenthesize(self, state: CodegenState) -> Generator[None, None, None]: + for lpar in self.lpar: + lpar._codegen(state) + yield + for rpar in self.rpar: + rpar._codegen(state) + + +class ExpressionPosition(Enum): + LEFT = auto() + RIGHT = auto() + + +class BaseExpression(_BaseParenthesizedNode, ABC): + def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: + """ + Returns true if this expression is safe to be use with a word operator + such as "not" without space between the operator an ourselves. Examples + where this is true are "not(True)", "(1)in[1,2,3]", etc. 
This base + function handles parenthesized nodes, but certain nodes such as tuples, + dictionaries and lists will override this to signify that they're always + safe. + """ + + return len(self.lpar) > 0 and len(self.rpar) > 0 + + +class BaseAtom(BaseExpression, ABC): + """ + > Atoms are the most basic elements of expressions. The simplest atoms are + > identifiers or literals. Forms enclosed in parentheses, brackets or braces are + > also categorized syntactically as atoms. + + -- https://docs.python.org/3/reference/expressions.html#atoms + """ + + pass + + +class BaseAssignTargetExpression(BaseExpression, ABC): + """ + An expression that's valid on the left side of an assign statement. + + Python's grammar defines all expressions as valid in this position, but the AST + compiler further restricts the allowed types, which is what this type attempts to + express. + + See also: https://github.com/python/cpython/blob/v3.8.0a4/Python/ast.c#L1120 + """ + + pass + + +class BaseDelTargetExpression(BaseExpression, ABC): + """ + An expression that's valid on the right side of a 'del' statement. + + Python's grammar defines all expressions as valid in this position, but the AST + compiler further restricts the allowed types, which is what this type attempts to + express. + + This is similar to a BaseAssignTargetExpression, but excludes `Starred`. + + See also: https://github.com/python/cpython/blob/v3.8.0a4/Python/ast.c#L1120 + and: https://github.com/python/cpython/blob/v3.8.0a4/Python/compile.c#L4854 + """ + + pass + + +@add_slots +@dataclass(frozen=True) +class Name(BaseAssignTargetExpression, BaseDelTargetExpression, BaseAtom): + # The actual identifier string + value: str + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. 
+ rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Name": + return Name( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=self.value, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _validate(self) -> None: + super(Name, self)._validate() + if len(self.value) == 0: + raise CSTValidationError("Cannot have empty name identifier.") + if not self.value.isidentifier(): + raise CSTValidationError("Name is not a valid identifier.") + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append(self.value) + + +@add_slots +@dataclass(frozen=True) +class Ellipses(BaseAtom): + """ + An ellipses "..." + """ + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Ellipses": + return Ellipses( + lpar=visit_sequence("lpar", self.lpar, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append("...") + + +@add_slots +@dataclass(frozen=True) +class Integer(_BaseParenthesizedNode): + value: str + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. 
+ rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Integer": + return Integer( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=self.value, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _validate(self) -> None: + super(Integer, self)._validate() + if not re.fullmatch(INTNUMBER_RE, self.value): + raise CSTValidationError("Number is not a valid integer.") + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append(self.value) + + +@add_slots +@dataclass(frozen=True) +class Float(_BaseParenthesizedNode): + value: str + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Float": + return Float( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=self.value, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _validate(self) -> None: + super(Float, self)._validate() + if not re.fullmatch(FLOATNUMBER_RE, self.value): + raise CSTValidationError("Number is not a valid float.") + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append(self.value) + + +@add_slots +@dataclass(frozen=True) +class Imaginary(_BaseParenthesizedNode): + value: str + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. 
+ rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Imaginary": + return Imaginary( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=self.value, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _validate(self) -> None: + super(Imaginary, self)._validate() + if not re.fullmatch(IMAGNUMBER_RE, self.value): + raise CSTValidationError("Number is not a valid imaginary.") + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append(self.value) + + +@add_slots +@dataclass(frozen=True) +class Number(BaseAtom): + # The actual number component + number: Union[Integer, Float, Imaginary] + + # Any unary operator applied to the number + operator: Optional[Union[Plus, Minus]] = None + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: + """ + Numbers are funny. The expression "5in [1,2,3,4,5]" is a valid expression + which evaluates to "True". So, encapsulate that here by allowing zero spacing + with the left hand side of an expression with a comparison operator. 
+ """ + if position == ExpressionPosition.LEFT: + return True + return super(Number, self)._safe_to_use_with_word_operator(position) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Number": + return Number( + lpar=visit_sequence("lpar", self.lpar, visitor), + operator=visit_optional("operator", self.operator, visitor), + number=visit_required("number", self.number, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + operator = self.operator + if operator is not None: + operator._codegen(state) + self.number._codegen(state) + + +class BaseString(BaseAtom, ABC): + """ + A type that can be used anywhere that you need to explicitly take any + string. + """ + + pass + + +@add_slots +@dataclass(frozen=True) +class SimpleString(BaseString): + value: str + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + super(SimpleString, self)._validate() + + # Validate any prefix + prefix = self._get_prefix() + if prefix not in ("", "r", "u", "b", "br", "rb"): + raise CSTValidationError("Invalid string prefix.") + prefixlen = len(prefix) + # Validate wrapping quotes + if len(self.value) < (prefixlen + 2): + raise CSTValidationError("String must have enclosing quotes.") + if ( + self.value[prefixlen] not in ['"', "'"] + or self.value[prefixlen] != self.value[-1] + ): + raise CSTValidationError("String must have matching enclosing quotes.") + # Check validity of triple-quoted strings + if len(self.value) >= (prefixlen + 6): + if self.value[prefixlen] == self.value[prefixlen + 1]: + # We know this isn't an empty string, so there needs to be a third + # identical enclosing token. 
+ if ( + self.value[prefixlen] != self.value[prefixlen + 2] + or self.value[prefixlen] != self.value[-2] + or self.value[prefixlen] != self.value[-3] + ): + raise CSTValidationError( + "String must have matching enclosing quotes." + ) + # We should check the contents as well, but this is pretty complicated, + # partially due to triple-quoted strings. + + def _get_prefix(self) -> str: + prefix = "" + for c in self.value: + if c in ['"', "'"]: + break + prefix += c + return prefix.lower() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "SimpleString": + return SimpleString( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=self.value, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append(self.value) + + +class BaseFormattedStringContent(CSTNode, ABC): + """ + A type that can be used anywhere that you need to take any part of an f-string. + """ + + pass + + +@add_slots +@dataclass(frozen=True) +class FormattedStringText(BaseFormattedStringContent): + # The raw string value. 
+ value: str + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "FormattedStringText": + return FormattedStringText(value=self.value) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append(self.value) + + +@add_slots +@dataclass(frozen=True) +class FormattedStringExpression(BaseFormattedStringContent): + # The expression we will render when printing the string + expression: BaseExpression + + # An optional conversion specifier + conversion: Optional[str] = None + + # An optional format specifier + format_spec: Optional[Sequence[BaseFormattedStringContent]] = None + + # Whitespace + whitespace_before_expression: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after_expression: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + if self.conversion is not None and self.conversion not in ("s", "r", "a"): + raise CSTValidationError("Invalid f-string conversion.") + + def _visit_and_replace_children( + self, visitor: CSTVisitor + ) -> "FormattedStringExpression": + format_spec = self.format_spec + return FormattedStringExpression( + whitespace_before_expression=visit_required( + "whitespace_before_expression", + self.whitespace_before_expression, + visitor, + ), + expression=visit_required("expression", self.expression, visitor), + whitespace_after_expression=visit_required( + "whitespace_after_expression", self.whitespace_after_expression, visitor + ), + conversion=self.conversion, + format_spec=( + visit_sequence("format_spec", format_spec, visitor) + if format_spec is not None + else None + ), + ) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append("{") + self.whitespace_before_expression._codegen(state) + self.expression._codegen(state) + self.whitespace_after_expression._codegen(state) + conversion = self.conversion + if conversion is not None: + state.tokens.append("!") + state.tokens.append(conversion) + format_spec = self.format_spec + if format_spec is 
not None: + state.tokens.append(":") + for spec in format_spec: + spec._codegen(state) + state.tokens.append("}") + + +@add_slots +@dataclass(frozen=True) +class FormattedString(BaseString): + # Sequence of formatted string parts + parts: Sequence[BaseFormattedStringContent] + + # String start indicator + start: str = 'f"' + + # String end indicator + end: str = '"' + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + super(FormattedString, self)._validate() + + # Validate any prefix + prefix = self._get_prefix() + if prefix not in ("f", "fr", "rf"): + raise CSTValidationError("Invalid f-string prefix.") + + # Validate wrapping quotes + starttoken = self.start[len(prefix) :] + if starttoken != self.end: + raise CSTValidationError("f-string must have matching enclosing quotes.") + + # Validate valid wrapping quote usage + if starttoken not in ('"', "'", '"""', "'''"): + raise CSTValidationError("Invalid f-string enclosing quotes.") + + def _get_prefix(self) -> str: + prefix = "" + for c in self.start: + if c in ['"', "'"]: + break + prefix += c + return prefix.lower() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "FormattedString": + return FormattedString( + lpar=visit_sequence("lpar", self.lpar, visitor), + start=self.start, + parts=visit_sequence("parts", self.parts, visitor), + end=self.end, + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append(self.start) + for part in self.parts: + part._codegen(state) + state.tokens.append(self.end) + + +@add_slots +@dataclass(frozen=True) +class ConcatenatedString(BaseString): + # String on the left of the concatenation. + left: Union[SimpleString, FormattedString] + + # String on the right of the concatenation. 
+ right: Union[SimpleString, FormattedString, "ConcatenatedString"] + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + # Whitespace between strings. + whitespace_between: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + super(ConcatenatedString, self)._validate() + + # Strings that are concatenated cannot have parens. + if bool(self.left.lpar) or bool(self.left.rpar): + raise CSTValidationError("Cannot concatenate parenthesized strings.") + if bool(self.right.lpar) or bool(self.right.rpar): + raise CSTValidationError("Cannot concatenate parenthesized strings.") + + # Cannot concatenate str and bytes + leftbytes = "b" in self.left._get_prefix() + if isinstance(self.right, ConcatenatedString): + rightbytes = "b" in self.right.left._get_prefix() + elif isinstance(self.right, SimpleString): + rightbytes = "b" in self.right._get_prefix() + elif isinstance(self.right, FormattedString): + rightbytes = "b" in self.right._get_prefix() + else: + raise Exception("Logic error!") + if leftbytes != rightbytes: + raise CSTValidationError("Cannot concatenate string and bytes.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ConcatenatedString": + return ConcatenatedString( + lpar=visit_sequence("lpar", self.lpar, visitor), + left=visit_required("left", self.left, visitor), + whitespace_between=visit_required( + "whitespace_between", self.whitespace_between, visitor + ), + right=visit_required("right", self.right, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.left._codegen(state) + self.whitespace_between._codegen(state) + self.right._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Starred(BaseAssignTargetExpression): + # The actual expression + expression: 
BaseExpression + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + # Whitespace nodes + whitespace_after_star: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Starred": + return Starred( + lpar=visit_sequence("lpar", self.lpar, visitor), + whitespace_after_star=visit_required( + "whitespace_after_star", self.whitespace_after_star, visitor + ), + expression=visit_required("expression", self.expression, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append("*") + self.whitespace_after_star._codegen(state) + self.expression._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class ComparisonTarget(CSTNode): + """ + A target for a comparison. Owns the comparison operator itself. + """ + + # The actual comparison operator + operator: BaseCompOp + + # The right hand side of the comparison operation + comparator: BaseExpression + + def _validate(self) -> None: + # Validate operator spacing rules + if ( + isinstance(self.operator, (In, NotIn, Is, IsNot)) + and isinstance(self.operator.whitespace_after, SimpleWhitespace) + and self.operator.whitespace_after.value == "" + ): + if not self.comparator._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError( + "Must have at least one space around comparison operator." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ComparisonTarget": + return ComparisonTarget( + operator=visit_required("operator", self.operator, visitor), + comparator=visit_required("comparator", self.comparator, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + self.operator._codegen(state) + self.comparator._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Comparison(BaseExpression): + """ + Any comparison such as "x < y < z" + """ + + # The left hand side of the comparison operation + left: BaseExpression + + # The actual comparison operator + comparisons: Sequence[ComparisonTarget] + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + # Perform any validation on base type + super(Comparison, self)._validate() + + if len(self.comparisons) == 0: + raise CSTValidationError("Must have at least one ComparisonTarget.") + + # Validate operator spacing rules + if ( + isinstance(self.comparisons[0].operator, (In, NotIn, Is, IsNot)) + and isinstance( + # pyre-fixme[16]: `BaseCompOp` has no attribute `whitespace_before`. + self.comparisons[0].operator.whitespace_before, + SimpleWhitespace, + ) + # pyre-fixme[16]: `BaseCompOp` has no attribute `whitespace_before`. + and self.comparisons[0].operator.whitespace_before.value == "" + ): + if not self.left._safe_to_use_with_word_operator(ExpressionPosition.LEFT): + raise CSTValidationError( + "Must have at least one space around comparison operator." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Comparison": + return Comparison( + lpar=visit_sequence("lpar", self.lpar, visitor), + left=visit_required("left", self.left, visitor), + comparisons=visit_sequence("comparisons", self.comparisons, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.left._codegen(state) + for comp in self.comparisons: + comp._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class UnaryOperation(BaseExpression): + """ + Any generic unary expression, such as "not x" or "-x". Note that this node + does not get used for immediate number negation such as "-5". For that, + the Number class is used. + """ + + # The unary operator applied to the expression + operator: BaseUnaryOp + + # The actual expression or atom + expression: BaseExpression + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + # Perform any validation on base type + super(UnaryOperation, self)._validate() + + if ( + isinstance(self.operator, Not) + and isinstance(self.operator.whitespace_after, SimpleWhitespace) + and self.operator.whitespace_after.value == "" + ): + if not self.expression._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError( + "Must have at least one space after not operator." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "UnaryOperation": + return UnaryOperation( + lpar=visit_sequence("lpar", self.lpar, visitor), + operator=visit_required("operator", self.operator, visitor), + expression=visit_required("expression", self.expression, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.operator._codegen(state) + self.expression._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class BinaryOperation(BaseExpression): + """ + Any binary operation such as "x << y" or "y + z". + """ + + # The left hand side of the operation + left: BaseExpression + + # The actual operator + operator: BaseBinaryOp + + # The right hand side of the operation + right: BaseExpression + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "BinaryOperation": + return BinaryOperation( + lpar=visit_sequence("lpar", self.lpar, visitor), + left=visit_required("left", self.left, visitor), + operator=visit_required("operator", self.operator, visitor), + right=visit_required("right", self.right, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.left._codegen(state) + self.operator._codegen(state) + self.right._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class BooleanOperation(BaseExpression): + """ + Any boolean operation such as "x or y" or "z and w" + """ + + # The left hand side of the operation + left: BaseExpression + + # The actual operator + operator: BaseBooleanOp + + # The right hand side of the operation + right: BaseExpression + + # Sequence of open parenthesis for precedence dictation. 
+ lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _validate(self) -> None: + # Paren validation and such + super(BooleanOperation, self)._validate() + # Validate spacing rules + if ( + isinstance(self.operator.whitespace_before, SimpleWhitespace) + and self.operator.whitespace_before.value == "" + ): + if not self.left._safe_to_use_with_word_operator(ExpressionPosition.LEFT): + raise CSTValidationError( + "Must have at least one space around boolean operator." + ) + if ( + isinstance(self.operator.whitespace_after, SimpleWhitespace) + and self.operator.whitespace_after.value == "" + ): + if not self.right._safe_to_use_with_word_operator(ExpressionPosition.RIGHT): + raise CSTValidationError( + "Must have at least one space around boolean operator." + ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "BooleanOperation": + return BooleanOperation( + lpar=visit_sequence("lpar", self.lpar, visitor), + left=visit_required("left", self.left, visitor), + operator=visit_required("operator", self.operator, visitor), + right=visit_required("right", self.right, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.left._codegen(state) + self.operator._codegen(state) + self.right._codegen(state) + + +@dataclass(frozen=True) +class Attribute(BaseAssignTargetExpression, BaseDelTargetExpression): + """ + An attribute reference, such as "x.y". Note that in the case of + "x.y.z", the outer attribute will have an attr of "z" and the + value will be another Attribute referencing the "y" attribute on + "x". + """ + + # Expression which, when evaluated, will have 'attr' as an attribute + value: BaseExpression + + # Name of the attribute being accessed. + attr: Name + + # Separating dot, with any whitespace it owns. 
+ dot: Dot = Dot() + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Attribute": + return Attribute( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=visit_required("value", self.value, visitor), + dot=visit_required("dot", self.dot, visitor), + attr=visit_required("attr", self.attr, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.value._codegen(state) + self.dot._codegen(state) + self.attr._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Index(CSTNode): + """ + Any index as passed to a subscript. + """ + + # The index value itself. + value: BaseExpression + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Index": + return Index(value=visit_required("value", self.value, visitor)) + + def _codegen(self, state: CodegenState) -> None: + self.value._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Slice(CSTNode): + """ + Any slice operation in a subscript, such as "1:", "2:3:4", etc. Note + that the grammar does NOT allow parenthesis around a slice so they + are not supported here. 
+ """ + + # The lower bound in the slice, if present + lower: Optional[BaseExpression] + + # The upper bound in the slice, if present + upper: Optional[BaseExpression] + + # The step in the slice, if present + step: Optional[BaseExpression] = None + + # The first slice operator + first_colon: Colon = Colon() + + # The second slice operator, usually omitted + second_colon: Union[Colon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Slice": + return Slice( + lower=visit_optional("lower", self.lower, visitor), + first_colon=visit_required("first_colon", self.first_colon, visitor), + upper=visit_optional("upper", self.upper, visitor), + second_colon=visit_sentinel("second_colon", self.second_colon, visitor), + step=visit_optional("step", self.step, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + lower = self.lower + if lower is not None: + lower._codegen(state) + self.first_colon._codegen(state) + upper = self.upper + if upper is not None: + upper._codegen(state) + second_colon = self.second_colon + if second_colon is MaybeSentinel.DEFAULT and self.step is not None: + state.tokens.append(":") + elif isinstance(second_colon, Colon): + second_colon._codegen(state) + step = self.step + if step is not None: + step._codegen(state) + + +@dataclass(frozen=True) +class ExtSlice(CSTNode): + """ + A list of slices, such as "1:2, 3". Not used in the stdlib but still + valid. This also does not allow for wrapping parenthesis. + "x". + """ + + # A slice or index that is part of the extslice. + slice: Union[Index, Slice] + + # Separating comma, with any whitespace it owns. 
+ comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ExtSlice": + return ExtSlice( + slice=visit_required("slice", self.slice, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + ) + + def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + self.slice._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + state.tokens.append(", ") + elif isinstance(comma, Comma): + comma._codegen(state) + + +@dataclass(frozen=True) +class Subscript(BaseAssignTargetExpression, BaseDelTargetExpression): + """ + A subscript reference such as "x[2]". + """ + + # Expression which, when evaluated, will be subscripted. + value: BaseExpression + + # Subscript to take on the value. + slice: Union[Index, Slice, Sequence[ExtSlice]] + + # Open bracket surrounding the slice + lbracket: LeftSquareBracket = LeftSquareBracket() + + # Close bracket surrounding the slice + rbracket: RightSquareBracket = RightSquareBracket() + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. 
+ rpar: Sequence[RightParen] = () + + # Whitespace + whitespace_after_value: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + super(Subscript, self)._validate() + if isinstance(self.slice, Sequence): + # Validate valid commas + if len(self.slice) < 1: + raise CSTValidationError("Cannot have empty ExtSlice.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Subscript": + slice = self.slice + return Subscript( + lpar=visit_sequence("lpar", self.lpar, visitor), + value=visit_required("value", self.value, visitor), + whitespace_after_value=visit_required( + "whitespace_after_value", self.whitespace_after_value, visitor + ), + lbracket=visit_required("lbracket", self.lbracket, visitor), + slice=visit_required("slice", slice, visitor) + if isinstance(slice, (Index, Slice)) + else visit_sequence("slice", slice, visitor), + rbracket=visit_required("rbracket", self.rbracket, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.value._codegen(state) + self.whitespace_after_value._codegen(state) + self.lbracket._codegen(state) + if isinstance(self.slice, (Index, Slice)): + self.slice._codegen(state) + elif isinstance(self.slice, Sequence): + lastslice = len(self.slice) - 1 + for i, slice in enumerate(self.slice): + slice._codegen(state, default_comma=(i != lastslice)) + else: + # We can make pyre happy this way! + raise Exception("Logic error!") + self.rbracket._codegen(state) + + +@dataclass(frozen=True) +class Annotation(CSTNode): + """ + An annotation. + """ + + # The annotation itself. + annotation: Union[Name, Attribute, BaseString, Subscript] + + # The indicator token before the annotation. 
+ indicator: Union[ + str, AnnotationIndicatorSentinel + ] = AnnotationIndicatorSentinel.DEFAULT + + # Whitespace + whitespace_before_indicator: Union[ + BaseParenthesizableWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + whitespace_after_indicator: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _validate(self) -> None: + if isinstance(self.indicator, str) and self.indicator not in [":", "->"]: + raise CSTValidationError( + "An Annotation indicator must be one of ':', '->'." + ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Annotation": + return Annotation( + whitespace_before_indicator=visit_sentinel( + "whitespace_before_indicator", self.whitespace_before_indicator, visitor + ), + indicator=self.indicator, + whitespace_after_indicator=visit_required( + "whitespace_after_indicator", self.whitespace_after_indicator, visitor + ), + annotation=visit_required("annotation", self.annotation, visitor), + ) + + def _codegen( + self, state: CodegenState, default_indicator: Optional[str] = None + ) -> None: + # First, figure out the indicator which tells us default whitespace. + indicator = self.indicator + if isinstance(indicator, AnnotationIndicatorSentinel): + if default_indicator is None: + raise CSTCodegenError( + "Must specify a concrete default_indicator if default used on indicator." 
+ ) + indicator = default_indicator + + # Now, output the whitespace + whitespace_before_indicator = self.whitespace_before_indicator + if isinstance(whitespace_before_indicator, BaseParenthesizableWhitespace): + whitespace_before_indicator._codegen(state) + elif isinstance(whitespace_before_indicator, MaybeSentinel): + if indicator == "->": + state.tokens.append(" ") + else: + raise Exception("Logic error!") + + # Now, output the indicator and the rest of the annotation + state.tokens.append(indicator) + self.whitespace_after_indicator._codegen(state) + self.annotation._codegen(state) + + +@dataclass(frozen=True) +class ParamStar(CSTNode): + """ + A sentinel indicator on a Parameter list to denote that the following params + are kwonly args. + """ + + # Comma that comes after the star. + comma: Comma = Comma(whitespace_after=SimpleWhitespace(" ")) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ParamStar": + return ParamStar(comma=visit_required("comma", self.comma, visitor)) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append("*") + self.comma._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Param(CSTNode): + """ + A single parameter in a Parameter list. May contain a type annotation and + in some cases a default. + """ + + # The parameter name itself + name: Name + + # Any optional annotation + annotation: Optional[Annotation] = None + + # The equals sign used to denote assignment if there is a default. 
+ equal: Union[AssignEqual, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Any optional default + default: Optional[BaseExpression] = None + + # Any trailing comma + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Optional star appearing before name for star_arg and star_kwarg + star: Union[str, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Whitespace + whitespace_after_star: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after_param: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + if self.default is None and isinstance(self.equal, AssignEqual): + raise CSTValidationError( + "Must have a default when specifying an AssignEqual." + ) + if isinstance(self.star, str) and self.star not in ("", "*", "**"): + raise CSTValidationError("Must specify either '', '*' or '**' for star.") + if ( + self.annotation is not None + and isinstance(self.annotation.indicator, str) + and self.annotation.indicator != ":" + ): + raise CSTValidationError("A param Annotation must be denoted with a ':'.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Param": + return Param( + star=self.star, + whitespace_after_star=visit_required( + "whitespace_after_star", self.whitespace_after_star, visitor + ), + name=visit_required("name", self.name, visitor), + annotation=visit_optional("annotation", self.annotation, visitor), + equal=visit_sentinel("equal", self.equal, visitor), + default=visit_optional("default", self.default, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + whitespace_after_param=visit_required( + "whitespace_after_param", self.whitespace_after_param, visitor + ), + ) + + def _codegen( + self, + state: CodegenState, + default_star: Optional[str] = None, + default_comma: bool = False, + ) -> None: + star = self.star + if isinstance(star, MaybeSentinel): + if default_star is None: + raise CSTCodegenError( + "Must specify a concrete default_star if default used on star." 
+ ) + star = default_star + if isinstance(star, str): + state.tokens.append(star) + self.whitespace_after_star._codegen(state) + self.name._codegen(state) + annotation = self.annotation + if annotation is not None: + annotation._codegen(state, default_indicator=":") + equal = self.equal + if equal is MaybeSentinel.DEFAULT and self.default is not None: + state.tokens.append(" = ") + elif isinstance(equal, AssignEqual): + equal._codegen(state) + default = self.default + if default is not None: + default._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + state.tokens.append(", ") + elif isinstance(comma, Comma): + comma._codegen(state) + self.whitespace_after_param._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Parameters(CSTNode): + """ + A function or lambda parameter list. + """ + + # Positional parameters. + params: Sequence[Param] = () + + # Positional parameters with defaults. + default_params: Sequence[Param] = () + + # Optional parameter that captures unspecified positional arguments or a sentinel + # star that dictates parameters following are kwonly args. + star_arg: Union[Param, ParamStar, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Keyword-only params that may or may not have defaults. + kwonly_params: Sequence[Param] = () + + # Optional parameter that captures unspecified kwargs. + star_kwarg: Optional[Param] = None + + def _validate_stars_sequence(self, vals: Sequence[Param], *, section: str) -> None: + if len(vals) == 0: + return + for val in vals: + if isinstance(val.star, str) and val.star != "": + raise CSTValidationError( + f"Expecting a star prefix of '' for {section} Param." + ) + + def _validate_kwonlystar(self) -> None: + if isinstance(self.star_arg, ParamStar) and len(self.kwonly_params) == 0: + raise CSTValidationError( + "Must have at least one kwonly param if ParamStar is used." 
+ ) + + def _validate_defaults(self) -> None: + for param in self.params: + if param.default is not None: + raise CSTValidationError( + "Cannot have defaults for params. Place them in default_params." + ) + for param in self.default_params: + if param.default is None: + raise CSTValidationError( + "Must have defaults for default_params. Place non-defaults in params." + ) + if isinstance(self.star_arg, Param) and self.star_arg.default is not None: + raise CSTValidationError("Cannot have default for star_arg.") + if self.star_kwarg is not None and self.star_kwarg.default is not None: + raise CSTValidationError("Cannot have default for star_kwarg.") + + def _validate_stars(self) -> None: + if len(self.params) > 0: + self._validate_stars_sequence(self.params, section="params") + if len(self.default_params) > 0: + self._validate_stars_sequence(self.default_params, section="default_params") + star_arg = self.star_arg + if ( + isinstance(star_arg, Param) + and isinstance(star_arg.star, str) + and star_arg.star != "*" + ): + raise CSTValidationError( + "Expecting a star prefix of '*' for star_arg Param." + ) + if len(self.kwonly_params) > 0: + self._validate_stars_sequence(self.kwonly_params, section="kwonly_params") + star_kwarg = self.star_kwarg + if ( + star_kwarg is not None + and isinstance(star_kwarg.star, str) + and star_kwarg.star != "**" + ): + raise CSTValidationError( + "Expecting a star prefix of '**' for star_kwarg Param." + ) + + def _validate(self) -> None: + # Validate kwonly_param star placement semantics. + self._validate_kwonlystar() + # Validate defaults semantics for params, default_params and star_arg/star_kwarg. + self._validate_defaults() + # Validate that we don't have random stars on non star_kwarg. 
+ self._validate_stars() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Parameters": + return Parameters( + params=visit_sequence("params", self.params, visitor), + default_params=visit_sequence( + "default_params", self.default_params, visitor + ), + star_arg=visit_sentinel("star_arg", self.star_arg, visitor), + kwonly_params=visit_sequence("kwonly_params", self.kwonly_params, visitor), + star_kwarg=visit_optional("star_kwarg", self.star_kwarg, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + # Compute the star existence first so we can ask about whether + # each element is the last in the list or not. + star_arg = self.star_arg + if isinstance(star_arg, MaybeSentinel): + starincluded = len(self.kwonly_params) > 0 + elif isinstance(star_arg, (Param, ParamStar)): + starincluded = True + else: + starincluded = False + # Render out the params first, computing necessary trailing commas. + lastparam = len(self.params) - 1 + more_values = ( + len(self.default_params) > 0 + or starincluded + or len(self.kwonly_params) > 0 + or self.star_kwarg is not None + ) + for i, param in enumerate(self.params): + param._codegen( + state, default_star="", default_comma=(i < lastparam or more_values) + ) + # Render out the default_params next, computing necessary trailing commas. + lastparam = len(self.default_params) - 1 + more_values = ( + starincluded or len(self.kwonly_params) > 0 or self.star_kwarg is not None + ) + for i, param in enumerate(self.default_params): + param._codegen( + state, default_star="", default_comma=(i < lastparam or more_values) + ) + # Render out optional star sentinel if it's explicitly included or + # if we are inferring it from kwonly_params. Otherwise, render out the + # optional star_arg. 
+ if isinstance(star_arg, MaybeSentinel): + if starincluded: + state.tokens.append("*, ") + elif isinstance(star_arg, Param): + more_values = len(self.kwonly_params) > 0 or self.star_kwarg is not None + star_arg._codegen(state, default_star="*", default_comma=more_values) + elif isinstance(star_arg, ParamStar): + star_arg._codegen(state) + # Render out the kwonly_params next, computing necessary trailing commas. + lastparam = len(self.kwonly_params) - 1 + more_values = self.star_kwarg is not None + for i, param in enumerate(self.kwonly_params): + param._codegen( + state, default_star="", default_comma=(i < lastparam or more_values) + ) + # Finally, render out any optional star_kwarg + star_kwarg = self.star_kwarg + if star_kwarg is not None: + star_kwarg._codegen(state, default_star="**", default_comma=False) + + +@add_slots +@dataclass(frozen=True) +class Lambda(BaseExpression): + # The parameters to the lambda + params: Parameters + + # The body of the lambda + body: BaseExpression + + # The colon separating the parameters from the body + colon: Colon = Colon(whitespace_after=SimpleWhitespace(" ")) + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + # Whitespace + whitespace_after_lambda: Union[ + BaseParenthesizableWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + # Validate parents + super(Lambda, self)._validate() + # Sum up all parameters + all_params: List[Param] = [ + *self.params.params, + *self.params.default_params, + *self.params.kwonly_params, + ] + if isinstance(self.params.star_arg, Param): + all_params.append(self.params.star_arg) + if self.params.star_kwarg is not None: + all_params.append(self.params.star_kwarg) + # Check for nonzero parameters because several checks care + # about this. 
+ if len(all_params) > 0: + for param in all_params: + if param.annotation is not None: + raise CSTValidationError( + "Lambda params cannot have type annotations." + ) + if ( + isinstance(self.whitespace_after_lambda, SimpleWhitespace) + and len(self.whitespace_after_lambda.value) == 0 + ): + raise CSTValidationError( + "Must have at least one space after lambda when specifying params" + ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Lambda": + return Lambda( + lpar=visit_sequence("lpar", self.lpar, visitor), + whitespace_after_lambda=visit_sentinel( + "whitespace_after_lambda", self.whitespace_after_lambda, visitor + ), + params=visit_required("params", self.params, visitor), + colon=visit_required("colon", self.colon, visitor), + body=visit_required("body", self.body, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append("lambda") + whitespace_after_lambda = self.whitespace_after_lambda + if isinstance(whitespace_after_lambda, MaybeSentinel): + if not ( + len(self.params.params) == 0 + and len(self.params.default_params) == 0 + and not isinstance(self.params.star_arg, Param) + and len(self.params.kwonly_params) == 0 + and self.params.star_kwarg is None + ): + # We have one or more params, provide a space + state.tokens.append(" ") + elif isinstance(whitespace_after_lambda, BaseParenthesizableWhitespace): + whitespace_after_lambda._codegen(state) + self.params._codegen(state) + self.colon._codegen(state) + self.body._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Arg(CSTNode): + """ + A single argument to a Call. It may be a * or a ** expansion, or it may be in + the form of "keyword=expression" for named arguments. 
+ """ + + # The argument expression itself + value: BaseExpression + + # Optional keyword for the argument + keyword: Optional[Name] = None + + # The equals sign used to denote assignment if there is a keyword. + equal: Union[AssignEqual, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Any trailing comma + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Optional star appearing before name for * and ** expansion + star: Literal["", "*", "**"] = "" + + # Whitespace + whitespace_after_star: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after_arg: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + if self.keyword is None and isinstance(self.equal, AssignEqual): + raise CSTValidationError( + "Must have a keyword when specifying an AssignEqual." + ) + if self.star not in ("", "*", "**"): + raise CSTValidationError("Must specify either '', '*' or '**' for star.") + if self.star in ("*", "**") and self.keyword is not None: + raise CSTValidationError("Cannot specify a star and a keyword together.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Arg": + return Arg( + star=self.star, + whitespace_after_star=visit_required( + "whitespace_after_star", self.whitespace_after_star, visitor + ), + keyword=visit_optional("keyword", self.keyword, visitor), + equal=visit_sentinel("equal", self.equal, visitor), + value=visit_required("value", self.value, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + whitespace_after_arg=visit_required( + "whitespace_after_arg", self.whitespace_after_arg, visitor + ), + ) + + def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + state.tokens.append(self.star) + self.whitespace_after_star._codegen(state) + keyword = self.keyword + if keyword is not None: + keyword._codegen(state) + equal = self.equal + if equal is MaybeSentinel.DEFAULT and self.keyword is not None: + state.tokens.append(" = ") + elif isinstance(equal, 
AssignEqual): + equal._codegen(state) + self.value._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + state.tokens.append(", ") + elif isinstance(comma, Comma): + comma._codegen(state) + self.whitespace_after_arg._codegen(state) + + +class _BaseExpressionWithArgs(BaseExpression, ABC): + """ + Arguments are complicated enough that we can't represent them easily + in typing. So, we have common validation functions here. + """ + + # Sequence of arguments that will be passed to the function call + args: Sequence[Arg] = () # TODO This can also be a single Generator. + + def _check_kwargs_or_keywords( + self, arg: Arg + ) -> Optional[Callable[[Arg], Callable]]: + """ + Validates that we only have a mix of "keyword=arg" and "**arg" expansion. + """ + + if arg.keyword is not None: + # Valid, keyword argument + return None + elif arg.star == "**": + # Valid, kwargs + return None + elif arg.star == "*": + # Invalid, cannot have "*" follow "**" + raise CSTValidationError( + "Cannot have iterable argument unpacking after keyword argument unpacking." + ) + else: + # Invalid, cannot have positional argument follow **/keyword + raise CSTValidationError( + "Cannot have positional argument after keyword argument unpacking." + ) + + def _check_starred_or_keywords( + self, arg: Arg + ) -> Optional[Callable[[Arg], Callable]]: + """ + Validates that we only have a mix of "*arg" expansion and "keyword=arg". + """ + + if arg.keyword is not None: + # Valid, keyword argument + return None + elif arg.star == "**": + # Valid, but we now no longer allow "*" args + # pyre-fixme[7]: Expected `Optional[Callable[[Arg], Callable[..., + # Any]]]` but got `Callable[[Arg], Optional[Callable[[Arg], Callable[..., + # Any]]]]`. 
+ return self._check_kwargs_or_keywords + elif arg.star == "*": + # Valid, iterable unpacking + return None + else: + # Invalid, cannot have positional argument follow **/keyword + raise CSTValidationError( + "Cannot have positional argument after keyword argument." + ) + + def _check_positional(self, arg: Arg) -> Optional[Callable[[Arg], Callable]]: + """ + Validates that we only have a mix of positional args and "*arg" expansion. + """ + + if arg.keyword is not None: + # Valid, but this puts us into starred/keyword state + # pyre-fixme[7]: Expected `Optional[Callable[[Arg], Callable[..., + # Any]]]` but got `Callable[[Arg], Optional[Callable[[Arg], Callable[..., + # Any]]]]`. + return self._check_starred_or_keywords + elif arg.star == "**": + # Valid, but we skip states to kwargs/keywords + # pyre-fixme[7]: Expected `Optional[Callable[[Arg], Callable[..., + # Any]]]` but got `Callable[[Arg], Optional[Callable[[Arg], Callable[..., + # Any]]]]`. + return self._check_kwargs_or_keywords + elif arg.star == "*": + # Valid, iterator expansion + return None + else: + # Valid, allowed to have positional arguments here + return None + + def _validate(self) -> None: + # Validate any super-class stuff, whatever it may be. + super()._validate() + # Now, validate the weird intermingling rules for arguments by running + # a small validator state machine. This works by passing each argument + # to a validator function which can either raise an exception if it + # detects an invalid sequence, return a new validator to be used for the + # next arg, or return None to use the same validator. We could enforce + # always returning ourselves instead of None but it ends up making the + # functions themselves less readable. In this way, the current validator + # function encodes the state we're in (positional state, iterable + # expansion state, or dictionary expansion state). 
+ validator = self._check_positional + for arg in self.args: + # pyre-fixme[29]: `Union[Callable[[Arg], Callable[..., Any]], + # Callable[..., Any]]` is not a function. + validator = validator(arg) or validator + + +@add_slots +@dataclass(frozen=True) +class Call(_BaseExpressionWithArgs): + # The expression resulting in a callable that we are to call + func: Union[BaseAtom, Attribute, Subscript, "Call"] + + # The arguments to pass to the resulting callable + args: Sequence[Arg] = () # TODO This can also be a single Generator. + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + # Whitespace nodes + whitespace_after_func: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_before_args: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: + """ + Calls have a close paren on the right side regardless of whether they're + parenthesized as a whole. As a result, they are safe to use directly against + an adjacent node to the right. 
+ """ + if position == ExpressionPosition.LEFT: + return True + return super(Call, self)._safe_to_use_with_word_operator(position) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Call": + return Call( + lpar=visit_sequence("lpar", self.lpar, visitor), + func=visit_required("func", self.func, visitor), + whitespace_after_func=visit_required( + "whitespace_after_func", self.whitespace_after_func, visitor + ), + whitespace_before_args=visit_required( + "whitespace_before_args", self.whitespace_before_args, visitor + ), + args=visit_sequence("args", self.args, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.func._codegen(state) + self.whitespace_after_func._codegen(state) + state.tokens.append("(") + self.whitespace_before_args._codegen(state) + lastarg = len(self.args) - 1 + for i, arg in enumerate(self.args): + arg._codegen(state, default_comma=(i != lastarg)) + state.tokens.append(")") + + +@add_slots +@dataclass(frozen=True) +class Await(BaseExpression): + # The actual expression we need to await on + expression: BaseExpression + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + # Whitespace nodes + whitespace_after_await: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _validate(self) -> None: + # Validate any super-class stuff, whatever it may be. + super(Await, self)._validate() + # Make sure we don't run identifiers together. 
+ if ( + isinstance(self.whitespace_after_await, SimpleWhitespace) + and len(self.whitespace_after_await.value) == 0 + ): + raise CSTValidationError("Must have at least one space after await") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Await": + return Await( + lpar=visit_sequence("lpar", self.lpar, visitor), + whitespace_after_await=visit_required( + "whitespace_after_await", self.whitespace_after_await, visitor + ), + expression=visit_required("expression", self.expression, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append("await") + self.whitespace_after_await._codegen(state) + self.expression._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class IfExp(BaseExpression): + """ + An if expression similar to "body if test else orelse". + """ + + # The test to perform. + test: BaseExpression + + # The expression to evaluate if the test is true. + body: BaseExpression + + # The expression to evaluate if the test is false. + orelse: BaseExpression + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. 
+ rpar: Sequence[RightParen] = () + + # Whitespace nodes + whitespace_before_if: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after_if: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_before_else: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after_else: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _validate(self) -> None: + # Paren validation and such + super(IfExp, self)._validate() + # Validate spacing rules + if ( + isinstance(self.whitespace_before_if, SimpleWhitespace) + and self.whitespace_before_if.value == "" + ): + if not self.body._safe_to_use_with_word_operator(ExpressionPosition.LEFT): + raise CSTValidationError( + "Must have at least one space before 'if' keyword." + ) + if ( + isinstance(self.whitespace_after_if, SimpleWhitespace) + and self.whitespace_after_if.value == "" + ): + if not self.test._safe_to_use_with_word_operator(ExpressionPosition.RIGHT): + raise CSTValidationError( + "Must have at least one space after 'if' keyword." + ) + if ( + isinstance(self.whitespace_before_else, SimpleWhitespace) + and self.whitespace_before_else.value == "" + ): + if not self.test._safe_to_use_with_word_operator(ExpressionPosition.LEFT): + raise CSTValidationError( + "Must have at least one space before 'else' keyword." + ) + if ( + isinstance(self.whitespace_after_else, SimpleWhitespace) + and self.whitespace_after_else.value == "" + ): + if not self.orelse._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError( + "Must have at least one space after 'else' keyword." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "IfExp": + return IfExp( + lpar=visit_sequence("lpar", self.lpar, visitor), + body=visit_required("body", self.body, visitor), + whitespace_before_if=visit_required( + "whitespace_before_if", self.whitespace_before_if, visitor + ), + whitespace_after_if=visit_required( + "whitespace_after_if", self.whitespace_after_if, visitor + ), + test=visit_required("test", self.test, visitor), + whitespace_before_else=visit_required( + "whitespace_before_else", self.whitespace_before_else, visitor + ), + whitespace_after_else=visit_required( + "whitespace_after_else", self.whitespace_after_else, visitor + ), + orelse=visit_required("orelse", self.orelse, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + self.body._codegen(state) + self.whitespace_before_if._codegen(state) + state.tokens.append("if") + self.whitespace_after_if._codegen(state) + self.test._codegen(state) + self.whitespace_before_else._codegen(state) + state.tokens.append("else") + self.whitespace_after_else._codegen(state) + self.orelse._codegen(state) + + +@dataclass(frozen=True) +class From(CSTNode): + """ + A 'from x' stanza in a Yield or Raise. + """ + + # Expression that we are yielding/raising from. + item: BaseExpression + + whitespace_before_from: Union[ + BaseParenthesizableWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + whitespace_after_from: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _validate(self) -> None: + if ( + isinstance(self.whitespace_after_from, SimpleWhitespace) + and self.whitespace_after_from.value == "" + and not self.item._safe_to_use_with_word_operator(ExpressionPosition.RIGHT) + ): + raise CSTValidationError( + "Must have at least one space after 'from' keyword." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "From": + return From( + whitespace_before_from=visit_sentinel( + "whitespace_before_from", self.whitespace_before_from, visitor + ), + item=visit_required("item", self.item, visitor), + whitespace_after_from=visit_required( + "whitespace_after_from", self.whitespace_after_from, visitor + ), + ) + + def _codegen(self, state: CodegenState, default_space: str = "") -> None: + whitespace_before_from = self.whitespace_before_from + if isinstance(whitespace_before_from, BaseParenthesizableWhitespace): + whitespace_before_from._codegen(state) + else: + state.tokens.append(default_space) + state.tokens.append("from") + self.whitespace_after_from._codegen(state) + self.item._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Yield(BaseExpression): + """ + A yield expression similar to "yield x" or "yield from fun()" + """ + + # The value to yield: an expression, a From for "yield from", or None. + value: Optional[Union[BaseExpression, From]] = None + + # Sequence of open parenthesis for precedence dictation. + lpar: Sequence[LeftParen] = () + + # Sequence of close parenthesis for precedence dictation. + rpar: Sequence[RightParen] = () + + # Whitespace nodes + whitespace_after_yield: Union[ + BaseParenthesizableWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + # Paren rules and such + super(Yield, self)._validate() + # Our own rules + if ( + isinstance(self.whitespace_after_yield, SimpleWhitespace) + and self.whitespace_after_yield.value == "" + ): + if isinstance(self.value, From): + raise CSTValidationError( + "Must have at least one space after 'yield' keyword." + ) + if isinstance( + self.value, BaseExpression + ) and not self.value._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError( + "Must have at least one space after 'yield' keyword." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Yield": + return Yield( + lpar=visit_sequence("lpar", self.lpar, visitor), + whitespace_after_yield=visit_sentinel( + "whitespace_after_yield", self.whitespace_after_yield, visitor + ), + value=visit_optional("value", self.value, visitor), + rpar=visit_sequence("rpar", self.rpar, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + with self._parenthesize(state): + state.tokens.append("yield") + whitespace_after_yield = self.whitespace_after_yield + if isinstance(whitespace_after_yield, BaseParenthesizableWhitespace): + whitespace_after_yield._codegen(state) + else: + # Only need a space after yield if there is a value to yield. + if self.value is not None: + state.tokens.append(" ") + value = self.value + if isinstance(value, From): + value._codegen(state, default_space="") + elif value is not None: + value._codegen(state) diff --git a/libcst/nodes/_internal.py b/libcst/nodes/_internal.py new file mode 100644 index 00000000..7274413c --- /dev/null +++ b/libcst/nodes/_internal.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, TypeVar, Union + +from libcst._add_slots import add_slots +from libcst._maybe_sentinel import MaybeSentinel +from libcst._removal_sentinel import RemovalSentinel + + +if TYPE_CHECKING: + # These are circular dependencies only used for typing purposes + from libcst.nodes._base import CSTNode + from libcst._base_visitor import CSTVisitor + + +_CSTNodeT = TypeVar("_CSTNodeT", bound="CSTNode") + + +@add_slots +@dataclass(frozen=False) +class CodegenState: + # These are derived from a Module + default_indent: str + default_newline: str + + indent: List[str] = field(default_factory=list) + tokens: List[str] = field(default_factory=list) + + +def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT: + """ + Given a node, visits the node using `visitor`. If removal is attempted by the + visitor, an exception is raised. + """ + result = node.visit(visitor) + if isinstance(result, RemovalSentinel): + raise TypeError( + f"We got a RemovalSentinel while visiting a {type(node).__name__}. This " + + "node's parent does not allow it to be removed." + ) + return result + + +def visit_optional( + fieldname: str, node: Optional[_CSTNodeT], visitor: "CSTVisitor" +) -> Optional[_CSTNodeT]: + """ + Given an optional node, visits the node if it exists with `visitor`. If the node is + removed, returns None. + """ + if node is None: + return None + result = node.visit(visitor) + return None if isinstance(result, RemovalSentinel) else result + + +def visit_sentinel( + fieldname: str, node: Union[_CSTNodeT, MaybeSentinel], visitor: "CSTVisitor" +) -> Union[_CSTNodeT, MaybeSentinel]: + """ + Given a node that can be a real value or a sentinel value, visits the node if it + is real with `visitor`. If the node is removed, returns MaybeSentinel. 
+ """ + if isinstance(node, MaybeSentinel): + return MaybeSentinel.DEFAULT + result = node.visit(visitor) + return MaybeSentinel.DEFAULT if isinstance(result, RemovalSentinel) else result + + +def visit_iterable( + fieldname: str, children: Iterable[_CSTNodeT], visitor: "CSTVisitor" +) -> Iterable[_CSTNodeT]: + """ + Given an iterable of children, visits each child with `visitor`, and yields the new + children with any `RemovalSentinel` values removed. + """ + for child in children: + new_child = child.visit(visitor) + if not isinstance(new_child, RemovalSentinel): + yield new_child + + +def visit_sequence( + fieldname: str, children: Sequence[_CSTNodeT], visitor: "CSTVisitor" +) -> Sequence[_CSTNodeT]: + """ + A convenience wrapper for `visit_iterable` that returns a sequence instead of an + iterable. + """ + return tuple(visit_iterable(fieldname, children, visitor)) diff --git a/libcst/nodes/_module.py b/libcst/nodes/_module.py new file mode 100644 index 00000000..6d042f3f --- /dev/null +++ b/libcst/nodes/_module.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import Sequence, TypeVar, Union + +from libcst._add_slots import add_slots +from libcst._base_visitor import CSTVisitor +from libcst._removal_sentinel import RemovalSentinel +from libcst.nodes._base import CSTNode +from libcst.nodes._internal import CodegenState, visit_sequence +from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine +from libcst.nodes._whitespace import EmptyLine + + +_ModuleSelfT = TypeVar("_ModuleSelfT", bound="Module") + +# type alias needed for scope overlap in type definition +builtin_bytes = bytes + + +@add_slots +@dataclass(frozen=True) +class Module(CSTNode): + """ + Contains some top-level information inferred from the file letting us set correct + defaults when printing the tree about global formatting rules. + """ + + body: Sequence[Union[SimpleStatementLine, BaseCompoundStatement]] + # Normally any whitespace/comments are assigned to the next node visited, but Module + # is a special case, and comments at the top of the file tend to refer to the module + # itself, so we assign them to the Module. 
+ header: Sequence[EmptyLine] = () + footer: Sequence[EmptyLine] = () + + encoding: str = "utf-8" + default_indent: str = " " * 4 + default_newline: str = "\n" + has_trailing_newline: bool = True + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module": + return Module( + header=visit_sequence("header", self.header, visitor), + body=visit_sequence("body", self.body, visitor), + footer=visit_sequence("footer", self.footer, visitor), + encoding=self.encoding, + default_indent=self.default_indent, + default_newline=self.default_newline, + has_trailing_newline=self.has_trailing_newline, + ) + + def visit(self: _ModuleSelfT, visitor: CSTVisitor) -> _ModuleSelfT: + result = CSTNode.visit(self, visitor) + if isinstance(result, RemovalSentinel): + return self.with_changes(body=(), header=(), footer=()) + else: # is a Module + return result + + def _codegen(self, state: CodegenState) -> None: + for h in self.header: + h._codegen(state) + for stmt in self.body: + stmt._codegen(state) + for f in self.footer: + f._codegen(state) + if self.has_trailing_newline: + if len(state.tokens) == 0: + # There was nothing in the header, footer, or body. Just add a newline + # to preserve the trailing newline. + state.tokens.append(state.default_newline) + else: # has_trailing_newline is false + if len(state.tokens) > 0: + # EmptyLine and all statements generate newlines, so we can be sure that + # the last token (if we're not an empty file) is a newline. + state.tokens.pop() + + @property + def code(self) -> str: + return self.code_for_node(self) + + @property + def bytes(self) -> builtin_bytes: + return self.code.encode(self.encoding) + + def code_for_node(self, node: CSTNode) -> str: + """ + Generates the code for the given node in the context of this module. This is a + method of Module, not CSTNode, because we need to know the module's default + indentation and newline formats. 
+ """ + state = CodegenState( + default_indent=self.default_indent, default_newline=self.default_newline + ) + node._codegen(state) + return "".join(state.tokens) diff --git a/libcst/nodes/_op.py b/libcst/nodes/_op.py new file mode 100644 index 00000000..ce6be65e --- /dev/null +++ b/libcst/nodes/_op.py @@ -0,0 +1,635 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Tuple + +from libcst._add_slots import add_slots +from libcst._base_visitor import CSTVisitor +from libcst.nodes._base import BaseLeaf, CSTNode, CSTValidationError +from libcst.nodes._internal import CodegenState, visit_required +from libcst.nodes._whitespace import BaseParenthesizableWhitespace, SimpleWhitespace + + +class _BaseOneTokenOp(CSTNode, ABC): + """ + Any node that has a static value and needs to own whitespace on both sides. + """ + + whitespace_before: BaseParenthesizableWhitespace + whitespace_after: BaseParenthesizableWhitespace + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "_BaseOneTokenOp": + return self.__class__( + whitespace_before=visit_required( + "whitespace_before", self.whitespace_before, visitor + ), + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ), + ) + + def _codegen(self, state: CodegenState) -> None: + self.whitespace_before._codegen(state) + state.tokens.append(self._get_token()) + self.whitespace_after._codegen(state) + + @abstractmethod + def _get_token(self) -> str: + ... + + +class _BaseTwoTokenOp(CSTNode, ABC): + """ + This node ends up as two tokens, so we must preserve the whitespace + in between them.
+ """ + + whitespace_before: BaseParenthesizableWhitespace + whitespace_between: BaseParenthesizableWhitespace + whitespace_after: BaseParenthesizableWhitespace + + def _validate(self) -> None: + if ( + isinstance(self.whitespace_between, SimpleWhitespace) + and len(self.whitespace_between.value) == 0 + ): + raise CSTValidationError("Must have at least one space between not and in.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "_BaseTwoTokenOp": + return self.__class__( + whitespace_before=visit_required( + "whitespace_before", self.whitespace_before, visitor + ), + whitespace_between=visit_required( + "whitespace_between", self.whitespace_between, visitor + ), + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ), + ) + + def _codegen(self, state: CodegenState) -> None: + self.whitespace_before._codegen(state) + state.tokens.append(self._get_tokens()[0]) + self.whitespace_between._codegen(state) + state.tokens.append(self._get_tokens()[1]) + self.whitespace_after._codegen(state) + + @abstractmethod + def _get_tokens(self) -> Tuple[str, str]: + ... + + +class BaseUnaryOp(CSTNode, ABC): + """ + Any node that has a static value used in a Unary expression. + """ + + whitespace_after: BaseParenthesizableWhitespace + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "BaseUnaryOp": + return self.__class__( + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ) + ) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append(self._get_token()) + self.whitespace_after._codegen(state) + + @abstractmethod + def _get_token(self) -> str: + ... + + +class BaseBooleanOp(_BaseOneTokenOp, ABC): + """ + Any node that has a static value used in a Binary expression. This node + is purely for typing. + """ + + +class BaseBinaryOp(CSTNode, ABC): + """ + Any node that has a static value used in a Binary expression. This node + is purely for typing. 
+ """ + + +class BaseCompOp(CSTNode, ABC): + """ + Any node that has a static value used in a CompExpression. This node + is purely for typing. + """ + + +class BaseAugOp(CSTNode, ABC): + """ + Any node that has a static value used in an AugAssign. This node is purely + for typing. + """ + + +@add_slots +@dataclass(frozen=True) +class Semicolon(_BaseOneTokenOp): + """ + Used by SmallStatement as a separator between subsequent SmallStatements contained + within a SimpleStatementLine or SimpleStatementSuite. + """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return ";" + + +@add_slots +@dataclass(frozen=True) +class Colon(_BaseOneTokenOp): + """ + Used by Slice as a separator between subsequent Expressions, and in Lambda + to separate arguments and body. + """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return ":" + + +@add_slots +@dataclass(frozen=True) +class Comma(_BaseOneTokenOp): + """ + Used by ImportAlias as a separator between subsequent ImportAliases contained + within a Import or ImportFrom. + """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return "," + + +@add_slots +@dataclass(frozen=True) +class Dot(_BaseOneTokenOp): + """ + Used by Attribute and DottedName as a separator between subsequent + Name nodes. + """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace("") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return "." + + +@add_slots +@dataclass(frozen=True) +class ImportStar(BaseLeaf): + """ + Used by ImportFrom to denote a star import. 
+ """ + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append("*") + + +@add_slots +@dataclass(frozen=True) +class AssignEqual(_BaseOneTokenOp): + """ + Used by AnnAssign to denote a single equal character when doing an + assignment on top of a type annotation. + """ + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "=" + + +@dataclass(frozen=True) +class Plus(BaseUnaryOp): + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return "+" + + +@dataclass(frozen=True) +class Minus(BaseUnaryOp): + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return "-" + + +@dataclass(frozen=True) +class BitInvert(BaseUnaryOp): + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") + + def _get_token(self) -> str: + return "~" + + +@dataclass(frozen=True) +class Not(BaseUnaryOp): + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "not" + + +@dataclass(frozen=True) +class And(BaseBooleanOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "and" + + +@dataclass(frozen=True) +class Or(BaseBooleanOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "or" + + +@dataclass(frozen=True) +class Add(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "+" + + +@dataclass(frozen=True) +class 
Subtract(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "-" + + +@dataclass(frozen=True) +class Multiply(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "*" + + +@dataclass(frozen=True) +class Divide(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "/" + + +@dataclass(frozen=True) +class FloorDivide(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "//" + + +@dataclass(frozen=True) +class Modulo(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "%" + + +@dataclass(frozen=True) +class Power(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "**" + + +@dataclass(frozen=True) +class LeftShift(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "<<" + + +@dataclass(frozen=True) +class RightShift(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = 
SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return ">>" + + +@dataclass(frozen=True) +class BitOr(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "|" + + +@dataclass(frozen=True) +class BitAnd(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "&" + + +@dataclass(frozen=True) +class BitXor(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "^" + + +@dataclass(frozen=True) +class MatrixMultiply(BaseBinaryOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "@" + + +@dataclass(frozen=True) +class LessThan(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "<" + + +@dataclass(frozen=True) +class GreaterThan(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return ">" + + +@dataclass(frozen=True) +class Equal(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + 
def _get_token(self) -> str: + return "==" + + +@dataclass(frozen=True) +class LessThanEqual(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "<=" + + +@dataclass(frozen=True) +class GreaterThanEqual(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return ">=" + + +@dataclass(frozen=True) +class NotEqual(BaseCompOp): + """ + This node defines a static value for convenience, but in reality due to + PEP 401 it can be one of two values, both of which should be a NotEqual + CompOp. + """ + + value: str = "!=" + + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _validate(self) -> None: + if self.value not in ["!=", "<>"]: + raise CSTValidationError("Invalid value for NotEqual node.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "BaseCompOp": + return self.__class__( + whitespace_before=visit_required( + "whitespace_before", self.whitespace_before, visitor + ), + value=self.value, + whitespace_after=visit_required( + "whitespace_after", self.whitespace_after, visitor + ), + ) + + def _codegen(self, state: CodegenState) -> None: + self.whitespace_before._codegen(state) + state.tokens.append(self.value) + self.whitespace_after._codegen(state) + + +@dataclass(frozen=True) +class In(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "in" + + +@dataclass(frozen=True) +class NotIn(BaseCompOp, _BaseTwoTokenOp): + whitespace_before: 
BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_between: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_tokens(self) -> Tuple[str, str]: + return ("not", "in") + + +@dataclass(frozen=True) +class Is(BaseCompOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "is" + + +@dataclass(frozen=True) +class IsNot(BaseCompOp, _BaseTwoTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_between: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_tokens(self) -> Tuple[str, str]: + return ("is", "not") + + +@add_slots +@dataclass(frozen=True) +class AddAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "+=" + + +@add_slots +@dataclass(frozen=True) +class SubtractAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "-=" + + +@add_slots +@dataclass(frozen=True) +class MultiplyAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "*=" + + +@add_slots +@dataclass(frozen=True) +class MatrixMultiplyAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = 
SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "@=" + + +@add_slots +@dataclass(frozen=True) +class DivideAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "/=" + + +@add_slots +@dataclass(frozen=True) +class ModuloAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "%=" + + +@add_slots +@dataclass(frozen=True) +class BitAndAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "&=" + + +@add_slots +@dataclass(frozen=True) +class BitOrAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "|=" + + +@add_slots +@dataclass(frozen=True) +class BitXorAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "^=" + + +@add_slots +@dataclass(frozen=True) +class LeftShiftAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "<<=" + + +@add_slots +@dataclass(frozen=True) +class RightShiftAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: 
BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return ">>=" + + +@add_slots +@dataclass(frozen=True) +class PowerAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "**=" + + +@add_slots +@dataclass(frozen=True) +class FloorDivideAssign(BaseAugOp, _BaseOneTokenOp): + whitespace_before: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace(" ") + + def _get_token(self) -> str: + return "//=" diff --git a/libcst/nodes/_statement.py b/libcst/nodes/_statement.py new file mode 100644 index 00000000..e54b7071 --- /dev/null +++ b/libcst/nodes/_statement.py @@ -0,0 +1,2009 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Sequence, Union + +from libcst._add_slots import add_slots +from libcst._base_visitor import CSTVisitor +from libcst._maybe_sentinel import MaybeSentinel +from libcst.nodes._base import CSTNode, CSTValidationError +from libcst.nodes._expression import ( + Annotation, + Arg, + Attribute, + BaseAssignTargetExpression, + BaseAtom, + BaseDelTargetExpression, + BaseExpression, + Call, + ExpressionPosition, + From, + LeftParen, + Name, + Parameters, + RightParen, +) +from libcst.nodes._internal import ( + CodegenState, + visit_optional, + visit_required, + visit_sentinel, + visit_sequence, +) +from libcst.nodes._op import AssignEqual, BaseAugOp, Comma, Dot, ImportStar, Semicolon +from libcst.nodes._whitespace import ( + BaseParenthesizableWhitespace, + EmptyLine, + SimpleWhitespace, + TrailingWhitespace, +) + + +_INDENT_WHITESPACE_RE = re.compile(r"[ \f\t]+", re.UNICODE) + + +class BaseSuite(CSTNode, ABC): + """ + A dummy base-class for both SimpleStatementSuite and IndentedBlock. This exists to + simplify type definitions and isinstance checks. + + > A suite is a group of statements controlled by a clause. A suite can be one or + > more semicolon-separated simple statements on the same line as the header, + > following the header’s colon, or it can be one or more indented statements on + > subsequent lines. + + -- https://docs.python.org/3/reference/compound_stmts.html + """ + + body: Union[ + Sequence[Union["SimpleStatementLine", "BaseCompoundStatement"]], + Sequence["BaseSmallStatement"], + ] + + +class BaseSmallStatement(CSTNode, ABC): + """ + Encapsulates a small statement, like "del" or "pass", and optionally adds a trailing + semicolon. A SmallStatement is always contained inside a SimpleStatementLine or + SimpleStatementSuite.
+ """ + + # This is optional for the last SmallStatement in a SimpleStatementLine or + # SimpleStatementSuite, but all other SmallStatements inside a simple statement must + # contain a semicolon to disambiguate multiple small statements on the same line. + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + @abstractmethod + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + ... + + +@add_slots +@dataclass(frozen=True) +class Del(BaseSmallStatement): + """ + Represents a `del` statement. `del` is always followed by a target. + """ + + target: BaseDelTargetExpression + whitespace_after_del: SimpleWhitespace = SimpleWhitespace(" ") + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + has_no_gap = len(self.whitespace_after_del.value) == 0 + if has_no_gap and not self.target._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError("Must have at least one space after 'del'.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Del": + return Del( + target=visit_required("target", self.target, visitor), + whitespace_after_del=visit_required( + "whitespace_after_del", self.whitespace_after_del, visitor + ), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("del") + self.whitespace_after_del._codegen(state) + self.target._codegen(state) + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Pass(BaseSmallStatement): + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + 
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Pass": + return Pass(semicolon=visit_sentinel("semicolon", self.semicolon, visitor)) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("pass") + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Break(BaseSmallStatement): + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Break": + return Break(semicolon=visit_sentinel("semicolon", self.semicolon, visitor)) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("break") + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Continue(BaseSmallStatement): + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Continue": + return Continue(semicolon=visit_sentinel("semicolon", self.semicolon, visitor)) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("continue") + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Return(BaseSmallStatement): + value: Optional[BaseExpression] = None + + whitespace_after_return: Union[ + SimpleWhitespace, MaybeSentinel + 
] = MaybeSentinel.DEFAULT + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + value = self.value + if value is not None: + whitespace_after_return = self.whitespace_after_return + has_no_gap = ( + not isinstance(whitespace_after_return, MaybeSentinel) + and len(whitespace_after_return.value) == 0 + ) + if has_no_gap and not value._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError("Must have at least one space after 'return'.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Return": + return Return( + whitespace_after_return=visit_sentinel( + "whitespace_after_return", self.whitespace_after_return, visitor + ), + value=visit_optional("value", self.value, visitor), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + value = self.value + + state.tokens.append("return") + + whitespace_after_return = self.whitespace_after_return + if isinstance(whitespace_after_return, MaybeSentinel): + if value is not None: + state.tokens.append(" ") + else: + whitespace_after_return._codegen(state) + + if value is not None: + value._codegen(state) + + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Expr(BaseSmallStatement): + """ + An expression used as a statement, where the result is unused and unassigned. 
+ """ + + # The expression itself + value: BaseExpression + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Expr": + return Expr( + value=visit_required("value", self.value, visitor), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + self.value._codegen(state) + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +class _BaseSimpleStatement(CSTNode, ABC): + """ + A simple statement is a series of small statements joined together by semicolons. + + simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE + + Whitespace between each small statement is owned by the small statements themselves. + """ + + body: Sequence[BaseSmallStatement] + # a NEWLINE token is actually part of simple_stmt's grammar + trailing_whitespace: TrailingWhitespace + + def _validate(self) -> None: + body = self.body + if len(body) == 0: + raise CSTValidationError( + "An empty StatementLine is useless, and should be removed." + ) + for small_stmt in body[:-1]: + if small_stmt.semicolon is None: + raise CSTValidationError( + "All but the last SmallStatement in a SimpleStatementLine or " + + "SimpleStatementSuite must have a trailing semicolon. Otherwise, " + + "there's no way to syntatically disambiguate each SmallStatement " + + "on the same line." 
+ ) + + def _codegen(self, state: CodegenState) -> None: + body = self.body + laststmt = len(body) - 1 + for idx, stmt in enumerate(body): + stmt._codegen(state, default_semicolon=(idx != laststmt)) + self.trailing_whitespace._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class SimpleStatementLine(_BaseSimpleStatement): + """ + A simple statement that's part of an IndentedBlock or Module. A simple statement is + a series of small statements joined together by semicolons. + + This isn't differentiated from a SimpleStatementSuite in the grammar, but because a + SimpleStatementLine can own additional whitespace that a SimpleStatementSuite + doesn't have, we're differentiating it in the CST. + """ + + body: Sequence[BaseSmallStatement] + leading_lines: Sequence[EmptyLine] = () + trailing_whitespace: TrailingWhitespace = TrailingWhitespace() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "SimpleStatementLine": + leading_lines = visit_sequence("leading_lines", self.leading_lines, visitor) + new_body = visit_sequence("body", self.body, visitor) + return SimpleStatementLine( + leading_lines=leading_lines, # hoisted above to preserve order + # replace the body with a pass statement if it's empty + body=(Pass(),) if len(new_body) == 0 else new_body, + trailing_whitespace=visit_required( + "trailing_whitespace", self.trailing_whitespace, visitor + ), + ) + + def _codegen(self, state: CodegenState) -> None: + for ll in self.leading_lines: + ll._codegen(state) + state.tokens.extend(state.indent) + _BaseSimpleStatement._codegen(self, state) + + +@add_slots +@dataclass(frozen=True) +class SimpleStatementSuite(_BaseSimpleStatement, BaseSuite): + """ + A simple statement that's used as a suite. A simple statement is a series of small + statements joined together by semicolons. A suite is the thing that follows the + colon in a compound statement. 
+ + if test: + + This isn't differentiated from a SimpleStatementLine in the grammar, but because the + two classes need to track different whitespace, we're differentiating it in the CST. + """ + + body: Sequence[BaseSmallStatement] + leading_whitespace: SimpleWhitespace = SimpleWhitespace(" ") + trailing_whitespace: TrailingWhitespace = TrailingWhitespace() + + def _visit_and_replace_children( + self, visitor: CSTVisitor + ) -> "SimpleStatementSuite": + leading_whitespace = visit_required( + "leading_whitespace", self.leading_whitespace, visitor + ) + new_body = visit_sequence("body", self.body, visitor) + return SimpleStatementSuite( + leading_whitespace=leading_whitespace, # hoisted above to preserve order + # replace the body with a pass statement if it's empty + body=(Pass(),) if len(new_body) == 0 else new_body, + trailing_whitespace=visit_required( + "trailing_whitespace", self.trailing_whitespace, visitor + ), + ) + + def _codegen(self, state: CodegenState) -> None: + self.leading_whitespace._codegen(state) + _BaseSimpleStatement._codegen(self, state) + + +@add_slots +@dataclass(frozen=True) +class Else(CSTNode): + """ + An `else` clause that appears optionally after an `If`, `While`, `Try`, or `For` + statement. + + This node does not match `elif` clauses in `If` statements. It also does not match + the required `else` clause in an `if` expression (`a = b if c else d`).
+ """ + + body: BaseSuite + leading_lines: Sequence[EmptyLine] = () + whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Else": + return Else( + leading_lines=visit_sequence("leading_lines", self.leading_lines, visitor), + whitespace_before_colon=visit_required( + "whitespace_before_colon", self.whitespace_before_colon, visitor + ), + body=visit_required("body", self.body, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + for ll in self.leading_lines: + ll._codegen(state) + state.tokens.extend(state.indent) + state.tokens.append("else") + self.whitespace_before_colon._codegen(state) + state.tokens.append(":") + self.body._codegen(state) + + +class BaseCompoundStatement(CSTNode, ABC): + """ + > Compound statements contain (groups of) other statements; they affect or control + > the execution of those other statements in some way. In general, compound + > statements span multiple lines, although in simple incarnations a whole compound + > statement may be contained in one line. + + -- https://docs.python.org/3/reference/compound_stmts.html + """ + + body: BaseSuite + leading_lines: Sequence[EmptyLine] + + +@add_slots +@dataclass(frozen=True) +class If(BaseCompoundStatement): + """ + An `if` statement. `test` holds a single test expression. + + `elif` clauses don’t have a special representation in the AST, but rather appear as + extra `If` nodes within the `orelse` section of the previous one. + """ + + test: BaseExpression # TODO: should be a test_nocond + body: BaseSuite + # A value of orelse with the type of: + # - If signifies an elif block. + # - Else signifies an else block. + # - None signifies no else or elif block. 
@add_slots
@dataclass(frozen=True)
class If(BaseCompoundStatement):
    """
    An `if` statement guarded by the single expression stored in `test`.

    There is no dedicated node for `elif`: an `elif` branch is simply another `If`
    node stored in the `orelse` slot of the `If` that precedes it.
    """

    # Condition that guards the body.
    test: BaseExpression  # TODO: should be a test_nocond

    # Suite that follows the colon.
    body: BaseSuite

    # What follows the body:
    # - an If node is rendered as an `elif` branch,
    # - an Else node is rendered as an `else` branch,
    # - None means there is no further branch.
    orelse: Union["If", Else, None] = None

    # Whitespace:
    leading_lines: Sequence[EmptyLine] = ()
    whitespace_before_test: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_after_test: SimpleWhitespace = SimpleWhitespace("")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "If":
        # Visit order mirrors the order in which _codegen renders the children.
        new_leading_lines = visit_sequence(
            "leading_lines", self.leading_lines, visitor
        )
        new_before_test = visit_required(
            "whitespace_before_test", self.whitespace_before_test, visitor
        )
        new_test = visit_required("test", self.test, visitor)
        new_after_test = visit_required(
            "whitespace_after_test", self.whitespace_after_test, visitor
        )
        new_body = visit_required("body", self.body, visitor)
        new_orelse = visit_optional("orelse", self.orelse, visitor)
        return If(
            leading_lines=new_leading_lines,
            whitespace_before_test=new_before_test,
            test=new_test,
            whitespace_after_test=new_after_test,
            body=new_body,
            orelse=new_orelse,
        )

    def _codegen(self, state: CodegenState, is_elif: bool = False) -> None:
        for line in self.leading_lines:
            line._codegen(state)
        state.tokens.extend(state.indent)
        keyword = "elif" if is_elif else "if"
        state.tokens.append(keyword)
        self.whitespace_before_test._codegen(state)
        self.test._codegen(state)
        self.whitespace_after_test._codegen(state)
        state.tokens.append(":")
        self.body._codegen(state)

        branch = self.orelse
        if branch is None:
            return
        if isinstance(branch, If):
            # A chained If is rendered with the `elif` keyword.
            branch._codegen(state, is_elif=True)
            return
        # Anything else is an Else clause.
        branch._codegen(state)
@add_slots
@dataclass(frozen=True)
class IndentedBlock(BaseSuite):
    """
    A block of statements that starts with an INDENT token and finishes with a
    DEDENT token. It serves as the body of compound statements, for example the body
    of an if statement.

    The usual alternative is a SimpleStatementSuite, which is also a BaseSuite and can
    therefore be the body of many compound statements as well.
    """

    # Statements contained in the block; must not be empty (see _validate).
    body: Sequence[Union[SimpleStatementLine, BaseCompoundStatement]]

    # An IndentedBlock always follows the colon of a BaseCompoundStatement, so it owns
    # the trailing whitespace of that clause:
    #
    # if test:  # IndentedBlock's header
    #     body
    header: TrailingWhitespace = TrailingWhitespace()

    # A str is a specific indentation; None selects the module's default indentation.
    #
    # Indentation may be inconsistent across a file as long as it is not ambiguous,
    # which is why it is tracked per block.
    indent: Optional[str] = None

    # Trailing comments or lines after the dedent. Statements own preceding and
    # same-line trailing comments but not trailing lines, so the block owns these.
    footer: Sequence[EmptyLine] = ()

    def _validate(self) -> None:
        if len(self.body) == 0:
            raise CSTValidationError(
                "An indented block must have at least one StatementLine in the body."
            )
        indent = self.indent
        if indent is None:
            # The default indentation is used; nothing further to check.
            return
        if len(indent) == 0:
            raise CSTValidationError(
                "An indented block must have a non-zero width indent."
            )
        if _INDENT_WHITESPACE_RE.fullmatch(indent) is None:
            raise CSTValidationError(
                "An indent must be composed of only whitespace characters."
            )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "IndentedBlock":
        new_header = visit_required("header", self.header, visitor)
        visited_body = visit_sequence("body", self.body, visitor)
        # A block may not be empty, so fall back to a lone `pass` statement when the
        # visitor removed every statement.
        new_body = (
            (SimpleStatementLine((Pass(),)),) if len(visited_body) == 0 else visited_body
        )
        new_footer = visit_sequence("footer", self.footer, visitor)
        return IndentedBlock(
            header=new_header,
            indent=self.indent,
            body=new_body,
            footer=new_footer,
        )

    def _codegen(self, state: CodegenState) -> None:
        self.header._codegen(state)

        own_indent = self.indent
        if own_indent is None:
            state.indent.append(state.default_indent)
        else:
            state.indent.append(own_indent)

        # The block only adjusts the current indentation level; each child statement
        # is responsible for writing that indentation into the token list.
        for statement in self.body:
            statement._codegen(state)

        state.indent.pop()

        for line in self.footer:
            line._codegen(state)
@add_slots
@dataclass(frozen=True)
class AsName(CSTNode):
    """
    An `as name` clause inside an `ExceptHandler`, `ImportAlias` or `WithItem` node.
    """

    # Identifier that the parent node will be aliased to.
    name: Name  # TODO: This should be Union[Name, Tuple, List] once we support those

    # Whitespace on either side of the `as` keyword; neither may be empty.
    whitespace_before_as: BaseParenthesizableWhitespace = SimpleWhitespace(" ")
    whitespace_after_as: BaseParenthesizableWhitespace = SimpleWhitespace(" ")

    def _validate(self) -> None:
        """Raise CSTValidationError if either side of `as` has no whitespace."""
        if self.whitespace_after_as.empty:
            raise CSTValidationError(
                "There must be at least one space between 'as' and name."
            )
        if self.whitespace_before_as.empty:
            raise CSTValidationError("There must be at least one space before 'as'.")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "AsName":
        """Visit each child and return a new AsName built from the results."""
        # Visit children in the order _codegen renders them: the whitespace after
        # `as` comes before the name, so it has to be visited before the name too.
        whitespace_before_as = visit_required(
            "whitespace_before_as", self.whitespace_before_as, visitor
        )
        whitespace_after_as = visit_required(
            "whitespace_after_as", self.whitespace_after_as, visitor
        )
        name = visit_required("name", self.name, visitor)
        return AsName(
            whitespace_before_as=whitespace_before_as,
            whitespace_after_as=whitespace_after_as,
            name=name,
        )

    def _codegen(self, state: CodegenState) -> None:
        """Append `<ws>as<ws><name>` to the token list held by `state`."""
        self.whitespace_before_as._codegen(state)
        state.tokens.append("as")
        self.whitespace_after_as._codegen(state)
        self.name._codegen(state)
@add_slots
@dataclass(frozen=True)
class ExceptHandler(CSTNode):
    """
    An `except` clause that appears optionally after a `Try` statement.
    """

    # The body of the except
    body: BaseSuite

    # The type of exception this catches. Can be a tuple in some cases,
    # or none for an empty exception.
    type: Optional[BaseExpression] = None

    # The name that a caught exception is assigned to
    name: Optional[AsName] = None

    # Whitespace nodes
    leading_lines: Sequence[EmptyLine] = ()
    whitespace_after_except: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("")

    def _validate(self) -> None:
        """Raise CSTValidationError for combinations that cannot be rendered."""
        if self.type is None and self.name is not None:
            raise CSTValidationError("Cannot have a name for an empty type.")
        if self.name is not None and not isinstance(self.name.name, Name):
            raise CSTValidationError(
                "Must use a Name node for AsName name inside ExceptHandler."
            )
        if self.type is not None and len(self.whitespace_after_except.value) == 0:
            raise CSTValidationError(
                "Must have at least one space after except when ExceptHandler has a type."
            )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ExceptHandler":
        """Visit each child in rendered order and rebuild the node."""
        return ExceptHandler(
            # leading_lines must be forwarded explicitly; otherwise the rebuilt node
            # falls back to the field default and loses them.
            leading_lines=visit_sequence("leading_lines", self.leading_lines, visitor),
            whitespace_after_except=visit_required(
                "whitespace_after_except", self.whitespace_after_except, visitor
            ),
            type=visit_optional("type", self.type, visitor),
            name=visit_optional("name", self.name, visitor),
            whitespace_before_colon=visit_required(
                "whitespace_before_colon", self.whitespace_before_colon, visitor
            ),
            body=visit_required("body", self.body, visitor),
        )

    def _codegen(self, state: CodegenState) -> None:
        """Append the rendered `except` clause to the token list held by `state`."""
        # Leading lines are rendered before the indentation, as in Else/If/Try.
        for ll in self.leading_lines:
            ll._codegen(state)
        state.tokens.extend(state.indent)
        state.tokens.append("except")
        self.whitespace_after_except._codegen(state)
        typenode = self.type
        if typenode is not None:
            typenode._codegen(state)
        namenode = self.name
        if namenode is not None:
            namenode._codegen(state)
        self.whitespace_before_colon._codegen(state)
        state.tokens.append(":")
        self.body._codegen(state)


@add_slots
@dataclass(frozen=True)
class Finally(CSTNode):
    """
    A `finally` clause that appears optionally after a `Try` statement.
    """

    # Suite that follows the colon.
    body: BaseSuite

    # Lines rendered before the `finally` keyword.
    leading_lines: Sequence[EmptyLine] = ()

    # Whitespace between the `finally` keyword and the colon.
    whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Finally":
        """Visit each child in rendered order and rebuild the node."""
        return Finally(
            # Forward leading_lines so a transform does not reset them to ().
            leading_lines=visit_sequence("leading_lines", self.leading_lines, visitor),
            whitespace_before_colon=visit_required(
                "whitespace_before_colon", self.whitespace_before_colon, visitor
            ),
            body=visit_required("body", self.body, visitor),
        )

    def _codegen(self, state: CodegenState) -> None:
        """Append the rendered `finally` clause to the token list held by `state`."""
        # Leading lines are rendered before the indentation, as in Else/If/Try.
        for ll in self.leading_lines:
            ll._codegen(state)
        state.tokens.extend(state.indent)
        state.tokens.append("finally")
        self.whitespace_before_colon._codegen(state)
        state.tokens.append(":")
        self.body._codegen(state)
@add_slots
@dataclass(frozen=True)
class Try(BaseCompoundStatement):
    """
    A `try` statement together with its handlers and optional trailing clauses.
    """

    # Suite guarded by the `try`.
    body: BaseSuite

    # Zero or more `except` clauses.
    handlers: Sequence[ExceptHandler] = ()

    # Optional `else` clause; only allowed together with a handler (see _validate).
    orelse: Optional[Else] = None

    # Optional `finally` clause.
    finalbody: Optional[Finally] = None

    # Whitespace
    leading_lines: Sequence[EmptyLine] = ()
    whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("")

    def _validate(self) -> None:
        if len(self.handlers) != 0:
            # Any combination of else/finally is acceptable once a handler exists.
            return
        if self.finalbody is None:
            raise CSTValidationError(
                "A Try statement must have at least one ExceptHandler or Finally"
            )
        if self.orelse is not None:
            raise CSTValidationError(
                "A Try statement must have at least one ExceptHandler in order "
                "to have an Else"
            )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Try":
        # Visit order mirrors the order in which _codegen renders the children.
        new_leading_lines = visit_sequence(
            "leading_lines", self.leading_lines, visitor
        )
        new_whitespace = visit_required(
            "whitespace_before_colon", self.whitespace_before_colon, visitor
        )
        new_body = visit_required("body", self.body, visitor)
        new_handlers = visit_sequence("handlers", self.handlers, visitor)
        new_orelse = visit_optional("orelse", self.orelse, visitor)
        new_finalbody = visit_optional("finalbody", self.finalbody, visitor)
        return Try(
            leading_lines=new_leading_lines,
            whitespace_before_colon=new_whitespace,
            body=new_body,
            handlers=new_handlers,
            orelse=new_orelse,
            finalbody=new_finalbody,
        )

    def _codegen(self, state: CodegenState) -> None:
        for line in self.leading_lines:
            line._codegen(state)
        state.tokens.extend(state.indent)
        state.tokens.append("try")
        self.whitespace_before_colon._codegen(state)
        state.tokens.append(":")
        self.body._codegen(state)
        for except_clause in self.handlers:
            except_clause._codegen(state)
        # The optional clauses are rendered in source order: else, then finally.
        else_clause = self.orelse
        if else_clause is not None:
            else_clause._codegen(state)
        finally_clause = self.finalbody
        if finally_clause is not None:
            finally_clause._codegen(state)
@dataclass(frozen=True)
class ImportAlias(CSTNode):
    """
    A single imported name, optionally followed by an `as` alias.
    """

    # Name or Attribute node naming what is imported.
    name: Union[Attribute, Name]

    # The `as` alias, when there is one.
    asname: Optional[AsName] = None

    # Optional for the last ImportAlias of an Import or ImportFrom; every other
    # ImportAlias in an import needs a comma so that multiple small statements on the
    # same line stay unambiguous.
    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT

    def _validate(self) -> None:
        alias = self.asname
        if alias is None:
            return
        if not isinstance(alias.name, Name):
            raise CSTValidationError(
                "Must use a Name node for AsName name inside ImportAlias."
            )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ImportAlias":
        new_name = visit_required("name", self.name, visitor)
        new_asname = visit_optional("asname", self.asname, visitor)
        new_comma = visit_sentinel("comma", self.comma, visitor)
        return ImportAlias(name=new_name, asname=new_asname, comma=new_comma)

    def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
        self.name._codegen(state)
        alias = self.asname
        if alias is not None:
            alias._codegen(state)
        separator = self.comma
        if isinstance(separator, Comma):
            separator._codegen(state)
        elif default_comma and separator is MaybeSentinel.DEFAULT:
            # No explicit comma was given, but the caller asked for a default one.
            state.tokens.append(", ")
@dataclass(frozen=True)
class Import(BaseSmallStatement):
    """
    An `import a, b as c` statement.
    """

    # The imported names; at least one is required (see _validate).
    names: Sequence[ImportAlias]

    # Optional semicolon when this is used in a statement line
    semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Whitespace between the `import` keyword and the first name.
    whitespace_after_import: SimpleWhitespace = SimpleWhitespace(" ")

    def _validate(self) -> None:
        aliases = self.names
        if len(aliases) == 0:
            raise CSTValidationError(
                "An ImportStatement must have at least one ImportAlias"
            )
        final_alias = aliases[-1]
        if isinstance(final_alias.comma, Comma):
            raise CSTValidationError(
                "An ImportStatement does not allow a trailing comma"
            )
        if len(self.whitespace_after_import.value) == 0:
            raise CSTValidationError("Must have at least one space after import.")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Import":
        new_whitespace = visit_required(
            "whitespace_after_import", self.whitespace_after_import, visitor
        )
        new_names = visit_sequence("names", self.names, visitor)
        new_semicolon = visit_sentinel("semicolon", self.semicolon, visitor)
        return Import(
            whitespace_after_import=new_whitespace,
            names=new_names,
            semicolon=new_semicolon,
        )

    def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
        state.tokens.append("import")
        self.whitespace_after_import._codegen(state)
        final_index = len(self.names) - 1
        for index, alias in enumerate(self.names):
            # Every alias except the last gets a default comma when it has none.
            alias._codegen(state, default_comma=(index != final_index))
        terminator = self.semicolon
        if isinstance(terminator, Semicolon):
            terminator._codegen(state)
        elif isinstance(terminator, MaybeSentinel) and default_semicolon:
            state.tokens.append("; ")
@dataclass(frozen=True)
class ImportFrom(BaseSmallStatement):
    """
    A `from x import y` statement, including relative and star imports.
    """

    # Name or Attribute node naming the module; may be None for a purely relative
    # import (see _validate_module).
    module: Optional[Union[Attribute, Name]]

    # Either the names imported from the module or a single ImportStar.
    names: Union[Sequence[ImportAlias], ImportStar]

    # Dot nodes that make the import relative.
    relative: Sequence[Dot] = ()

    # Optional open parenthesis wrapping the imported names.
    lpar: Optional[LeftParen] = None

    # Optional close parenthesis wrapping the imported names.
    rpar: Optional[RightParen] = None

    # Optional semicolon when this is used in a statement line
    semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Whitespace nodes owned by ImportFrom.
    whitespace_after_from: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_before_import: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_after_import: SimpleWhitespace = SimpleWhitespace(" ")

    def _validate_module(self) -> None:
        has_relative_dots = len(self.relative) != 0
        if self.module is None and not has_relative_dots:
            raise CSTValidationError(
                "Must have a module specified if there is no relative import."
            )

    def _validate_names(self) -> None:
        imported = self.names
        if isinstance(imported, Sequence):
            if len(imported) == 0:
                raise CSTValidationError(
                    "An ImportFrom must have at least one ImportAlias"
                )
            for alias in imported[:-1]:
                if alias.comma is None:
                    raise CSTValidationError("Non-final ImportAliases require a comma")
            # Parentheses must come as a pair.
            has_lpar = self.lpar is not None
            has_rpar = self.rpar is not None
            if has_lpar and not has_rpar:
                raise CSTValidationError("Cannot have left paren without right paren.")
            if has_rpar and not has_lpar:
                raise CSTValidationError("Cannot have right paren without left paren.")
        if isinstance(imported, ImportStar):
            if not (self.lpar is None and self.rpar is None):
                raise CSTValidationError(
                    "An ImportFrom using ImportStar cannot have parens"
                )

    def _validate_whitespace(self) -> None:
        if len(self.whitespace_after_from.value) == 0:
            raise CSTValidationError("Must have at least one space after from.")
        if len(self.whitespace_before_import.value) == 0:
            raise CSTValidationError("Must have at least one space before import.")
        # An opening parenthesis may directly follow the `import` keyword.
        if self.lpar is None and len(self.whitespace_after_import.value) == 0:
            raise CSTValidationError("Must have at least one space after import.")

    def _validate(self) -> None:
        for check in (
            self._validate_module,
            self._validate_names,
            self._validate_whitespace,
        ):
            check()

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ImportFrom":
        # Visit order mirrors the order in which _codegen renders the children.
        imported = self.names
        new_after_from = visit_required(
            "whitespace_after_from", self.whitespace_after_from, visitor
        )
        new_relative = visit_sequence("relative", self.relative, visitor)
        new_module = visit_optional("module", self.module, visitor)
        new_before_import = visit_required(
            "whitespace_before_import", self.whitespace_before_import, visitor
        )
        new_after_import = visit_required(
            "whitespace_after_import", self.whitespace_after_import, visitor
        )
        new_lpar = visit_optional("lpar", self.lpar, visitor)
        if isinstance(imported, ImportStar):
            new_names = visit_required("names", imported, visitor)
        else:
            new_names = visit_sequence("names", imported, visitor)
        new_rpar = visit_optional("rpar", self.rpar, visitor)
        new_semicolon = visit_sentinel("semicolon", self.semicolon, visitor)
        return ImportFrom(
            whitespace_after_from=new_after_from,
            relative=new_relative,
            module=new_module,
            whitespace_before_import=new_before_import,
            whitespace_after_import=new_after_import,
            lpar=new_lpar,
            names=new_names,
            rpar=new_rpar,
            semicolon=new_semicolon,
        )

    def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
        state.tokens.append("from")
        self.whitespace_after_from._codegen(state)
        for dot in self.relative:
            dot._codegen(state)
        source_module = self.module
        if source_module is not None:
            source_module._codegen(state)
        self.whitespace_before_import._codegen(state)
        state.tokens.append("import")
        self.whitespace_after_import._codegen(state)
        open_paren = self.lpar
        if open_paren is not None:
            open_paren._codegen(state)
        imported = self.names
        if isinstance(imported, Sequence):
            final_index = len(imported) - 1
            for index, alias in enumerate(imported):
                # Every alias except the last gets a default comma when it has none.
                alias._codegen(state, default_comma=(index != final_index))
        if isinstance(imported, ImportStar):
            imported._codegen(state)
        close_paren = self.rpar
        if close_paren is not None:
            close_paren._codegen(state)
        terminator = self.semicolon
        if isinstance(terminator, Semicolon):
            terminator._codegen(state)
        elif isinstance(terminator, MaybeSentinel) and default_semicolon:
            state.tokens.append("; ")
@dataclass(frozen=True)
class AssignTarget(CSTNode):
    """
    One target of an assignment, together with the `=` that follows it.
    """

    # Expression that receives the value.
    target: BaseAssignTargetExpression

    # Whitespace on either side of the equals sign.
    whitespace_before_equal: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_after_equal: SimpleWhitespace = SimpleWhitespace(" ")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "AssignTarget":
        new_target = visit_required("target", self.target, visitor)
        new_before_equal = visit_required(
            "whitespace_before_equal", self.whitespace_before_equal, visitor
        )
        new_after_equal = visit_required(
            "whitespace_after_equal", self.whitespace_after_equal, visitor
        )
        return AssignTarget(
            target=new_target,
            whitespace_before_equal=new_before_equal,
            whitespace_after_equal=new_after_equal,
        )

    def _codegen(self, state: CodegenState) -> None:
        # Rendered as: <target><ws>=<ws>
        self.target._codegen(state)
        self.whitespace_before_equal._codegen(state)
        state.tokens.append("=")
        self.whitespace_after_equal._codegen(state)
@dataclass(frozen=True)
class Assign(BaseSmallStatement):
    """
    A plain assignment such as `a = b = value`.
    """

    # The targets being assigned to; at least one is required (see _validate).
    targets: Sequence[AssignTarget]

    # Expression whose value is assigned to the targets.
    value: BaseExpression

    # Optional semicolon when this is used in a statement line
    semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT

    def _validate(self) -> None:
        if len(self.targets) != 0:
            return
        raise CSTValidationError(
            "An Assign statement must have at least one AssignTarget"
        )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Assign":
        new_targets = visit_sequence("targets", self.targets, visitor)
        new_value = visit_required("value", self.value, visitor)
        new_semicolon = visit_sentinel("semicolon", self.semicolon, visitor)
        return Assign(targets=new_targets, value=new_value, semicolon=new_semicolon)

    def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
        # Each AssignTarget renders its own trailing `=`.
        for assign_target in self.targets:
            assign_target._codegen(state)
        self.value._codegen(state)
        terminator = self.semicolon
        if isinstance(terminator, Semicolon):
            terminator._codegen(state)
        elif isinstance(terminator, MaybeSentinel) and default_semicolon:
            state.tokens.append("; ")
@dataclass(frozen=True)
class AnnAssign(BaseSmallStatement):
    """
    An annotated assignment: a target, its annotation and an optional value.
    """

    # The single target being annotated (and optionally assigned to).
    target: BaseExpression

    # The annotation for the target.
    annotation: Annotation

    # The optional expression being assigned to the target.
    value: Optional[BaseExpression] = None

    # The equals sign used to denote assignment if there is a value.
    equal: Union[AssignEqual, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Optional semicolon when this is used in a statement line
    semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT

    def _validate(self) -> None:
        indicator = self.annotation.indicator
        if isinstance(indicator, str) and indicator != ":":
            raise CSTValidationError("An Annotation must be denoted with a ':'.")
        # An explicit equals sign without a value cannot be rendered.
        if isinstance(self.equal, AssignEqual) and self.value is None:
            raise CSTValidationError(
                "Must have a value when specifying an AssignEqual."
            )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "AnnAssign":
        # Visit order mirrors the order in which _codegen renders the children.
        new_target = visit_required("target", self.target, visitor)
        new_annotation = visit_required("annotation", self.annotation, visitor)
        new_equal = visit_sentinel("equal", self.equal, visitor)
        new_value = visit_optional("value", self.value, visitor)
        new_semicolon = visit_sentinel("semicolon", self.semicolon, visitor)
        return AnnAssign(
            target=new_target,
            annotation=new_annotation,
            equal=new_equal,
            value=new_value,
            semicolon=new_semicolon,
        )

    def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
        self.target._codegen(state)
        self.annotation._codegen(state, default_indicator=":")
        assigned = self.value
        equals_sign = self.equal
        if isinstance(equals_sign, AssignEqual):
            equals_sign._codegen(state)
        elif assigned is not None and equals_sign is MaybeSentinel.DEFAULT:
            # A value without an explicit equals sign gets the default spacing.
            state.tokens.append(" = ")
        if assigned is not None:
            assigned._codegen(state)
        terminator = self.semicolon
        if isinstance(terminator, Semicolon):
            terminator._codegen(state)
        elif isinstance(terminator, MaybeSentinel) and default_semicolon:
            state.tokens.append("; ")
@dataclass(frozen=True)
class AugAssign(BaseSmallStatement):
    """
    An augmented assignment: a target, an augmented operator and a value.
    """

    # Target that is being assigned to
    target: BaseExpression

    # The augmented assignment operation being performed
    operator: BaseAugOp

    # The value being assigned to the target.
    value: BaseExpression

    # Optional semicolon when this is used in a statement line
    semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "AugAssign":
        new_target = visit_required("target", self.target, visitor)
        new_operator = visit_required("operator", self.operator, visitor)
        new_value = visit_required("value", self.value, visitor)
        new_semicolon = visit_sentinel("semicolon", self.semicolon, visitor)
        return AugAssign(
            target=new_target,
            operator=new_operator,
            value=new_value,
            semicolon=new_semicolon,
        )

    def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
        for part in (self.target, self.operator, self.value):
            part._codegen(state)
        terminator = self.semicolon
        if isinstance(terminator, Semicolon):
            terminator._codegen(state)
        elif isinstance(terminator, MaybeSentinel) and default_semicolon:
            state.tokens.append("; ")


@add_slots
@dataclass(frozen=True)
class Asynchronous(CSTNode):
    """
    The `async` modifier used by asynchronous function definitions, as well as
    async for and async with.
    """

    # Whitespace between `async` and the keyword that follows; must not be empty.
    whitespace_after: SimpleWhitespace = SimpleWhitespace(" ")

    def _validate(self) -> None:
        if len(self.whitespace_after.value) < 1:
            raise CSTValidationError("Must have at least one space after Asynchronous.")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Asynchronous":
        new_whitespace = visit_required(
            "whitespace_after", self.whitespace_after, visitor
        )
        return Asynchronous(whitespace_after=new_whitespace)

    def _codegen(self, state: CodegenState) -> None:
        state.tokens.append("async")
        self.whitespace_after._codegen(state)
@add_slots
@dataclass(frozen=True)
class Decorator(CSTNode):
    """
    One `@decorator` line placed above a FunctionDef or a ClassDef.
    """

    # Expression after the `@`: the decorator that will return a new function
    # wrapping the parent of this decorator.
    decorator: Union[Name, Attribute, Call]

    # Line comments and empty lines before this decorator. The parent FunctionDef or
    # ClassDef owns the leading lines that come before the comments of the first
    # decorator, so that spacing survives when the first decorator is removed.
    leading_lines: Sequence[EmptyLine] = ()

    # Whitespace between the `@` and the decorator expression.
    whitespace_after_at: SimpleWhitespace = SimpleWhitespace("")

    # Whitespace following the decorator before the next line
    trailing_whitespace: TrailingWhitespace = TrailingWhitespace()

    def _validate(self) -> None:
        expression = self.decorator
        if len(expression.lpar) > 0 or len(expression.rpar) > 0:
            raise CSTValidationError(
                "Cannot have parens around decorator in a Decorator."
            )
        if not isinstance(expression, Call):
            return
        # For a call, restrict what may be called.
        if not isinstance(expression.func, (BaseAtom, Attribute)):
            raise CSTValidationError(
                "Decorator call function must be an atom or attribute."
            )

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Decorator":
        # Visit order mirrors the order in which _codegen renders the children.
        new_leading_lines = visit_sequence(
            "leading_lines", self.leading_lines, visitor
        )
        new_after_at = visit_required(
            "whitespace_after_at", self.whitespace_after_at, visitor
        )
        new_decorator = visit_required("decorator", self.decorator, visitor)
        new_trailing = visit_required(
            "trailing_whitespace", self.trailing_whitespace, visitor
        )
        return Decorator(
            leading_lines=new_leading_lines,
            whitespace_after_at=new_after_at,
            decorator=new_decorator,
            trailing_whitespace=new_trailing,
        )

    def _codegen(self, state: CodegenState) -> None:
        for line in self.leading_lines:
            line._codegen(state)
        state.tokens.extend(state.indent)
        state.tokens.append("@")
        self.whitespace_after_at._codegen(state)
        self.decorator._codegen(state)
        self.trailing_whitespace._codegen(state)
@add_slots
@dataclass(frozen=True)
class FunctionDef(BaseCompoundStatement):
    """
    A `def` statement, optionally decorated, annotated and/or asynchronous.
    """

    # The function name.
    name: Name

    # The function parameters. Present even if there are no params.
    params: Parameters

    # The function body.
    body: BaseSuite

    # Decorators applied to this function, rendered in sequence order.
    decorators: Sequence[Decorator] = ()

    # An optional return type annotation
    returns: Optional[Annotation] = None

    # Optional async modifier.
    asynchronous: Optional[Asynchronous] = None

    # Empty lines and comments before the first decorator; anything in front of the
    # first decorator is assumed to belong to the function definition itself. Without
    # decorators this still holds all of the empty lines and comments before the
    # function definition.
    leading_lines: Sequence[EmptyLine] = ()

    # Empty lines and comments between the final decorator and the `def`. Empty when
    # there are no decorators.
    lines_after_decorators: Sequence[EmptyLine] = ()

    # Whitespace between the tokens that make up the function definition.
    whitespace_after_def: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_after_name: SimpleWhitespace = SimpleWhitespace("")
    whitespace_before_params: SimpleWhitespace = SimpleWhitespace("")
    whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("")

    def _validate(self) -> None:
        function_name = self.name
        if len(function_name.lpar) > 0 or len(function_name.rpar) > 0:
            raise CSTValidationError("Cannot have parens around Name in a FunctionDef.")
        if len(self.whitespace_after_def.value) == 0:
            raise CSTValidationError(
                "There must be at least one space between 'def' and name."
            )
        return_annotation = self.returns
        if return_annotation is None:
            return
        # Only an explicit string indicator is checked here.
        if (
            isinstance(return_annotation.indicator, str)
            and return_annotation.indicator != "->"
        ):
            raise CSTValidationError("A return Annotation must be denoted with a '->'.")

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "FunctionDef":
        # Visit order mirrors the order in which _codegen renders the children.
        new_leading_lines = visit_sequence(
            "leading_lines", self.leading_lines, visitor
        )
        new_decorators = visit_sequence("decorators", self.decorators, visitor)
        new_lines_after_decorators = visit_sequence(
            "lines_after_decorators", self.lines_after_decorators, visitor
        )
        new_asynchronous = visit_optional("asynchronous", self.asynchronous, visitor)
        new_after_def = visit_required(
            "whitespace_after_def", self.whitespace_after_def, visitor
        )
        new_name = visit_required("name", self.name, visitor)
        new_after_name = visit_required(
            "whitespace_after_name", self.whitespace_after_name, visitor
        )
        new_before_params = visit_required(
            "whitespace_before_params", self.whitespace_before_params, visitor
        )
        new_params = visit_required("params", self.params, visitor)
        new_returns = visit_optional("returns", self.returns, visitor)
        new_before_colon = visit_required(
            "whitespace_before_colon", self.whitespace_before_colon, visitor
        )
        new_body = visit_required("body", self.body, visitor)
        return FunctionDef(
            leading_lines=new_leading_lines,
            decorators=new_decorators,
            lines_after_decorators=new_lines_after_decorators,
            asynchronous=new_asynchronous,
            whitespace_after_def=new_after_def,
            name=new_name,
            whitespace_after_name=new_after_name,
            whitespace_before_params=new_before_params,
            params=new_params,
            returns=new_returns,
            whitespace_before_colon=new_before_colon,
            body=new_body,
        )

    def _codegen(self, state: CodegenState) -> None:
        # Everything that precedes the `def` line itself.
        for line in self.leading_lines:
            line._codegen(state)
        for decorator_node in self.decorators:
            decorator_node._codegen(state)
        for line in self.lines_after_decorators:
            line._codegen(state)

        state.tokens.extend(state.indent)
        async_modifier = self.asynchronous
        if async_modifier is not None:
            async_modifier._codegen(state)
        state.tokens.append("def")
        self.whitespace_after_def._codegen(state)
        self.name._codegen(state)
        self.whitespace_after_name._codegen(state)

        # The parentheses are always written, even when there are no parameters.
        state.tokens.append("(")
        self.whitespace_before_params._codegen(state)
        self.params._codegen(state)
        state.tokens.append(")")

        return_annotation = self.returns
        if return_annotation is not None:
            return_annotation._codegen(state, default_indicator="->")
        self.whitespace_before_colon._codegen(state)
        state.tokens.append(":")
        self.body._codegen(state)
@add_slots
@dataclass(frozen=True)
class ClassDef(BaseCompoundStatement):
    """
    A `class` statement, optionally decorated and with bases and keywords.
    """

    # The class name.
    name: Name

    # The class body.
    body: BaseSuite

    # The base classes this class inherits from
    bases: Sequence[Arg] = ()

    # Any keywords, such as "metaclass"
    keywords: Sequence[Arg] = ()

    # Decorators applied to this class, rendered in sequence order.
    decorators: Sequence[Decorator] = ()

    # Optional open parenthesis used when there are bases or keywords.
    lpar: Union[LeftParen, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Optional close parenthesis used when there are bases or keywords.
    rpar: Union[RightParen, MaybeSentinel] = MaybeSentinel.DEFAULT

    # Empty lines and comments before the first decorator; anything in front of the
    # first decorator is assumed to belong to the class definition itself. Without
    # decorators this still holds all of the empty lines and comments before the
    # class definition.
    leading_lines: Sequence[EmptyLine] = ()

    # Empty lines and comments between the final decorator and the `class` keyword.
    # Empty when there are no decorators.
    lines_after_decorators: Sequence[EmptyLine] = ()

    # Whitespace between the tokens that make up the class definition.
    whitespace_after_class: SimpleWhitespace = SimpleWhitespace(" ")
    whitespace_after_name: SimpleWhitespace = SimpleWhitespace("")
    whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("")

    def _validate_whitespace(self) -> None:
        if len(self.whitespace_after_class.value) != 0:
            return
        raise CSTValidationError(
            "There must be at least one space between 'class' and name."
        )

    def _validate_parens(self) -> None:
        class_name = self.name
        if len(class_name.lpar) > 0 or len(class_name.rpar) > 0:
            raise CSTValidationError("Cannot have parens around Name in a ClassDef.")
        # Both parentheses must be concrete nodes or both must be sentinels.
        left, right = self.lpar, self.rpar
        right_only = isinstance(left, MaybeSentinel) and isinstance(right, RightParen)
        left_only = isinstance(left, LeftParen) and isinstance(right, MaybeSentinel)
        if right_only or left_only:
            raise CSTValidationError(
                "Do not mix concrete LeftParen/RightParen with MaybeSentinel."
            )

    def _validate_args(self) -> None:
        for base in self.bases:
            if base.keyword is not None:
                raise CSTValidationError("Bases must be arguments without keywords.")
        for keyword_arg in self.keywords:
            if keyword_arg.keyword is None and keyword_arg.star != "**":
                raise CSTValidationError(
                    "Keywords must be arguments with keywords or dictionary expansions."
                )

    def _validate(self) -> None:
        for check in (
            self._validate_whitespace,
            self._validate_parens,
            self._validate_args,
        ):
            check()

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ClassDef":
        # Visit order mirrors the order in which _codegen renders the children.
        new_leading_lines = visit_sequence(
            "leading_lines", self.leading_lines, visitor
        )
        new_decorators = visit_sequence("decorators", self.decorators, visitor)
        new_lines_after_decorators = visit_sequence(
            "lines_after_decorators", self.lines_after_decorators, visitor
        )
        new_after_class = visit_required(
            "whitespace_after_class", self.whitespace_after_class, visitor
        )
        new_name = visit_required("name", self.name, visitor)
        new_after_name = visit_required(
            "whitespace_after_name", self.whitespace_after_name, visitor
        )
        new_lpar = visit_sentinel("lpar", self.lpar, visitor)
        new_bases = visit_sequence("bases", self.bases, visitor)
        new_keywords = visit_sequence("keywords", self.keywords, visitor)
        new_rpar = visit_sentinel("rpar", self.rpar, visitor)
        new_before_colon = visit_required(
            "whitespace_before_colon", self.whitespace_before_colon, visitor
        )
        new_body = visit_required("body", self.body, visitor)
        return ClassDef(
            leading_lines=new_leading_lines,
            decorators=new_decorators,
            lines_after_decorators=new_lines_after_decorators,
            whitespace_after_class=new_after_class,
            name=new_name,
            whitespace_after_name=new_after_name,
            lpar=new_lpar,
            bases=new_bases,
            keywords=new_keywords,
            rpar=new_rpar,
            whitespace_before_colon=new_before_colon,
            body=new_body,
        )

    def _codegen(self, state: CodegenState) -> None:
        # Everything that precedes the `class` line itself.
        for line in self.leading_lines:
            line._codegen(state)
        for decorator_node in self.decorators:
            decorator_node._codegen(state)
        for line in self.lines_after_decorators:
            line._codegen(state)

        state.tokens.extend(state.indent)
        state.tokens.append("class")
        self.whitespace_after_class._codegen(state)
        self.name._codegen(state)
        self.whitespace_after_name._codegen(state)

        # A sentinel parenthesis is only written when there is something to wrap.
        open_paren = self.lpar
        if isinstance(open_paren, LeftParen):
            open_paren._codegen(state)
        elif isinstance(open_paren, MaybeSentinel) and (self.bases or self.keywords):
            state.tokens.append("(")

        # Bases are rendered first, then keywords.
        arguments = list(self.bases)
        arguments.extend(self.keywords)
        final_index = len(arguments) - 1
        for index, argument in enumerate(arguments):
            argument._codegen(state, default_comma=(index != final_index))

        close_paren = self.rpar
        if isinstance(close_paren, RightParen):
            close_paren._codegen(state)
        elif isinstance(close_paren, MaybeSentinel) and (self.bases or self.keywords):
            state.tokens.append(")")

        self.whitespace_before_colon._codegen(state)
        state.tokens.append(":")
        self.body._codegen(state)


@dataclass(frozen=True)
class WithItem(CSTNode):
    """
    One context manager of a with block, optionally bound to a name with `as`.
    """

    # Expression that evaluates to a context manager.
    item: BaseExpression

    # Variable to assign the context manager to.
    asname: Optional[AsName] = None

    # Forbidden on the last WithItem of a With; every other WithItem in a with block
    # needs a comma to separate it from the next one.
    comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT

    def _visit_and_replace_children(self, visitor: CSTVisitor) -> "WithItem":
        new_item = visit_required("item", self.item, visitor)
        new_asname = visit_optional("asname", self.asname, visitor)
        new_comma = visit_sentinel("comma", self.comma, visitor)
        return WithItem(item=new_item, asname=new_asname, comma=new_comma)

    def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
        self.item._codegen(state)
        alias = self.asname
        if alias is not None:
            alias._codegen(state)
        separator = self.comma
        if isinstance(separator, Comma):
            separator._codegen(state)
        elif default_comma and separator is MaybeSentinel.DEFAULT:
            # No explicit comma was given, but the caller asked for a default one.
            state.tokens.append(", ")
+ comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "WithItem": + return WithItem( + item=visit_required("item", self.item, visitor), + asname=visit_optional("asname", self.asname, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + ) + + def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + self.item._codegen(state) + asname = self.asname + if asname is not None: + asname._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + state.tokens.append(", ") + elif isinstance(comma, Comma): + comma._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class With(BaseCompoundStatement): + """ + A `with` statement. + """ + + # A list of one or more WithItems. + items: Sequence[WithItem] + + # The suite that is wrapped with this statement. + body: BaseSuite + + # Optional async modifier. + asynchronous: Optional[Asynchronous] = None + + # Whitespace + leading_lines: Sequence[EmptyLine] = () + whitespace_after_with: SimpleWhitespace = SimpleWhitespace(" ") + whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + if len(self.items) == 0: + raise CSTValidationError( + "A With statement must have at least one WithItem." + ) + if self.items[-1].comma != MaybeSentinel.DEFAULT: + raise CSTValidationError( + "The last WithItem in a With cannot have a trailing comma." 
+ ) + has_no_gap = len(self.whitespace_after_with.value) == 0 + if has_no_gap and not self.items[0].item._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError("Must have at least one space after with keyword.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "With": + return With( + leading_lines=visit_sequence("leading_lines", self.leading_lines, visitor), + asynchronous=visit_optional("asynchronous", self.asynchronous, visitor), + whitespace_after_with=visit_required( + "whitespace_after_with", self.whitespace_after_with, visitor + ), + items=visit_sequence("items", self.items, visitor), + whitespace_before_colon=visit_required( + "whitespace_before_colon", self.whitespace_before_colon, visitor + ), + body=visit_required("body", self.body, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + for ll in self.leading_lines: + ll._codegen(state) + state.tokens.extend(state.indent) + asynchronous = self.asynchronous + if asynchronous is not None: + asynchronous._codegen(state) + state.tokens.append("with") + self.whitespace_after_with._codegen(state) + last_item = len(self.items) - 1 + for i, item in enumerate(self.items): + item._codegen(state, default_comma=(i != last_item)) + self.whitespace_before_colon._codegen(state) + state.tokens.append(":") + self.body._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class For(BaseCompoundStatement): + """ + A `for` statement. + """ + + # The target of the iterator in the for statement. + target: Name # TODO: Should be a Union[Name, Tuple, List] once we support this. + + # The iterable expression we will loop over. + iter: BaseExpression + + # The suite that is wrapped with this statement. + body: BaseSuite + + # An optional else case. + orelse: Optional[Else] = None + + # Optional async modifier. 
+ asynchronous: Optional[Asynchronous] = None + + # Whitespace + leading_lines: Sequence[EmptyLine] = () + whitespace_after_for: SimpleWhitespace = SimpleWhitespace(" ") + whitespace_before_in: SimpleWhitespace = SimpleWhitespace(" ") + whitespace_after_in: SimpleWhitespace = SimpleWhitespace(" ") + whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + has_no_gap = len(self.whitespace_after_for.value) == 0 + if has_no_gap and not self.target._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError( + "Must have at least one space after 'for' keyword." + ) + has_no_gap = len(self.whitespace_before_in.value) == 0 + if has_no_gap and not self.target._safe_to_use_with_word_operator( + ExpressionPosition.LEFT + ): + raise CSTValidationError( + "Must have at least one space before 'in' keyword." + ) + has_no_gap = len(self.whitespace_after_in.value) == 0 + if has_no_gap and not self.iter._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError("Must have at least one space after 'in' keyword.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "For": + return For( + leading_lines=visit_sequence("leading_lines", self.leading_lines, visitor), + asynchronous=visit_optional("asynchronous", self.asynchronous, visitor), + whitespace_after_for=visit_required( + "whitespace_after_for", self.whitespace_after_for, visitor + ), + target=visit_required("target", self.target, visitor), + whitespace_before_in=visit_required( + "whitespace_before_in", self.whitespace_before_in, visitor + ), + whitespace_after_in=visit_required( + "whitespace_after_in", self.whitespace_after_in, visitor + ), + iter=visit_required("iter", self.iter, visitor), + whitespace_before_colon=visit_required( + "whitespace_before_colon", self.whitespace_before_colon, visitor + ), + body=visit_required("body", self.body, visitor), + orelse=visit_optional("orelse", self.orelse, 
visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + for ll in self.leading_lines: + ll._codegen(state) + state.tokens.extend(state.indent) + asynchronous = self.asynchronous + if asynchronous is not None: + asynchronous._codegen(state) + state.tokens.append("for") + self.whitespace_after_for._codegen(state) + self.target._codegen(state) + self.whitespace_before_in._codegen(state) + state.tokens.append("in") + self.whitespace_after_in._codegen(state) + self.iter._codegen(state) + self.whitespace_before_colon._codegen(state) + state.tokens.append(":") + self.body._codegen(state) + orelse = self.orelse + if orelse is not None: + orelse._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class While(BaseCompoundStatement): + """ + A `while` statement. + """ + + # The test we will loop against. + test: BaseExpression + + # The suite that is wrapped with this statement. + body: BaseSuite + + # An optional else case. + orelse: Optional[Else] = None + + # Whitespace + leading_lines: Sequence[EmptyLine] = () + whitespace_after_while: SimpleWhitespace = SimpleWhitespace(" ") + whitespace_before_colon: SimpleWhitespace = SimpleWhitespace("") + + def _validate(self) -> None: + has_no_gap = len(self.whitespace_after_while.value) == 0 + if has_no_gap and not self.test._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError( + "Must have at least one space after 'while' keyword." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "While": + return While( + leading_lines=visit_sequence("leading_lines", self.leading_lines, visitor), + whitespace_after_while=visit_required( + "whitespace_after_while", self.whitespace_after_while, visitor + ), + test=visit_required("test", self.test, visitor), + whitespace_before_colon=visit_required( + "whitespace_before_colon", self.whitespace_before_colon, visitor + ), + body=visit_required("body", self.body, visitor), + orelse=visit_optional("orelse", self.orelse, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + for ll in self.leading_lines: + ll._codegen(state) + state.tokens.extend(state.indent) + state.tokens.append("while") + self.whitespace_after_while._codegen(state) + self.test._codegen(state) + self.whitespace_before_colon._codegen(state) + state.tokens.append(":") + self.body._codegen(state) + orelse = self.orelse + if orelse is not None: + orelse._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Raise(BaseSmallStatement): + exc: Optional[BaseExpression] = None + + cause: Optional[From] = None + + whitespace_after_raise: Union[ + SimpleWhitespace, MaybeSentinel + ] = MaybeSentinel.DEFAULT + + # Optional semicolon when this is used in a statement line + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + # Validate correct construction + if self.exc is None and self.cause is not None: + raise CSTValidationError( + "Must have an 'exc' when specifying 'clause'. on Raise." 
+ ) + + # Validate spacing between "raise" and "exc" + exc = self.exc + if exc is not None: + whitespace_after_raise = self.whitespace_after_raise + has_no_gap = ( + not isinstance(whitespace_after_raise, MaybeSentinel) + and len(whitespace_after_raise.value) == 0 + ) + if has_no_gap and not exc._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError("Must have at least one space after 'raise'.") + + # Validate spacing between "exc" and "from" + cause = self.cause + if exc is not None and cause is not None: + whitespace_before_from = self.cause.whitespace_before_from + has_no_gap = ( + isinstance(whitespace_before_from, SimpleWhitespace) + and len(whitespace_before_from.value) == 0 + ) + if has_no_gap and not exc._safe_to_use_with_word_operator( + ExpressionPosition.LEFT + ): + raise CSTValidationError("Must have at least one space before 'from'.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Raise": + return Raise( + whitespace_after_raise=visit_sentinel( + "whitespace_after_raise", self.whitespace_after_raise, visitor + ), + exc=visit_optional("exc", self.exc, visitor), + cause=visit_optional("cause", self.cause, visitor), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + exc = self.exc + cause = self.cause + + state.tokens.append("raise") + + whitespace_after_raise = self.whitespace_after_raise + if isinstance(whitespace_after_raise, MaybeSentinel): + if exc is not None: + state.tokens.append(" ") + else: + whitespace_after_raise._codegen(state) + + if exc is not None: + exc._codegen(state) + if cause is not None: + cause._codegen(state, default_space=" ") + + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class 
Assert(BaseSmallStatement): + """ + An assert statement such as "assert x > 5" or "assert x > 5, 'Uh oh!'" + """ + + # The test we are going to assert on. + test: BaseExpression + + # The optional message to display if the assert fails. + msg: Optional[BaseExpression] = None + + # A comma separating test and message, if there is a message. + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + # Whitespace nodes. + whitespace_after_assert: SimpleWhitespace = SimpleWhitespace(" ") + + # Optional semicolon when this is used in a statement line. + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + # Validate whitespace + has_no_gap = len(self.whitespace_after_assert.value) == 0 + if has_no_gap and not self.test._safe_to_use_with_word_operator( + ExpressionPosition.RIGHT + ): + raise CSTValidationError("Must have at least one space after 'assert'.") + + # Validate comma rules + if self.msg is None and isinstance(self.comma, Comma): + raise CSTValidationError("Cannot have trailing comma after 'test'.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Assert": + return Assert( + whitespace_after_assert=visit_required( + "whitespace_after_assert", self.whitespace_after_assert, visitor + ), + test=visit_required("test", self.test, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + msg=visit_optional("msg", self.msg, visitor), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("assert") + self.whitespace_after_assert._codegen(state) + self.test._codegen(state) + + comma = self.comma + msg = self.msg + if isinstance(comma, MaybeSentinel): + if msg is not None: + state.tokens.append(", ") + else: + comma._codegen(state) + if msg is not None: + msg._codegen(state) + + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: 
+ state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@dataclass(frozen=True) +class NameItem(CSTNode): + """ + A single identifier name inside a Global or Nonlocal statement. + """ + + # Identifier name. + name: Name + + # This is forbidden for the last NameItem in a Global/Nonlocal, but all other + # NameItems inside the statement must contain a comma to separate them. + comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + # No parens around names here + if len(self.name.lpar) > 0 or len(self.name.rpar) > 0: + raise CSTValidationError("Cannot have parens around names in NameItem.") + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "NameItem": + return NameItem( + name=visit_required("name", self.name, visitor), + comma=visit_sentinel("comma", self.comma, visitor), + ) + + def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + self.name._codegen(state) + comma = self.comma + if comma is MaybeSentinel.DEFAULT and default_comma: + state.tokens.append(", ") + elif isinstance(comma, Comma): + comma._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Global(BaseSmallStatement): + """ + A `global` statement. + """ + + # A list of one or more NameItems. + names: Sequence[NameItem] + + # Whitespace + whitespace_after_global: SimpleWhitespace = SimpleWhitespace(" ") + + # Optional semicolon when this is used in a statement line. + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + if len(self.names) == 0: + raise CSTValidationError( + "A Global statement must have at least one NameItem." + ) + if self.names[-1].comma != MaybeSentinel.DEFAULT: + raise CSTValidationError( + "The last NameItem in a Global cannot have a trailing comma." + ) + if len(self.whitespace_after_global.value) == 0: + raise CSTValidationError( + "Must have at least one space after 'global' keyword." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Global": + return Global( + whitespace_after_global=visit_required( + "whitespace_after_global", self.whitespace_after_global, visitor + ), + names=visit_sequence("names", self.names, visitor), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("global") + self.whitespace_after_global._codegen(state) + last_name = len(self.names) - 1 + for i, name in enumerate(self.names): + name._codegen(state, default_comma=(i != last_name)) + + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class Nonlocal(BaseSmallStatement): + """ + A `nonlocal` statement. + """ + + # A list of one or more NameItems. + names: Sequence[NameItem] + + # Whitespace + whitespace_after_nonlocal: SimpleWhitespace = SimpleWhitespace(" ") + + # Optional semicolon when this is used in a statement line. + semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT + + def _validate(self) -> None: + if len(self.names) == 0: + raise CSTValidationError( + "A Nonlocal statement must have at least one NameItem." + ) + if self.names[-1].comma != MaybeSentinel.DEFAULT: + raise CSTValidationError( + "The last NameItem in a Nonlocal cannot have a trailing comma." + ) + if len(self.whitespace_after_nonlocal.value) == 0: + raise CSTValidationError( + "Must have at least one space after 'nonlocal' keyword." 
+ ) + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Nonlocal": + return Nonlocal( + whitespace_after_nonlocal=visit_required( + "whitespace_after_nonlocal", self.whitespace_after_nonlocal, visitor + ), + names=visit_sequence("names", self.names, visitor), + semicolon=visit_sentinel("semicolon", self.semicolon, visitor), + ) + + def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + state.tokens.append("nonlocal") + self.whitespace_after_nonlocal._codegen(state) + last_name = len(self.names) - 1 + for i, name in enumerate(self.names): + name._codegen(state, default_comma=(i != last_name)) + + semicolon = self.semicolon + if isinstance(semicolon, MaybeSentinel): + if default_semicolon: + state.tokens.append("; ") + elif isinstance(semicolon, Semicolon): + semicolon._codegen(state) diff --git a/libcst/nodes/_whitespace.py b/libcst/nodes/_whitespace.py new file mode 100644 index 00000000..e7ffc4eb --- /dev/null +++ b/libcst/nodes/_whitespace.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Pattern, Sequence + +from libcst._add_slots import add_slots +from libcst._base_visitor import CSTVisitor +from libcst.nodes._base import BaseLeaf, BaseValueToken, CSTNode, CSTValidationError +from libcst.nodes._internal import ( + CodegenState, + visit_optional, + visit_required, + visit_sequence, +) + + +# SimpleWhitespace includes continuation characters, which must be followed immediately +# by a newline. SimpleWhitespace does not include other kinds of newlines, because those +# may have semantic significance. 
+SIMPLE_WHITESPACE_RE: Pattern[str] = re.compile(r"([ \f\t]|\\(\r\n?|\n))*", re.UNICODE) +NEWLINE_RE: Pattern[str] = re.compile(r"\r\n?|\n", re.UNICODE) +COMMENT_RE: Pattern[str] = re.compile(r"#[^\r\n]*", re.UNICODE) + + +class BaseParenthesizableWhitespace(CSTNode, ABC): + """ + This is the kind of whitespace you might see inside the body of a statement or + expression between two tokens. This is the most common type of whitespace. + + The list of allowed characters in a whitespace depends on whether it is found + inside a parenthesized expression or not. This class allows nodes which can be + found inside or outside a (), [] or {} section to accept either whitespace + form. + + https://docs.python.org/3/reference/lexical_analysis.html#implicit-line-joining + + ParenthesizableWhitespace may contain a backslash character (`\`), when used as a + line-continuation character. While the continuation character isn't technically + "whitespace", it serves the same purpose. + + ParenthesizableWhitespace is often non-semantic (optional), but in cases where whitespace + solves a grammar ambiguity between tokens (e.g. `if test`, versus `iftest`), it has + some semantic value. + """ + + # TODO: Should we somehow differentiate places where we require non-zero whitespace + # with a separate type? + + @property + @abstractmethod + def empty(self) -> bool: + ... + + +@add_slots +@dataclass(frozen=True) +class SimpleWhitespace(BaseParenthesizableWhitespace, BaseValueToken): + + value: str + + def _validate(self) -> None: + if SIMPLE_WHITESPACE_RE.fullmatch(self.value) is None: + raise CSTValidationError( + f"Got non-whitespace value for whitespace node: {repr(self.value)}" + ) + + @property + def empty(self) -> bool: + return len(self.value) == 0 + + +@add_slots +@dataclass(frozen=True) +class Newline(BaseLeaf): + """ + Represents the newline that ends an EmptyLine or a statement (as part of + TrailingWhitespace). 
+ + Other newlines may occur in the document after continuation characters (the + backslash, `\`), but those newlines are treated as part of the SimpleWhitespace. + """ + + # A value of 'None' indicates that the module's default newline sequence should be + # used. A value is allowed only because python modules are permitted to mix multiple + # unambiguous newline markers. + value: Optional[str] = None + + def _validate(self) -> None: + if self.value and NEWLINE_RE.fullmatch(self.value) is None: + raise CSTValidationError( + f"Got an invalid value for newline node: {repr(self.value)}" + ) + + def _codegen(self, state: CodegenState) -> None: + state.tokens.append(state.default_newline if self.value is None else self.value) + + +@add_slots +@dataclass(frozen=True) +class Comment(BaseValueToken): + """ + A comment including the leading pound (`#`) character. + + The leading pound character is included in the 'value' property (instead of being + stripped) to help re-enforce the idea that whitespace immediately after the pound + character may be significant. E.g: + + # comment with whitespace at the start (usually preferred), versus + #comment without whitespace at the start (usually not desirable) + + Usually wrapped in a TrailingWhitespace or EmptyLine node. + """ + + value: str + + def _validate(self) -> None: + if COMMENT_RE.fullmatch(self.value) is None: + raise CSTValidationError( + f"Got non-comment value for comment node: {repr(self.value)}" + ) + + +@add_slots +@dataclass(frozen=True) +class TrailingWhitespace(CSTNode): + """ + The whitespace at the end of a line after a statement. If a line contains only + whitespace, EmptyLine should be used instead. 
+ """ + + whitespace: SimpleWhitespace = SimpleWhitespace("") + comment: Optional[Comment] = None + newline: Newline = Newline() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "TrailingWhitespace": + return TrailingWhitespace( + whitespace=visit_required("whitespace", self.whitespace, visitor), + comment=visit_optional("comment", self.comment, visitor), + newline=visit_required("newline", self.newline, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + self.whitespace._codegen(state) + if self.comment is not None: + self.comment._codegen(state) + self.newline._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class EmptyLine(CSTNode): + """ + Represents a line with only whitespace/comments. Usually statements will own any + EmptyLine nodes above themselves, and a Module will own the document's header/footer + EmptyLine nodes. + """ + + # An empty line doesn't have to correspond to the current indentation level. For + # example, this happens when all trailing whitespace is stripped. 
+ indent: bool = True + # Extra whitespace after the indent, but before the comment + whitespace: SimpleWhitespace = SimpleWhitespace("") + comment: Optional[Comment] = None + newline: Newline = Newline() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "EmptyLine": + return EmptyLine( + indent=self.indent, + whitespace=visit_required("whitespace", self.whitespace, visitor), + comment=visit_optional("comment", self.comment, visitor), + newline=visit_required("newline", self.newline, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + if self.indent: + state.tokens.extend(state.indent) + self.whitespace._codegen(state) + if self.comment is not None: + self.comment._codegen(state) + self.newline._codegen(state) + + +@add_slots +@dataclass(frozen=True) +class ParenthesizedWhitespace(BaseParenthesizableWhitespace): + + # The whitespace that comes after the previous node, up to and including + # the end-of-line comment. + first_line: TrailingWhitespace = TrailingWhitespace() + + # Any lines that contain only indentation and/or comments + empty_lines: Sequence[EmptyLine] = () + + # Whether or not the final comment is indented regularly + indent: bool = False + + # Extra whitespace after the indent, but before the next node + last_line: SimpleWhitespace = SimpleWhitespace("") + + def _visit_and_replace_children( + self, visitor: CSTVisitor + ) -> "ParenthesizedWhitespace": + return ParenthesizedWhitespace( + first_line=visit_required("first_line", self.first_line, visitor), + empty_lines=visit_sequence("empty_lines", self.empty_lines, visitor), + indent=self.indent, + last_line=visit_required("last_line", self.last_line, visitor), + ) + + def _codegen(self, state: CodegenState) -> None: + self.first_line._codegen(state) + for line in self.empty_lines: + line._codegen(state) + if self.indent: + state.tokens.extend(state.indent) + self.last_line._codegen(state) + + @property + def empty(self) -> bool: + # It's not possible to have a 
ParenthesizedWhitespace with zero characters. + # If we did, the TrailingWhitespace would not have parsed. + return False diff --git a/libcst/nodes/tests/base.py b/libcst/nodes/tests/base.py new file mode 100644 index 00000000..d4375349 --- /dev/null +++ b/libcst/nodes/tests/base.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses +from contextlib import ExitStack +from dataclasses import dataclass +from typing import Callable, Iterable, List, Optional, Sequence, Type, TypeVar +from unittest.mock import patch + +import libcst.nodes as cst +from libcst._base_visitor import CSTVisitor +from libcst.nodes._internal import CodegenState, visit_required +from libcst.testing.utils import UnitTest + + +_CSTNodeT = TypeVar("_CSTNodeT", bound="cst.CSTNode") + + +@dataclass(frozen=True) +class _CSTCodegenPatchTarget: + type: Type[cst.CSTNode] + name: str + old_codegen: Callable[..., None] + + +class _NOOPVisitor(CSTVisitor): + pass + + +def _cst_node_equality_func(a: cst.CSTNode, b: cst.CSTNode, msg=None) -> None: + """ + For use with addTypeEqualityFunc. + """ + if not a.deep_equals(b): + suffix = "" if msg is None else f"\n{msg}" + raise AssertionError(f"\n{a!r}\nis not deeply equal to \n{b!r}{suffix}") + + +# We can't use an ABCMeta here, because of metaclass conflicts +class CSTNodeTest(UnitTest): + def setUp(self) -> None: + # Fix `self.assertEqual` for CSTNode subclasses. We should compare equality by + # value instead of identity (what `CSTNode.__eq__` does) for tests. + # + # The time complexity of CSTNode.deep_equals doesn't matter much inside tests. 
+ for v in cst.__dict__.values(): + if isinstance(v, type) and issubclass(v, cst.CSTNode): + self.addTypeEqualityFunc(v, _cst_node_equality_func) + self.addTypeEqualityFunc(DummyIndentedBlock, _cst_node_equality_func) + + def validate_node( + self, + node: _CSTNodeT, + code: str, + parser: Optional[Callable[[str], _CSTNodeT]] = None, + ) -> None: + self.__assert_codegen(node, code) + + if parser is not None: + parsed_node = parser(code) + self.assertEqual(parsed_node, node) + + # Tests of children should unwrap DummyIndentedBlock first, because we don't + # want to test DummyIndentedBlock's behavior. + unwrapped_node = node + while isinstance(unwrapped_node, DummyIndentedBlock): + unwrapped_node = unwrapped_node.child + self.__assert_children_match_codegen(unwrapped_node) + self.__assert_children_match_fields(unwrapped_node) + self.__assert_visit_returns_identity(unwrapped_node) + + def assert_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + with self.assertRaisesRegex(cst.CSTValidationError, expected_re): + get_node() + + def __assert_codegen(self, node: cst.CSTNode, expected: str) -> None: + """ + Verifies that the given node's `_codegen` method is correct. + """ + self.assertEqual(cst.Module([]).code_for_node(node), expected) + + def __assert_children_match_codegen(self, node: cst.CSTNode) -> None: + children = node.children + codegen_children = self.__derive_children_from_codegen(node) + self.assertSequenceEqual( + children, + codegen_children, + msg=( + "The list of children we got from `node.children` differs from the " + + "children that were visited by `node._codegen`." + ), + ) + + def __derive_children_from_codegen( + self, node: cst.CSTNode + ) -> Sequence[cst.CSTNode]: + """ + Patches all subclasses of `CSTNode` exported by the `cst` module to track which + `_codegen` methods get called, generating a list of children. 
+ + Because all children must be rendered out into lexical order, this should be + equivalent to `node.children`. + + `node.children` uses `_visit_and_replace_children` under the hood, not + `_codegen`, so this helps us verify that the behaviors of those two methods + are in sync. + """ + + patch_targets: Iterable[_CSTCodegenPatchTarget] = [ + _CSTCodegenPatchTarget(type=v, name=k, old_codegen=v._codegen) + for (k, v) in cst.__dict__.items() + if isinstance(v, type) + and issubclass(v, cst.CSTNode) + and hasattr(v, "_codegen") + ] + + children: List[cst.CSTNode] = [] + codegen_stack: List[cst.CSTNode] = [] + + def _get_codegen_override(target: _CSTCodegenPatchTarget): + def _codegen(self, *args, **kwargs) -> None: + should_pop = False + # Don't stick duplicates in the stack. This is needed so that we don't + # track calls to `super()._codegen()`. + if len(codegen_stack) == 0 or codegen_stack[-1] is not self: + # Check the stack to see that we're a direct child, not the root or + # a transitive child. + if len(codegen_stack) == 1: + children.append(self) + codegen_stack.append(self) + should_pop = True + target.old_codegen(self, *args, **kwargs) + # only pop if we pushed something to the stack earlier + if should_pop: + codegen_stack.pop() + + return _codegen + + with ExitStack() as patch_stack: + for t in patch_targets: + patch_stack.enter_context( + # pyre-ignore Incompatible parameter type [6]: Expected + # pyre-ignore `typing.ContextManager[Variable[contextlib._T]]` + # pyre-ignore for 1st anonymous parameter to call + # pyre-ignore `contextlib.ExitStack.enter_context` but got + # pyre-ignore `unittest.mock._patch`. 
+ patch(f"libcst.nodes.{t.name}._codegen", _get_codegen_override(t)) + ) + # Execute `node._codegen()` + cst.Module([]).code_for_node(node) + + return children + + def __assert_children_match_fields(self, node: cst.CSTNode) -> None: + """ + We expect `node.children` to match everything we can extract from the node's + fields, but maybe in a different order. This asserts that those things match. + + If you want to verify order as well, use `assert_children_ordered`. + """ + node_children_ids = {id(child) for child in node.children} + fields = dataclasses.fields(node) + field_child_ids = set() + for f in fields: + value = getattr(node, f.name) + if isinstance(value, cst.CSTNode): + field_child_ids.add(id(value)) + elif isinstance(value, Iterable): + field_child_ids.update( + id(el) for el in value if isinstance(el, cst.CSTNode) + ) + + # order doesn't matter + self.assertSetEqual( + node_children_ids, + field_child_ids, + msg="`node.children` doesn't match what we found through introspection", + ) + + def __assert_visit_returns_identity(self, node: cst.CSTNode) -> None: + """ + When visit is called with a visitor that acts as a no-op, the visit method + should return the same node it started with. + """ + # TODO: We're only checking equality right now, because visit currently clones + # the node, since that was easier to implement. We should fix that behavior in a + # later version and tighten this check. + self.assertEqual(node, node.visit(_NOOPVisitor())) + + +@dataclass(frozen=True) +class DummyIndentedBlock(cst.CSTNode): + """ + A stripped-down version of cst.IndentedBlock that only sets/clears the indentation + state for the purpose of testing cst.IndentWhitespace in isolation. 
+ """ + + value: str + child: cst.CSTNode + + def _codegen(self, state: CodegenState) -> None: + state.indent.append(self.value) + self.child._codegen(state) + state.indent.pop() + + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "DummyIndentedBlock": + return DummyIndentedBlock( + value=self.value, child=visit_required("child", self.child, visitor) + ) diff --git a/libcst/nodes/tests/test_assert.py b/libcst/nodes/tests/test_assert.py new file mode 100644 index 00000000..aef02f59 --- /dev/null +++ b/libcst/nodes/tests/test_assert.py @@ -0,0 +1,114 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class AssertConstructionTest(CSTNodeTest): + @data_provider( + ( + # Simple assert + (cst.Assert(cst.Name("True")), "assert True"), + # Assert with message + ( + cst.Assert( + cst.Name("True"), cst.SimpleString('"Value should be true"') + ), + 'assert True, "Value should be true"', + ), + # Whitespace oddities test + ( + cst.Assert( + cst.Name("True", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)), + whitespace_after_assert=cst.SimpleWhitespace(""), + ), + "assert(True)", + ), + # Whitespace rendering test + ( + cst.Assert( + whitespace_after_assert=cst.SimpleWhitespace(" "), + test=cst.Name("True"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + msg=cst.SimpleString('"Value should be true"'), + ), + 'assert True , "Value should be true"', + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + # Validate whitespace handling + ( + lambda: cst.Assert( + 
cst.Name("True"), whitespace_after_assert=cst.SimpleWhitespace("") + ), + "Must have at least one space after 'assert'", + ), + # Validate comma handling + ( + lambda: cst.Assert(test=cst.Name("True"), comma=cst.Comma()), + "Cannot have trailing comma after 'test'", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class AssertParsingTest(CSTNodeTest): + @data_provider( + ( + # Simple assert + (cst.Assert(cst.Name("True")), "assert True"), + # Assert with message + ( + cst.Assert( + cst.Name("True"), + cst.SimpleString('"Value should be true"'), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + 'assert True, "Value should be true"', + ), + # Whitespace oddities test + ( + cst.Assert( + cst.Name("True", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)), + whitespace_after_assert=cst.SimpleWhitespace(""), + ), + "assert(True)", + ), + # Whitespace rendering test + ( + cst.Assert( + whitespace_after_assert=cst.SimpleWhitespace(" "), + test=cst.Name("True"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + msg=cst.SimpleString('"Value should be true"'), + ), + 'assert True , "Value should be true"', + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0]) diff --git a/libcst/nodes/tests/test_assign.py b/libcst/nodes/tests/test_assign.py new file mode 100644 index 00000000..273dafc5 --- /dev/null +++ b/libcst/nodes/tests/test_assign.py @@ -0,0 +1,374 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
# pyre-strict
from typing import Callable, Optional

import libcst.nodes as cst
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
from libcst.testing.utils import data_provider


class AssignTest(CSTNodeTest):
    """Data-driven cases for plain assignment (`cst.Assign`)."""

    # Each row is (node, expected code, parser-or-None). Rows with `None` are
    # construction-only cases; rows with `parse_statement` wrap the node in a
    # `SimpleStatementLine` and carry a trailing newline in the code.
    @data_provider(
        (
            # Construction: single target.
            (
                cst.Assign(
                    targets=(cst.AssignTarget(cst.Name("foo")),),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo = 5",
                None,
            ),
            # Construction: chained targets.
            (
                cst.Assign(
                    targets=(
                        cst.AssignTarget(cst.Name("foo")),
                        cst.AssignTarget(cst.Name("bar")),
                    ),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo = bar = 5",
                None,
            ),
            # Construction: no whitespace around the equals sign.
            (
                cst.Assign(
                    targets=(
                        cst.AssignTarget(
                            cst.Name("foo"),
                            whitespace_before_equal=cst.SimpleWhitespace(""),
                            whitespace_after_equal=cst.SimpleWhitespace(""),
                        ),
                    ),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo=5",
                None,
            ),
            # Parser: single target.
            (
                cst.SimpleStatementLine(
                    (
                        cst.Assign(
                            targets=(cst.AssignTarget(cst.Name("foo")),),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo = 5\n",
                parse_statement,
            ),
            # Parser: chained targets.
            (
                cst.SimpleStatementLine(
                    (
                        cst.Assign(
                            targets=(
                                cst.AssignTarget(cst.Name("foo")),
                                cst.AssignTarget(cst.Name("bar")),
                            ),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo = bar = 5\n",
                parse_statement,
            ),
            # Parser: no whitespace around the equals sign.
            (
                cst.SimpleStatementLine(
                    (
                        cst.Assign(
                            targets=(
                                cst.AssignTarget(
                                    cst.Name("foo"),
                                    whitespace_before_equal=cst.SimpleWhitespace(""),
                                    whitespace_after_equal=cst.SimpleWhitespace(""),
                                ),
                            ),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo=5\n",
                parse_statement,
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
    ) -> None:
        # Delegates to the shared helper on CSTNodeTest (defined outside this file).
        self.validate_node(node, code, parser)

    # Each row is (node factory, regex expected in the validation error).
    @data_provider(
        (
            # An assignment with an empty target tuple must be rejected.
            (
                lambda: cst.Assign(targets=(), value=cst.Number(cst.Integer("5"))),
                "at least one AssignTarget",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)


class AnnAssignTest(CSTNodeTest):
    """Data-driven cases for annotated assignment (`cst.AnnAssign`)."""

    # Rows follow the same (node, expected code, parser-or-None) layout as
    # AssignTest.test_valid.
    @data_provider(
        (
            # Construction: annotation plus value.
            (
                cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(cst.Name("str")),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo: str = 5",
                None,
            ),
            # Construction: annotation only, no value.
            (
                cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(cst.Name("str")),
                ),
                "foo: str",
                None,
            ),
            # Construction: subscripted annotation.
            (
                cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(
                        cst.Subscript(cst.Name("Optional"), cst.Index(cst.Name("str")))
                    ),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo: Optional[str] = 5",
                None,
            ),
            # Parser: annotation plus value.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                indicator=":",
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo: str = 5\n",
                parse_statement,
            ),
            # Parser: annotation only, no value.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Name("str"),
                                indicator=":",
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            value=None,
                        ),
                    )
                ),
                "foo: str\n",
                parse_statement,
            ),
            # Parser: subscripted annotation.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Subscript(
                                    cst.Name("Optional"), cst.Index(cst.Name("str"))
                                ),
                                indicator=":",
                                whitespace_before_indicator=cst.SimpleWhitespace(""),
                            ),
                            equal=cst.AssignEqual(),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo: Optional[str] = 5\n",
                parse_statement,
            ),
            # Construction: explicit whitespace around ':' and '='.
            (
                cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(
                        annotation=cst.Subscript(
                            cst.Name("Optional"), cst.Index(cst.Name("str"))
                        ),
                        whitespace_before_indicator=cst.SimpleWhitespace(" "),
                        whitespace_after_indicator=cst.SimpleWhitespace(" "),
                    ),
                    equal=cst.AssignEqual(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo : Optional[str] = 5",
                None,
            ),
            # Parser: explicit whitespace around ':' and '='.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AnnAssign(
                            target=cst.Name("foo"),
                            annotation=cst.Annotation(
                                annotation=cst.Subscript(
                                    cst.Name("Optional"), cst.Index(cst.Name("str"))
                                ),
                                whitespace_before_indicator=cst.SimpleWhitespace(" "),
                                indicator=":",
                                whitespace_after_indicator=cst.SimpleWhitespace(" "),
                            ),
                            equal=cst.AssignEqual(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo : Optional[str] = 5\n",
                parse_statement,
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
    ) -> None:
        # Delegates to the shared helper on CSTNodeTest (defined outside this file).
        self.validate_node(node, code, parser)

    # Each row is (node factory, regex expected in the validation error).
    @data_provider(
        (
            # An explicit AssignEqual without a value is rejected.
            (
                lambda: cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(cst.Name("str")),
                    equal=cst.AssignEqual(),
                    value=None,
                ),
                "Must have a value when specifying an AssignEqual.",
            ),
            # An annotation built with '->' instead of ':' is rejected.
            (
                lambda: cst.AnnAssign(
                    target=cst.Name("foo"),
                    annotation=cst.Annotation(cst.Name("str"), "->"),
                    value=cst.Number(cst.Integer("5")),
                ),
                "must be denoted with a ':'",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)


class AugAssignTest(CSTNodeTest):
    """Data-driven cases for augmented assignment (`cst.AugAssign`)."""

    # Rows follow the same (node, expected code, parser-or-None) layout as
    # AssignTest.test_valid.
    @data_provider(
        (
            # Construction: `+=` with an integer.
            (
                cst.AugAssign(
                    target=cst.Name("foo"),
                    operator=cst.AddAssign(),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo += 5",
                None,
            ),
            # Construction: `*=` with a name.
            (
                cst.AugAssign(
                    target=cst.Name("bar"),
                    operator=cst.MultiplyAssign(),
                    value=cst.Name("foo"),
                ),
                "bar *= foo",
                None,
            ),
            # Construction: explicit whitespace around the operator.
            (
                cst.AugAssign(
                    target=cst.Name("foo"),
                    operator=cst.LeftShiftAssign(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    value=cst.Number(cst.Integer("5")),
                ),
                "foo <<= 5",
                None,
            ),
            # Parser: `+=` with an integer.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AugAssign(
                            target=cst.Name("foo"),
                            operator=cst.AddAssign(),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo += 5\n",
                parse_statement,
            ),
            # Parser: `*=` with a name.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AugAssign(
                            target=cst.Name("bar"),
                            operator=cst.MultiplyAssign(),
                            value=cst.Name("foo"),
                        ),
                    )
                ),
                "bar *= foo\n",
                parse_statement,
            ),
            # Parser: explicit whitespace around the operator.
            (
                cst.SimpleStatementLine(
                    (
                        cst.AugAssign(
                            target=cst.Name("foo"),
                            operator=cst.LeftShiftAssign(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            value=cst.Number(cst.Integer("5")),
                        ),
                    )
                ),
                "foo <<= 5\n",
                parse_statement,
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
    ) -> None:
        # Delegates to the shared helper on CSTNodeTest (defined outside this file).
        self.validate_node(node, code, parser)
# pyre-strict
from typing import Callable

import libcst.nodes as cst
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_expression
from libcst.testing.utils import data_provider


class AtomTest(CSTNodeTest):
    """Data-driven cases for atom nodes: names, numbers, ellipses and strings."""

    # Each row is (node, expected code).
    @data_provider(
        (
            # Simple identifier
            (cst.Name("test"), "test"),
            # Parenthesized identifier
            (
                cst.Name("test", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)),
                "(test)",
            ),
            # Decimal integers
            (cst.Number(cst.Integer("12345")), "12345"),
            (cst.Number(cst.Integer("0000")), "0000"),
            (cst.Number(cst.Integer("1_234_567")), "1_234_567"),
            (cst.Number(cst.Integer("0_000")), "0_000"),
            # Binary integers
            (cst.Number(cst.Integer("0b0000")), "0b0000"),
            (cst.Number(cst.Integer("0B1011_0100")), "0B1011_0100"),
            # Octal integers
            (cst.Number(cst.Integer("0o12345")), "0o12345"),
            (cst.Number(cst.Integer("0O12_345")), "0O12_345"),
            # Hex numbers
            (cst.Number(cst.Integer("0x123abc")), "0x123abc"),
            (cst.Number(cst.Integer("0X12_3ABC")), "0X12_3ABC"),
            # Parenthesized integers
            (
                cst.Number(
                    cst.Integer(
                        "123", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    )
                ),
                "(123)",
            ),
            # Non-exponent floats
            (cst.Number(cst.Float("12345.")), "12345."),
            (cst.Number(cst.Float("00.00")), "00.00"),
            (cst.Number(cst.Float("12.21")), "12.21"),
            (cst.Number(cst.Float(".321")), ".321"),
            (cst.Number(cst.Float("1_234_567.")), "1_234_567."),
            (cst.Number(cst.Float("0.000_000")), "0.000_000"),
            # Exponent floats
            (cst.Number(cst.Float("12345.e10")), "12345.e10"),
            (cst.Number(cst.Float("00.00e10")), "00.00e10"),
            (cst.Number(cst.Float("12.21e10")), "12.21e10"),
            (cst.Number(cst.Float(".321e10")), ".321e10"),
            (cst.Number(cst.Float("1_234_567.e10")), "1_234_567.e10"),
            (cst.Number(cst.Float("0.000_000e10")), "0.000_000e10"),
            (cst.Number(cst.Float("1e+10")), "1e+10"),
            (cst.Number(cst.Float("1e-10")), "1e-10"),
            # Parenthesized floats
            (
                cst.Number(
                    cst.Float(
                        "123.4", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    )
                ),
                "(123.4)",
            ),
            # Imaginary numbers
            (cst.Number(cst.Imaginary("12345j")), "12345j"),
            (cst.Number(cst.Imaginary("1_234_567J")), "1_234_567J"),
            (cst.Number(cst.Imaginary("12345.e10j")), "12345.e10j"),
            (cst.Number(cst.Imaginary(".321J")), ".321J"),
            # Parenthesized imaginary
            (
                cst.Number(
                    cst.Imaginary(
                        "123.4j", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    )
                ),
                "(123.4j)",
            ),
            # Simple ellipses
            (cst.Ellipses(), "..."),
            # Parenthesized ellipses
            (cst.Ellipses(lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)), "(...)"),
            # Simple strings
            (cst.SimpleString('""'), '""'),
            (cst.SimpleString("''"), "''"),
            (cst.SimpleString('"test"'), '"test"'),
            (cst.SimpleString('b"test"'), 'b"test"'),
            (cst.SimpleString('r"test"'), 'r"test"'),
            (cst.SimpleString('"""test"""'), '"""test"""'),
            # Validate parens
            (
                cst.SimpleString(
                    '"test"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                '("test")',
            ),
            (
                cst.SimpleString(
                    'rb"test"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                '(rb"test")',
            ),
            # Empty formatted strings
            (cst.FormattedString(start='f"', parts=(), end='"'), 'f""'),
            (cst.FormattedString(start="f'", parts=(), end="'"), "f''"),
            (cst.FormattedString(start='f"""', parts=(), end='"""'), 'f""""""'),
            (cst.FormattedString(start="f'''", parts=(), end="'''"), "f''''''"),
            # Non-empty formatted strings
            (cst.FormattedString(parts=(cst.FormattedStringText("foo"),)), 'f"foo"'),
            (
                cst.FormattedString(
                    parts=(cst.FormattedStringExpression(cst.Name("foo")),)
                ),
                'f"{foo}"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringText("foo "),
                        cst.FormattedStringExpression(cst.Name("bar")),
                        cst.FormattedStringText(" baz"),
                    )
                ),
                'f"foo {bar} baz"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringText("foo "),
                        cst.FormattedStringExpression(cst.Call(cst.Name("bar"))),
                        cst.FormattedStringText(" baz"),
                    )
                ),
                'f"foo {bar()} baz"',
            ),
            # Formatted strings with conversions and format specifiers
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(cst.Name("foo"), conversion="s"),
                    )
                ),
                'f"{foo!s}"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(cst.Name("foo"), format_spec=()),
                    )
                ),
                'f"{foo:}"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("today"),
                            format_spec=(cst.FormattedStringText("%B %d, %Y"),),
                        ),
                    )
                ),
                'f"{today:%B %d, %Y}"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                            ),
                        ),
                    )
                ),
                'f"{foo:{bar}}"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                                cst.FormattedStringText("."),
                                cst.FormattedStringExpression(cst.Name("baz")),
                            ),
                        ),
                    )
                ),
                'f"{foo:{bar}.{baz}}"',
            ),
            (
                cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            conversion="s",
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                            ),
                        ),
                    )
                ),
                'f"{foo!s:{bar}}"',
            ),
            # Validate parens
            (
                cst.FormattedString(
                    start='f"',
                    parts=(),
                    end='"',
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                '(f"")',
            ),
            # Concatenated strings
            (
                cst.ConcatenatedString(
                    cst.SimpleString('"ab"'), cst.SimpleString('"c"')
                ),
                '"ab""c"',
            ),
            # Concatenated parenthesized strings
            (
                cst.ConcatenatedString(
                    lpar=(cst.LeftParen(),),
                    left=cst.SimpleString('"ab"'),
                    right=cst.SimpleString('"c"'),
                    rpar=(cst.RightParen(),),
                ),
                '("ab""c")',
            ),
            # Validate spacing
            (
                cst.ConcatenatedString(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    left=cst.SimpleString('"ab"'),
                    whitespace_between=cst.SimpleWhitespace(" "),
                    right=cst.SimpleString('"c"'),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                '( "ab" "c" )',
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(node, code, parse_expression)

    # Each row is (node factory, regex expected in the validation error).
    @data_provider(
        (
            # Expression wrapping parenthesis rules
            (
                lambda: cst.Name("foo", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Name("foo", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Ellipses(lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Ellipses(rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Integer("5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Integer("5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Float("5.5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Float("5.5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Imaginary("5j", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Number(cst.Integer("5"), lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Number(cst.Integer("5"), rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.SimpleString("'foo'", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.SimpleString("'foo'", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            # The f-string itself is well-formed (empty parts, default quotes) so
            # that the unbalanced paren is the only thing wrong with the node.
            (
                lambda: cst.FormattedString(parts=(), lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.FormattedString(parts=(), rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.ConcatenatedString(
                    cst.SimpleString("'foo'"),
                    cst.SimpleString("'foo'"),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.ConcatenatedString(
                    cst.SimpleString("'foo'"),
                    cst.SimpleString("'foo'"),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
            # Node-specific rules
            (lambda: cst.Name(""), "empty name identifier"),
            (lambda: cst.Name(r"\/"), "not a valid identifier"),
            (lambda: cst.Integer(""), "not a valid integer"),
            (lambda: cst.Integer("012345"), "not a valid integer"),
            (lambda: cst.Integer("_12345"), "not a valid integer"),
            (lambda: cst.Integer("0b2"), "not a valid integer"),
            (lambda: cst.Integer("0o8"), "not a valid integer"),
            (lambda: cst.Integer("0xg"), "not a valid integer"),
            (lambda: cst.Integer("123.45"), "not a valid integer"),
            (lambda: cst.Integer("12345j"), "not a valid integer"),
            (lambda: cst.Float("12.3.45"), "not a valid float"),
            (lambda: cst.Float("12"), "not a valid float"),
            (lambda: cst.Float("12.3j"), "not a valid float"),
            (lambda: cst.Imaginary("_12345j"), "not a valid imaginary"),
            (lambda: cst.Imaginary("0b0j"), "not a valid imaginary"),
            (lambda: cst.Imaginary("0o0j"), "not a valid imaginary"),
            (lambda: cst.Imaginary("0x0j"), "not a valid imaginary"),
            (lambda: cst.SimpleString('wee""'), "Invalid string prefix"),
            (lambda: cst.SimpleString(""), "must have enclosing quotes"),
            (lambda: cst.SimpleString("'"), "must have enclosing quotes"),
            (lambda: cst.SimpleString('"'), "must have enclosing quotes"),
            (lambda: cst.SimpleString("\"'"), "must have matching enclosing quotes"),
            (lambda: cst.SimpleString("'bla"), "must have matching enclosing quotes"),
            (lambda: cst.SimpleString("f''"), "Invalid string prefix"),
            (
                lambda: cst.SimpleString("'''bla''"),
                "must have matching enclosing quotes",
            ),
            (
                lambda: cst.SimpleString("'''bla\"\"\""),
                "must have matching enclosing quotes",
            ),
            (
                lambda: cst.FormattedString(start="'", parts=(), end="'"),
                "Invalid f-string prefix",
            ),
            (
                lambda: cst.FormattedString(start="f'", parts=(), end='"'),
                "must have matching enclosing quotes",
            ),
            (
                lambda: cst.FormattedString(start="f'''", parts=(), end="''"),
                "must have matching enclosing quotes",
            ),
            (
                lambda: cst.ConcatenatedString(
                    cst.SimpleString(
                        '"ab"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    cst.SimpleString('"c"'),
                ),
                "Cannot concatenate parenthesized",
            ),
            (
                lambda: cst.ConcatenatedString(
                    cst.SimpleString('"ab"'),
                    cst.SimpleString(
                        '"c"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                ),
                "Cannot concatenate parenthesized",
            ),
            (
                lambda: cst.ConcatenatedString(
                    cst.SimpleString('"ab"'), cst.SimpleString('b"c"')
                ),
                "Cannot concatenate string and bytes",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
# pyre-strict
from typing import Callable

import libcst.nodes as cst
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_expression
from libcst.testing.utils import data_provider


class AttributeTest(CSTNodeTest):
    """Data-driven cases for attribute access (`cst.Attribute`)."""

    # Each row is (node, expected code); every row is also handed
    # `parse_expression` as the parser.
    @data_provider(
        (
            # Plain `value.attr` access with default formatting.
            (
                cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar")),
                "foo.bar",
            ),
            # The whole attribute expression wrapped in one pair of parens.
            (
                cst.Attribute(
                    value=cst.Name("foo"),
                    attr=cst.Name("bar"),
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                "(foo.bar)",
            ),
            # Explicit whitespace inside the parens and on both sides of the dot.
            (
                cst.Attribute(
                    value=cst.Name("foo"),
                    dot=cst.Dot(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    attr=cst.Name("bar"),
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "( foo . bar )",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # Delegates to the shared helper on CSTNodeTest (defined outside this file).
        self.validate_node(node, code, parse_expression)

    # Each row is (node factory, regex expected in the validation error).
    @data_provider(
        (
            # Opening paren with no closing paren.
            (
                lambda: cst.Attribute(
                    value=cst.Name("foo"),
                    attr=cst.Name("bar"),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            # Closing paren with no opening paren.
            (
                lambda: cst.Attribute(
                    value=cst.Name("foo"),
                    attr=cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
# pyre-strict
from typing import Callable

import libcst.nodes as cst
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_expression
from libcst.testing.utils import data_provider


class AwaitTest(CSTNodeTest):
    """Data-driven cases for await expressions (`cst.Await`)."""

    # Each row is (node, expected code).
    @data_provider(
        (
            # Awaiting a bare name and awaiting a call.
            (cst.Await(cst.Name("test")), "await test"),
            (cst.Await(cst.Call(cst.Name("test"))), "await test()"),
            # Explicit whitespace after `await` and inside the wrapping parens.
            (
                cst.Await(
                    cst.Name("test"),
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    whitespace_after_await=cst.SimpleWhitespace(" "),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "( await test )",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # The same node/code pairs are handed to `parse_expression` as the
        # parser. `validate_node` lives on CSTNodeTest (not visible here);
        # presumably it checks rendering and parsing against each other.
        self.validate_node(node, code, parse_expression)

    # Each row is (node factory, regex expected in the validation error).
    @data_provider(
        (
            # Opening paren with no closing paren.
            (
                lambda: cst.Await(cst.Name("foo"), lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            # Closing paren with no opening paren.
            (
                lambda: cst.Await(cst.Name("foo"), rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            # Empty whitespace between `await` and an unparenthesized name.
            (
                lambda: cst.Await(
                    cst.Name("foo"), whitespace_after_await=cst.SimpleWhitespace("")
                ),
                "at least one space after await",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
# pyre-strict
from typing import Callable

import libcst.nodes as cst
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_expression
from libcst.testing.utils import data_provider


class BinaryOperationTest(CSTNodeTest):
    """Data-driven cases for binary operator expressions (`cst.BinaryOperation`)."""

    # Each row is (node, expected code).
    @data_provider(
        (
            # Simple binary operations
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.Add(), cst.Number(cst.Float("5.5"))
                ),
                "foo + 5.5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.Subtract(), cst.Number(cst.Float("5.5"))
                ),
                "foo - 5.5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.LeftShift(), cst.Number(cst.Integer("5"))
                ),
                "foo << 5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.RightShift(), cst.Number(cst.Integer("5"))
                ),
                "foo >> 5",
            ),
            (
                cst.BinaryOperation(cst.Name("foo"), cst.BitAnd(), cst.Name("bar")),
                "foo & bar",
            ),
            (
                cst.BinaryOperation(cst.Name("foo"), cst.BitXor(), cst.Name("bar")),
                "foo ^ bar",
            ),
            (
                cst.BinaryOperation(cst.Name("foo"), cst.BitOr(), cst.Name("bar")),
                "foo | bar",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.Multiply(), cst.Number(cst.Float("5.5"))
                ),
                "foo * 5.5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.MatrixMultiply(), cst.Number(cst.Float("5.5"))
                ),
                "foo @ 5.5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.Divide(), cst.Number(cst.Float("5.5"))
                ),
                "foo / 5.5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.Modulo(), cst.Number(cst.Float("5.5"))
                ),
                "foo % 5.5",
            ),
            (
                cst.BinaryOperation(
                    cst.Name("foo"), cst.FloorDivide(), cst.Number(cst.Float("5.5"))
                ),
                "foo // 5.5",
            ),
            # Parenthesized binary operation
            (
                cst.BinaryOperation(
                    lpar=(cst.LeftParen(),),
                    left=cst.Name("foo"),
                    operator=cst.LeftShift(),
                    right=cst.Number(cst.Integer("5")),
                    rpar=(cst.RightParen(),),
                ),
                "(foo << 5)",
            ),
            # Make sure that spacing works
            (
                cst.BinaryOperation(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    left=cst.Name("foo"),
                    operator=cst.Multiply(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    right=cst.Name("bar"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "( foo * bar )",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # Delegates to the shared helper on CSTNodeTest (defined outside this file).
        self.validate_node(node, code, parse_expression)

    # Each row is (node factory, regex expected in the validation error).
    # Both rows use the binary `cst.Add()` operator (not the unary `cst.Plus()`)
    # so the unbalanced paren is the only invalid aspect of the node and no
    # type-checker suppression is needed.
    @data_provider(
        (
            (
                lambda: cst.BinaryOperation(
                    cst.Name("foo"),
                    cst.Add(),
                    cst.Name("bar"),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.BinaryOperation(
                    cst.Name("foo"),
                    cst.Add(),
                    cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class BooleanOperationTest(CSTNodeTest): + @data_provider( + ( + # Simple boolean operations + ( + cst.BooleanOperation(cst.Name("foo"), cst.And(), cst.Name("bar")), + "foo and bar", + ), + ( + cst.BooleanOperation(cst.Name("foo"), cst.Or(), cst.Name("bar")), + "foo or bar", + ), + # Parenthesized boolean operation + ( + cst.BooleanOperation( + lpar=(cst.LeftParen(),), + left=cst.Name("foo"), + operator=cst.Or(), + right=cst.Name("bar"), + rpar=(cst.RightParen(),), + ), + "(foo or bar)", + ), + ( + cst.BooleanOperation( + left=cst.Name( + "foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + operator=cst.Or( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + right=cst.Name( + "bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + ), + "(foo)or(bar)", + ), + # Make sure that spacing works + ( + cst.BooleanOperation( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + left=cst.Name("foo"), + operator=cst.And( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + right=cst.Name("bar"), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( foo and bar )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_expression) + + @data_provider( + ( + ( + lambda: cst.BooleanOperation( + cst.Name("foo"), cst.And(), cst.Name("bar"), lpar=(cst.LeftParen(),) + ), + "left paren without right paren", + ), + ( + lambda: cst.BooleanOperation( + cst.Name("foo"), + cst.And(), + cst.Name("bar"), + rpar=(cst.RightParen(),), + ), + "right paren without left paren", + ), + ( + lambda: cst.BooleanOperation( + left=cst.Name("foo"), + operator=cst.Or( + 
whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + right=cst.Name("bar"), + ), + "at least one space around boolean operator", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_call.py b/libcst/nodes/tests/test_call.py new file mode 100644 index 00000000..010fe9c9 --- /dev/null +++ b/libcst/nodes/tests/test_call.py @@ -0,0 +1,524 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class CallTest(CSTNodeTest): + @data_provider( + ( + # Simple call + (cst.Call(cst.Name("foo")), "foo()", parse_expression), + ( + cst.Call( + cst.Name("foo"), whitespace_before_args=cst.SimpleWhitespace(" ") + ), + "foo( )", + parse_expression, + ), + # Call with attribute dereference + ( + cst.Call(cst.Attribute(cst.Name("foo"), cst.Name("bar"))), + "foo.bar()", + parse_expression, + ), + # Positional arguments render test + ( + cst.Call(cst.Name("foo"), (cst.Arg(cst.Number(cst.Integer("1"))),)), + "foo(1)", + None, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg(cst.Number(cst.Integer("1"))), + cst.Arg(cst.Number(cst.Integer("2"))), + cst.Arg(cst.Number(cst.Integer("3"))), + ), + ), + "foo(1, 2, 3)", + None, + ), + # Positional arguments parse test + ( + cst.Call( + cst.Name("foo"), (cst.Arg(value=cst.Number(cst.Integer("1"))),) + ), + "foo(1)", + parse_expression, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + value=cst.Number(cst.Integer("1")), + whitespace_after_arg=cst.SimpleWhitespace(" "), + ), + ), + 
whitespace_after_func=cst.SimpleWhitespace(" "), + whitespace_before_args=cst.SimpleWhitespace(" "), + ), + "foo ( 1 )", + parse_expression, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + value=cst.Number(cst.Integer("1")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + ), + whitespace_after_func=cst.SimpleWhitespace(" "), + whitespace_before_args=cst.SimpleWhitespace(" "), + ), + "foo ( 1, )", + parse_expression, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + value=cst.Number(cst.Integer("1")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + value=cst.Number(cst.Integer("2")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(value=cst.Number(cst.Integer("3"))), + ), + ), + "foo(1, 2, 3)", + parse_expression, + ), + # Keyword arguments render test + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + keyword=cst.Name("one"), value=cst.Number(cst.Integer("1")) + ), + ), + ), + "foo(one = 1)", + None, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + keyword=cst.Name("one"), value=cst.Number(cst.Integer("1")) + ), + cst.Arg( + keyword=cst.Name("two"), value=cst.Number(cst.Integer("2")) + ), + cst.Arg( + keyword=cst.Name("three"), + value=cst.Number(cst.Integer("3")), + ), + ), + ), + "foo(one = 1, two = 2, three = 3)", + None, + ), + # Keyword arguments parser test + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + keyword=cst.Name("one"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("1")), + ), + ), + ), + "foo(one = 1)", + parse_expression, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + keyword=cst.Name("one"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("1")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + keyword=cst.Name("two"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("2")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + 
keyword=cst.Name("three"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("3")), + ), + ), + ), + "foo(one = 1, two = 2, three = 3)", + parse_expression, + ), + # Iterator expansion render test + ( + cst.Call(cst.Name("foo"), (cst.Arg(star="*", value=cst.Name("one")),)), + "foo(*one)", + None, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg(star="*", value=cst.Name("one")), + cst.Arg(star="*", value=cst.Name("two")), + cst.Arg(star="*", value=cst.Name("three")), + ), + ), + "foo(*one, *two, *three)", + None, + ), + # Iterator expansion parser test + ( + cst.Call(cst.Name("foo"), (cst.Arg(star="*", value=cst.Name("one")),)), + "foo(*one)", + parse_expression, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + star="*", + value=cst.Name("one"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("two"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(star="*", value=cst.Name("three")), + ), + ), + "foo(*one, *two, *three)", + parse_expression, + ), + # Dictionary expansion render test + ( + cst.Call(cst.Name("foo"), (cst.Arg(star="**", value=cst.Name("one")),)), + "foo(**one)", + None, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg(star="**", value=cst.Name("one")), + cst.Arg(star="**", value=cst.Name("two")), + cst.Arg(star="**", value=cst.Name("three")), + ), + ), + "foo(**one, **two, **three)", + None, + ), + # Dictionary expansion parser test + ( + cst.Call(cst.Name("foo"), (cst.Arg(star="**", value=cst.Name("one")),)), + "foo(**one)", + parse_expression, + ), + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + star="**", + value=cst.Name("one"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="**", + value=cst.Name("two"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(star="**", value=cst.Name("three")), + ), + ), + "foo(**one, **two, **three)", + parse_expression, + ), + # 
Complicated mingling rules render test + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg(value=cst.Name("pos1")), + cst.Arg(star="*", value=cst.Name("list1")), + cst.Arg(value=cst.Name("pos2")), + cst.Arg(value=cst.Name("pos3")), + cst.Arg(star="*", value=cst.Name("list2")), + cst.Arg(value=cst.Name("pos4")), + cst.Arg(star="*", value=cst.Name("list3")), + cst.Arg( + keyword=cst.Name("kw1"), value=cst.Number(cst.Integer("1")) + ), + cst.Arg(star="*", value=cst.Name("list4")), + cst.Arg( + keyword=cst.Name("kw2"), value=cst.Number(cst.Integer("2")) + ), + cst.Arg(star="*", value=cst.Name("list5")), + cst.Arg( + keyword=cst.Name("kw3"), value=cst.Number(cst.Integer("3")) + ), + cst.Arg(star="**", value=cst.Name("dict1")), + cst.Arg( + keyword=cst.Name("kw4"), value=cst.Number(cst.Integer("4")) + ), + cst.Arg(star="**", value=cst.Name("dict2")), + ), + ), + "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)", + None, + ), + # Complicated mingling rules parser test + ( + cst.Call( + cst.Name("foo"), + ( + cst.Arg( + value=cst.Name("pos1"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("list1"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + value=cst.Name("pos2"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + value=cst.Name("pos3"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("list2"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + value=cst.Name("pos4"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("list3"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + keyword=cst.Name("kw1"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("1")), + 
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("list4"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + keyword=cst.Name("kw2"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("2")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("list5"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + keyword=cst.Name("kw3"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("3")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="**", + value=cst.Name("dict1"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + keyword=cst.Name("kw4"), + equal=cst.AssignEqual(), + value=cst.Number(cst.Integer("4")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(star="**", value=cst.Name("dict2")), + ), + ), + "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)", + parse_expression, + ), + # Test whitespace + ( + cst.Call( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + func=cst.Name("foo"), + whitespace_after_func=cst.SimpleWhitespace(" "), + whitespace_before_args=cst.SimpleWhitespace(" "), + args=( + cst.Arg( + keyword=None, + value=cst.Name("pos1"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.Arg( + star="*", + whitespace_after_star=cst.SimpleWhitespace(" "), + keyword=None, + value=cst.Name("list1"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + keyword=cst.Name("kw1"), + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + value=cst.Number(cst.Integer("1")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" 
")), + ), + cst.Arg( + star="**", + keyword=None, + whitespace_after_star=cst.SimpleWhitespace(" "), + value=cst.Name("dict1"), + whitespace_after_arg=cst.SimpleWhitespace(" "), + ), + ), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( foo ( pos1 , * list1, kw1=1, ** dict1 ) )", + parse_expression, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + # Basic expression parenthesizing tests. + ( + lambda: cst.Call(func=cst.Name("foo"), lpar=(cst.LeftParen(),)), + "left paren without right paren", + ), + ( + lambda: cst.Call(func=cst.Name("foo"), rpar=(cst.RightParen(),)), + "right paren without left paren", + ), + # Test that we handle keyword stuff correctly. + ( + lambda: cst.Call( + func=cst.Name("foo"), + args=( + cst.Arg( + equal=cst.AssignEqual(), value=cst.SimpleString("'baz'") + ), + ), + ), + "Must have a keyword when specifying an AssignEqual", + ), + # Test that we separate *, ** and keyword args correctly + ( + lambda: cst.Call( + func=cst.Name("foo"), + args=( + cst.Arg( + star="*", + keyword=cst.Name("bar"), + value=cst.SimpleString("'baz'"), + ), + ), + ), + "Cannot specify a star and a keyword together", + ), + # Test for expected star inputs only + ( + lambda: cst.Call( + func=cst.Name("foo"), + # pyre-fixme[6]: Expected `Union[typing_extensions.Literal[''], + # typing_extensions.Literal['*'], + # typing_extensions.Literal['**']]` for 1st param but got + # `typing_extensions.Literal['***']`. 
+ args=(cst.Arg(star="***", value=cst.SimpleString("'baz'")),), + ), + r"Must specify either '', '\*' or '\*\*' for star", + ), + # Test ordering exceptions + ( + lambda: cst.Call( + func=cst.Name("foo"), + args=( + cst.Arg(star="**", value=cst.Name("bar")), + cst.Arg(star="*", value=cst.Name("baz")), + ), + ), + "Cannot have iterable argument unpacking after keyword argument unpacking", + ), + ( + lambda: cst.Call( + func=cst.Name("foo"), + args=( + cst.Arg(star="**", value=cst.Name("bar")), + cst.Arg(value=cst.Name("baz")), + ), + ), + "Cannot have positional argument after keyword argument unpacking", + ), + ( + lambda: cst.Call( + func=cst.Name("foo"), + args=( + cst.Arg( + keyword=cst.Name("arg"), value=cst.SimpleString("'baz'") + ), + cst.Arg(value=cst.SimpleString("'bar'")), + ), + ), + "Cannot have positional argument after keyword argument", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_classdef.py b/libcst/nodes/tests/test_classdef.py new file mode 100644 index 00000000..f12fbaa7 --- /dev/null +++ b/libcst/nodes/tests/test_classdef.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class ClassDefCreationTest(CSTNodeTest): + @data_provider( + ( + # Simple classdef + ( + cst.ClassDef(cst.Name("Foo"), cst.SimpleStatementSuite((cst.Pass(),))), + "class Foo: pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + rpar=cst.RightParen(), + ), + "class Foo(): pass\n", + ), + # Positional arguments render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + bases=(cst.Arg(cst.Name("obj")),), + ), + "class Foo(obj): pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + bases=( + cst.Arg(cst.Name("Bar")), + cst.Arg(cst.Name("Baz")), + cst.Arg(cst.Name("object")), + ), + ), + "class Foo(Bar, Baz, object): pass\n", + ), + # Keyword arguments render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + keywords=( + cst.Arg(keyword=cst.Name("metaclass"), value=cst.Name("Bar")), + ), + ), + "class Foo(metaclass = Bar): pass\n", + ), + # Iterator expansion render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + bases=(cst.Arg(star="*", value=cst.Name("one")),), + ), + "class Foo(*one): pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + bases=( + cst.Arg(star="*", value=cst.Name("one")), + cst.Arg(star="*", value=cst.Name("two")), + cst.Arg(star="*", value=cst.Name("three")), + ), + ), + "class Foo(*one, *two, *three): pass\n", + ), + # Dictionary expansion render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + keywords=(cst.Arg(star="**", value=cst.Name("one")),), + ), + "class Foo(**one): pass\n", + ), + ( + cst.ClassDef( + 
cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + keywords=( + cst.Arg(star="**", value=cst.Name("one")), + cst.Arg(star="**", value=cst.Name("two")), + cst.Arg(star="**", value=cst.Name("three")), + ), + ), + "class Foo(**one, **two, **three): pass\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + # Basic parenthesis tests. + ( + lambda: cst.ClassDef( + name=cst.Name("Foo"), + body=cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + ), + "Do not mix concrete LeftParen/RightParen with MaybeSentinel", + ), + ( + lambda: cst.ClassDef( + name=cst.Name("Foo"), + body=cst.SimpleStatementSuite((cst.Pass(),)), + rpar=cst.RightParen(), + ), + "Do not mix concrete LeftParen/RightParen with MaybeSentinel", + ), + # Whitespace validation + ( + lambda: cst.ClassDef( + name=cst.Name("Foo"), + body=cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_class=cst.SimpleWhitespace(""), + ), + "at least one space between 'class' and name", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class ClassDefParserTest(CSTNodeTest): + @data_provider( + ( + # Simple classdef + ( + cst.ClassDef(cst.Name("Foo"), cst.SimpleStatementSuite((cst.Pass(),))), + "class Foo: pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + rpar=cst.RightParen(), + ), + "class Foo(): pass\n", + ), + # Positional arguments render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + bases=(cst.Arg(cst.Name("obj")),), + rpar=cst.RightParen(), + ), + "class Foo(obj): pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + bases=( + cst.Arg( + cst.Name("Bar"), + 
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + cst.Name("Baz"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(cst.Name("object")), + ), + rpar=cst.RightParen(), + ), + "class Foo(Bar, Baz, object): pass\n", + ), + # Keyword arguments render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + keywords=( + cst.Arg( + keyword=cst.Name("metaclass"), + equal=cst.AssignEqual(), + value=cst.Name("Bar"), + ), + ), + rpar=cst.RightParen(), + ), + "class Foo(metaclass = Bar): pass\n", + ), + # Iterator expansion render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + bases=(cst.Arg(star="*", value=cst.Name("one")),), + rpar=cst.RightParen(), + ), + "class Foo(*one): pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + bases=( + cst.Arg( + star="*", + value=cst.Name("one"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="*", + value=cst.Name("two"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(star="*", value=cst.Name("three")), + ), + rpar=cst.RightParen(), + ), + "class Foo(*one, *two, *three): pass\n", + ), + # Dictionary expansion render test + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + keywords=(cst.Arg(star="**", value=cst.Name("one")),), + rpar=cst.RightParen(), + ), + "class Foo(**one): pass\n", + ), + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + keywords=( + cst.Arg( + star="**", + value=cst.Name("one"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg( + star="**", + value=cst.Name("two"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.Arg(star="**", 
value=cst.Name("three")), + ), + rpar=cst.RightParen(), + ), + "class Foo(**one, **two, **three): pass\n", + ), + # Decorator render tests + ( + cst.ClassDef( + cst.Name("Foo"), + cst.SimpleStatementSuite((cst.Pass(),)), + decorators=(cst.Decorator(cst.Name("foo")),), + lpar=cst.LeftParen(), + rpar=cst.RightParen(), + ), + "@foo\nclass Foo(): pass\n", + ), + ( + cst.ClassDef( + leading_lines=( + cst.EmptyLine(), + cst.EmptyLine(comment=cst.Comment("# leading comment 1")), + ), + decorators=( + cst.Decorator(cst.Name("foo"), leading_lines=()), + cst.Decorator( + cst.Name("bar"), + leading_lines=( + cst.EmptyLine( + comment=cst.Comment("# leading comment 2") + ), + ), + ), + cst.Decorator( + cst.Name("baz"), + leading_lines=( + cst.EmptyLine( + comment=cst.Comment("# leading comment 3") + ), + ), + ), + ), + lines_after_decorators=( + cst.EmptyLine(comment=cst.Comment("# class comment")), + ), + name=cst.Name("Foo"), + body=cst.SimpleStatementSuite((cst.Pass(),)), + lpar=cst.LeftParen(), + rpar=cst.RightParen(), + ), + "\n# leading comment 1\n@foo\n# leading comment 2\n@bar\n# leading comment 3\n@baz\n# class comment\nclass Foo(): pass\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_statement) diff --git a/libcst/nodes/tests/test_comment.py b/libcst/nodes/tests/test_comment.py new file mode 100644 index 00000000..d492551b --- /dev/null +++ b/libcst/nodes/tests/test_comment.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class CommentTest(CSTNodeTest): + @data_provider( + ( + (cst.Comment("#"), "#"), + (cst.Comment("#comment text"), "#comment text"), + (cst.Comment("# comment text"), "# comment text"), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + (lambda: cst.Comment(" bad input"), "non-comment"), + (lambda: cst.Comment("# newline shouldn't be here\n"), "non-comment"), + (lambda: cst.Comment(" # Leading space is wrong"), "non-comment"), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_comparison.py b/libcst/nodes/tests/test_comparison.py new file mode 100644 index 00000000..c05c1ff9 --- /dev/null +++ b/libcst/nodes/tests/test_comparison.py @@ -0,0 +1,234 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class ComparisonTest(CSTNodeTest): + @data_provider( + ( + # Simple comparison statements + ( + cst.Comparison( + cst.Name("foo"), + ( + cst.ComparisonTarget( + cst.LessThan(), cst.Number(cst.Integer("5")) + ), + ), + ), + "foo < 5", + ), + ( + cst.Comparison( + cst.Name("foo"), + ( + cst.ComparisonTarget( + cst.NotEqual(), cst.Number(cst.Integer("5")) + ), + ), + ), + "foo != 5", + ), + ( + cst.Comparison( + cst.Name("foo"), (cst.ComparisonTarget(cst.Is(), cst.Name("True")),) + ), + "foo is True", + ), + ( + cst.Comparison( + cst.Name("foo"), + (cst.ComparisonTarget(cst.IsNot(), cst.Name("False")),), + ), + "foo is not False", + ), + ( + cst.Comparison( + cst.Name("foo"), (cst.ComparisonTarget(cst.In(), cst.Name("bar")),) + ), + "foo in bar", + ), + ( + cst.Comparison( + cst.Name("foo"), + (cst.ComparisonTarget(cst.NotIn(), cst.Name("bar")),), + ), + "foo not in bar", + ), + # Comparison with parens + ( + cst.Comparison( + lpar=(cst.LeftParen(),), + left=cst.Name("foo"), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn(), comparator=cst.Name("bar") + ), + ), + rpar=(cst.RightParen(),), + ), + "(foo not in bar)", + ), + ( + cst.Comparison( + left=cst.Name( + "foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + comparator=cst.Name( + "bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + ), + ), + ), + "(foo)not in(bar)", + ), + # Valid expressions that look like they shouldn't parse + ( + cst.Comparison( + left=cst.Number(cst.Integer("5")), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn( + whitespace_before=cst.SimpleWhitespace("") + ), + 
comparator=cst.Name("bar"), + ), + ), + ), + "5not in bar", + ), + # Validate that spacing works properly + ( + cst.Comparison( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + left=cst.Name("foo"), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_between=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + comparator=cst.Name("bar"), + ), + ), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( foo not in bar )", + ), + # Do some complex nodes + ( + cst.Comparison( + left=cst.Name("baz"), + comparisons=( + cst.ComparisonTarget( + operator=cst.Equal(), + comparator=cst.Comparison( + lpar=(cst.LeftParen(),), + left=cst.Name("foo"), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn(), comparator=cst.Name("bar") + ), + ), + rpar=(cst.RightParen(),), + ), + ), + ), + ), + "baz == (foo not in bar)", + ), + ( + cst.Comparison( + left=cst.Name("a"), + comparisons=( + cst.ComparisonTarget( + operator=cst.GreaterThan(), comparator=cst.Name("b") + ), + cst.ComparisonTarget( + operator=cst.GreaterThan(), comparator=cst.Name("c") + ), + ), + ), + "a > b > c", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_expression) + + @data_provider( + ( + ( + lambda: cst.Comparison( + cst.Name("foo"), + # pyre-fixme[6]: Expected `BaseExpression` for 2nd param but got + # `Integer`. + (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")),), + lpar=(cst.LeftParen(),), + ), + "left paren without right paren", + ), + ( + lambda: cst.Comparison( + cst.Name("foo"), + # pyre-fixme[6]: Expected `BaseExpression` for 2nd param but got + # `Integer`. 
+ (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")),), + rpar=(cst.RightParen(),), + ), + "right paren without left paren", + ), + ( + lambda: cst.Comparison(cst.Name("foo"), ()), + "at least one ComparisonTarget", + ), + ( + lambda: cst.Comparison( + left=cst.Name("foo"), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn( + whitespace_before=cst.SimpleWhitespace("") + ), + comparator=cst.Name("bar"), + ), + ), + ), + "at least one space around comparison operator", + ), + ( + lambda: cst.Comparison( + left=cst.Name("foo"), + comparisons=( + cst.ComparisonTarget( + operator=cst.NotIn( + whitespace_after=cst.SimpleWhitespace("") + ), + comparator=cst.Name("bar"), + ), + ), + ), + "at least one space around comparison operator", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_cst_node.py b/libcst/nodes/tests/test_cst_node.py new file mode 100644 index 00000000..0013a2f0 --- /dev/null +++ b/libcst/nodes/tests/test_cst_node.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from textwrap import dedent +from typing import TypeVar, Union + +import libcst.nodes as cst +from libcst._base_visitor import CSTVisitor +from libcst._removal_sentinel import RemovalSentinel +from libcst.testing.utils import UnitTest, data_provider, none_throws + + +_CSTNodeT = TypeVar("_CSTNodeT", bound="cst.CSTNode") +_EMPTY_SIMPLE_WHITESPACE = cst.SimpleWhitespace("") + + +class _TestVisitor(CSTVisitor): + def __init__(self, test: UnitTest) -> None: + self.counter = 0 + self.test = test + + def assert_counter(self, expected: int) -> None: + self.test.assertEqual(self.counter, expected) + self.counter += 1 + + def on_visit(self, node: cst.CSTNode) -> bool: + if isinstance(node, cst.Module): + self.assert_counter(0) + elif isinstance(node, cst.SimpleStatementLine): + self.assert_counter(1) + elif isinstance(node, cst.Pass): + self.assert_counter(2) + elif isinstance(node, cst.Newline): + self.assert_counter(4) + return True + + def on_leave( + self, original_node: _CSTNodeT, updated_node: _CSTNodeT + ) -> Union[_CSTNodeT, RemovalSentinel]: + self.test.assertTrue(original_node.deep_equals(updated_node)) + if isinstance(updated_node, cst.Pass): + self.assert_counter(3) + elif isinstance(updated_node, cst.Newline): + self.assert_counter(5) + elif isinstance(updated_node, cst.SimpleStatementLine): + self.assert_counter(6) + elif isinstance(updated_node, cst.Module): + self.assert_counter(7) + # pyre: Expected `Union[RemovalSentinel, Variable[_CSTNodeT (bound to + # pyre-ignore[7]: cst._base.CSTNode)]]` but got `cst._statement.Pass`. 
+ return updated_node + + +class CSTNodeTest(UnitTest): + def test_with_changes(self) -> None: + initial = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" \\\n "), + comment=cst.Comment("# initial"), + newline=cst.Newline("\r\n"), + ) + changed = initial.with_changes(comment=cst.Comment("# new comment")) + + # see that we have the updated fields + self.assertEqual(none_throws(changed.comment).value, "# new comment") + # and that the old fields are still there + self.assertEqual(changed.whitespace.value, " \\\n ") + self.assertEqual(changed.newline.value, "\r\n") + + # ensure no mutation actually happened + self.assertEqual(none_throws(initial.comment).value, "# initial") + + def test_default_eq(self) -> None: + sw1 = cst.SimpleWhitespace("") + sw2 = cst.SimpleWhitespace("") + self.assertNotEqual(sw1, sw2) + self.assertEqual(sw1, sw1) + self.assertEqual(sw2, sw2) + self.assertTrue(sw1.deep_equals(sw2)) + self.assertTrue(sw2.deep_equals(sw1)) + + def test_hash(self) -> None: + sw1 = cst.SimpleWhitespace("") + sw2 = cst.SimpleWhitespace("") + self.assertNotEqual(hash(sw1), hash(sw2)) + self.assertEqual(hash(sw1), hash(sw1)) + self.assertEqual(hash(sw2), hash(sw2)) + + @data_provider( + { + "simple": (cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), + "identity": (_EMPTY_SIMPLE_WHITESPACE, _EMPTY_SIMPLE_WHITESPACE), + "nested": ( + cst.EmptyLine(whitespace=cst.SimpleWhitespace("")), + cst.EmptyLine(whitespace=cst.SimpleWhitespace("")), + ), + "tuple_versus_list": ( + cst.SimpleStatementLine(body=[cst.Pass()]), + cst.SimpleStatementLine(body=(cst.Pass(),)), + ), + } + ) + def test_deep_equals_success(self, a: cst.CSTNode, b: cst.CSTNode) -> None: + self.assertTrue(a.deep_equals(b)) + + @data_provider( + { + "simple": (cst.SimpleWhitespace(" "), cst.SimpleWhitespace(" ")), + "nested": ( + cst.EmptyLine(whitespace=cst.SimpleWhitespace(" ")), + cst.EmptyLine(whitespace=cst.SimpleWhitespace(" ")), + ), + "list": ( + 
cst.SimpleStatementLine(body=[cst.Pass(semicolon=cst.Semicolon())]), + cst.SimpleStatementLine(body=[cst.Pass(semicolon=cst.Semicolon())] * 2), + ), + } + ) + def test_deep_equals_fails(self, a: cst.CSTNode, b: cst.CSTNode) -> None: + self.assertFalse(a.deep_equals(b)) + + def test_repr(self) -> None: + self.assertEqual( + repr( + cst.SimpleStatementLine( + body=[cst.Pass()], + # tuple with multiple items + leading_lines=( + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline(), + ), + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline(), + ), + ), + trailing_whitespace=cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# comment"), + newline=cst.Newline(), + ), + ) + ), + dedent( + """ + SimpleStatementLine( + body=[ + Pass( + semicolon=, + ), + ], + leading_lines=[ + EmptyLine( + indent=True, + whitespace=SimpleWhitespace( + value='', + ), + comment=None, + newline=Newline( + value=None, + ), + ), + EmptyLine( + indent=True, + whitespace=SimpleWhitespace( + value='', + ), + comment=None, + newline=Newline( + value=None, + ), + ), + ], + trailing_whitespace=TrailingWhitespace( + whitespace=SimpleWhitespace( + value=' ', + ), + comment=Comment( + value='# comment', + ), + newline=Newline( + value=None, + ), + ), + ) + """ + ).strip(), + ) + + def test_visit(self) -> None: + tree = cst.Module((cst.SimpleStatementLine((cst.Pass(),)),)) + tree.visit(_TestVisitor(self)) diff --git a/libcst/nodes/tests/test_del.py b/libcst/nodes/tests/test_del.py new file mode 100644 index 00000000..6e357862 --- /dev/null +++ b/libcst/nodes/tests/test_del.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class DelTest(CSTNodeTest): + @data_provider( + ( + (cst.SimpleStatementLine([cst.Del(cst.Name("abc"))]), "del abc\n"), + ( + cst.SimpleStatementLine( + [ + cst.Del( + cst.Name("abc"), + whitespace_after_del=cst.SimpleWhitespace(" "), + ) + ] + ), + "del abc\n", + ), + ( + cst.SimpleStatementLine( + [ + cst.Del( + cst.Name( + "abc", lpar=[cst.LeftParen()], rpar=[cst.RightParen()] + ), + whitespace_after_del=cst.SimpleWhitespace(""), + ) + ] + ), + "del(abc)\n", + ), + ( + cst.SimpleStatementLine( + [cst.Del(cst.Name("abc"), semicolon=cst.Semicolon())] + ), + "del abc;\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_statement) + + @data_provider( + ( + ( + lambda: cst.Del( + cst.Name("abc"), whitespace_after_del=cst.SimpleWhitespace("") + ), + "Must have at least one space after 'del'.", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_else.py b/libcst/nodes/tests/test_else.py new file mode 100644 index 00000000..c93935bf --- /dev/null +++ b/libcst/nodes/tests/test_else.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class ElseTest(CSTNodeTest): + @data_provider( + ( + (cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), "else: pass\n"), + ( + cst.Else( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + "else : pass\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) diff --git a/libcst/nodes/tests/test_empty_line.py b/libcst/nodes/tests/test_empty_line.py new file mode 100644 index 00000000..483d9f1c --- /dev/null +++ b/libcst/nodes/tests/test_empty_line.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.testing.utils import data_provider + + +class EmptyLineTest(CSTNodeTest): + @data_provider( + ( + (cst.EmptyLine(), "\n"), + (cst.EmptyLine(whitespace=cst.SimpleWhitespace(" ")), " \n"), + (cst.EmptyLine(comment=cst.Comment("# comment")), "# comment\n"), + (cst.EmptyLine(newline=cst.Newline("\r\n")), "\r\n"), + (DummyIndentedBlock(" ", cst.EmptyLine()), " \n"), + (DummyIndentedBlock(" ", cst.EmptyLine(indent=False)), "\n"), + ( + DummyIndentedBlock( + "\t", + cst.EmptyLine( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# comment"), + newline=cst.Newline("\r\n"), + ), + ), + "\t # comment\r\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) diff --git a/libcst/nodes/tests/test_for.py b/libcst/nodes/tests/test_for.py new file mode 100644 index 00000000..6bfc6085 --- /dev/null +++ b/libcst/nodes/tests/test_for.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class ForTest(CSTNodeTest): + @data_provider( + ( + # Simple for block + ( + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "for target in iter(): pass\n", + parse_statement, + ), + # Simple async for block + ( + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous(), + ), + "async for target in iter(): pass\n", + parse_statement, + ), + # For block with else + ( + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "for target in iter(): pass\nelse: pass\n", + parse_statement, + ), + # indentation + ( + DummyIndentedBlock( + " ", + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + ), + " for target in iter(): pass\n", + None, + ), + # for an indented body + ( + DummyIndentedBlock( + " ", + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + ), + ), + " for target in iter():\n pass\n", + None, + ), + # leading_lines + ( + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + cst.Else( + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# else comment")), + ), + ), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + 
), + "# leading comment\nfor target in iter():\n pass\n# else comment\nelse:\n pass\n", + None, + ), + # Weird spacing rules + ( + cst.For( + cst.Name( + "target", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + cst.Call( + cst.Name("iter"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_for=cst.SimpleWhitespace(""), + whitespace_before_in=cst.SimpleWhitespace(""), + whitespace_after_in=cst.SimpleWhitespace(""), + ), + "for(target)in(iter()): pass\n", + parse_statement, + ), + # Whitespace + ( + cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_for=cst.SimpleWhitespace(" "), + whitespace_before_in=cst.SimpleWhitespace(" "), + whitespace_after_in=cst.SimpleWhitespace(" "), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + "for target in iter() : pass\n", + parse_statement, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + ( + lambda: cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_for=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'for' keyword", + ), + ( + lambda: cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_before_in=cst.SimpleWhitespace(""), + ), + "Must have at least one space before 'in' keyword", + ), + ( + lambda: cst.For( + cst.Name("target"), + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_in=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'in' keyword", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff 
--git a/libcst/nodes/tests/test_funcdef.py b/libcst/nodes/tests/test_funcdef.py new file mode 100644 index 00000000..9e64910d --- /dev/null +++ b/libcst/nodes/tests/test_funcdef.py @@ -0,0 +1,1619 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class FunctionDefCreationTest(CSTNodeTest): + @data_provider( + ( + # Simple function definition without any arguments or return + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(): pass\n", + ), + # Functiondef with a return annotation + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + returns=cst.Annotation(cst.Name("str")), + ), + "def foo() -> str: pass\n", + ), + # Async function definition. + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous(), + ), + "async def foo(): pass\n", + ), + # Async function definition with annotation. 
+ ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous(), + returns=cst.Annotation(cst.Name("int")), + ), + "async def foo() -> int: pass\n", + ), + # Test basic positional params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=(cst.Param(cst.Name("bar")), cst.Param(cst.Name("baz"))) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(bar, baz): pass\n", + ), + # Typed positional params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param(cst.Name("bar"), cst.Annotation(cst.Name("str"))), + cst.Param(cst.Name("baz"), cst.Annotation(cst.Name("int"))), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(bar: str, baz: int): pass\n", + ), + # Test basic positional default params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("bar"), default=cst.SimpleString('"one"') + ), + cst.Param( + cst.Name("baz"), default=cst.Number(cst.Integer("5")) + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(bar = "one", baz = 5): pass\n', + ), + # Typed positional default params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation(cst.Name("int")), + default=cst.Number(cst.Integer("5")), + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(bar: str = "one", baz: int = 5): pass\n', + ), + # Mixed positional and default params. 
+ ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param(cst.Name("bar"), cst.Annotation(cst.Name("str"))), + ), + default_params=( + cst.Param( + cst.Name("baz"), + cst.Annotation(cst.Name("int")), + default=cst.Number(cst.Integer("5")), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(bar: str, baz: int = 5): pass\n", + ), + # Test kwonly params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + kwonly_params=( + cst.Param( + cst.Name("bar"), default=cst.SimpleString('"one"') + ), + cst.Param(cst.Name("baz")), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(*, bar = "one", baz): pass\n', + ), + # Typed kwonly params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param(cst.Name("baz"), cst.Annotation(cst.Name("int"))), + cst.Param( + cst.Name("biz"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"two"'), + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(*, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed params and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param(cst.Name("first")), + cst.Param(cst.Name("second")), + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param(cst.Name("baz"), cst.Annotation(cst.Name("int"))), + cst.Param( + cst.Name("biz"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"two"'), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first, second, *, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed default_params and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("first"), 
default=cst.Number(cst.Float("1.0")) + ), + cst.Param( + cst.Name("second"), default=cst.Number(cst.Float("1.5")) + ), + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param(cst.Name("baz"), cst.Annotation(cst.Name("int"))), + cst.Param( + cst.Name("biz"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"two"'), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first = 1.0, second = 1.5, *, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed params, default_params, and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param(cst.Name("first")), + cst.Param(cst.Name("second")), + ), + default_params=( + cst.Param( + cst.Name("third"), default=cst.Number(cst.Float("1.0")) + ), + cst.Param( + cst.Name("fourth"), default=cst.Number(cst.Float("1.5")) + ), + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param(cst.Name("baz"), cst.Annotation(cst.Name("int"))), + cst.Param( + cst.Name("biz"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"two"'), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first, second, third = 1.0, fourth = 1.5, *, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Test star_arg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(star_arg=cst.Param(cst.Name("params"))), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(*params): pass\n", + ), + # Typed star_arg, include kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.Param( + cst.Name("params"), cst.Annotation(cst.Name("str")) + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param(cst.Name("baz"), 
cst.Annotation(cst.Name("int"))), + cst.Param( + cst.Name("biz"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"two"'), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(*params: str, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed params, default_params, star_arg and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param(cst.Name("first")), + cst.Param(cst.Name("second")), + ), + default_params=( + cst.Param( + cst.Name("third"), default=cst.Number(cst.Float("1.0")) + ), + cst.Param( + cst.Name("fourth"), default=cst.Number(cst.Float("1.5")) + ), + ), + star_arg=cst.Param( + cst.Name("params"), cst.Annotation(cst.Name("str")) + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"one"'), + ), + cst.Param(cst.Name("baz"), cst.Annotation(cst.Name("int"))), + cst.Param( + cst.Name("biz"), + cst.Annotation(cst.Name("str")), + default=cst.SimpleString('"two"'), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first, second, third = 1.0, fourth = 1.5, *params: str, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Test star_kwarg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(**kwparams): pass\n", + ), + # Test star_arg and star_kwarg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.Param(cst.Name("params")), + star_kwarg=cst.Param(cst.Name("kwparams")), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(*params, **kwparams): pass\n", + ), + # Test typed star_arg and star_kwarg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.Param( + cst.Name("params"), cst.Annotation(cst.Name("str")) + ), + star_kwarg=cst.Param( + cst.Name("kwparams"), cst.Annotation(cst.Name("int")) + ), + ),
+ cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(*params: str, **kwparams: int): pass\n", + ), + # Test decorators + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + (cst.Decorator(cst.Name("bar")),), + ), + "@bar\ndef foo(): pass\n", + ), + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ( + cst.Decorator( + cst.Call( + cst.Name("bar"), + ( + cst.Arg(cst.Name("baz")), + cst.Arg(cst.SimpleString("'123'")), + ), + ) + ), + ), + ), + "@bar(baz, '123')\ndef foo(): pass\n", + ), + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ( + cst.Decorator( + cst.Call( + cst.Name("bar"), (cst.Arg(cst.SimpleString("'123'")),) + ) + ), + cst.Decorator( + cst.Call( + cst.Name("baz"), (cst.Arg(cst.SimpleString("'456'")),) + ) + ), + ), + ), + "@bar('123')\n@baz('456')\ndef foo(): pass\n", + ), + # Test indentation + ( + DummyIndentedBlock( + " ", + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + (cst.Decorator(cst.Name("bar")),), + ), + ), + " @bar\n def foo(): pass\n", + ), + # With an indented body + ( + DummyIndentedBlock( + " ", + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + (cst.Decorator(cst.Name("bar")),), + ), + ), + " @bar\n def foo():\n pass\n", + ), + # Leading lines + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + ), + "# leading comment\ndef foo(): pass\n", + ), + # Inner whitespace + ( + cst.FunctionDef( + leading_lines=( + cst.EmptyLine(), + cst.EmptyLine( + comment=cst.Comment("# What an amazing decorator") + ), + ), + decorators=( + cst.Decorator( + whitespace_after_at=cst.SimpleWhitespace(" "), + decorator=cst.Call( + 
func=cst.Name("bar"), + whitespace_after_func=cst.SimpleWhitespace(" "), + whitespace_before_args=cst.SimpleWhitespace(" "), + ), + ), + ), + lines_after_decorators=( + cst.EmptyLine(comment=cst.Comment("# What a great function")), + ), + asynchronous=cst.Asynchronous( + whitespace_after=cst.SimpleWhitespace(" ") + ), + whitespace_after_def=cst.SimpleWhitespace(" "), + name=cst.Name("foo"), + whitespace_after_name=cst.SimpleWhitespace(" "), + whitespace_before_params=cst.SimpleWhitespace(" "), + params=cst.Parameters(), + returns=cst.Annotation( + whitespace_before_indicator=cst.SimpleWhitespace(" "), + whitespace_after_indicator=cst.SimpleWhitespace(" "), + annotation=cst.Name("str"), + ), + whitespace_before_colon=cst.SimpleWhitespace(" "), + body=cst.SimpleStatementSuite((cst.Pass(),)), + ), + "\n# What an amazing decorator\n@ bar ( )\n# What a great function\nasync def foo ( ) -> str : pass\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + ( + lambda: cst.FunctionDef( + cst.Name("foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Cannot have parens around Name", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous( + whitespace_after=cst.SimpleWhitespace("") + ), + ), + "one space after Asynchronous", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_def=cst.SimpleWhitespace(""), + ), + "one space between 'def' and name", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_kwarg=cst.Param(cst.Name("bar"), equal=cst.AssignEqual()) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Must have a default when specifying an AssignEqual.", + ), + ( + lambda: cst.FunctionDef( + 
cst.Name("foo"), + cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="***")), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + r"Must specify either '', '\*' or '\*\*' for star.", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("bar"), default=cst.SimpleString('"one"') + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Cannot have defaults for params", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(default_params=(cst.Param(cst.Name("bar")),)), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Must have defaults for default_params", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(star_arg=cst.ParamStar()), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Must have at least one kwonly param if ParamStar is used.", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"),)), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Expecting a star prefix of ''", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("bar"), + default=cst.SimpleString('"one"'), + star="*", + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Expecting a star prefix of ''", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + kwonly_params=(cst.Param(cst.Name("bar"), star="*"),) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "Expecting a star prefix of ''", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + r"Expecting a star prefix of '\*'", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + r"Expecting a star prefix of '\*\*'", + ), + # Validate decorator name 
semantics + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ( + cst.Decorator( + cst.Name( + "bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ) + ), + ), + ), + "Cannot have parens around decorator in a Decorator", + ), + # Validate annotations + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + returns=cst.Annotation(cst.Name("str"), indicator=":"), + ), + "return Annotation must be denoted with a '->'", + ), + ( + lambda: cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("baz"), + cst.Annotation(cst.Name("int"), indicator="->"), + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "param Annotation must be denoted with a ':'", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class FunctionDefParserTest(CSTNodeTest): + @data_provider( + ( + # Simple function definition without any arguments or return + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(): pass\n", + ), + # Functiondef with a return annotation + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + returns=cst.Annotation( + cst.Name("str"), + indicator="->", + whitespace_before_indicator=cst.SimpleWhitespace(" "), + ), + ), + "def foo() -> str: pass\n", + ), + # Async function definition. + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous(), + ), + "async def foo(): pass\n", + ), + # Async function definition with annotation. 
+ ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous(), + returns=cst.Annotation( + cst.Name("int"), + indicator="->", + whitespace_before_indicator=cst.SimpleWhitespace(" "), + ), + ), + "async def foo() -> int: pass\n", + ), + # Test basic positional params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("bar"), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param(cst.Name("baz"), star=""), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(bar, baz): pass\n", + ), + # Typed positional params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + star="", + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(bar: str, baz: int): pass\n", + ), + # Test basic positional default params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("bar"), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + equal=cst.AssignEqual(), + default=cst.Number(cst.Integer("5")), + star="", + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(bar = "one", baz = 5): pass\n', + ), + # Typed positional default params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + 
whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.Number(cst.Integer("5")), + star="", + ), + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(bar: str = "one", baz: int = 5): pass\n', + ), + # Mixed positional and default params. + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + default_params=( + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.Number(cst.Integer("5")), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(bar: str, baz: int = 5): pass\n", + ), + # Test kwonly params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.ParamStar(), + kwonly_params=( + cst.Param( + cst.Name("bar"), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param(cst.Name("baz"), default=None, star=""), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(*, bar = "one", baz): pass\n', + ), + # Typed kwonly params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.ParamStar(), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + 
whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("biz"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"two"'), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(*, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed params and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("first"), + annotation=None, + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("second"), + annotation=None, + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + star_arg=cst.ParamStar(), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("biz"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" 
+ ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"two"'), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first, second, *, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed default_params and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + default_params=( + cst.Param( + cst.Name("first"), + annotation=None, + equal=cst.AssignEqual(), + default=cst.Number(cst.Float("1.0")), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("second"), + annotation=None, + equal=cst.AssignEqual(), + default=cst.Number(cst.Float("1.5")), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + star_arg=cst.ParamStar(), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("biz"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"two"'), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first = 1.0, second = 1.5, *, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed params, default_params, and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("first"), + annotation=None, + default=None, + star="", + comma=cst.Comma( + 
whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("second"), + annotation=None, + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + default_params=( + cst.Param( + cst.Name("third"), + annotation=None, + equal=cst.AssignEqual(), + default=cst.Number(cst.Float("1.0")), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("fourth"), + annotation=None, + equal=cst.AssignEqual(), + default=cst.Number(cst.Float("1.5")), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + star_arg=cst.ParamStar(), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("biz"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"two"'), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first, second, third = 1.0, fourth = 1.5, *, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Test star_arg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.Param( + cst.Name("params"), annotation=None, default=None, star="*" + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(*params): pass\n", + ), + # Typed star_arg, include kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + 
cst.Parameters( + star_arg=cst.Param( + cst.Name("params"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace(""), + ), + default=None, + star="*", + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("biz"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"two"'), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(*params: str, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Mixed params default_params, star_arg and kwonly_params + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + params=( + cst.Param( + cst.Name("first"), + annotation=None, + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("second"), + annotation=None, + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + default_params=( + cst.Param( + cst.Name("third"), + annotation=None, + equal=cst.AssignEqual(), + default=cst.Number(cst.Float("1.0")), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("fourth"), + annotation=None, + equal=cst.AssignEqual(), + 
default=cst.Number(cst.Float("1.5")), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + ), + star_arg=cst.Param( + cst.Name("params"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace(""), + ), + default=None, + star="*", + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + kwonly_params=( + cst.Param( + cst.Name("bar"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"one"'), + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("baz"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + default=None, + star="", + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Param( + cst.Name("biz"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace( + "" + ), + ), + equal=cst.AssignEqual(), + default=cst.SimpleString('"two"'), + star="", + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + 'def foo(first, second, third = 1.0, fourth = 1.5, *params: str, bar: str = "one", baz: int, biz: str = "two"): pass\n', + ), + # Test star_arg and star_kwarg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_kwarg=cst.Param( + cst.Name("kwparams"), + annotation=None, + default=None, + star="**", + ) + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(**kwparams): pass\n", + ), + # Test star_arg and kwarg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.Param( + cst.Name("params"), + annotation=None, + default=None, + star="*", + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + star_kwarg=cst.Param( + cst.Name("kwparams"), + annotation=None, + default=None, + 
star="**", + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(*params, **kwparams): pass\n", + ), + # Test typed star_arg and star_kwarg + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters( + star_arg=cst.Param( + cst.Name("params"), + cst.Annotation( + cst.Name("str"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace(""), + ), + default=None, + star="*", + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + star_kwarg=cst.Param( + cst.Name("kwparams"), + cst.Annotation( + cst.Name("int"), + indicator=":", + whitespace_before_indicator=cst.SimpleWhitespace(""), + ), + default=None, + star="**", + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "def foo(*params: str, **kwparams: int): pass\n", + ), + # Test decorators + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + (cst.Decorator(cst.Name("bar")),), + ), + "@bar\ndef foo(): pass\n", + ), + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ( + cst.Decorator( + cst.Call( + cst.Name("bar"), + ( + cst.Arg( + cst.Name("baz"), + keyword=None, + comma=cst.Comma( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + cst.Arg(cst.SimpleString("'123'"), keyword=None), + ), + ) + ), + ), + ), + "@bar(baz, '123')\ndef foo(): pass\n", + ), + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + ( + cst.Decorator( + cst.Call( + cst.Name("bar"), + (cst.Arg(cst.SimpleString("'123'"), keyword=None),), + ) + ), + cst.Decorator( + cst.Call( + cst.Name("baz"), + (cst.Arg(cst.SimpleString("'456'"), keyword=None),), + ) + ), + ), + ), + "@bar('123')\n@baz('456')\ndef foo(): pass\n", + ), + # Leading lines + ( + cst.FunctionDef( + cst.Name("foo"), + cst.Parameters(), + cst.SimpleStatementSuite((cst.Pass(),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + ), + "# 
leading comment\ndef foo(): pass\n", + ), + # Inner whitespace + ( + cst.FunctionDef( + leading_lines=( + cst.EmptyLine(), + cst.EmptyLine( + comment=cst.Comment("# What an amazing decorator") + ), + ), + decorators=( + cst.Decorator( + whitespace_after_at=cst.SimpleWhitespace(" "), + decorator=cst.Call( + func=cst.Name("bar"), + whitespace_after_func=cst.SimpleWhitespace(" "), + whitespace_before_args=cst.SimpleWhitespace(" "), + ), + ), + ), + lines_after_decorators=( + cst.EmptyLine(comment=cst.Comment("# What a great function")), + ), + asynchronous=cst.Asynchronous( + whitespace_after=cst.SimpleWhitespace(" ") + ), + whitespace_after_def=cst.SimpleWhitespace(" "), + name=cst.Name("foo"), + whitespace_after_name=cst.SimpleWhitespace(" "), + whitespace_before_params=cst.SimpleWhitespace(" "), + params=cst.Parameters(), + returns=cst.Annotation( + whitespace_before_indicator=cst.SimpleWhitespace(" "), + whitespace_after_indicator=cst.SimpleWhitespace(" "), + annotation=cst.Name("str"), + indicator="->", + ), + whitespace_before_colon=cst.SimpleWhitespace(" "), + body=cst.SimpleStatementSuite((cst.Pass(),)), + ), + "\n# What an amazing decorator\n@ bar ( )\n# What a great function\nasync def foo ( ) -> str : pass\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_statement) diff --git a/libcst/nodes/tests/test_global.py b/libcst/nodes/tests/test_global.py new file mode 100644 index 00000000..784e6a66 --- /dev/null +++ b/libcst/nodes/tests/test_global.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class GlobalConstructionTest(CSTNodeTest): + @data_provider( + ( + # Single global statement + (cst.Global((cst.NameItem(cst.Name("a")),)), "global a"), + # Multiple entries in global statement + ( + cst.Global((cst.NameItem(cst.Name("a")), cst.NameItem(cst.Name("b")))), + "global a, b", + ), + # Whitespace rendering test + ( + cst.Global( + ( + cst.NameItem( + cst.Name("a"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.NameItem(cst.Name("b")), + ), + whitespace_after_global=cst.SimpleWhitespace(" "), + ), + "global a , b", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + # Validate construction + ( + lambda: cst.Global(()), + "A Global statement must have at least one NameItem", + ), + # Validate whitespace handling + ( + lambda: cst.Global( + (cst.NameItem(cst.Name("a")),), + whitespace_after_global=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'global' keyword", + ), + # Validate comma handling + ( + lambda: cst.Global((cst.NameItem(cst.Name("a"), comma=cst.Comma()),)), + "The last NameItem in a Global cannot have a trailing comma", + ), + # Validate paren handling + ( + lambda: cst.Global( + ( + cst.NameItem( + cst.Name( + "a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ) + ), + ) + ), + "Cannot have parens around names in NameItem", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class GlobalParsingTest(CSTNodeTest): + @data_provider( + ( + # Single global statement + (cst.Global((cst.NameItem(cst.Name("a")),)), "global a"), + # Multiple 
entries in global statement + ( + cst.Global( + ( + cst.NameItem( + cst.Name("a"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.NameItem(cst.Name("b")), + ) + ), + "global a, b", + ), + # Whitespace rendering test + ( + cst.Global( + ( + cst.NameItem( + cst.Name("a"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.NameItem(cst.Name("b")), + ), + whitespace_after_global=cst.SimpleWhitespace(" "), + ), + "global a , b", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0]) diff --git a/libcst/nodes/tests/test_if.py b/libcst/nodes/tests/test_if.py new file mode 100644 index 00000000..fb9ff39d --- /dev/null +++ b/libcst/nodes/tests/test_if.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class IfTest(CSTNodeTest): + @data_provider( + ( + # Simple if without elif or else + ( + cst.If( + cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),)) + ), + "if conditional: pass\n", + parse_statement, + ), + # else clause + ( + cst.If( + cst.Name("conditional"), + cst.SimpleStatementSuite((cst.Pass(),)), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "if conditional: pass\nelse: pass\n", + parse_statement, + ), + # elif clause + ( + cst.If( + cst.Name("conditional"), + cst.SimpleStatementSuite((cst.Pass(),)), + orelse=cst.If( + cst.Name("other_conditional"), + cst.SimpleStatementSuite((cst.Pass(),)), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + ), + ), + "if conditional: pass\nelif other_conditional: pass\nelse: pass\n", + parse_statement, + ), + # indentation + ( + DummyIndentedBlock( + " ", + cst.If( + cst.Name("conditional"), + cst.SimpleStatementSuite((cst.Pass(),)), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + ), + ), + " if conditional: pass\n else: pass\n", + None, + ), + # with an indented body + ( + DummyIndentedBlock( + " ", + cst.If( + cst.Name("conditional"), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + ), + ), + " if conditional:\n pass\n", + None, + ), + # leading_lines + ( + cst.If( + cst.Name("conditional"), + cst.SimpleStatementSuite((cst.Pass(),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + ), + "# leading comment\nif conditional: pass\n", + parse_statement, + ), + # whitespace before/after test and else + ( + cst.If( + cst.Name("conditional"), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_before_test=cst.SimpleWhitespace(" "), + 
whitespace_after_test=cst.SimpleWhitespace(" "), + orelse=cst.Else( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + ), + "if conditional : pass\nelse : pass\n", + parse_statement, + ), + # empty lines between if/elif/else clauses, not captured by the suite. + ( + cst.If( + cst.Name("test_a"), + cst.SimpleStatementSuite((cst.Pass(),)), + orelse=cst.If( + cst.Name("test_b"), + cst.SimpleStatementSuite((cst.Pass(),)), + leading_lines=(cst.EmptyLine(),), + orelse=cst.Else( + cst.SimpleStatementSuite((cst.Pass(),)), + leading_lines=(cst.EmptyLine(),), + ), + ), + ), + "if test_a: pass\n\nelif test_b: pass\n\nelse: pass\n", + parse_statement, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) diff --git a/libcst/nodes/tests/test_ifexp.py b/libcst/nodes/tests/test_ifexp.py new file mode 100644 index 00000000..6eb3d773 --- /dev/null +++ b/libcst/nodes/tests/test_ifexp.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class IfExpTest(CSTNodeTest): + @data_provider( + ( + # Simple if experessions + ( + cst.IfExp( + body=cst.Name("foo"), test=cst.Name("bar"), orelse=cst.Name("baz") + ), + "foo if bar else baz", + ), + # Parenthesized if expressions + ( + cst.IfExp( + lpar=(cst.LeftParen(),), + body=cst.Name("foo"), + test=cst.Name("bar"), + orelse=cst.Name("baz"), + rpar=(cst.RightParen(),), + ), + "(foo if bar else baz)", + ), + ( + cst.IfExp( + body=cst.Name( + "foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + whitespace_before_if=cst.SimpleWhitespace(""), + whitespace_after_if=cst.SimpleWhitespace(""), + test=cst.Name( + "bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + whitespace_before_else=cst.SimpleWhitespace(""), + whitespace_after_else=cst.SimpleWhitespace(""), + orelse=cst.Name( + "baz", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + ), + "(foo)if(bar)else(baz)", + ), + # Make sure that spacing works + ( + cst.IfExp( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + body=cst.Name("foo"), + whitespace_before_if=cst.SimpleWhitespace(" "), + whitespace_after_if=cst.SimpleWhitespace(" "), + test=cst.Name("bar"), + whitespace_before_else=cst.SimpleWhitespace(" "), + whitespace_after_else=cst.SimpleWhitespace(" "), + orelse=cst.Name("baz"), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( foo if bar else baz )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_expression) + + @data_provider( + ( + ( + lambda: cst.IfExp( + cst.Name("bar"), + cst.Name("foo"), + cst.Name("baz"), + lpar=(cst.LeftParen(),), + ), + "left paren without right paren", + ), + ( + lambda: cst.IfExp( + cst.Name("bar"), + 
cst.Name("foo"), + cst.Name("baz"), + rpar=(cst.RightParen(),), + ), + "right paren without left paren", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_import.py b/libcst/nodes/tests/test_import.py new file mode 100644 index 00000000..ec38de7e --- /dev/null +++ b/libcst/nodes/tests/test_import.py @@ -0,0 +1,688 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class ImportCreateTest(CSTNodeTest): + @data_provider( + ( + # Simple import statement + (cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), "import foo"), + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")) + ), + ) + ), + "import foo.bar", + ), + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")) + ), + ) + ), + "import foo.bar", + ), + # Comma-separated list of imports + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")) + ), + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("baz")) + ), + ) + ), + "import foo.bar, foo.baz", + ), + # Import with an alias + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + asname=cst.AsName(cst.Name("baz")), + ), + ) + ), + "import foo.bar as baz", + ), + # Import with an alias, comma separated + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + asname=cst.AsName(cst.Name("baz")), + ), + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), 
cst.Name("baz")), + asname=cst.AsName(cst.Name("bar")), + ), + ) + ), + "import foo.bar as baz, foo.baz as bar", + ), + # Combine for fun and profit + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + asname=cst.AsName(cst.Name("baz")), + ), + cst.ImportAlias( + cst.Attribute(cst.Name("insta"), cst.Name("gram")) + ), + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("baz")) + ), + cst.ImportAlias( + cst.Name("unittest"), asname=cst.AsName(cst.Name("ut")) + ), + ) + ), + "import foo.bar as baz, insta.gram, foo.baz, unittest as ut", + ), + # Verify whitespace works everywhere. + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute( + cst.Name("foo"), + cst.Name("bar"), + dot=cst.Dot( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + asname=cst.AsName( + cst.Name("baz"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.ImportAlias( + cst.Name("unittest"), + asname=cst.AsName( + cst.Name("ut"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + ), + ), + whitespace_after_import=cst.SimpleWhitespace(" "), + ), + "import foo . 
bar as baz , unittest as ut", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + (lambda: cst.Import(names=()), "at least one ImportAlias"), + ( + lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)), + "empty name identifier", + ), + ( + lambda: cst.Import( + names=( + cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))), + ) + ), + "empty name identifier", + ), + ( + lambda: cst.Import( + names=( + cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))), + ) + ), + "empty name identifier", + ), + ( + lambda: cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + comma=cst.Comma(), + ), + ) + ), + "trailing comma", + ), + ( + lambda: cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")) + ), + ), + whitespace_after_import=cst.SimpleWhitespace(""), + ), + "at least one space", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class ImportParseTest(CSTNodeTest): + @data_provider( + ( + # Simple import statement + (cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), "import foo"), + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")) + ), + ) + ), + "import foo.bar", + ), + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")) + ), + ) + ), + "import foo.bar", + ), + # Comma-separated list of imports + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("baz")) + ), + ) + ), + "import foo.bar, foo.baz", + ), + # Import with an alias + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), 
+ asname=cst.AsName(cst.Name("baz")), + ), + ) + ), + "import foo.bar as baz", + ), + # Import with an alias, comma separated + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + asname=cst.AsName(cst.Name("baz")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("baz")), + asname=cst.AsName(cst.Name("bar")), + ), + ) + ), + "import foo.bar as baz, foo.baz as bar", + ), + # Combine for fun and profit + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("bar")), + asname=cst.AsName(cst.Name("baz")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias( + cst.Attribute(cst.Name("insta"), cst.Name("gram")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias( + cst.Attribute(cst.Name("foo"), cst.Name("baz")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias( + cst.Name("unittest"), asname=cst.AsName(cst.Name("ut")) + ), + ) + ), + "import foo.bar as baz, insta.gram, foo.baz, unittest as ut", + ), + # Verify whitespace works everywhere. + ( + cst.Import( + names=( + cst.ImportAlias( + cst.Attribute( + cst.Name("foo"), + cst.Name("bar"), + dot=cst.Dot( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + asname=cst.AsName( + cst.Name("baz"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.ImportAlias( + cst.Name("unittest"), + asname=cst.AsName( + cst.Name("ut"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + ), + ), + whitespace_after_import=cst.SimpleWhitespace(" "), + ), + "import foo . 
bar as baz , unittest as ut", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0]) + + +class ImportFromCreateTest(CSTNodeTest): + @data_provider( + ( + # Simple from import statement + ( + cst.ImportFrom( + module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),) + ), + "from foo import bar", + ), + # From import statement with alias + ( + cst.ImportFrom( + module=cst.Name("foo"), + names=( + cst.ImportAlias( + cst.Name("bar"), asname=cst.AsName(cst.Name("baz")) + ), + ), + ), + "from foo import bar as baz", + ), + # Multiple imports + ( + cst.ImportFrom( + module=cst.Name("foo"), + names=( + cst.ImportAlias(cst.Name("bar")), + cst.ImportAlias(cst.Name("baz")), + ), + ), + "from foo import bar, baz", + ), + # Trailing comma + ( + cst.ImportFrom( + module=cst.Name("foo"), + names=( + cst.ImportAlias(cst.Name("bar"), comma=cst.Comma()), + cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()), + ), + ), + "from foo import bar,baz,", + ), + # Star import statement + ( + cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()), + "from foo import *", + ), + # Simple relative import statement + ( + cst.ImportFrom( + relative=(cst.Dot(),), + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + ), + "from .foo import bar", + ), + ( + cst.ImportFrom( + relative=(cst.Dot(), cst.Dot()), + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + ), + "from ..foo import bar", + ), + # Relative only import + ( + cst.ImportFrom( + relative=(cst.Dot(), cst.Dot()), + module=None, + names=(cst.ImportAlias(cst.Name("bar")),), + ), + "from .. 
import bar", + ), + # Parenthesis + ( + cst.ImportFrom( + module=cst.Name("foo"), + lpar=cst.LeftParen(), + names=( + cst.ImportAlias( + cst.Name("bar"), asname=cst.AsName(cst.Name("baz")) + ), + ), + rpar=cst.RightParen(), + ), + "from foo import (bar as baz)", + ), + # Verify whitespace works everywhere. + ( + cst.ImportFrom( + relative=( + cst.Dot( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + cst.Dot( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + module=cst.Name("foo"), + lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")), + names=( + cst.ImportAlias( + cst.Name("bar"), + asname=cst.AsName( + cst.Name("baz"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.ImportAlias( + cst.Name("unittest"), + asname=cst.AsName( + cst.Name("ut"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + ), + ), + rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")), + whitespace_after_from=cst.SimpleWhitespace(" "), + whitespace_before_import=cst.SimpleWhitespace(" "), + whitespace_after_import=cst.SimpleWhitespace(" "), + ), + "from . . foo import ( bar as baz , unittest as ut )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + ( + # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]` + # for 2nd param but got `Tuple[Name]`. 
+ lambda: cst.ImportFrom(module=None, names=(cst.Name("bar"),)), + "Must have a module specified", + ), + ( + lambda: cst.ImportFrom(module=cst.Name("foo"), names=()), + "at least one ImportAlias", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + lpar=cst.LeftParen(), + ), + "left paren without right paren", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + rpar=cst.RightParen(), + ), + "right paren without left paren", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), names=cst.ImportStar(), lpar=cst.LeftParen() + ), + "cannot have parens", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), + names=cst.ImportStar(), + rpar=cst.RightParen(), + ), + "cannot have parens", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + whitespace_after_from=cst.SimpleWhitespace(""), + ), + "one space after from", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + whitespace_before_import=cst.SimpleWhitespace(""), + ), + "one space before import", + ), + ( + lambda: cst.ImportFrom( + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + whitespace_after_import=cst.SimpleWhitespace(""), + ), + "one space after import", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class ImportFromParseTest(CSTNodeTest): + @data_provider( + ( + # Simple from import statement + ( + cst.ImportFrom( + module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),) + ), + "from foo import bar", + ), + # From import statement with alias + ( + cst.ImportFrom( + module=cst.Name("foo"), + names=( + cst.ImportAlias( + cst.Name("bar"), asname=cst.AsName(cst.Name("baz")) + ), + ), + ), + "from foo import bar as baz", + ), + # Multiple 
imports + ( + cst.ImportFrom( + module=cst.Name("foo"), + names=( + cst.ImportAlias( + cst.Name("bar"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias(cst.Name("baz")), + ), + ), + "from foo import bar, baz", + ), + # Trailing comma + ( + cst.ImportFrom( + module=cst.Name("foo"), + names=( + cst.ImportAlias( + cst.Name("bar"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()), + ), + ), + "from foo import bar, baz,", + ), + # Star import statement + ( + cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()), + "from foo import *", + ), + # Simple relative import statement + ( + cst.ImportFrom( + relative=(cst.Dot(),), + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + ), + "from .foo import bar", + ), + ( + cst.ImportFrom( + relative=(cst.Dot(), cst.Dot()), + module=cst.Name("foo"), + names=(cst.ImportAlias(cst.Name("bar")),), + ), + "from ..foo import bar", + ), + # Relative only import + ( + cst.ImportFrom( + relative=(cst.Dot(), cst.Dot()), + module=None, + names=(cst.ImportAlias(cst.Name("bar")),), + ), + "from .. import bar", + ), + # Parenthesis + ( + cst.ImportFrom( + module=cst.Name("foo"), + lpar=cst.LeftParen(), + names=( + cst.ImportAlias( + cst.Name("bar"), asname=cst.AsName(cst.Name("baz")) + ), + ), + rpar=cst.RightParen(), + ), + "from foo import (bar as baz)", + ), + # Verify whitespace works everywhere. 
+ ( + cst.ImportFrom( + relative=( + cst.Dot( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(" "), + ), + cst.Dot( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + module=cst.Name("foo"), + lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")), + names=( + cst.ImportAlias( + cst.Name("bar"), + asname=cst.AsName( + cst.Name("baz"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.ImportAlias( + cst.Name("unittest"), + asname=cst.AsName( + cst.Name("ut"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + ), + ), + rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")), + whitespace_after_from=cst.SimpleWhitespace(" "), + whitespace_before_import=cst.SimpleWhitespace(" "), + whitespace_after_import=cst.SimpleWhitespace(" "), + ), + "from . . foo import ( bar as baz , unittest as ut )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0]) diff --git a/libcst/nodes/tests/test_indented_block.py b/libcst/nodes/tests/test_indented_block.py new file mode 100644 index 00000000..0c025141 --- /dev/null +++ b/libcst/nodes/tests/test_indented_block.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-strict
+from typing import Callable, Optional, TypeVar, Union
+
+import libcst.nodes as cst
+from libcst._base_visitor import CSTVisitor
+from libcst._removal_sentinel import RemovalSentinel
+from libcst.nodes.tests.base import CSTNodeTest
+from libcst.parser import parse_statement
+from libcst.testing.utils import data_provider
+
+
+_CSTNodeT = TypeVar("_CSTNodeT", bound=cst.CSTNode)
+
+
+class IfStatementRemovalVisitor(CSTVisitor):
+    """Visitor used by the removal test below: it asks for every ``cst.If`` to be removed."""
+
+    def on_leave(
+        self, original_node: _CSTNodeT, updated_node: _CSTNodeT
+    ) -> Union[_CSTNodeT, RemovalSentinel]:
+        """Return ``RemovalSentinel.REMOVE`` for ``cst.If`` nodes, else ``updated_node``.
+
+        ``original_node`` is accepted but not consulted.
+        """
+        if isinstance(updated_node, cst.If):
+            return RemovalSentinel.REMOVE
+        else:
+            return updated_node
+
+
+class IndentedBlockTest(CSTNodeTest):
+    """Fixtures for ``cst.IndentedBlock``: valid nodes, invalid constructions, and removal."""
+
+    # NOTE(review): runs of spaces inside the string literals of this chunk look
+    # collapsed by extraction (the "multi-level" and "inconsistent indentation"
+    # fixtures now show identical code strings) -- confirm against the upstream file.
+    @data_provider(
+        (
+            (
+                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
+                "\n pass\n",
+                None,
+            ),
+            (
+                cst.IndentedBlock(
+                    (cst.SimpleStatementLine((cst.Pass(),)),), indent="\t"
+                ),
+                "\n\tpass\n",
+                None,
+            ),
+            (
+                cst.IndentedBlock(
+                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                    header=cst.TrailingWhitespace(
+                        whitespace=cst.SimpleWhitespace(" "),
+                        comment=cst.Comment("# header comment"),
+                    ),
+                ),
+                " # header comment\n pass\n",
+                None,
+            ),
+            (
+                cst.IndentedBlock(
+                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                    footer=(cst.EmptyLine(comment=cst.Comment("# footer comment")),),
+                ),
+                "\n pass\n# footer comment\n",
+                None,
+            ),
+            (
+                cst.IndentedBlock(
+                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                    footer=(
+                        cst.EmptyLine(
+                            whitespace=cst.SimpleWhitespace(" "),
+                            comment=cst.Comment("# footer comment"),
+                        ),
+                    ),
+                ),
+                "\n pass\n # footer comment\n",
+                None,
+            ),
+            (
+                cst.IndentedBlock(
+                    (
+                        cst.SimpleStatementLine((cst.Continue(),)),
+                        cst.SimpleStatementLine((cst.Pass(),)),
+                    )
+                ),
+                "\n continue\n pass\n",
+                None,
+            ),
+            # Basic parsing test
+            (
+                cst.If(
+                    cst.Name("conditional"),
+                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
+                ),
+                "if conditional:\n pass\n",
+                parse_statement,
+            ),
+            # Multi-level parsing test
+            (
+                cst.If(
+                    cst.Name("conditional"),
+                    cst.IndentedBlock(
+                        (
+                            cst.SimpleStatementLine((cst.Pass(),)),
+                            cst.If(
+                                cst.Name("other_conditional"),
+                                cst.IndentedBlock(
+                                    (cst.SimpleStatementLine((cst.Pass(),)),)
+                                ),
+                            ),
+                        )
+                    ),
+                ),
+                "if conditional:\n pass\n if other_conditional:\n pass\n",
+                parse_statement,
+            ),
+            # Inconsistent indentation parsing test
+            (
+                cst.If(
+                    cst.Name("conditional"),
+                    cst.IndentedBlock(
+                        (
+                            cst.SimpleStatementLine((cst.Pass(),)),
+                            cst.If(
+                                cst.Name("other_conditional"),
+                                cst.IndentedBlock(
+                                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                                    indent=" ",
+                                ),
+                            ),
+                        )
+                    ),
+                ),
+                "if conditional:\n pass\n if other_conditional:\n pass\n",
+                parse_statement,
+            ),
+        )
+    )
+    def test_valid(
+        self,
+        node: cst.CSTNode,
+        code: str,
+        parser: Optional[Callable[[str], cst.CSTNode]],
+    ) -> None:
+        """Pass each (node, code, parser) fixture to ``validate_node``.
+
+        ``parser`` is ``None`` for the bare ``IndentedBlock`` fixtures and
+        ``parse_statement`` for the ``If`` fixtures. ``validate_node`` is
+        defined outside this chunk.
+        """
+        self.validate_node(node, code, parser)
+
+    @data_provider(
+        (
+            (lambda: cst.IndentedBlock(()), "at least one"),
+            (
+                lambda: cst.IndentedBlock(
+                    (cst.SimpleStatementLine((cst.Pass(),)),), indent=""
+                ),
+                "non-zero width indent",
+            ),
+            (
+                lambda: cst.IndentedBlock(
+                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                    indent="this isn't valid whitespace!",
+                ),
+                "only whitespace",
+            ),
+        )
+    )
+    def test_invalid(
+        self, get_node: Callable[[], cst.CSTNode], expected_re: str
+    ) -> None:
+        """Pass each node factory and its expected message pattern to ``assert_invalid``."""
+        self.assert_invalid(get_node, expected_re)
+
+    def test_removal_creates_pass(self) -> None:
+        """A block whose only statement is removed must equal a block holding one ``pass``."""
+        original = cst.IndentedBlock(
+            (cst.If(cst.Name("conditional"), cst.SimpleStatementSuite((cst.Break(),))),)
+        )
+        expected = cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
+
+        self.assertEqual(original.visit(IfStatementRemovalVisitor()), expected)
diff --git a/libcst/nodes/tests/test_lambda.py b/libcst/nodes/tests/test_lambda.py
new file mode 100644
index 00000000..f29fbdd4
--- /dev/null
+++ b/libcst/nodes/tests/test_lambda.py
@@ -0,0 +1,925 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+from typing import Callable
+
+import libcst.nodes as cst
+from libcst.nodes.tests.base import CSTNodeTest
+from libcst.parser import parse_expression
+from libcst.testing.utils import data_provider
+
+
+class LambdaCreationTest(CSTNodeTest):
+    """Fixtures that build ``cst.Lambda`` nodes directly, without going through the parser."""
+
+    @data_provider(
+        (
+            # Simple lambda
+            (cst.Lambda(cst.Parameters(), cst.Number(cst.Integer("5"))), "lambda: 5"),
+            # Test basic positional params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(cst.Param(cst.Name("bar")), cst.Param(cst.Name("baz")))
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "lambda bar, baz: 5",
+            ),
+            # Test basic positional default params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(
+                                cst.Name("baz"), default=cst.Number(cst.Integer("5"))
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda bar = "one", baz = 5: 5',
+            ),
+            # Mixed positional and default params.
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(cst.Param(cst.Name("bar")),),
+                        default_params=(
+                            cst.Param(
+                                cst.Name("baz"), default=cst.Number(cst.Integer("5"))
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "lambda bar, baz = 5: 5",
+            ),
+            # Test kwonly params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(cst.Name("baz")),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda *, bar = "one", baz: 5',
+            ),
+            # Mixed params and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(cst.Name("first")),
+                            cst.Param(cst.Name("second")),
+                        ),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(cst.Name("baz")),
+                            cst.Param(
+                                cst.Name("biz"), default=cst.SimpleString('"two"')
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda first, second, *, bar = "one", baz, biz = "two": 5',
+            ),
+            # Mixed default_params and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("first"), default=cst.Number(cst.Float("1.0"))
+                            ),
+                            cst.Param(
+                                cst.Name("second"), default=cst.Number(cst.Float("1.5"))
+                            ),
+                        ),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(cst.Name("baz")),
+                            cst.Param(
+                                cst.Name("biz"), default=cst.SimpleString('"two"')
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
+            ),
+            # Mixed params, default_params, and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(cst.Name("first")),
+                            cst.Param(cst.Name("second")),
+                        ),
+                        default_params=(
+                            cst.Param(
+                                cst.Name("third"), default=cst.Number(cst.Float("1.0"))
+                            ),
+                            cst.Param(
+                                cst.Name("fourth"), default=cst.Number(cst.Float("1.5"))
+                            ),
+                        ),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(cst.Name("baz")),
+                            cst.Param(
+                                cst.Name("biz"), default=cst.SimpleString('"two"')
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
+            ),
+            # Test star_arg
+            (
+                cst.Lambda(
+                    cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "lambda *params: 5",
+            ),
+            # star_arg (no annotation: see test_invalid), include kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        star_arg=cst.Param(cst.Name("params")),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(cst.Name("baz")),
+                            cst.Param(
+                                cst.Name("biz"), default=cst.SimpleString('"two"')
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda *params, bar = "one", baz, biz = "two": 5',
+            ),
+            # Mixed params default_params, star_arg and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(cst.Name("first")),
+                            cst.Param(cst.Name("second")),
+                        ),
+                        default_params=(
+                            cst.Param(
+                                cst.Name("third"), default=cst.Number(cst.Float("1.0"))
+                            ),
+                            cst.Param(
+                                cst.Name("fourth"), default=cst.Number(cst.Float("1.5"))
+                            ),
+                        ),
+                        star_arg=cst.Param(cst.Name("params")),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                            cst.Param(cst.Name("baz")),
+                            cst.Param(
+                                cst.Name("biz"), default=cst.SimpleString('"two"')
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
+            ),
+            # Test star_kwarg on its own
+            (
+                cst.Lambda(
+                    cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "lambda **kwparams: 5",
+            ),
+            # Test star_arg and kwarg
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        star_arg=cst.Param(cst.Name("params")),
+                        star_kwarg=cst.Param(cst.Name("kwparams")),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "lambda *params, **kwparams: 5",
+            ),
+            # Inner whitespace
+            (
+                cst.Lambda(
+                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                    params=cst.Parameters(),
+                    colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
+                    body=cst.Number(cst.Integer("5")),
+                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
+                ),
+                "( lambda : 5 )",
+            ),
+        )
+    )
+    def test_valid(self, node: cst.CSTNode, code: str) -> None:
+        """Pass each (node, code) fixture to ``validate_node`` with no parser.
+
+        ``validate_node`` is defined outside this chunk.
+        """
+        self.validate_node(node, code)
+
+    @data_provider(
+        (
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
+                    cst.Number(cst.Integer("5")),
+                    lpar=(cst.LeftParen(),),
+                ),
+                "left paren without right paren",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
+                    cst.Number(cst.Integer("5")),
+                    rpar=(cst.RightParen(),),
+                ),
+                "right paren without left paren",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "at least one space after lambda",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("arg"), default=cst.Number(cst.Integer("5"))
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "at least one space after lambda",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "at least one space after lambda",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")),)),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "at least one space after lambda",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "at least one space after lambda",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        star_kwarg=cst.Param(cst.Name("bar"), equal=cst.AssignEqual())
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Must have a default when specifying an AssignEqual.",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="***")),
+                    cst.Number(cst.Integer("5")),
+                ),
+                r"Must specify either '', '\*' or '\*\*' for star.",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("bar"), default=cst.SimpleString('"one"')
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Cannot have defaults for params",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(default_params=(cst.Param(cst.Name("bar")),)),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Must have defaults for default_params",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(star_arg=cst.ParamStar()),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Must have at least one kwonly param if ParamStar is used.",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"),)),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Expecting a star prefix of ''",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                star="*",
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Expecting a star prefix of ''",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        kwonly_params=(cst.Param(cst.Name("bar"), star="*"),)
+                    ),
+                    cst.Number(cst.Integer("5")),
+                ),
+                "Expecting a star prefix of ''",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
+                    cst.Number(cst.Integer("5")),
+                ),
+                r"Expecting a star prefix of '\*'",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")),
+                    cst.Number(cst.Integer("5")),
+                ),
+                r"Expecting a star prefix of '\*\*'",
+            ),
+            # NOTE(review): the annotation cases below also pass an empty
+            # whitespace_after_lambda; which message wins depends on validation
+            # order, which is not visible in this chunk.
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("arg"),
+                                annotation=cst.Annotation(cst.Name("str")),
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "Lambda params cannot have type annotations",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("arg"),
+                                default=cst.Number(cst.Integer("5")),
+                                annotation=cst.Annotation(cst.Name("str")),
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "Lambda params cannot have type annotations",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        star_arg=cst.Param(
+                            cst.Name("arg"), annotation=cst.Annotation(cst.Name("str"))
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "Lambda params cannot have type annotations",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("arg"),
+                                annotation=cst.Annotation(cst.Name("str")),
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "Lambda params cannot have type annotations",
+            ),
+            (
+                lambda: cst.Lambda(
+                    cst.Parameters(
+                        star_kwarg=cst.Param(
+                            cst.Name("arg"), annotation=cst.Annotation(cst.Name("str"))
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(""),
+                ),
+                "Lambda params cannot have type annotations",
+            ),
+        )
+    )
+    def test_invalid(
+        self, get_node: Callable[[], cst.CSTNode], expected_re: str
+    ) -> None:
+        """Pass each node factory and its expected message pattern to ``assert_invalid``."""
+        self.assert_invalid(get_node, expected_re)
+
+
+class LambdaParserTest(CSTNodeTest):
+    """Lambda fixtures validated with ``parse_expression`` passed as the parser.
+
+    The expected nodes spell out ``star``, ``equal`` and ``comma`` explicitly
+    instead of leaving them at their defaults.
+    """
+
+    @data_provider(
+        (
+            # Simple lambda
+            (cst.Lambda(cst.Parameters(), cst.Number(cst.Integer("5"))), "lambda: 5"),
+            # Test basic positional params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(cst.Name("baz"), star=""),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                "lambda bar, baz: 5",
+            ),
+            # Test basic positional default params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("baz"),
+                                default=cst.Number(cst.Integer("5")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        )
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda bar = "one", baz = 5: 5',
+            ),
+            # Mixed positional and default params.
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        default_params=(
+                            cst.Param(
+                                cst.Name("baz"),
+                                default=cst.Number(cst.Integer("5")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                "lambda bar, baz = 5: 5",
+            ),
+            # Test kwonly params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        star_arg=cst.ParamStar(),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(cst.Name("baz"), star=""),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda *, bar = "one", baz: 5',
+            ),
+            # Mixed params and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("first"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("second"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        star_arg=cst.ParamStar(),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("baz"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("biz"),
+                                default=cst.SimpleString('"two"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda first, second, *, bar = "one", baz, biz = "two": 5',
+            ),
+            # Mixed default_params and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        default_params=(
+                            cst.Param(
+                                cst.Name("first"),
+                                default=cst.Number(cst.Float("1.0")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("second"),
+                                default=cst.Number(cst.Float("1.5")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        star_arg=cst.ParamStar(),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("baz"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("biz"),
+                                default=cst.SimpleString('"two"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
+            ),
+            # Mixed params, default_params, and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("first"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("second"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        default_params=(
+                            cst.Param(
+                                cst.Name("third"),
+                                default=cst.Number(cst.Float("1.0")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("fourth"),
+                                default=cst.Number(cst.Float("1.5")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        star_arg=cst.ParamStar(),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("baz"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("biz"),
+                                default=cst.SimpleString('"two"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
+            ),
+            # Test star_arg
+            (
+                cst.Lambda(
+                    cst.Parameters(star_arg=cst.Param(cst.Name("params"), star="*")),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                "lambda *params: 5",
+            ),
+            # star_arg (no annotation), include kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        star_arg=cst.Param(
+                            cst.Name("params"),
+                            star="*",
+                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
+                        ),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("baz"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("biz"),
+                                default=cst.SimpleString('"two"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda *params, bar = "one", baz, biz = "two": 5',
+            ),
+            # Mixed params default_params, star_arg and kwonly_params
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        params=(
+                            cst.Param(
+                                cst.Name("first"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("second"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        default_params=(
+                            cst.Param(
+                                cst.Name("third"),
+                                default=cst.Number(cst.Float("1.0")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("fourth"),
+                                default=cst.Number(cst.Float("1.5")),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                        ),
+                        star_arg=cst.Param(
+                            cst.Name("params"),
+                            star="*",
+                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
+                        ),
+                        kwonly_params=(
+                            cst.Param(
+                                cst.Name("bar"),
+                                default=cst.SimpleString('"one"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("baz"),
+                                star="",
+                                comma=cst.Comma(
+                                    whitespace_after=cst.SimpleWhitespace(" ")
+                                ),
+                            ),
+                            cst.Param(
+                                cst.Name("biz"),
+                                default=cst.SimpleString('"two"'),
+                                equal=cst.AssignEqual(),
+                                star="",
+                            ),
+                        ),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
+            ),
+            # Test star_kwarg on its own
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        star_kwarg=cst.Param(cst.Name("kwparams"), star="**")
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                "lambda **kwparams: 5",
+            ),
+            # Test star_arg and kwarg
+            (
+                cst.Lambda(
+                    cst.Parameters(
+                        star_arg=cst.Param(
+                            cst.Name("params"),
+                            star="*",
+                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
+                        ),
+                        star_kwarg=cst.Param(cst.Name("kwparams"), star="**"),
+                    ),
+                    cst.Number(cst.Integer("5")),
+                    whitespace_after_lambda=cst.SimpleWhitespace(" "),
+                ),
+                "lambda *params, **kwparams: 5",
+            ),
+            # Inner whitespace
+            (
+                cst.Lambda(
+                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
+                    params=cst.Parameters(),
+                    colon=cst.Colon(
+                        whitespace_before=cst.SimpleWhitespace(" "),
+                        whitespace_after=cst.SimpleWhitespace(" "),
+                    ),
+                    body=cst.Number(cst.Integer("5")),
+                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
+                ),
+                "( lambda : 5 )",
+            ),
+        )
+    )
+    def test_valid(self, node: cst.CSTNode, code: str) -> None:
+        """Pass each (node, code) fixture to ``validate_node`` with ``parse_expression``."""
+        self.validate_node(node, code, parse_expression)
diff --git a/libcst/nodes/tests/test_leaf_small_statements.py b/libcst/nodes/tests/test_leaf_small_statements.py
new file mode 100644
index 00000000..b174138e
--- /dev/null
+++ b/libcst/nodes/tests/test_leaf_small_statements.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import libcst.nodes as cst
+from libcst.nodes.tests.base import CSTNodeTest
+from libcst.testing.utils import data_provider
+
+
+class LeafSmallStatementsTest(CSTNodeTest):
+    """Fixtures pairing ``Pass``, ``Break`` and ``Continue`` nodes with their keywords."""
+
+    @data_provider(
+        ((cst.Pass(), "pass"), (cst.Break(), "break"), (cst.Continue(), "continue"))
+    )
+    def test_valid(self, node: cst.CSTNode, code: str) -> None:
+        """Pass each (node, code) fixture to ``validate_node`` with no parser."""
+        self.validate_node(node, code)
diff --git a/libcst/nodes/tests/test_module.py b/libcst/nodes/tests/test_module.py
new file mode 100644
index 00000000..0fd4d776
--- /dev/null
+++ b/libcst/nodes/tests/test_module.py
@@ -0,0 +1,114 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import libcst.nodes as cst
+from libcst.nodes.tests.base import CSTNodeTest
+from libcst.parser import parse_module
+from libcst.testing.utils import data_provider
+
+
+class ModuleTest(CSTNodeTest):
+    """Fixtures for ``cst.Module``: rendered code and bytes, ``code_for_node``, and parsing."""
+
+    @data_provider(
+        (
+            # simplest possible program
+            (cst.Module((cst.SimpleStatementLine((cst.Pass(),)),)), "pass\n"),
+            # test default_newline
+            (
+                cst.Module(
+                    (cst.SimpleStatementLine((cst.Pass(),)),), default_newline="\r"
+                ),
+                "pass\r",
+            ),
+            # test header/footer
+            (
+                cst.Module(
+                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                    header=(cst.EmptyLine(comment=cst.Comment("# header")),),
+                    footer=(cst.EmptyLine(comment=cst.Comment("# footer")),),
+                ),
+                "# header\npass\n# footer\n",
+            ),
+            # test has_trailing_newline
+            (
+                cst.Module(
+                    (cst.SimpleStatementLine((cst.Pass(),)),),
+                    has_trailing_newline=False,
+                ),
+                "pass",
+            ),
+            # an empty file
+            (cst.Module((), has_trailing_newline=False), ""),
+            # a file with only comments
+            (
+                cst.Module(
+                    (),
+                    header=(
+                        cst.EmptyLine(comment=cst.Comment("# nothing to see here")),
+                    ),
+                ),
+                "# nothing to see here\n",
+            ),
+            # TODO: test default_indent
+        )
+    )
+    def test_code_and_bytes_properties(self, module: cst.Module, expected: str) -> None:
+        """Check ``module.code`` equals ``expected`` and ``module.bytes`` its UTF-8 encoding."""
+        self.assertEqual(module.code, expected)
+        self.assertEqual(module.bytes, expected.encode("utf-8"))
+
+    @data_provider(
+        (
+            (cst.Module(()), cst.Newline(), "\n"),
+            (cst.Module((), default_newline="\r\n"), cst.Newline(), "\r\n"),
+            # has_trailing_newline has no effect on code_for_node
+            (cst.Module((), has_trailing_newline=False), cst.Newline(), "\n"),
+            # TODO: test default_indent
+        )
+    )
+    def test_code_for_node(
+        self, module: cst.Module, node: cst.CSTNode, expected: str
+    ) -> None:
+        """Check ``module.code_for_node(node)`` equals ``expected``.
+
+        Every fixture passes a default ``cst.Newline()``, so only the module's
+        own settings vary between cases.
+        """
+        self.assertEqual(module.code_for_node(node), expected)
+
+    @data_provider(
+        {
+            "empty_program": {
+                "code": "",
+                "expected": cst.Module([], has_trailing_newline=False),
+            },
+            "empty_program_with_newline": {
+                "code": "\n",
+                "expected": cst.Module([], has_trailing_newline=True),
+            },
+            "empty_program_with_comments": {
+                "code": "# some comment\n",
+                "expected": cst.Module(
+                    [], header=[cst.EmptyLine(comment=cst.Comment("# some comment"))]
+                ),
+            },
+            "simple_pass": {
+                "code": "pass\n",
+                "expected": cst.Module([cst.SimpleStatementLine([cst.Pass()])]),
+            },
+            "simple_pass_with_header_footer": {
+                "code": "# header\npass # trailing\n# footer\n",
+                "expected": cst.Module(
+                    [
+                        cst.SimpleStatementLine(
+                            [cst.Pass()],
+                            trailing_whitespace=cst.TrailingWhitespace(
+                                whitespace=cst.SimpleWhitespace(" "),
+                                comment=cst.Comment("# trailing"),
+                            ),
+                        )
+                    ],
+                    header=[cst.EmptyLine(comment=cst.Comment("# header"))],
+                    footer=[cst.EmptyLine(comment=cst.Comment("# footer"))],
+                ),
+            },
+        }
+    )
+    def test_parser(self, *, code: str, expected: cst.Module) -> None:
+        """Check ``parse_module(code)`` compares equal to the ``expected`` module.
+
+        Fixtures are keyed by case name and supply ``code``/``expected`` as
+        keyword arguments, hence the keyword-only signature.
+        """
+        self.assertEqual(parse_module(code), expected)
diff --git a/libcst/nodes/tests/test_newline.py b/libcst/nodes/tests/test_newline.py
new file mode 100644
index 00000000..93ef63da
--- /dev/null
+++ b/libcst/nodes/tests/test_newline.py
@@ -0,0 +1,35 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class NewlineTest(CSTNodeTest): + @data_provider( + ( + (cst.Newline("\r\n"), "\r\n"), + (cst.Newline("\r"), "\r"), + (cst.Newline("\n"), "\n"), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + (lambda: cst.Newline("bad input"), "invalid value"), + (lambda: cst.Newline("\nbad input\n"), "invalid value"), + (lambda: cst.Newline("\n\n"), "invalid value"), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_nonlocal.py b/libcst/nodes/tests/test_nonlocal.py new file mode 100644 index 00000000..2cc1dea3 --- /dev/null +++ b/libcst/nodes/tests/test_nonlocal.py @@ -0,0 +1,129 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class NonlocalConstructionTest(CSTNodeTest): + @data_provider( + ( + # Single nonlocal statement + (cst.Nonlocal((cst.NameItem(cst.Name("a")),)), "nonlocal a"), + # Multiple entries in nonlocal statement + ( + cst.Nonlocal( + (cst.NameItem(cst.Name("a")), cst.NameItem(cst.Name("b"))) + ), + "nonlocal a, b", + ), + # Whitespace rendering test + ( + cst.Nonlocal( + ( + cst.NameItem( + cst.Name("a"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.NameItem(cst.Name("b")), + ), + whitespace_after_nonlocal=cst.SimpleWhitespace(" "), + ), + "nonlocal a , b", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + # Validate construction + ( + lambda: cst.Nonlocal(()), + "A Nonlocal statement must have at least one NameItem", + ), + # Validate whitespace handling + ( + lambda: cst.Nonlocal( + (cst.NameItem(cst.Name("a")),), + whitespace_after_nonlocal=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'nonlocal' keyword", + ), + # Validate comma handling + ( + lambda: cst.Nonlocal((cst.NameItem(cst.Name("a"), comma=cst.Comma()),)), + "The last NameItem in a Nonlocal cannot have a trailing comma", + ), + # Validate paren handling + ( + lambda: cst.Nonlocal( + ( + cst.NameItem( + cst.Name( + "a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ) + ), + ) + ), + "Cannot have parens around names in NameItem", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class NonlocalParsingTest(CSTNodeTest): + @data_provider( + ( + # Single nonlocal statement + 
(cst.Nonlocal((cst.NameItem(cst.Name("a")),)), "nonlocal a"), + # Multiple entries in nonlocal statement + ( + cst.Nonlocal( + ( + cst.NameItem( + cst.Name("a"), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.NameItem(cst.Name("b")), + ) + ), + "nonlocal a, b", + ), + # Whitespace rendering test + ( + cst.Nonlocal( + ( + cst.NameItem( + cst.Name("a"), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.NameItem(cst.Name("b")), + ), + whitespace_after_nonlocal=cst.SimpleWhitespace(" "), + ), + "nonlocal a , b", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0]) diff --git a/libcst/nodes/tests/test_number.py b/libcst/nodes/tests/test_number.py new file mode 100644 index 00000000..0af3368b --- /dev/null +++ b/libcst/nodes/tests/test_number.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class NumberTest(CSTNodeTest): + @data_provider( + ( + # Simple number + (cst.Number(cst.Integer("5")), "5", parse_expression), + # Negated number + ( + cst.Number(operator=cst.Minus(), number=cst.Integer("5")), + "-5", + parse_expression, + ), + # In parenthesis + ( + cst.Number( + lpar=(cst.LeftParen(),), + operator=cst.Minus(), + number=cst.Integer("5"), + rpar=(cst.RightParen(),), + ), + "(-5)", + parse_expression, + ), + ( + cst.Number( + lpar=(cst.LeftParen(),), + operator=cst.Minus(), + number=cst.Integer( + "5", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + rpar=(cst.RightParen(),), + ), + "(-(5))", + parse_expression, + ), + ( + cst.UnaryOperation( + operator=cst.Minus(), + expression=cst.Number( + operator=cst.Minus(), number=cst.Integer("5") + ), + ), + "--5", + parse_expression, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + ( + lambda: cst.Number(cst.Integer("5"), lpar=(cst.LeftParen(),)), + "left paren without right paren", + ), + ( + lambda: cst.Number(cst.Integer("5"), rpar=(cst.RightParen(),)), + "right paren without left paren", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_raise.py b/libcst/nodes/tests/test_raise.py new file mode 100644 index 00000000..1816d824 --- /dev/null +++ b/libcst/nodes/tests/test_raise.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class RaiseConstructionTest(CSTNodeTest): + @data_provider( + ( + # Simple raise + (cst.Raise(), "raise"), + # Raise exception + (cst.Raise(cst.Call(cst.Name("Exception"))), "raise Exception()"), + # Raise exception from cause + ( + cst.Raise(cst.Call(cst.Name("Exception")), cst.From(cst.Name("cause"))), + "raise Exception() from cause", + ), + # Whitespace oddities test + ( + cst.Raise( + cst.Call( + cst.Name("Exception"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ), + cst.From( + cst.Name( + "cause", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + whitespace_before_from=cst.SimpleWhitespace(""), + whitespace_after_from=cst.SimpleWhitespace(""), + ), + whitespace_after_raise=cst.SimpleWhitespace(""), + ), + "raise(Exception())from(cause)", + ), + ( + cst.Raise( + cst.Call(cst.Name("Exception")), + cst.From( + cst.Name("cause"), + whitespace_before_from=cst.SimpleWhitespace(""), + ), + ), + "raise Exception()from cause", + ), + # Whitespace rendering test + ( + cst.Raise( + exc=cst.Call(cst.Name("Exception")), + cause=cst.From( + cst.Name("cause"), + whitespace_before_from=cst.SimpleWhitespace(" "), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + whitespace_after_raise=cst.SimpleWhitespace(" "), + ), + "raise Exception() from cause", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + # Validate construction + ( + lambda: cst.Raise(cause=cst.From(cst.Name("cause"))), + "Must have an 'exc' when specifying 'clause'. 
on Raise", + ), + # Validate whitespace handling + ( + lambda: cst.Raise( + cst.Call(cst.Name("Exception")), + whitespace_after_raise=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'raise'", + ), + ( + lambda: cst.Raise( + cst.Name("exc"), + cst.From( + cst.Name("cause"), + whitespace_before_from=cst.SimpleWhitespace(""), + ), + ), + "Must have at least one space before 'from'", + ), + ( + lambda: cst.Raise( + cst.Name("exc"), + cst.From( + cst.Name("cause"), + whitespace_after_from=cst.SimpleWhitespace(""), + ), + ), + "Must have at least one space after 'from'", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class RaiseParsingTest(CSTNodeTest): + @data_provider( + ( + # Simple raise + (cst.Raise(), "raise"), + # Raise exception + ( + cst.Raise( + cst.Call(cst.Name("Exception")), + whitespace_after_raise=cst.SimpleWhitespace(" "), + ), + "raise Exception()", + ), + # Raise exception from cause + ( + cst.Raise( + cst.Call(cst.Name("Exception")), + cst.From( + cst.Name("cause"), + whitespace_before_from=cst.SimpleWhitespace(" "), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + whitespace_after_raise=cst.SimpleWhitespace(" "), + ), + "raise Exception() from cause", + ), + # Whitespace oddities test + ( + cst.Raise( + cst.Call( + cst.Name("Exception"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ), + cst.From( + cst.Name( + "cause", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + whitespace_before_from=cst.SimpleWhitespace(""), + whitespace_after_from=cst.SimpleWhitespace(""), + ), + whitespace_after_raise=cst.SimpleWhitespace(""), + ), + "raise(Exception())from(cause)", + ), + ( + cst.Raise( + cst.Call(cst.Name("Exception")), + cst.From( + cst.Name("cause"), + whitespace_before_from=cst.SimpleWhitespace(""), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + 
whitespace_after_raise=cst.SimpleWhitespace(" "), + ), + "raise Exception()from cause", + ), + # Whitespace rendering test + ( + cst.Raise( + exc=cst.Call(cst.Name("Exception")), + cause=cst.From( + cst.Name("cause"), + whitespace_before_from=cst.SimpleWhitespace(" "), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + whitespace_after_raise=cst.SimpleWhitespace(" "), + ), + "raise Exception() from cause", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0]) diff --git a/libcst/nodes/tests/test_return.py b/libcst/nodes/tests/test_return.py new file mode 100644 index 00000000..05f2cf8a --- /dev/null +++ b/libcst/nodes/tests/test_return.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class ReturnCreateTest(CSTNodeTest): + @data_provider( + ( + (cst.SimpleStatementLine([cst.Return()]), "return\n"), + (cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]), "return abc\n"), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + ( + lambda: cst.Return( + cst.Name("abc"), whitespace_after_return=cst.SimpleWhitespace("") + ), + "Must have at least one space after 'return'.", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class ReturnParseTest(CSTNodeTest): + @data_provider( + ( + ( + cst.SimpleStatementLine( + [cst.Return(whitespace_after_return=cst.SimpleWhitespace(""))] + ), + "return\n", + ), + ( + cst.SimpleStatementLine( + [ + cst.Return( + cst.Name("abc"), + whitespace_after_return=cst.SimpleWhitespace(" "), + ) + ] + ), + "return abc\n", + ), + ( + cst.SimpleStatementLine( + [ + cst.Return( + cst.Name("abc"), + whitespace_after_return=cst.SimpleWhitespace(" "), + ) + ] + ), + "return abc\n", + ), + ( + cst.SimpleStatementLine( + [ + cst.Return( + cst.Name( + "abc", lpar=[cst.LeftParen()], rpar=[cst.RightParen()] + ), + whitespace_after_return=cst.SimpleWhitespace(""), + ) + ] + ), + "return(abc)\n", + ), + ( + cst.SimpleStatementLine( + [ + cst.Return( + cst.Name("abc"), + whitespace_after_return=cst.SimpleWhitespace(" "), + semicolon=cst.Semicolon(), + ) + ] + ), + "return abc;\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_statement) diff --git a/libcst/nodes/tests/test_simple_statement.py b/libcst/nodes/tests/test_simple_statement.py new file mode 100644 index 
00000000..b52bb7b0 --- /dev/null +++ b/libcst/nodes/tests/test_simple_statement.py @@ -0,0 +1,362 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class SimpleStatementTest(CSTNodeTest): + @data_provider( + ( + # a single-element SimpleStatementLine + (cst.SimpleStatementLine((cst.Pass(),)), "pass\n", parse_statement), + # a multi-element SimpleStatementLine + ( + cst.SimpleStatementLine( + (cst.Pass(semicolon=cst.Semicolon()), cst.Continue()) + ), + "pass;continue\n", + parse_statement, + ), + # a multi-element SimpleStatementLine with whitespace + ( + cst.SimpleStatementLine( + ( + cst.Pass( + semicolon=cst.Semicolon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ) + ), + cst.Continue(), + ) + ), + "pass ; continue\n", + parse_statement, + ), + # A more complicated SimpleStatementLine + ( + cst.SimpleStatementLine( + ( + cst.Pass(semicolon=cst.Semicolon()), + cst.Continue(semicolon=cst.Semicolon()), + cst.Break(), + ) + ), + "pass;continue;break\n", + parse_statement, + ), + # a multi-element SimpleStatementLine, inferred semicolons + ( + cst.SimpleStatementLine((cst.Pass(), cst.Continue(), cst.Break())), + "pass; continue; break\n", + None, # No test for parsing, since we are using sentinels. 
+ ), + # some expression statements + ( + cst.SimpleStatementLine((cst.Expr(cst.Name("None")),)), + "None\n", + parse_statement, + ), + ( + cst.SimpleStatementLine((cst.Expr(cst.Name("True")),)), + "True\n", + parse_statement, + ), + ( + cst.SimpleStatementLine((cst.Expr(cst.Name("False")),)), + "False\n", + parse_statement, + ), + ( + cst.SimpleStatementLine((cst.Expr(cst.Ellipses()),)), + "...\n", + parse_statement, + ), + # Test some numbers + ( + cst.SimpleStatementLine((cst.Expr(cst.Number(cst.Integer("5"))),)), + "5\n", + parse_statement, + ), + ( + cst.SimpleStatementLine((cst.Expr(cst.Number(cst.Float("5.5"))),)), + "5.5\n", + parse_statement, + ), + ( + cst.SimpleStatementLine((cst.Expr(cst.Number(cst.Imaginary("5j"))),)), + "5j\n", + parse_statement, + ), + # Test some numbers with parens + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Number( + cst.Integer( + "5", + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ) + ) + ), + ) + ), + "(5)\n", + parse_statement, + ), + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Number( + cst.Float( + "5.5", + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ) + ) + ), + ) + ), + "(5.5)\n", + parse_statement, + ), + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Number( + cst.Imaginary( + "5j", + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ) + ) + ), + ) + ), + "(5j)\n", + parse_statement, + ), + # Test some strings + ( + cst.SimpleStatementLine((cst.Expr(cst.SimpleString('"abc"')),)), + '"abc"\n', + parse_statement, + ), + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.ConcatenatedString( + cst.SimpleString('"abc"'), cst.SimpleString('"def"') + ) + ), + ) + ), + '"abc""def"\n', + parse_statement, + ), + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.ConcatenatedString( + left=cst.SimpleString('"abc"'), + whitespace_between=cst.SimpleWhitespace(" "), + right=cst.ConcatenatedString( + left=cst.SimpleString('"def"'), + whitespace_between=cst.SimpleWhitespace(" "), + 
right=cst.SimpleString('"ghi"'), + ), + ) + ), + ) + ), + '"abc" "def" "ghi"\n', + parse_statement, + ), + # Test parenthesis rules + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Ellipses( + lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ) + ), + ) + ), + "(...)\n", + parse_statement, + ), + # Test parenthesis with whitespace ownership + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Ellipses( + lpar=( + cst.LeftParen( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + rpar=( + cst.RightParen( + whitespace_before=cst.SimpleWhitespace(" ") + ), + ), + ) + ), + ) + ), + "( ... )\n", + parse_statement, + ), + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Ellipses( + lpar=( + cst.LeftParen( + whitespace_after=cst.SimpleWhitespace(" ") + ), + cst.LeftParen( + whitespace_after=cst.SimpleWhitespace(" ") + ), + cst.LeftParen( + whitespace_after=cst.SimpleWhitespace(" ") + ), + ), + rpar=( + cst.RightParen( + whitespace_before=cst.SimpleWhitespace(" ") + ), + cst.RightParen( + whitespace_before=cst.SimpleWhitespace(" ") + ), + cst.RightParen( + whitespace_before=cst.SimpleWhitespace(" ") + ), + ), + ) + ), + ) + ), + "( ( ( ... ) ) )\n", + parse_statement, + ), + # Test parenthesis rules with expressions + ( + cst.SimpleStatementLine( + ( + cst.Expr( + cst.Ellipses( + lpar=( + cst.LeftParen( + whitespace_after=cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace(), + empty_lines=( + cst.EmptyLine( + comment=cst.Comment( + "# Wow, a comment!" 
+ ) + ), + ), + indent=True, + last_line=cst.SimpleWhitespace(" "), + ) + ), + ), + rpar=( + cst.RightParen( + whitespace_before=cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace(), + empty_lines=(), + indent=True, + last_line=cst.SimpleWhitespace(""), + ) + ), + ), + ) + ), + ) + ), + "(\n# Wow, a comment!\n ...\n)\n", + parse_statement, + ), + # test trailing whitespace + ( + cst.SimpleStatementLine( + (cst.Pass(),), + trailing_whitespace=cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# trailing comment"), + ), + ), + "pass # trailing comment\n", + parse_statement, + ), + # test leading comment + ( + cst.SimpleStatementLine( + (cst.Pass(),), + leading_lines=(cst.EmptyLine(comment=cst.Comment("# comment")),), + ), + "# comment\npass\n", + parse_statement, + ), + # test indentation + ( + DummyIndentedBlock( + " ", + cst.SimpleStatementLine( + (cst.Pass(),), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# comment")), + ), + ), + ), + " # comment\n pass\n", + None, + ), + # test suite variant + (cst.SimpleStatementSuite((cst.Pass(),)), " pass\n", None), + ( + cst.SimpleStatementSuite( + (cst.Pass(),), leading_whitespace=cst.SimpleWhitespace("") + ), + "pass\n", + None, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + (lambda: cst.SimpleStatementLine(()), "empty"), + (lambda: cst.SimpleStatementSuite(()), "empty"), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_simple_whitespace.py b/libcst/nodes/tests/test_simple_whitespace.py new file mode 100644 index 00000000..dbb50b1b --- /dev/null +++ b/libcst/nodes/tests/test_simple_whitespace.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.testing.utils import data_provider + + +class SimpleWhitespaceTest(CSTNodeTest): + @data_provider( + ( + (cst.SimpleWhitespace(""), ""), + (cst.SimpleWhitespace(" "), " "), + (cst.SimpleWhitespace(" \t\f"), " \t\f"), + (cst.SimpleWhitespace("\\\n "), "\\\n "), + (cst.SimpleWhitespace("\\\r\n "), "\\\r\n "), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + (lambda: cst.SimpleWhitespace(" bad input"), "non-whitespace"), + (lambda: cst.SimpleWhitespace("\\"), "non-whitespace"), + (lambda: cst.SimpleWhitespace("\\\n\n "), "non-whitespace"), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class ParenthesizedWhitespaceTest(CSTNodeTest): + @data_provider( + ( + (cst.ParenthesizedWhitespace(), "\n"), + ( + cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace( + cst.SimpleWhitespace(" "), cst.Comment("# This is a comment") + ) + ), + " # This is a comment\n", + ), + ( + cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace( + cst.SimpleWhitespace(" "), cst.Comment("# This is a comment") + ), + empty_lines=(cst.EmptyLine(), cst.EmptyLine(), cst.EmptyLine()), + ), + " # This is a comment\n\n\n\n", + ), + ( + cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace( + cst.SimpleWhitespace(" "), cst.Comment("# This is a comment") + ), + empty_lines=(cst.EmptyLine(), cst.EmptyLine(), cst.EmptyLine()), + indent=False, + last_line=cst.SimpleWhitespace(" "), + ), + " # This is a comment\n\n\n\n ", + ), + ( + DummyIndentedBlock( + " ", + cst.ParenthesizedWhitespace( + 
first_line=cst.TrailingWhitespace( + cst.SimpleWhitespace(" "), + cst.Comment("# This is a comment"), + ), + empty_lines=(cst.EmptyLine(), cst.EmptyLine(), cst.EmptyLine()), + indent=True, + last_line=cst.SimpleWhitespace(" "), + ), + ), + " # This is a comment\n \n \n \n ", + ), + ( + DummyIndentedBlock( + " ", + cst.ParenthesizedWhitespace( + first_line=cst.TrailingWhitespace( + cst.SimpleWhitespace(" "), + cst.Comment("# This is a comment"), + ), + indent=True, + last_line=cst.SimpleWhitespace(""), + ), + ), + " # This is a comment\n ", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) diff --git a/libcst/nodes/tests/test_small_statement.py b/libcst/nodes/tests/test_small_statement.py new file mode 100644 index 00000000..c5efc634 --- /dev/null +++ b/libcst/nodes/tests/test_small_statement.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class SmallStatementTest(CSTNodeTest): + @data_provider( + ( + (cst.Pass(), "pass"), + (cst.Pass(semicolon=cst.Semicolon()), "pass;"), + ( + cst.Pass( + semicolon=cst.Semicolon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ) + ), + "pass ; ", + ), + (cst.Continue(), "continue"), + (cst.Continue(semicolon=cst.Semicolon()), "continue;"), + ( + cst.Continue( + semicolon=cst.Semicolon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ) + ), + "continue ; ", + ), + (cst.Break(), "break"), + (cst.Break(semicolon=cst.Semicolon()), "break;"), + ( + cst.Break( + semicolon=cst.Semicolon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ) + ), + "break ; ", + ), + ( + cst.Expr(cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y"))), + "x + y", + ), + ( + cst.Expr( + cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")), + semicolon=cst.Semicolon(), + ), + "x + y;", + ), + ( + cst.Expr( + cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")), + semicolon=cst.Semicolon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + "x + y ; ", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) diff --git a/libcst/nodes/tests/test_starred.py b/libcst/nodes/tests/test_starred.py new file mode 100644 index 00000000..48cacb45 --- /dev/null +++ b/libcst/nodes/tests/test_starred.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class StarredTest(CSTNodeTest): + @data_provider( + ( + # Simple starred expression + (cst.Starred(cst.Name("foo")), "*foo", parse_expression), + # In parenthesis + ( + cst.Starred( + lpar=(cst.LeftParen(),), + expression=cst.Name("foo"), + rpar=(cst.RightParen(),), + ), + "(*foo)", + None, + ), + # Verify spacing + ( + cst.Starred( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + expression=cst.Name("foo"), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + whitespace_after_star=cst.SimpleWhitespace(" "), + ), + "( * foo )", + None, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + ( + lambda: cst.Starred(cst.Name("foo"), lpar=(cst.LeftParen(),)), + "left paren without right paren", + ), + ( + lambda: cst.Starred(cst.Name("foo"), rpar=(cst.RightParen(),)), + "right paren without left paren", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_subscript.py b/libcst/nodes/tests/test_subscript.py new file mode 100644 index 00000000..ccfa8b5d --- /dev/null +++ b/libcst/nodes/tests/test_subscript.py @@ -0,0 +1,353 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class SubscriptTest(CSTNodeTest): + @data_provider( + ( + # Simple subscript expression + ( + cst.Subscript(cst.Name("foo"), cst.Index(cst.Number(cst.Integer("5")))), + "foo[5]", + True, + ), + # Test creation of subscript with slice/extslice. + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=cst.Number(cst.Integer("1")), + upper=cst.Number(cst.Integer("2")), + step=cst.Number(cst.Integer("3")), + ), + ), + "foo[1:2:3]", + False, + ), + ( + cst.Subscript( + cst.Name("foo"), + ( + cst.ExtSlice( + cst.Slice( + lower=cst.Number(cst.Integer("1")), + upper=cst.Number(cst.Integer("2")), + step=cst.Number(cst.Integer("3")), + ) + ), + cst.ExtSlice(cst.Index(cst.Number(cst.Integer("5")))), + ), + ), + "foo[1:2:3, 5]", + False, + ), + # Test parsing of subscript with slice/extslice. 
+ ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=cst.Number(cst.Integer("1")), + first_colon=cst.Colon(), + upper=cst.Number(cst.Integer("2")), + second_colon=cst.Colon(), + step=cst.Number(cst.Integer("3")), + ), + ), + "foo[1:2:3]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + ( + cst.ExtSlice( + cst.Slice( + lower=cst.Number(cst.Integer("1")), + first_colon=cst.Colon(), + upper=cst.Number(cst.Integer("2")), + second_colon=cst.Colon(), + step=cst.Number(cst.Integer("3")), + ), + comma=cst.Comma(), + ), + cst.ExtSlice(cst.Index(cst.Number(cst.Integer("5")))), + ), + ), + "foo[1:2:3,5]", + True, + ), + # Some more wild slice creations + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=cst.Number(cst.Integer("1")), + upper=cst.Number(cst.Integer("2")), + ), + ), + "foo[1:2]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice(lower=cst.Number(cst.Integer("1")), upper=None), + ), + "foo[1:]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice(lower=None, upper=cst.Number(cst.Integer("2"))), + ), + "foo[:2]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=cst.Number(cst.Integer("1")), + upper=None, + step=cst.Number(cst.Integer("3")), + ), + ), + "foo[1::3]", + False, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=None, upper=None, step=cst.Number(cst.Integer("3")) + ), + ), + "foo[::3]", + False, + ), + # Some more wild slice parsings + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=cst.Number(cst.Integer("1")), + upper=cst.Number(cst.Integer("2")), + ), + ), + "foo[1:2]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice(lower=cst.Number(cst.Integer("1")), upper=None), + ), + "foo[1:]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice(lower=None, upper=cst.Number(cst.Integer("2"))), + ), + "foo[:2]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=cst.Number(cst.Integer("1")), + upper=None, + 
second_colon=cst.Colon(), + step=cst.Number(cst.Integer("3")), + ), + ), + "foo[1::3]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=None, + upper=None, + second_colon=cst.Colon(), + step=cst.Number(cst.Integer("3")), + ), + ), + "foo[::3]", + True, + ), + # Valid list clone operations rendering + ( + cst.Subscript(cst.Name("foo"), cst.Slice(lower=None, upper=None)), + "foo[:]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=None, upper=None, second_colon=cst.Colon(), step=None + ), + ), + "foo[::]", + True, + ), + # Valid list clone operations parsing + ( + cst.Subscript(cst.Name("foo"), cst.Slice(lower=None, upper=None)), + "foo[:]", + True, + ), + ( + cst.Subscript( + cst.Name("foo"), + cst.Slice( + lower=None, upper=None, second_colon=cst.Colon(), step=None + ), + ), + "foo[::]", + True, + ), + # In parenthesis + ( + cst.Subscript( + lpar=(cst.LeftParen(),), + value=cst.Name("foo"), + slice=cst.Index(cst.Number(cst.Integer("5"))), + rpar=(cst.RightParen(),), + ), + "(foo[5])", + True, + ), + # Verify spacing + ( + cst.Subscript( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.Name("foo"), + lbracket=cst.LeftSquareBracket( + whitespace_after=cst.SimpleWhitespace(" ") + ), + slice=cst.Index(cst.Number(cst.Integer("5"))), + rbracket=cst.RightSquareBracket( + whitespace_before=cst.SimpleWhitespace(" ") + ), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + whitespace_after_value=cst.SimpleWhitespace(" "), + ), + "( foo [ 5 ] )", + True, + ), + ( + cst.Subscript( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.Name("foo"), + lbracket=cst.LeftSquareBracket( + whitespace_after=cst.SimpleWhitespace(" ") + ), + slice=cst.Slice( + lower=cst.Number(cst.Integer("1")), + first_colon=cst.Colon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + upper=cst.Number(cst.Integer("2")), + 
second_colon=cst.Colon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + step=cst.Number(cst.Integer("3")), + ), + rbracket=cst.RightSquareBracket( + whitespace_before=cst.SimpleWhitespace(" ") + ), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + whitespace_after_value=cst.SimpleWhitespace(" "), + ), + "( foo [ 1 : 2 : 3 ] )", + True, + ), + ( + cst.Subscript( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.Name("foo"), + lbracket=cst.LeftSquareBracket( + whitespace_after=cst.SimpleWhitespace(" ") + ), + slice=( + cst.ExtSlice( + slice=cst.Slice( + lower=cst.Number(cst.Integer("1")), + first_colon=cst.Colon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + upper=cst.Number(cst.Integer("2")), + second_colon=cst.Colon( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + step=cst.Number(cst.Integer("3")), + ), + comma=cst.Comma( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + ), + cst.ExtSlice(slice=cst.Index(cst.Number(cst.Integer("5")))), + ), + rbracket=cst.RightSquareBracket( + whitespace_before=cst.SimpleWhitespace(" ") + ), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + whitespace_after_value=cst.SimpleWhitespace(" "), + ), + "( foo [ 1 : 2 : 3 , 5 ] )", + True, + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str, check_parsing: bool) -> None: + if check_parsing: + self.validate_node(node, code, parse_expression) + else: + self.validate_node(node, code) + + @data_provider( + ( + ( + lambda: cst.Subscript( + cst.Name("foo"), + cst.Index(cst.Number(cst.Integer("5"))), + lpar=(cst.LeftParen(),), + ), + "left paren without right paren", + ), + ( + lambda: cst.Subscript( + cst.Name("foo"), + cst.Index(cst.Number(cst.Integer("5"))), + rpar=(cst.RightParen(),), + ), + "right 
paren without left paren", + ), + (lambda: cst.Subscript(cst.Name("foo"), ()), "empty ExtSlice"), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_trailing_whitespace.py b/libcst/nodes/tests/test_trailing_whitespace.py new file mode 100644 index 00000000..00f2138d --- /dev/null +++ b/libcst/nodes/tests/test_trailing_whitespace.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.testing.utils import data_provider + + +class TrailingWhitespaceTest(CSTNodeTest): + @data_provider( + ( + (cst.TrailingWhitespace(), "\n"), + (cst.TrailingWhitespace(whitespace=cst.SimpleWhitespace(" ")), " \n"), + (cst.TrailingWhitespace(comment=cst.Comment("# comment")), "# comment\n"), + (cst.TrailingWhitespace(newline=cst.Newline("\r\n")), "\r\n"), + ( + cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# comment"), + newline=cst.Newline("\r\n"), + ), + " # comment\r\n", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) diff --git a/libcst/nodes/tests/test_try.py b/libcst/nodes/tests/test_try.py new file mode 100644 index 00000000..08a623b5 --- /dev/null +++ b/libcst/nodes/tests/test_try.py @@ -0,0 +1,328 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class TryTest(CSTNodeTest): + @data_provider( + ( + # Simple try/except block + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + ), + "try: pass\nexcept: pass\n", + parse_statement, + ), + # Try/except with a class + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("Exception"), + ), + ), + ), + "try: pass\nexcept Exception: pass\n", + parse_statement, + ), + # Try/except with a named class + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("Exception"), + name=cst.AsName(cst.Name("exc")), + ), + ), + ), + "try: pass\nexcept Exception as exc: pass\n", + parse_statement, + ), + # Try/except with multiple clauses + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("TypeError"), + name=cst.AsName(cst.Name("e")), + ), + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("KeyError"), + name=cst.AsName(cst.Name("e")), + ), + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + ), + "try: pass\n" + + "except TypeError as e: pass\n" + + "except KeyError as e: pass\n" + + "except: pass\n", + parse_statement, + ), + # Simple try/finally block + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "try: pass\nfinally: 
pass\n", + parse_statement, + ), + # Simple try/except/finally block + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "try: pass\nexcept: pass\nfinally: pass\n", + parse_statement, + ), + # Simple try/except/else block + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "try: pass\nexcept: pass\nelse: pass\n", + parse_statement, + ), + # Simple try/except/else block/finally + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n", + parse_statement, + ), + # Verify whitespace in various locations + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("TypeError"), + name=cst.AsName( + cst.Name("e"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + whitespace_after_except=cst.SimpleWhitespace(" "), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + ), + orelse=cst.Else( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + finalbody=cst.Finally( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + leading_lines=(cst.EmptyLine(),), + 
whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + "\ntry : pass\nexcept TypeError as e : pass\nelse : pass\nfinally : pass\n", + parse_statement, + ), + # Please don't write code like this + ( + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("TypeError"), + name=cst.AsName(cst.Name("e")), + ), + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("KeyError"), + name=cst.AsName(cst.Name("e")), + ), + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "try: pass\n" + + "except TypeError as e: pass\n" + + "except KeyError as e: pass\n" + + "except: pass\n" + + "else: pass\n" + + "finally: pass\n", + parse_statement, + ), + # Verify indentation + ( + DummyIndentedBlock( + " ", + cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + handlers=( + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("TypeError"), + name=cst.AsName(cst.Name("e")), + ), + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("KeyError"), + name=cst.AsName(cst.Name("e")), + ), + cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), + ), + ), + " try: pass\n" + + " except TypeError as e: pass\n" + + " except KeyError as e: pass\n" + + " except: pass\n" + + " else: pass\n" + + " finally: pass\n", + None, + ), + # Verify indentation in bodies + ( + DummyIndentedBlock( + " ", + cst.Try( + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + handlers=( + cst.ExceptHandler( + cst.IndentedBlock( + 
(cst.SimpleStatementLine((cst.Pass(),)),) + ), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + ), + orelse=cst.Else( + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)) + ), + finalbody=cst.Finally( + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)) + ), + ), + ), + " try:\n" + + " pass\n" + + " except:\n" + + " pass\n" + + " else:\n" + + " pass\n" + + " finally:\n" + + " pass\n", + None, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + (lambda: cst.AsName(cst.Name("")), "empty name identifier"), + ( + lambda: cst.AsName( + cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("") + ), + "between 'as'", + ), + ( + lambda: cst.AsName( + cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("") + ), + "before 'as'", + ), + ( + lambda: cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + name=cst.AsName(cst.Name("bla")), + ), + "name for an empty type", + ), + ( + lambda: cst.ExceptHandler( + cst.SimpleStatementSuite((cst.Pass(),)), + type=cst.Name("TypeError"), + whitespace_after_except=cst.SimpleWhitespace(""), + ), + "at least one space after except", + ), + ( + lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))), + "at least one ExceptHandler or Finally", + ), + ( + lambda: cst.Try( + cst.SimpleStatementSuite((cst.Pass(),)), + orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "at least one ExceptHandler in order to have an Else", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_unary_op.py b/libcst/nodes/tests/test_unary_op.py new file mode 100644 index 00000000..bd1a82c0 --- /dev/null +++ b/libcst/nodes/tests/test_unary_op.py @@ -0,0 
+1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_expression +from libcst.testing.utils import data_provider + + +class UnaryOperationTest(CSTNodeTest): + @data_provider( + ( + # Simple unary operations + (cst.UnaryOperation(cst.Plus(), cst.Name("foo")), "+foo"), + (cst.UnaryOperation(cst.Minus(), cst.Name("foo")), "-foo"), + (cst.UnaryOperation(cst.BitInvert(), cst.Name("foo")), "~foo"), + (cst.UnaryOperation(cst.Not(), cst.Name("foo")), "not foo"), + # Parenthesized unary operation + ( + cst.UnaryOperation( + lpar=(cst.LeftParen(),), + operator=cst.Not(), + expression=cst.Name("foo"), + rpar=(cst.RightParen(),), + ), + "(not foo)", + ), + ( + cst.UnaryOperation( + operator=cst.Not(whitespace_after=cst.SimpleWhitespace("")), + expression=cst.Name( + "foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),) + ), + ), + "not(foo)", + ), + # Make sure that spacing works + ( + cst.UnaryOperation( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + operator=cst.Not(whitespace_after=cst.SimpleWhitespace(" ")), + expression=cst.Name("foo"), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( not foo )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code, parse_expression) + + @data_provider( + ( + ( + lambda: cst.UnaryOperation( + cst.Plus(), cst.Name("foo"), lpar=(cst.LeftParen(),) + ), + "left paren without right paren", + ), + ( + lambda: cst.UnaryOperation( + cst.Plus(), cst.Name("foo"), rpar=(cst.RightParen(),) + ), + "right paren without left paren", + ), + ( + lambda: cst.UnaryOperation( + operator=cst.Not(whitespace_after=cst.SimpleWhitespace("")), + 
expression=cst.Name("foo"), + ), + "at least one space after not operator", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_while.py b/libcst/nodes/tests/test_while.py new file mode 100644 index 00000000..2fc3712b --- /dev/null +++ b/libcst/nodes/tests/test_while.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class WhileTest(CSTNodeTest): + @data_provider( + ( + # Simple while block + ( + cst.While( + cst.Call(cst.Name("iter")), cst.SimpleStatementSuite((cst.Pass(),)) + ), + "while iter(): pass\n", + parse_statement, + ), + # While block with else + ( + cst.While( + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + ), + "while iter(): pass\nelse: pass\n", + parse_statement, + ), + # indentation + ( + DummyIndentedBlock( + " ", + cst.While( + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + ), + " while iter(): pass\n", + None, + ), + # while an indented body + ( + DummyIndentedBlock( + " ", + cst.While( + cst.Call(cst.Name("iter")), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + ), + ), + " while iter():\n pass\n", + None, + ), + # leading_lines + ( + cst.While( + cst.Call(cst.Name("iter")), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + ), + "# leading comment\nwhile iter():\n pass\n", + parse_statement, + 
), + ( + cst.While( + cst.Call(cst.Name("iter")), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + cst.Else( + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# else comment")), + ), + ), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + ), + "# leading comment\nwhile iter():\n pass\n# else comment\nelse:\n pass\n", + None, + ), + # Weird spacing rules + ( + cst.While( + cst.Call( + cst.Name("iter"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_while=cst.SimpleWhitespace(""), + ), + "while(iter()): pass\n", + parse_statement, + ), + # Whitespace + ( + cst.While( + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_while=cst.SimpleWhitespace(" "), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + "while iter() : pass\n", + parse_statement, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + ( + lambda: cst.While( + cst.Call(cst.Name("iter")), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_while=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'while' keyword", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_with.py b/libcst/nodes/tests/test_with.py new file mode 100644 index 00000000..d49e1215 --- /dev/null +++ b/libcst/nodes/tests/test_with.py @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable, Optional + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest, DummyIndentedBlock +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class WithTest(CSTNodeTest): + @data_provider( + ( + # Simple with block + ( + cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "with context_mgr(): pass\n", + parse_statement, + ), + # Simple async with block + ( + cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + asynchronous=cst.Asynchronous(), + ), + "async with context_mgr(): pass\n", + parse_statement, + ), + # Multiple context managers + ( + cst.With( + ( + cst.WithItem(cst.Call(cst.Name("foo"))), + cst.WithItem(cst.Call(cst.Name("bar"))), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "with foo(), bar(): pass\n", + None, + ), + ( + cst.With( + ( + cst.WithItem( + cst.Call(cst.Name("foo")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.WithItem(cst.Call(cst.Name("bar"))), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "with foo(), bar(): pass\n", + parse_statement, + ), + # With block containing variable for context manager. 
+ ( + cst.With( + ( + cst.WithItem( + cst.Call(cst.Name("context_mgr")), + cst.AsName(cst.Name("ctx")), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + "with context_mgr() as ctx: pass\n", + parse_statement, + ), + # indentation + ( + DummyIndentedBlock( + " ", + cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + ), + ), + " with context_mgr(): pass\n", + None, + ), + # with an indented body + ( + DummyIndentedBlock( + " ", + cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + ), + ), + " with context_mgr():\n pass\n", + None, + ), + # leading_lines + ( + cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + leading_lines=( + cst.EmptyLine(comment=cst.Comment("# leading comment")), + ), + ), + "# leading comment\nwith context_mgr(): pass\n", + parse_statement, + ), + # Weird spacing rules + ( + cst.With( + ( + cst.WithItem( + cst.Call( + cst.Name("context_mgr"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ) + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_with=cst.SimpleWhitespace(""), + ), + "with(context_mgr()): pass\n", + parse_statement, + ), + # Whitespace + ( + cst.With( + ( + cst.WithItem( + cst.Call(cst.Name("context_mgr")), + cst.AsName( + cst.Name("ctx"), + whitespace_before_as=cst.SimpleWhitespace(" "), + whitespace_after_as=cst.SimpleWhitespace(" "), + ), + ), + ), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_with=cst.SimpleWhitespace(" "), + whitespace_before_colon=cst.SimpleWhitespace(" "), + ), + "with context_mgr() as ctx : pass\n", + parse_statement, + ), + ) + ) + def test_valid( + self, + node: cst.CSTNode, + code: str, + parser: Optional[Callable[[str], cst.CSTNode]], + ) -> None: + self.validate_node(node, code, parser) + + @data_provider( + ( + ( + lambda: cst.With( + (), 
cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)) + ), + "A With statement must have at least one WithItem", + ), + ( + lambda: cst.With( + ( + cst.WithItem( + cst.Call(cst.Name("foo")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + ), + cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), + ), + "The last WithItem in a With cannot have a trailing comma", + ), + ( + lambda: cst.With( + (cst.WithItem(cst.Call(cst.Name("context_mgr"))),), + cst.SimpleStatementSuite((cst.Pass(),)), + whitespace_after_with=cst.SimpleWhitespace(""), + ), + "Must have at least one space after with keyword", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) diff --git a/libcst/nodes/tests/test_yield.py b/libcst/nodes/tests/test_yield.py new file mode 100644 index 00000000..16d6a6e0 --- /dev/null +++ b/libcst/nodes/tests/test_yield.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable + +import libcst.nodes as cst +from libcst.nodes.tests.base import CSTNodeTest +from libcst.parser import parse_statement +from libcst.testing.utils import data_provider + + +class YieldConstructionTest(CSTNodeTest): + @data_provider( + ( + # Simple yield + (cst.Yield(), "yield"), + # yield expression + (cst.Yield(cst.Name("a")), "yield a"), + # yield from expression + (cst.Yield(cst.From(cst.Call(cst.Name("a")))), "yield from a()"), + # Parenthesizing tests + ( + cst.Yield( + lpar=(cst.LeftParen(),), + value=cst.Number(cst.Integer("5")), + rpar=(cst.RightParen(),), + ), + "(yield 5)", + ), + # Whitespace oddities tests + ( + cst.Yield( + cst.Name("a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)), + whitespace_after_yield=cst.SimpleWhitespace(""), + ), + "yield(a)", + ), + ( + cst.Yield( + cst.From( + cst.Call( + cst.Name("a"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ), + whitespace_after_from=cst.SimpleWhitespace(""), + ) + ), + "yield from(a())", + ), + # Whitespace rendering/parsing tests + ( + cst.Yield( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.Number(cst.Integer("5")), + whitespace_after_yield=cst.SimpleWhitespace(" "), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( yield 5 )", + ), + ( + cst.Yield( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.From( + cst.Call(cst.Name("bla")), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + whitespace_after_yield=cst.SimpleWhitespace(" "), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( yield from bla() )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + self.validate_node(node, code) + + @data_provider( + ( + # Paren validation + ( + lambda: cst.Yield(lpar=(cst.LeftParen(),)), + "left paren without right paren", + ), + ( + lambda: cst.Yield(rpar=(cst.RightParen(),)), + "right 
paren without left paren", + ), + # Make sure we have adequate space after yield + ( + lambda: cst.Yield( + cst.Name("a"), whitespace_after_yield=cst.SimpleWhitespace("") + ), + "Must have at least one space after 'yield' keyword", + ), + ( + lambda: cst.Yield( + cst.From(cst.Call(cst.Name("a"))), + whitespace_after_yield=cst.SimpleWhitespace(""), + ), + "Must have at least one space after 'yield' keyword", + ), + # MAke sure we have adequate space after from + ( + lambda: cst.Yield( + cst.From( + cst.Call(cst.Name("a")), + whitespace_after_from=cst.SimpleWhitespace(""), + ) + ), + "Must have at least one space after 'from' keyword", + ), + ) + ) + def test_invalid( + self, get_node: Callable[[], cst.CSTNode], expected_re: str + ) -> None: + self.assert_invalid(get_node, expected_re) + + +class YieldParsingTest(CSTNodeTest): + @data_provider( + ( + # Simple yield + (cst.Yield(), "yield"), + # yield expression + ( + cst.Yield( + cst.Name("a"), whitespace_after_yield=cst.SimpleWhitespace(" ") + ), + "yield a", + ), + # yield from expression + ( + cst.Yield( + cst.From( + cst.Call(cst.Name("a")), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + whitespace_after_yield=cst.SimpleWhitespace(" "), + ), + "yield from a()", + ), + # Parenthesizing tests + ( + cst.Yield( + lpar=(cst.LeftParen(),), + whitespace_after_yield=cst.SimpleWhitespace(" "), + value=cst.Number(cst.Integer("5")), + rpar=(cst.RightParen(),), + ), + "(yield 5)", + ), + # Whitespace oddities tests + ( + cst.Yield( + cst.Name("a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)), + whitespace_after_yield=cst.SimpleWhitespace(""), + ), + "yield(a)", + ), + ( + cst.Yield( + cst.From( + cst.Call( + cst.Name("a"), + lpar=(cst.LeftParen(),), + rpar=(cst.RightParen(),), + ), + whitespace_after_from=cst.SimpleWhitespace(""), + ), + whitespace_after_yield=cst.SimpleWhitespace(" "), + ), + "yield from(a())", + ), + # Whitespace rendering/parsing tests + ( + cst.Yield( + 
lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.Number(cst.Integer("5")), + whitespace_after_yield=cst.SimpleWhitespace(" "), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( yield 5 )", + ), + ( + cst.Yield( + lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),), + value=cst.From( + cst.Call(cst.Name("bla")), + whitespace_after_from=cst.SimpleWhitespace(" "), + ), + whitespace_after_yield=cst.SimpleWhitespace(" "), + rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),), + ), + "( yield from bla() )", + ), + ) + ) + def test_valid(self, node: cst.CSTNode, code: str) -> None: + # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. + self.validate_node(node, code, lambda code: parse_statement(code).body[0].value) diff --git a/libcst/parser/__init__.py b/libcst/parser/__init__.py new file mode 100644 index 00000000..4fc8f8b7 --- /dev/null +++ b/libcst/parser/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from libcst.parser._entrypoints import parse_expression, parse_module, parse_statement + + +__all__ = ["parse_module", "parse_expression", "parse_statement"] diff --git a/libcst/parser/_base_parser.py b/libcst/parser/_base_parser.py new file mode 100644 index 00000000..efca8fbe --- /dev/null +++ b/libcst/parser/_base_parser.py @@ -0,0 +1,187 @@ +# Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +# Modifications: +# Copyright David Halter and Contributors +# Modifications are dual-licensed: MIT and PSF. +# 99% of the code is different from pgen2, now. + +# A fork of `parso.parser`. +# https://github.com/davidhalter/parso/blob/v0.3.4/parso/parser.py +# +# The following changes were made: +# - Typing was added. 
+# - Error recovery is removed. +# - The Jedi-specific _allowed_transition_names_and_token_types API is removed. +# - Improved error messages by using our exceptions module. +# - node_map/leaf_map were removed in favor of just calling convert_*. +# - convert_node/convert_leaf were renamed to convert_nonterminal/convert_terminal +# - convert_nonterminal is called regardless of the number of children. Parso avoids +# calling it in some cases to avoid creating extra nodes. +# - The parser is constructed with the tokens to allow us to track a bit more state. +# As a consequence, a parser may only be used once. +# - Supports our custom Token class, instead of `parso.python.tokenize.Token`. + +# pyre-strict + +from dataclasses import dataclass, field +from typing import Generic, Iterable, List, Sequence, TypeVar, Union + +from parso.pgen2.generator import DFAState, Grammar, ReservedString +from parso.python.token import TokenType + +from libcst.exceptions import ParserSyntaxError +from libcst.parser._types.token import Token + + +_NodeT = TypeVar("_NodeT") +_LeafT = TypeVar("_LeafT") +_TokenTypeT = TypeVar("_TokenTypeT", bound=TokenType) +_TokenT = TypeVar("_TokenT", bound=Token) + + +@dataclass(frozen=False) +class StackNode(Generic[_TokenTypeT, _NodeT]): + dfa: "DFAState[_TokenTypeT]" + nodes: List[_NodeT] = field(default_factory=list) + + @property + def nonterminal(self) -> str: + return self.dfa.from_rule + + +def _token_to_transition( + grammar: "Grammar[_TokenTypeT]", type_: _TokenTypeT, value: str +) -> Union[ReservedString, _TokenTypeT]: + # Map from token to label + if type_.contains_syntax: + # Check for reserved words (keywords) + try: + return grammar.reserved_syntax_strings[value] + except KeyError: + pass + + return type_ + + +# TODO: This should be an ABC, but there's a metaclass conflict between Generic and ABC +# that's fixed in Python 3.7. +class BaseParser(Generic[_TokenT, _TokenTypeT, _NodeT]): + """Parser engine.
+ + A Parser instance contains state pertaining to the current token + sequence, and should not be used concurrently by different threads + to parse separate token sequences. + + See python/tokenize.py for how to get input tokens by a string. + """ + + tokens: Iterable[_TokenT] + lines: Sequence[str] # used when generating parse errors + _pgen_grammar: "Grammar[_TokenTypeT]" + stack: List[StackNode[_TokenTypeT, _NodeT]] + # Keep track of if parse was called. Because a parser may keep global mutable state, + # each BaseParser instance should only be used once. + __was_parse_called: bool + + def __init__( + self, + *, + tokens: Iterable[_TokenT], + lines: Sequence[str], + pgen_grammar: "Grammar[_TokenTypeT]", + start_nonterminal: str, + ) -> None: + self.tokens = tokens + self.lines = lines + self._pgen_grammar = pgen_grammar + first_dfa = pgen_grammar.nonterminal_to_dfas[start_nonterminal][0] + self.stack = [StackNode(first_dfa)] + self.__was_parse_called = False + + def parse(self) -> _NodeT: + # Ensure that we don't re-use parsers. + if self.__was_parse_called: + raise Exception("Each parser object may only be used to parse once.") + self.__was_parse_called = True + + for token in self.tokens: + self._add_token(token) + + while True: + tos = self.stack[-1] + if not tos.dfa.is_final: + # We never broke out -- EOF is too soon -- Unfinished statement. + raise ParserSyntaxError( + message="incomplete input", + encountered=None, + expected=tos.dfa.arcs.keys(), + pos=(len(self.lines), len(self.lines[-1])), + lines=self.lines, + ) + + if len(self.stack) > 1: + self._pop() + else: + return self.convert_nonterminal(tos.nonterminal, tos.nodes) + + def convert_nonterminal( + self, nonterminal: str, children: Sequence[_NodeT] + ) -> _NodeT: + ... + + def convert_terminal(self, token: _TokenT) -> _NodeT: + ... + + def _add_token(self, token: _TokenT) -> None: + """ + This is the only core function for parsing. Here happens basically + everything. 
Everything is well prepared by the parser generator and we + only apply the necessary steps here. + """ + grammar = self._pgen_grammar + stack = self.stack + # pyre-fixme[6]: Expected `_TokenTypeT` for 2nd param but got `TokenType`. + transition = _token_to_transition(grammar, token.type, token.string) + + while True: + try: + plan = stack[-1].dfa.transitions[transition] + break + except KeyError: + if stack[-1].dfa.is_final: + self._pop() + else: + raise ParserSyntaxError( + message="incomplete input", + encountered=token.string, + expected=stack[-1].dfa.arcs.keys(), + pos=token.start_pos, + lines=self.lines, + ) + except IndexError: + raise ParserSyntaxError( + message="too much input", + encountered=token.string, + expected=None, # EOF + pos=token.start_pos, + lines=self.lines, + ) + + # Logically, `plan` is always defined, but pyre can't reasonably determine that. + # pyre-fixme[18]: Global name `plan` is undefined. + stack[-1].dfa = plan.next_dfa + + for push in plan.dfa_pushes: + stack.append(StackNode(push)) + + leaf = self.convert_terminal(token) + stack[-1].nodes.append(leaf) + + def _pop(self) -> None: + tos = self.stack.pop() + # Unlike parso and lib2to3, we call `convert_nonterminal` unconditionally + # instead of only when we have more than one child. This allows us to create a + # far more consistent and predictable tree. + new_node = self.convert_nonterminal(tos.dfa.from_rule, tos.nodes) + self.stack[-1].nodes.append(new_node) diff --git a/libcst/parser/_conversions/README.md b/libcst/parser/_conversions/README.md new file mode 100644 index 00000000..798e3d18 --- /dev/null +++ b/libcst/parser/_conversions/README.md @@ -0,0 +1,209 @@ +# Parser Conversions Developer Guide + +Parser conversions take grammar productions and convert them to CST nodes, or to some +"partial" value that will later be converted to a CST node. 
+ +The grammar production that parser conversions are associated with is co-located +alongside the conversion function using our `@with_production` decorator. This is +similar to the API that [rply](https://github.com/alex/rply/) uses. + +Grammar productions are collected when the parser is first called, and converted into a +state machine by Parso's pgen2 fork. + +Unlike rply's API, productions are not automatically gathered, because that would be +dependent on implicit import-time side-effects. Instead all conversion functions must be +listed in `_grammar.py`. + +# What's a production? + +A production is a line in our BNF-like grammar definition. A production has a name (the +first argument of `@with_production`), and a sequence of children (the second argument +of `@with_production`). + +Python's full grammar is here: https://docs.python.org/3/reference/grammar.html + +We use Parso's fork of pgen2, and therefore support the same BNF-like syntax that +Python's documentation uses. + +# Why is everything `Any`-typed? Isn't that bad? + +Yes, `Any` types indicate a gap in static type coverage. Unfortunately, this isn't +easily solved. + +The value of `children` given to a conversion function is dependent on textual grammar +representation and pgen2's implementation, which the type system is unaware of. Unless +we extend the type system to support pgen2 (unlikely) or add a layer of +machine-generated code (possible, but we're not there), there's no way for the type +system to validate any annotations on `children`. + +We could add annotations to `children`, but they're usually complicated types (so they +wouldn't be very human-readable), and they wouldn't actually provide any type safety +because the type checker doesn't know about them. + +Similarly, we could annotate return type annotations, but that's just duplicating the +type we're already expressing in our return statement (so it doesn't improve readability +much), and it's not providing any static type safety. 
+ +We do perform runtime type checks inside tests, and we hope that this test coverage will +help compensate for the lack of static type safety. + +# Where's the whitespace? + +The most important differentiation between an Abstract Syntax Tree and a Concrete Syntax +Tree (for our purposes) is that the CST contains enough information to exactly reproduce +the original program. This means that we must somehow capture and store whitespace. + +The grammar does not contain whitespace information, and there are no explicit tokens +for whitespace. If the grammar did contain whitespace information, the grammar likely +wouldn't be LL(1), and while we could use another context free grammar parsing +algorithm, it would add complexity and likely wouldn't be as efficient. + +Instead, we have a hand-written re-entrant recursive-descent parser for whitespace. It's +the responsibility of conversion functions to call into this parser given whitespace +states before and after a token. + +# Token and WhitespaceState Data Structures + +A token is defined as: + +``` +class Token: + type: TokenType + string: str + # The start of where `string` is in the source, not including leading whitespace. + start_pos: Tuple[int, int] + # The end of where `string` is in the source, not including trailing whitespace. + end_pos: Tuple[int, int] + whitespace_before: WhitespaceState + whitespace_after: WhitespaceState +``` + +Or, in the order that these pieces appear lexically in a parsed program: + +``` ++-------------------+--------+-------------------+ +| whitespace_before | string | whitespace_after | +| (WhitespaceState) | (str) | (WhitespaceState) | ++-------------------+--------+-------------------+ +``` + +Tokens are immutable, but only shallowly, because their whitespace fields are mutable +WhitespaceState objects. + +WhitespaceStates are opaque objects that the whitespace parser consumes and mutates. 
+WhitespaceState nodes are shared across multiple tokens, so `whitespace_after` is the +same object as `whitespace_before` in the next token. + +# Parser Execution Order + +The parser generator we use (`pgen2`) is bottom-up, meaning that children productions +are called before their parents. In contrast, our hand written whitespace parser is +top-down. + +Inside each production, child conversion functions are called from left to right. + +As an example, assume we're given the following simple grammar and program: + +``` +add_expr: NUMBER ['+' add_expr] +``` + +``` +1 + 2 + 3 +``` + +which forms the parse tree: + +``` + [H] add_expr + / | \ +[A] 1 [B] '+' [G] add_expr + / | \ + [C] 2 [D] '+' [F] add_expr + | + [E] 3 +``` + +The conversion functions would be called in the labeled alphabetical order, with `A` +converted first, and `H` converted last. + +# Who owns whitespace? + +There's a lot of holes between you and a correct whitespace representation, but these +can be divided into a few categories of potential mistakes: + +## Forgetting to Parse Whitespace + +Fortunately, the inverse (parsing the same whitespace twice) should not be possible, +because whitespace is "consumed" by the whitespace parser. + +This kind of mistake is easily caught with tests. + +## Assigning Whitespace to the Wrong Owner + +This is probably the easiest mistake to make. The general convention is that the +top-most possible node owns whitespace, but in a bottom-up parser like ours, the +children are parsed before their parents. + +In contrast, the best owner for whitespace in our tree when there's multiple possible +owners is usually the top-most node. + +As an example, assume we have the following grammar and program: + +``` +simple_stmt: (pass_stmt ';')* NEWLINE +``` + +``` +pass; # comment +``` + +Since both `cst.Semicolon` and `cst.SimpleStatement` can both store some whitespace +after themselves, there's some ambiguity about who should own the space character before +the comment. 
However, since `cst.SimpleStatement` is the parent, the convention is that +it should own it. + +Unfortunately, since nodes are processed bottom-to-top and left-to-right, the semicolon +under `simple_stmt` will get processed before `simple_stmt` is. This means that in a +naive implementation, the semicolon's conversion function would have a chance to consume +the whitespace before `simple_stmt` can. + +To solve this problem, you must "fix" the whitespace in the parent node's conversion +function or grammar. This can be done in a number of ways. In order of preference: + +1. Split the child's grammar production into two separate productions, one that consumes + its leading or trailing whitespace, and one that doesn't. Depending on the parent, + use the appropriate version of the child. +2. Construct a "partial" node in the child that doesn't consume the whitespace, and then + consume the correct whitespace in the parent. Be careful about what whitespace a + node's siblings consume. +3. "Steal" the whitespace from the child by replacing the child with a new version that + doesn't have the whitespace. + +This mistake is probably hard to catch with tests, because the CST will still reprint +correctly, but it creates ergonomic issues for tools consuming the CST. + +## Consuming Whitespace in the Wrong Order + +This mistake is probably the hardest to make by accident, but it could still happen, +and may be hard to catch with tests. + +Given the following piece of code: + +``` +pass # trailing +# empty line +pass +``` + +The first statement should own `# trailing` (parsed using `parse_trailing_whitespace`). +The second statement then should own `# empty line` (parsed using `parse_empty_lines`).
+ +However, it's possible that if you somehow called `parse_empty_lines` on the second +statement before calling `parse_trailing_whitespace` on the first statement, +`parse_empty_lines` could accidentally end up consuming the `# trailing` comment, +because `parse_trailing_whitespace` hasn't yet consumed it. + +However, this circumstance is unlikely, because you'd explicitly have to handle the +children out-of-order, and we have assertions inside the whitespace parser to prevent +some possible mistakes, like the one described above. diff --git a/libcst/parser/_conversions/__init__.py b/libcst/parser/_conversions/__init__.py new file mode 100644 index 00000000..62642369 --- /dev/null +++ b/libcst/parser/_conversions/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/libcst/parser/_conversions/dummy.py b/libcst/parser/_conversions/dummy.py new file mode 100644 index 00000000..aa8a610f --- /dev/null +++ b/libcst/parser/_conversions/dummy.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, List, Sequence, Union + +import libcst.nodes as cst +from libcst.nodes._base import CSTNode +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.partials import WithLeadingWhitespace +from libcst.parser._types.token import Token +from libcst.parser._whitespace_parser import parse_parenthesizable_whitespace + + +def make_dummy_node(config: ParserConfig, children: Sequence[Any]) -> Any: + wrapped_children: List[Union[CSTNode, str]] = [] + + for i, child in enumerate(children): + if isinstance(child, Token): + if i > 0: + # Leading whitespace for dummy is owned by the parent, so only + # add raw whitespace if this isn't the first node. + wrapped_children.append( + parse_parenthesizable_whitespace(config, child.whitespace_before) + ) + # Add ourselves unconditionally. + wrapped_children.append(child.string) + elif isinstance(child, WithLeadingWhitespace): + if i > 0: + # Leading whitespace for dummy is owned by the parent, so only + # add parsed whitespace if this isn't the first node. + wrapped_children.append( + parse_parenthesizable_whitespace(config, child.whitespace_before) + ) + # Add ourselves unconditionally. + wrapped_children.append(child.value) + else: + wrapped_children.append(child) + + if hasattr(children[0], "whitespace_before"): + return WithLeadingWhitespace( + cst.DummyNode(children=wrapped_children), children[0].whitespace_before + ) + else: + return cst.DummyNode(children=wrapped_children) diff --git a/libcst/parser/_conversions/expression.py b/libcst/parser/_conversions/expression.py new file mode 100644 index 00000000..6bc7edce --- /dev/null +++ b/libcst/parser/_conversions/expression.py @@ -0,0 +1,1116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +import re +from tokenize import ( + Floatnumber as FLOATNUMBER_RE, + Imagnumber as IMAGNUMBER_RE, + Intnumber as INTNUMBER_RE, +) +from typing import Any, Dict, List, Sequence, Type + +import libcst.nodes as cst +from libcst._maybe_sentinel import MaybeSentinel +from libcst.parser._conversions.dummy import make_dummy_node +from libcst.parser._custom_itertools import grouper +from libcst.parser._production_decorator import with_production +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.partials import ( + ArglistPartial, + AttributePartial, + CallPartial, + FormattedStringConversionPartial, + FormattedStringFormatSpecPartial, + SlicePartial, + SubscriptPartial, + WithLeadingWhitespace, +) +from libcst.parser._types.token import Token +from libcst.parser._whitespace_parser import parse_parenthesizable_whitespace + + +BINOP_TOKEN_LUT: Dict[str, Type[cst.BaseBinaryOp]] = { + "*": cst.Multiply, + "@": cst.MatrixMultiply, + "/": cst.Divide, + "%": cst.Modulo, + "//": cst.FloorDivide, + "+": cst.Add, + "-": cst.Subtract, + "<<": cst.LeftShift, + ">>": cst.RightShift, + "&": cst.BitAnd, + "^": cst.BitXor, + "|": cst.BitOr, +} + + +BOOLOP_TOKEN_LUT: Dict[str, Type[cst.BaseBooleanOp]] = {"and": cst.And, "or": cst.Or} + + +COMPOP_TOKEN_LUT: Dict[str, Type[cst.BaseCompOp]] = { + "<": cst.LessThan, + ">": cst.GreaterThan, + "==": cst.Equal, + "<=": cst.LessThanEqual, + ">=": cst.GreaterThanEqual, + "in": cst.In, + "is": cst.Is, +} + + +# N.B. This uses a `testlist | star_expr`, not a `testlist_star_expr` because +# `testlist_star_expr` may not always be representable by a non-partial node, since it's +# only used as part of `expr_stmt`. +@with_production("expression_input", "(testlist | star_expr) ENDMARKER") +def convert_expression_input(config: ParserConfig, children: Sequence[Any]) -> Any: + (child, endmarker) = children + # HACK: UGLY! REMOVE THIS SOON! + # Unwrap WithLeadingWhitespace if it exists. 
It shouldn't exist by this point, but + # testlist isn't fully implemented, and we currently leak these partial objects. + if isinstance(child, WithLeadingWhitespace): + child = child.value + return child + + +@with_production("testlist_star_expr", "(test|star_expr) (',' (test|star_expr))* [',']") +def convert_testlist_star_expr(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + else: + return make_dummy_node(config, children) + + +@with_production("test", "or_test ['if' or_test 'else' test] | lambdef") +def convert_test(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + else: + (body, if_token, test, else_token, orelse) = children + return WithLeadingWhitespace( + cst.IfExp( + body=body.value, + test=test.value, + orelse=orelse.value, + whitespace_before_if=parse_parenthesizable_whitespace( + config, if_token.whitespace_before + ), + whitespace_after_if=parse_parenthesizable_whitespace( + config, if_token.whitespace_after + ), + whitespace_before_else=parse_parenthesizable_whitespace( + config, else_token.whitespace_before + ), + whitespace_after_else=parse_parenthesizable_whitespace( + config, else_token.whitespace_after + ), + ), + body.whitespace_before, + ) + + +@with_production("test_nocond", "or_test | lambdef_nocond") +def convert_test_nocond(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("lambdef", "'lambda' [varargslist] ':' test") +@with_production("lambdef_nocond", "'lambda' [varargslist] ':' test_nocond") +def convert_lambda(config: ParserConfig, children: Sequence[Any]) -> Any: + lambdatoken, *params, colontoken, test = children + + # Grab the whitespace around the colon. If there are no params, then + # the colon owns the whitespace before and after it. If there are + # any params, then the last param owns the whitespace before the colon. 
+ # We handle the parameter movement below. + colon = cst.Colon( + whitespace_before=parse_parenthesizable_whitespace( + config, colontoken.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, colontoken.whitespace_after + ), + ) + + # Unpack optional parameters + if len(params) == 0: + parameters = cst.Parameters() + whitespace_after_lambda = MaybeSentinel.DEFAULT + else: + (parameters,) = params + whitespace_after_lambda = parse_parenthesizable_whitespace( + config, lambdatoken.whitespace_after + ) + + # Handle pre-colon whitespace + if parameters.star_kwarg is not None: + if parameters.star_kwarg.comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + star_kwarg=parameters.star_kwarg.with_changes( + whitespace_after_param=colon.whitespace_before + ) + ) + elif parameters.kwonly_params: + if parameters.kwonly_params[-1].comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + kwonly_params=( + *parameters.kwonly_params[:-1], + parameters.kwonly_params[-1].with_changes( + whitespace_after_param=colon.whitespace_before + ), + ) + ) + elif isinstance(parameters.star_arg, cst.Param): + if parameters.star_arg.comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + star_arg=parameters.star_arg.with_changes( + whitespace_after_param=colon.whitespace_before + ) + ) + elif parameters.default_params: + if parameters.default_params[-1].comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + default_params=( + *parameters.default_params[:-1], + parameters.default_params[-1].with_changes( + whitespace_after_param=colon.whitespace_before + ), + ) + ) + elif parameters.params: + if parameters.params[-1].comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + params=( + *parameters.params[:-1], + parameters.params[-1].with_changes( + whitespace_after_param=colon.whitespace_before + ), + ) + ) + + # Colon doesn't own its own pre-whitespace now. 
+ colon = colon.with_changes(whitespace_before=cst.SimpleWhitespace("")) + + # Return a lambda + return WithLeadingWhitespace( + cst.Lambda( + whitespace_after_lambda=whitespace_after_lambda, + params=parameters, + body=test.value, + colon=colon, + ), + lambdatoken.whitespace_before, + ) + + +@with_production("or_test", "and_test ('or' and_test)*") +@with_production("and_test", "not_test ('and' not_test)*") +def convert_boolop(config: ParserConfig, children: Sequence[Any]) -> Any: + leftexpr, *rightexprs = children + if len(rightexprs) == 0: + return leftexpr + + whitespace_before = leftexpr.whitespace_before + leftexpr = leftexpr.value + + # Convert all of the operations that have no precedence in a loop + for op, rightexpr in grouper(rightexprs, 2): + if op.string not in BOOLOP_TOKEN_LUT: + raise Exception(f"Unexpected token '{op.string}'!") + leftexpr = cst.BooleanOperation( + left=leftexpr, + operator=BOOLOP_TOKEN_LUT[op.string]( + whitespace_before=parse_parenthesizable_whitespace( + config, op.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ), + ), + right=rightexpr.value, + ) + return WithLeadingWhitespace(leftexpr, whitespace_before) + + +@with_production("not_test", "'not' not_test | comparison") +def convert_not_test(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + else: + nottoken, nottest = children + return WithLeadingWhitespace( + cst.UnaryOperation( + operator=cst.Not( + whitespace_after=parse_parenthesizable_whitespace( + config, nottoken.whitespace_after + ) + ), + expression=nottest.value, + ), + nottoken.whitespace_before, + ) + + +@with_production("comparison", "expr (comp_op expr)*") +def convert_comparison(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + + lhs, *rest = children + + comparisons: List[cst.ComparisonTarget] = [] + for operator, 
comparator in grouper(rest, 2): + comparisons.append( + cst.ComparisonTarget(operator=operator, comparator=comparator.value) + ) + + return WithLeadingWhitespace( + cst.Comparison(left=lhs.value, comparisons=tuple(comparisons)), + lhs.whitespace_before, + ) + + +@with_production( + "comp_op", "('<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not')" +) +def convert_comp_op(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (op,) = children + if op.string in COMPOP_TOKEN_LUT: + # A regular comparison containing one token + return COMPOP_TOKEN_LUT[op.string]( + whitespace_before=parse_parenthesizable_whitespace( + config, op.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ), + ) + elif op.string in ["!=", "<>"]: + # Not equal, which can take two forms in some cases + return cst.NotEqual( + whitespace_before=parse_parenthesizable_whitespace( + config, op.whitespace_before + ), + value=op.string, + whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ), + ) + else: + # TODO: Make this a ParserSyntaxError + raise Exception(f"Unexpected token '{op.string}'!") + else: + # A two-token comparison + leftcomp, rightcomp = children + + if leftcomp.string == "not" and rightcomp.string == "in": + return cst.NotIn( + whitespace_before=parse_parenthesizable_whitespace( + config, leftcomp.whitespace_before + ), + whitespace_between=parse_parenthesizable_whitespace( + config, leftcomp.whitespace_after + ), + whitespace_after=parse_parenthesizable_whitespace( + config, rightcomp.whitespace_after + ), + ) + elif leftcomp.string == "is" and rightcomp.string == "not": + return cst.IsNot( + whitespace_before=parse_parenthesizable_whitespace( + config, leftcomp.whitespace_before + ), + whitespace_between=parse_parenthesizable_whitespace( + config, leftcomp.whitespace_after + ), + whitespace_after=parse_parenthesizable_whitespace( + config, 
rightcomp.whitespace_after + ), + ) + else: + # TODO: Make this a ParserSyntaxError + raise Exception(f"Unexpected token '{leftcomp.string} {rightcomp.string}'!") + + +@with_production("star_expr", "'*' expr") +def convert_star_expr(config: ParserConfig, children: Sequence[Any]) -> Any: + star, expr = children + return WithLeadingWhitespace( + cst.Starred( + expr.value, + whitespace_after_star=parse_parenthesizable_whitespace( + config, star.whitespace_after + ), + ), + star.whitespace_before, + ) + + +@with_production("expr", "xor_expr ('|' xor_expr)*") +@with_production("xor_expr", "and_expr ('^' and_expr)*") +@with_production("and_expr", "shift_expr ('&' shift_expr)*") +@with_production("shift_expr", "arith_expr (('<<'|'>>') arith_expr)*") +@with_production("arith_expr", "term (('+'|'-') term)*") +@with_production("term", "factor (('*'|'@'|'/'|'%'|'//') factor)*") +def convert_binop(config: ParserConfig, children: Sequence[Any]) -> Any: + leftexpr, *rightexprs = children + if len(rightexprs) == 0: + return leftexpr + + whitespace_before = leftexpr.whitespace_before + leftexpr = leftexpr.value + + # Convert all of the operations that have no precedence in a loop + for op, rightexpr in grouper(rightexprs, 2): + if op.string not in BINOP_TOKEN_LUT: + raise Exception(f"Unexpected token '{op.string}'!") + leftexpr = cst.BinaryOperation( + left=leftexpr, + operator=BINOP_TOKEN_LUT[op.string]( + whitespace_before=parse_parenthesizable_whitespace( + config, op.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ), + ), + right=rightexpr.value, + ) + return WithLeadingWhitespace(leftexpr, whitespace_before) + + +@with_production("factor", "('+'|'-'|'~') factor | power") +def convert_factor(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + + op, factor = children + + # First, tokenize the unary operator + if op.string == "+": + opnode = cst.Plus( + 
whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ) + ) + elif op.string == "-": + opnode = cst.Minus( + whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ) + ) + elif op.string == "~": + opnode = cst.BitInvert( + whitespace_after=parse_parenthesizable_whitespace( + config, op.whitespace_after + ) + ) + else: + raise Exception(f"Unexpected token '{op.string}'!") + + # Second, bump the operator into a number node if that's what the + # factor is. Otherwise, return a unary operator node. + if ( + isinstance(factor.value, cst.Number) + and isinstance(opnode, (cst.Plus, cst.Minus)) + and factor.value.operator is None + ): + return WithLeadingWhitespace( + cst.Number(operator=opnode, number=factor.value.number), + op.whitespace_before, + ) + else: + return WithLeadingWhitespace( + cst.UnaryOperation(operator=opnode, expression=factor.value), + op.whitespace_before, + ) + + +@with_production("power", "atom_expr ['**' factor]") +def convert_power(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + + left, power, right = children + return WithLeadingWhitespace( + cst.BinaryOperation( + left=left.value, + operator=cst.Power( + whitespace_before=parse_parenthesizable_whitespace( + config, power.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, power.whitespace_after + ), + ), + right=right.value, + ), + left.whitespace_before, + ) + + +@with_production("atom_expr", "atom_expr_await | atom_expr_trailer") +def convert_atom_expr(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("atom_expr_await", "'await' atom_expr_trailer") +def convert_atom_expr_await(config: ParserConfig, children: Sequence[Any]) -> Any: + keyword, expr = children + return WithLeadingWhitespace( + cst.Await( + whitespace_after_await=parse_parenthesizable_whitespace( + config, 
keyword.whitespace_after + ), + expression=expr.value, + ), + keyword.whitespace_before, + ) + + +@with_production("atom_expr_trailer", "atom trailer*") +def convert_atom_expr_trailer(config: ParserConfig, children: Sequence[Any]) -> Any: + atom, *trailers = children + whitespace_before = atom.whitespace_before + atom = atom.value + + # Need to walk through all trailers from left to right and construct + # a series of nodes based on each partial type. We can't do this with + # left recursion due to limits in the parser. + for trailer in trailers: + if isinstance(trailer, SubscriptPartial): + atom = cst.Subscript( + value=atom, + whitespace_after_value=parse_parenthesizable_whitespace( + config, trailer.whitespace_before + ), + lbracket=trailer.lbracket, + slice=trailer.slice, + rbracket=trailer.rbracket, + ) + elif isinstance(trailer, AttributePartial): + atom = cst.Attribute(value=atom, dot=trailer.dot, attr=trailer.attr) + elif isinstance(trailer, CallPartial): + # If the trailing argument doesn't have a comma, then it owns the + # trailing whitespace before the rpar. Otherwise, the comma owns + # it. 
+ if ( + len(trailer.args) > 0 + and trailer.args[-1].comma == MaybeSentinel.DEFAULT + ): + args = ( + *trailer.args[:-1], + trailer.args[-1].with_changes( + whitespace_after_arg=trailer.rpar.whitespace_before + ), + ) + else: + args = trailer.args + atom = cst.Call( + func=atom, + whitespace_after_func=parse_parenthesizable_whitespace( + config, trailer.lpar.whitespace_before + ), + whitespace_before_args=trailer.lpar.value.whitespace_after, + args=tuple(args), + ) + else: + # This is an invalid trailer, so lets give up + raise Exception("Logic error!") + return WithLeadingWhitespace(atom, whitespace_before) + + +@with_production( + "trailer", "trailer_arglist | trailer_subscriptlist | trailer_attribute" +) +def convert_trailer(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("trailer_arglist", "'(' [arglist] ')'") +def convert_trailer_arglist(config: ParserConfig, children: Sequence[Any]) -> Any: + lpar, *arglist, rpar = children + return CallPartial( + lpar=WithLeadingWhitespace( + cst.LeftParen( + whitespace_after=parse_parenthesizable_whitespace( + config, lpar.whitespace_after + ) + ), + lpar.whitespace_before, + ), + args=() if not arglist else arglist[0].args, + rpar=cst.RightParen( + whitespace_before=parse_parenthesizable_whitespace( + config, rpar.whitespace_before + ) + ), + ) + + +@with_production("trailer_subscriptlist", "'[' subscriptlist ']'") +def convert_trailer_subscriptlist(config: ParserConfig, children: Sequence[Any]) -> Any: + (lbracket, subscriptlist, rbracket) = children + return SubscriptPartial( + lbracket=cst.LeftSquareBracket( + whitespace_after=parse_parenthesizable_whitespace( + config, lbracket.whitespace_after + ) + ), + slice=subscriptlist.value, + rbracket=cst.RightSquareBracket( + whitespace_before=parse_parenthesizable_whitespace( + config, rbracket.whitespace_before + ) + ), + whitespace_before=lbracket.whitespace_before, + ) + + +@with_production("subscriptlist", 
"subscript (',' subscript)* [',']") +def convert_subscriptlist(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) > 1: + # This is a list of ExtSlice, so construct as such by grouping every + # subscript with an optional comma and adding to a list. + extslices = [] + for slice, comma in grouper(children, 2): + if comma is None: + extslices.append(cst.ExtSlice(slice=slice.value)) + else: + extslices.append( + cst.ExtSlice( + slice=slice.value, + comma=cst.Comma( + whitespace_before=parse_parenthesizable_whitespace( + config, comma.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, comma.whitespace_after + ), + ), + ) + ) + return WithLeadingWhitespace(extslices, children[0].whitespace_before) + else: + # This is an Index or Slice, as parsed in the child. + (index_or_slice,) = children + return index_or_slice + + +@with_production("subscript", "test | [test] ':' [test] [sliceop]") +def convert_subscript(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1 and not isinstance(children[0], Token): + # This is just an index node + (test,) = children + return WithLeadingWhitespace(cst.Index(test.value), test.whitespace_before) + + if isinstance(children[-1], SlicePartial): + # We got a partial slice as the final param. Extract the final + # bits of the full subscript. + *others, sliceop = children + whitespace_before = others[0].whitespace_before + second_colon = sliceop.second_colon + step = sliceop.step + else: + # We can just parse this below, without taking extras from the + # partial child. + others = children + whitespace_before = others[0].whitespace_before + second_colon = MaybeSentinel.DEFAULT + step = None + + # We need to create a partial slice to pass up. So, align so we have + # a list that's always [Optional[Test], Colon, Optional[Test]]. + if isinstance(others[0], Token): + # First token is a colon, so insert an empty test on the LHS. 
We + # know the RHS is a test since it's not a sliceop. + slicechildren = [None, *others] + else: + # First token is non-colon, so its a test. + slicechildren = [*others] + + if len(slicechildren) < 3: + # Now, we have to fill in the RHS. We know its two long + # at this point if its not already 3. + slicechildren = [*slicechildren, None] + + lower, first_colon, upper = slicechildren + return WithLeadingWhitespace( + cst.Slice( + lower=lower.value if lower is not None else None, + first_colon=cst.Colon( + whitespace_before=parse_parenthesizable_whitespace( + config, first_colon.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, first_colon.whitespace_after + ), + ), + upper=upper.value if upper is not None else None, + second_colon=second_colon, + step=step, + ), + whitespace_before=whitespace_before, + ) + + +@with_production("sliceop", "':' [test]") +def convert_sliceop(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 2: + colon, test = children + step = test.value + else: + (colon,) = children + step = None + return SlicePartial( + second_colon=cst.Colon( + whitespace_before=parse_parenthesizable_whitespace( + config, colon.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, colon.whitespace_after + ), + ), + step=step, + ) + + +@with_production("trailer_attribute", "'.' 
NAME") +def convert_trailer_attribute(config: ParserConfig, children: Sequence[Any]) -> Any: + dot, name = children + return AttributePartial( + dot=cst.Dot( + whitespace_before=parse_parenthesizable_whitespace( + config, dot.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, dot.whitespace_after + ), + ), + attr=cst.Name(name.string), + ) + + +@with_production( + "atom", + "atom_parens | atom_squarebrackets | atom_curlybrackets | atom_string | atom_fstring | atom_basic | atom_ellipses", +) +def convert_atom(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("atom_basic", "NAME | NUMBER | 'None' | 'True' | 'False'") +def convert_atom_basic(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + if child.type.name == "NAME": + # This also handles 'None', 'True', and 'False' directly, but we + # keep it in the grammar to be more correct. + return WithLeadingWhitespace(cst.Name(child.string), child.whitespace_before) + elif child.type.name == "NUMBER": + # We must determine what type of number it is since we split node + # types up this way. 
+ if re.fullmatch(INTNUMBER_RE, child.string): + return WithLeadingWhitespace( + cst.Number(cst.Integer(child.string)), child.whitespace_before + ) + elif re.fullmatch(FLOATNUMBER_RE, child.string): + return WithLeadingWhitespace( + cst.Number(cst.Float(child.string)), child.whitespace_before + ) + elif re.fullmatch(IMAGNUMBER_RE, child.string): + return WithLeadingWhitespace( + cst.Number(cst.Imaginary(child.string)), child.whitespace_before + ) + else: + raise Exception(f"Unparseable number {child.string}") + else: + raise Exception(f"Logic error, unexpected token {child.type.name}") + + +@with_production("atom_squarebrackets", "'[' [testlist_comp] ']'") +def convert_atom_squarebrackets(config: ParserConfig, children: Sequence[Any]) -> Any: + return make_dummy_node(config, children) + + +@with_production("atom_curlybrackets", "'{' [dictorsetmaker] '}'") +def convert_atom_curlybrackets(config: ParserConfig, children: Sequence[Any]) -> Any: + return make_dummy_node(config, children) + + +@with_production("atom_parens", "'(' [yield_expr|testlist_comp] ')'") +def convert_atom_parens(config: ParserConfig, children: Sequence[Any]) -> Any: + lpar, *atoms, rpar = children + + if len(atoms) == 1: + inner_atom = atoms[0].value + # With numbers, we bubble up the parens to the innermost node since + # Number() is just a wrapper to match on any valid number. The only + # instance where we don't do this is in the case that a number has + # a unary operator associated with it. In this case, the outer parens + # are owned by the Number node instead of the inner Integer/Float/Imaginary. 
+ if isinstance(inner_atom, cst.Number) and inner_atom.operator is None: + return WithLeadingWhitespace( + inner_atom.with_changes( + number=inner_atom.number.with_changes( + lpar=( + ( + cst.LeftParen( + whitespace_after=parse_parenthesizable_whitespace( + config, lpar.whitespace_after + ) + ), + ) + + tuple(inner_atom.number.lpar) + ), + rpar=( + tuple(inner_atom.number.rpar) + + ( + cst.RightParen( + whitespace_before=parse_parenthesizable_whitespace( + config, rpar.whitespace_before + ) + ), + ) + ), + ) + ), + lpar.whitespace_before, + ) + else: + return WithLeadingWhitespace( + inner_atom.with_changes( + lpar=( + ( + cst.LeftParen( + whitespace_after=parse_parenthesizable_whitespace( + config, lpar.whitespace_after + ) + ), + ) + + tuple(inner_atom.lpar) + ), + rpar=( + tuple(inner_atom.rpar) + + ( + cst.RightParen( + whitespace_before=parse_parenthesizable_whitespace( + config, rpar.whitespace_before + ) + ), + ) + ), + ), + lpar.whitespace_before, + ) + else: + # We don't support tuples yet + return make_dummy_node(config, children) + + +@with_production("atom_ellipses", "'...'") +def convert_atom_ellipses(config: ParserConfig, children: Sequence[Any]) -> Any: + (token,) = children + return WithLeadingWhitespace(cst.Ellipses(), token.whitespace_before) + + +@with_production("atom_string", "STRING atom_string | STRING STRING | STRING") +def convert_atom_string(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + return WithLeadingWhitespace( + cst.SimpleString(children[0].string), children[0].whitespace_before + ) + else: + left, right = children + if isinstance(right, Token): + return WithLeadingWhitespace( + cst.ConcatenatedString( + left=cst.SimpleString(left.string), + whitespace_between=parse_parenthesizable_whitespace( + config, right.whitespace_before + ), + right=cst.SimpleString(right.string), + ), + left.whitespace_before, + ) + else: + return WithLeadingWhitespace( + cst.ConcatenatedString( + left=cst.SimpleString(left.string), + 
whitespace_between=parse_parenthesizable_whitespace( + config, right.whitespace_before + ), + right=right.value, + ), + left.whitespace_before, + ) + + +@with_production("atom_fstring", "fstring [ atom_fstring | fstring ]") +def convert_atom_fstring(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + # Return the already-parsed f-string object + (child,) = children + return child + else: + left, right = children + # Return a concatenated version of these two f-strings + return WithLeadingWhitespace( + cst.ConcatenatedString( + left=left.value, + whitespace_between=parse_parenthesizable_whitespace( + config, right.whitespace_before + ), + right=right.value, + ), + left.whitespace_before, + ) + + +@with_production("fstring", "FSTRING_START fstring_content* FSTRING_END") +def convert_fstring(config: ParserConfig, children: Sequence[Any]) -> Any: + start, *content, end = children + return WithLeadingWhitespace( + cst.FormattedString(start=start.string, parts=tuple(content), end=end.string), + start.whitespace_before, + ) + + +@with_production("fstring_content", "FSTRING_STRING | fstring_expr") +def convert_fstring_content(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + if isinstance(child, Token): + # Construct and return a raw string portion. + return cst.FormattedStringText(child.string) + else: + # Pass the expression up one production. + return child + + +@with_production("fstring_conversion", "'!' NAME") +def convert_fstring_conversion(config: ParserConfig, children: Sequence[Any]) -> Any: + exclaim, name = children + # There cannot be a space between the two tokens, so no need to preserve this. 
+ return FormattedStringConversionPartial(name.string, exclaim.whitespace_before) + + +@with_production( + "fstring_expr", "'{' testlist [ fstring_conversion ] [ fstring_format_spec ] '}'" +) +def convert_fstring_expr(config: ParserConfig, children: Sequence[Any]) -> Any: + openbrkt, testlist, *conversions, closebrkt = children + + # Extract any optional conversion + if len(conversions) > 0 and isinstance( + conversions[0], FormattedStringConversionPartial + ): + conversion = conversions[0].value + conversions = conversions[1:] + else: + conversion = None + + # Extract any optional format spec + if len(conversions) > 0: + format_spec = conversions[0].values + else: + format_spec = None + + return cst.FormattedStringExpression( + whitespace_before_expression=parse_parenthesizable_whitespace( + config, testlist.whitespace_before + ), + expression=testlist.value, + whitespace_after_expression=parse_parenthesizable_whitespace( + config, children[2].whitespace_before + ), + conversion=conversion, + format_spec=format_spec, + ) + + +@with_production("fstring_format_spec", "':' fstring_content*") +def convert_fstring_format_spec(config: ParserConfig, children: Sequence[Any]) -> Any: + colon, *content = children + return FormattedStringFormatSpecPartial(tuple(content), colon.whitespace_before) + + +@with_production( + "testlist_comp", "(test|star_expr) ( comp_for | (',' (test|star_expr))* [','] )" +) +def convert_testlist_comp(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + else: + return make_dummy_node(config, children) + + +@with_production("testlist", "test (',' test)* [',']") +def convert_testlist(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + else: + return make_dummy_node(config, children) + + +@with_production( + "dictorsetmaker", + ( + "( ((test ':' test | '**' expr)" + + "(comp_for | (',' (test ':' test | '**' expr))* 
[','])) |" + + "((test | star_expr) " + + "(comp_for | (',' (test | star_expr))* [','])) )" + ), +) +def convert_dictorsetmaker(config: ParserConfig, children: Sequence[Any]) -> Any: + return make_dummy_node(config, children) + + +@with_production("exprlist", "(expr|star_expr) (',' (expr|star_expr))* [',']") +def convert_exprlist(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + else: + return make_dummy_node(config, children) + + +@with_production("arglist", "argument (',' argument)* [',']") +def convert_arglist(config: ParserConfig, children: Sequence[Any]) -> Any: + args = [] + for argument, comma in grouper(children, 2): + if comma is None: + args.append(argument) + else: + args.append( + argument.with_changes( + comma=cst.Comma( + whitespace_before=parse_parenthesizable_whitespace( + config, comma.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, comma.whitespace_after + ), + ) + ) + ) + return ArglistPartial(args) + + +@with_production("argument", "arg_assign_comp_for | star_arg") +def convert_argument(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("arg_assign_comp_for", "test [comp_for] | test '=' test") +def convert_arg_assign_comp_for(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + # Simple test + (child,) = children + return cst.Arg(value=child.value) + elif len(children) == 2: + # Comprehension, but we don't support comprehensions yet, so + # just set the value to a dummy node. 
+ return cst.Arg(value=make_dummy_node(config, children).value) + else: + # "key = value" assignment argument + lhs, equal, rhs = children + return cst.Arg( + keyword=lhs.value, + equal=cst.AssignEqual( + whitespace_before=parse_parenthesizable_whitespace( + config, equal.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, equal.whitespace_after + ), + ), + value=rhs.value, + ) + + +@with_production("star_arg", "'**' test | '*' test") +def convert_star_arg(config: ParserConfig, children: Sequence[Any]) -> Any: + star, test = children + return cst.Arg( + star=star.string, + whitespace_after_star=parse_parenthesizable_whitespace( + config, star.whitespace_after + ), + value=test.value, + ) + + +@with_production("comp_iter", "comp_for | comp_if") +def convert_comp_iter(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("sync_comp_for", "'for' exprlist 'in' or_test [comp_iter]") +def convert_sync_comp_for(config: ParserConfig, children: Sequence[Any]) -> Any: + return make_dummy_node(config, children) + + +@with_production("comp_for", "['async'] sync_comp_for") +def convert_comp_for(config: ParserConfig, children: Sequence[Any]) -> Any: + return make_dummy_node(config, children) + + +@with_production("comp_if", "'if' test_nocond [comp_iter]") +def convert_comp_if(config: ParserConfig, children: Sequence[Any]) -> Any: + return make_dummy_node(config, children) + + +@with_production("yield_expr", "'yield' [yield_arg]") +def convert_yield_expr(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + # Yielding implicit none + (yield_token,) = children + yield_node = cst.Yield(value=None) + else: + # Yielding explicit value + (yield_token, yield_arg) = children + yield_node = cst.Yield( + value=yield_arg.value, + whitespace_after_yield=parse_parenthesizable_whitespace( + config, yield_arg.whitespace_before + ), + ) + + return 
WithLeadingWhitespace(yield_node, yield_token.whitespace_before) + + +@with_production("yield_arg", "'from' test | testlist") +def convert_yield_arg(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + # Just a regular testlist, pass it up + (child,) = children + return child + else: + # Its a yield from + (from_token, test) = children + + return WithLeadingWhitespace( + cst.From( + item=test.value, + whitespace_after_from=parse_parenthesizable_whitespace( + config, test.whitespace_before + ), + ), + from_token.whitespace_before, + ) diff --git a/libcst/parser/_conversions/module.py b/libcst/parser/_conversions/module.py new file mode 100644 index 00000000..18051d11 --- /dev/null +++ b/libcst/parser/_conversions/module.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Sequence + +import libcst.nodes as cst +from libcst.nodes._whitespace import NEWLINE_RE +from libcst.parser._production_decorator import with_production +from libcst.parser._types.config import ParserConfig + + +@with_production("file_input", "(NEWLINE | stmt)* ENDMARKER") +def convert_file_input(config: ParserConfig, children: Sequence[Any]) -> Any: + *body, footer = children + if len(body) == 0: + # If there's no body, the header and footer are ambiguous. The header is more + # important, and should own the EmptyLine nodes instead of the footer. + header = footer + footer = () + if ( + len(config.lines) == 2 + and NEWLINE_RE.fullmatch(config.lines[0]) + and config.lines[1] == "" + ): + # This is an empty file (not even a comment), so special-case this to an + # empty list instead of a single dummy EmptyLine (which is what we'd + # normally parse). + header = () + else: + # Steal the leading lines from the first statement, and move them into the + # header. 
+ first_stmt = body[0] + header = first_stmt.leading_lines + body[0] = first_stmt.with_changes(leading_lines=()) + return cst.Module( + header=header, + body=body, + footer=footer, + encoding=config.encoding, + default_indent=config.default_indent, + default_newline=config.default_newline, + has_trailing_newline=config.has_trailing_newline, + ) diff --git a/libcst/parser/_conversions/params.py b/libcst/parser/_conversions/params.py new file mode 100644 index 00000000..5faa1b9d --- /dev/null +++ b/libcst/parser/_conversions/params.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Optional, Sequence, Union + +import libcst.nodes as cst +from libcst._maybe_sentinel import MaybeSentinel +from libcst.parser._custom_itertools import grouper +from libcst.parser._production_decorator import with_production +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.partials import ParamStarPartial +from libcst.parser._whitespace_parser import parse_parenthesizable_whitespace + + +@with_production( + "typedargslist", + ( + "(tfpdef_assign (',' tfpdef_assign)* " + + "[',' [tfpdef_star (',' tfpdef_assign)* [',' [tfpdef_starstar [',']]] | tfpdef_starstar [',']]]" + + "| tfpdef_star (',' tfpdef_assign)* [',' [tfpdef_starstar [',']]] | tfpdef_starstar [','])" + ), +) +@with_production( + "varargslist", + ( + "(vfpdef_assign (',' vfpdef_assign)* " + + "[',' [vfpdef_star (',' vfpdef_assign)* [',' [vfpdef_starstar [',']]] | vfpdef_starstar [',']]]" + + "| vfpdef_star (',' vfpdef_assign)* [',' [vfpdef_starstar [',']]] | vfpdef_starstar [','])" + ), +) +def convert_argslist(config: ParserConfig, children: Sequence[Any]) -> Any: + params: List[cst.Param] = [] + default_params: List[cst.Param] = [] + star_arg: Union[cst.Param, cst.ParamStar, MaybeSentinel] = MaybeSentinel.DEFAULT + 
kwonly_params: List[cst.Param] = [] + star_kwarg: Optional[cst.Param] = None + + def add_param( + current_param: Optional[List[cst.Param]], param: Union[cst.Param, cst.ParamStar] + ) -> Optional[List[cst.Param]]: + nonlocal star_arg + nonlocal star_kwarg + + if isinstance(param, cst.ParamStar): + # Only can add this if we don't already have a "*" or a "*param". + if current_param is params or current_param is default_params: + star_arg = param + current_param = kwonly_params + else: + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + elif isinstance(param.star, str) and param.star == "" and param.default is None: + # Can only add this if we're in the params or kwonly_params section + if current_param is params: + params.append(param) + elif current_param is kwonly_params: + kwonly_params.append(param) + else: + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + elif ( + isinstance(param.star, str) + and param.star == "" + and param.default is not None + ): + if current_param is params: + current_param = default_params + # Can only add this if we're not yet at star args. + if current_param is default_params: + default_params.append(param) + elif current_param is kwonly_params: + kwonly_params.append(param) + else: + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + elif ( + isinstance(param.star, str) and param.star == "*" and param.default is None + ): + # Can only add this if we're in params/default_params, since + # we only allow one of "*" or "*param". + if current_param is params or current_param is default_params: + star_arg = param + current_param = kwonly_params + else: + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + elif ( + isinstance(param.star, str) and param.star == "**" and param.default is None + ): + # Can add this in all cases where we don't have a star_kwarg + # yet. 
+ if current_param is not None: + star_kwarg = param + current_param = None + else: + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + else: + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + + return current_param + + # The parameter list we are adding to + current: Optional[List[cst.Param]] = params + + # We should have every other item in the group as a param or a comma by now, + # so split them up, add commas and then put them in the appropriate group. + for parameter, comma in grouper(children, 2): + if comma is None: + if isinstance(parameter, ParamStarPartial): + # TODO: We need to inform the user of an invalid syntax here + raise Exception("Syntax error!") + else: + current = add_param(current, parameter) + else: + comma = cst.Comma( + whitespace_before=parse_parenthesizable_whitespace( + config, comma.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, comma.whitespace_after + ), + ) + if isinstance(parameter, ParamStarPartial): + current = add_param(current, cst.ParamStar(comma=comma)) + else: + current = add_param(current, parameter.with_changes(comma=comma)) + + return cst.Parameters( + params=tuple(params), + default_params=tuple(default_params), + star_arg=star_arg, + kwonly_params=tuple(kwonly_params), + star_kwarg=star_kwarg, + ) + + +@with_production("tfpdef_star", "'*' [tfpdef]") +@with_production("vfpdef_star", "'*' [vfpdef]") +def convert_fpdef_star(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (star,) = children + return ParamStarPartial() + else: + star, param = children + return param.with_changes( + star=star.string, + whitespace_after_star=parse_parenthesizable_whitespace( + config, star.whitespace_after + ), + ) + + +@with_production("tfpdef_starstar", "'**' tfpdef") +@with_production("vfpdef_starstar", "'**' vfpdef") +def convert_fpdef_starstar(config: ParserConfig, 
children: Sequence[Any]) -> Any: + starstar, param = children + return param.with_changes( + star=starstar.string, + whitespace_after_star=parse_parenthesizable_whitespace( + config, starstar.whitespace_after + ), + ) + + +@with_production("tfpdef_assign", "tfpdef ['=' test]") +@with_production("vfpdef_assign", "vfpdef ['=' test]") +def convert_fpdef_assign(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (child,) = children + return child + + param, equal, default = children + return param.with_changes( + equal=cst.AssignEqual( + whitespace_before=parse_parenthesizable_whitespace( + config, equal.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, equal.whitespace_after + ), + ), + default=default.value, + ) + + +@with_production("tfpdef", "NAME [':' test]") +@with_production("vfpdef", "NAME") +def convert_fpdef(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + # This is just a parameter + (child,) = children + namenode = cst.Name(child.string) + annotation = None + else: + # This is a parameter with a type hint + name, colon, typehint = children + namenode = cst.Name(name.string) + annotation = cst.Annotation( + whitespace_before_indicator=parse_parenthesizable_whitespace( + config, colon.whitespace_before + ), + indicator=":", + whitespace_after_indicator=parse_parenthesizable_whitespace( + config, colon.whitespace_after + ), + annotation=typehint.value, + ) + + return cst.Param(star="", name=namenode, annotation=annotation, default=None) diff --git a/libcst/parser/_conversions/statement.py b/libcst/parser/_conversions/statement.py new file mode 100644 index 00000000..bf96af1a --- /dev/null +++ b/libcst/parser/_conversions/statement.py @@ -0,0 +1,1248 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type + +import libcst.nodes as cst +from libcst._maybe_sentinel import MaybeSentinel +from libcst.parser._custom_itertools import grouper +from libcst.parser._production_decorator import with_production +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.partials import ( + AnnAssignPartial, + AssignPartial, + AugAssignPartial, + DecoratorPartial, + ExceptClausePartial, + FuncdefPartial, + ImportPartial, + ImportRelativePartial, + SimpleStatementPartial, + WithLeadingWhitespace, +) +from libcst.parser._types.token import Token +from libcst.parser._whitespace_parser import ( + parse_empty_lines, + parse_parenthesizable_whitespace, + parse_simple_whitespace, +) + + +AUGOP_TOKEN_LUT: Dict[str, Type[cst.BaseAugOp]] = { + "+=": cst.AddAssign, + "-=": cst.SubtractAssign, + "*=": cst.MultiplyAssign, + "@=": cst.MatrixMultiplyAssign, + "/=": cst.DivideAssign, + "%=": cst.ModuloAssign, + "&=": cst.BitAndAssign, + "|=": cst.BitOrAssign, + "^=": cst.BitXorAssign, + "<<=": cst.LeftShiftAssign, + ">>=": cst.RightShiftAssign, + "**=": cst.PowerAssign, + "//=": cst.FloorDivideAssign, +} + + +@with_production("stmt_input", "stmt ENDMARKER") +def convert_stmt_input(config: ParserConfig, children: Sequence[Any]) -> Any: + (child, endmarker) = children + return child + + +@with_production("stmt", "simple_stmt_line | compound_stmt") +def convert_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("simple_stmt_partial", "small_stmt (';' small_stmt)* [';'] NEWLINE") +def convert_simple_stmt_partial(config: ParserConfig, children: Sequence[Any]) -> Any: + *statements, trailing_whitespace = children + + last_stmt = len(statements) / 2 + body = [] + for i, (stmt_body, semi) in enumerate(grouper(statements, 2)): + if semi is not None: + if i == (last_stmt - 1): + # Trailing semicolons only own the whitespace before. 
+ semi = cst.Semicolon( + whitespace_before=parse_simple_whitespace( + config, semi.whitespace_before + ), + whitespace_after=cst.SimpleWhitespace(""), + ) + else: + # Middle semicolons own the whitespace before and after. + semi = cst.Semicolon( + whitespace_before=parse_simple_whitespace( + config, semi.whitespace_before + ), + whitespace_after=parse_simple_whitespace( + config, semi.whitespace_after + ), + ) + else: + semi = MaybeSentinel.DEFAULT + body.append(stmt_body.value.with_changes(semicolon=semi)) + return SimpleStatementPartial( + body, + whitespace_before=statements[0].whitespace_before, + trailing_whitespace=trailing_whitespace, + ) + + +@with_production("simple_stmt_line", "simple_stmt_partial") +def convert_simple_stmt_line(config: ParserConfig, children: Sequence[Any]) -> Any: + """ + This function is similar to convert_simple_stmt_suite, but yields a different type + """ + (partial,) = children + return cst.SimpleStatementLine( + partial.body, + leading_lines=parse_empty_lines(config, partial.whitespace_before), + trailing_whitespace=partial.trailing_whitespace, + ) + + +@with_production("simple_stmt_suite", "simple_stmt_partial") +def convert_simple_stmt_suite(config: ParserConfig, children: Sequence[Any]) -> Any: + """ + This function is similar to convert_simple_stmt_line, but yields a different type + """ + (partial,) = children + return cst.SimpleStatementSuite( + partial.body, + leading_whitespace=parse_simple_whitespace(config, partial.whitespace_before), + trailing_whitespace=partial.trailing_whitespace, + ) + + +@with_production( + "small_stmt", + ( + "expr_stmt | del_stmt | pass_stmt | break_stmt | continue_stmt | return_stmt" + + "| raise_stmt | yield_stmt | import_stmt | global_stmt | nonlocal_stmt" + + "| assert_stmt" + ), +) +def convert_small_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + # Doesn't construct SmallStatement, because we don't know about semicolons yet. 
+ # convert_simple_stmt will construct the SmallStatement nodes. + (small_stmt_body,) = children + return small_stmt_body + + +@with_production("expr_stmt", "testlist_star_expr (annassign | augassign | assign* )") +@with_production("yield_stmt", "yield_expr") +def convert_expr_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + # This is an unassigned expr statement (like a function call) + (test_node,) = children + return WithLeadingWhitespace( + cst.Expr(value=test_node.value), test_node.whitespace_before + ) + elif len(children) == 2: + lhs, rhs = children + if isinstance(rhs, AnnAssignPartial): + return WithLeadingWhitespace( + cst.AnnAssign( + target=lhs.value, + annotation=rhs.annotation, + equal=MaybeSentinel.DEFAULT if rhs.equal is None else rhs.equal, + value=rhs.value, + ), + lhs.whitespace_before, + ) + elif isinstance(rhs, AugAssignPartial): + return WithLeadingWhitespace( + cst.AugAssign(target=lhs.value, operator=rhs.operator, value=rhs.value), + lhs.whitespace_before, + ) + # The only thing it could be at this point is an assign with one or more targets. + # So, walk the children moving the equals ownership back one and constructing a + # list of AssignTargets. 
+ targets = [] + for i in range(len(children) - 1): + target = children[i].value + equal = children[i + 1].equal + + targets.append( + cst.AssignTarget( + target=target, + whitespace_before_equal=equal.whitespace_before, + whitespace_after_equal=equal.whitespace_after, + ) + ) + + return WithLeadingWhitespace( + cst.Assign(targets=tuple(targets), value=children[-1].value), + children[0].whitespace_before, + ) + + +@with_production("annassign", "':' test ['=' test]") +def convert_annassign(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 2: + # Variable annotation only + colon, annotation = children + annotation = annotation.value + equal = None + value = None + elif len(children) == 4: + # Variable annotation and assignment + colon, annotation, equal, value = children + annotation = annotation.value + value = value.value + equal = cst.AssignEqual( + whitespace_before=parse_simple_whitespace(config, equal.whitespace_before), + whitespace_after=parse_simple_whitespace(config, equal.whitespace_after), + ) + else: + raise Exception("Invalid parser state!") + + return AnnAssignPartial( + annotation=cst.Annotation( + whitespace_before_indicator=parse_simple_whitespace( + config, colon.whitespace_before + ), + indicator=colon.string, + whitespace_after_indicator=parse_simple_whitespace( + config, colon.whitespace_after + ), + annotation=annotation, + ), + equal=equal, + value=value, + ) + + +@with_production( + "augassign", + ( + "('+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^=' | '<<=' | " + + "'>>=' | '**=' | '//=') (yield_expr | testlist)" + ), +) +def convert_augassign(config: ParserConfig, children: Sequence[Any]) -> Any: + op, expr = children + if op.string not in AUGOP_TOKEN_LUT: + raise Exception(f"Unexpected token '{op.string}'!") + return AugAssignPartial( + operator=AUGOP_TOKEN_LUT[op.string]( + whitespace_before=parse_simple_whitespace(config, op.whitespace_before), + whitespace_after=parse_simple_whitespace(config, 
op.whitespace_after), + ), + value=expr.value, + ) + + +@with_production("assign", "'=' (yield_expr|testlist_star_expr)") +def convert_assign(config: ParserConfig, children: Sequence[Any]) -> Any: + equal, expr = children + return AssignPartial( + equal=cst.AssignEqual( + whitespace_before=parse_simple_whitespace(config, equal.whitespace_before), + whitespace_after=parse_simple_whitespace(config, equal.whitespace_after), + ), + value=expr.value, + ) + + +@with_production("pass_stmt", "'pass'") +def convert_pass_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (name,) = children + return WithLeadingWhitespace(cst.Pass(), name.whitespace_before) + + +@with_production("del_stmt", "'del' exprlist") +def convert_del_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (del_name, exprlist) = children + return WithLeadingWhitespace( + cst.Del( + target=exprlist.value, + whitespace_after_del=parse_simple_whitespace( + config, del_name.whitespace_after + ), + ), + del_name.whitespace_before, + ) + + +@with_production("continue_stmt", "'continue'") +def convert_continue_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (name,) = children + return WithLeadingWhitespace(cst.Continue(), name.whitespace_before) + + +@with_production("break_stmt", "'break'") +def convert_break_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (name,) = children + return WithLeadingWhitespace(cst.Break(), name.whitespace_before) + + +@with_production("return_stmt", "'return' [testlist]") +def convert_return_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (keyword,) = children + return WithLeadingWhitespace( + cst.Return(whitespace_after_return=cst.SimpleWhitespace("")), + keyword.whitespace_before, + ) + else: + (keyword, testlist) = children + return WithLeadingWhitespace( + cst.Return( + value=testlist.value, + whitespace_after_return=parse_simple_whitespace( + config, keyword.whitespace_after + ), + ), + 
keyword.whitespace_before, + ) + + +@with_production("import_stmt", "import_name | import_from") +def convert_import_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (child,) = children + return child + + +@with_production("import_name", "'import' dotted_as_names") +def convert_import_name(config: ParserConfig, children: Sequence[Any]) -> Any: + importtoken, names = children + return WithLeadingWhitespace( + cst.Import( + names=names.names, + whitespace_after_import=parse_simple_whitespace( + config, importtoken.whitespace_after + ), + ), + importtoken.whitespace_before, + ) + + +@with_production("import_relative", "('.' | '...')* dotted_name | ('.' | '...')+") +def convert_import_relative(config: ParserConfig, children: Sequence[Any]) -> Any: + dots = [] + dotted_name = None + for child in children: + if isinstance(child, Token): + # Special case for "...", which is part of the grammar + if child.string == "...": + dots.extend( + [ + cst.Dot(), + cst.Dot(), + cst.Dot( + whitespace_after=parse_simple_whitespace( + config, child.whitespace_after + ) + ), + ] + ) + else: + dots.append( + cst.Dot( + whitespace_after=parse_simple_whitespace( + config, child.whitespace_after + ) + ) + ) + else: + # This should be the dotted name, and we can't get more than + # one, but lets be sure anyway + if dotted_name is not None: + raise Exception("Logic error!") + dotted_name = child + + return ImportRelativePartial(relative=tuple(dots), module=dotted_name) + + +@with_production( + "import_from", + "'from' import_relative 'import' ('*' | '(' import_as_names ')' | import_as_names)", +) +def convert_import_from(config: ParserConfig, children: Sequence[Any]) -> Any: + fromtoken, import_relative, importtoken, *importlist = children + + if len(importlist) == 1: + (possible_star,) = importlist + if isinstance(possible_star, Token): + # Its a "*" import, so we must construct this node. 
+ names = cst.ImportStar() + else: + # Its an import as names partial, grab the names from that. + names = possible_star.names + lpar = None + rpar = None + else: + # Its an import as names partial with parens + lpartoken, namespartial, rpartoken = importlist + lpar = cst.LeftParen( + whitespace_after=parse_parenthesizable_whitespace( + config, lpartoken.whitespace_after + ) + ) + names = namespartial.names + rpar = cst.RightParen( + whitespace_before=parse_parenthesizable_whitespace( + config, rpartoken.whitespace_before + ) + ) + + # If we have a relative-only import, then we need to relocate the space + # after the final dot to be owned by the import token. + if len(import_relative.relative) > 0 and import_relative.module is None: + whitespace_before_import = import_relative.relative[-1].whitespace_after + relative = ( + *import_relative.relative[:-1], + import_relative.relative[-1].with_changes( + whitespace_after=cst.SimpleWhitespace("") + ), + ) + else: + whitespace_before_import = parse_simple_whitespace( + config, importtoken.whitespace_before + ) + relative = import_relative.relative + + return WithLeadingWhitespace( + cst.ImportFrom( + whitespace_after_from=parse_simple_whitespace( + config, fromtoken.whitespace_after + ), + relative=relative, + module=import_relative.module, + whitespace_before_import=whitespace_before_import, + whitespace_after_import=parse_simple_whitespace( + config, importtoken.whitespace_after + ), + lpar=lpar, + names=names, + rpar=rpar, + ), + fromtoken.whitespace_before, + ) + + +@with_production("import_as_name", "NAME ['as' NAME]") +def convert_import_as_name(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (dotted_name,) = children + return cst.ImportAlias(name=cst.Name(dotted_name.string), asname=None) + else: + dotted_name, astoken, name = children + return cst.ImportAlias( + name=cst.Name(dotted_name.string), + asname=cst.AsName( + whitespace_before_as=parse_simple_whitespace( + config, 
astoken.whitespace_before + ), + whitespace_after_as=parse_simple_whitespace( + config, astoken.whitespace_after + ), + name=cst.Name(name.string), + ), + ) + + +@with_production("dotted_as_name", "dotted_name ['as' NAME]") +def convert_dotted_as_name(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (dotted_name,) = children + return cst.ImportAlias(name=dotted_name, asname=None) + else: + dotted_name, astoken, name = children + return cst.ImportAlias( + name=dotted_name, + asname=cst.AsName( + whitespace_before_as=parse_parenthesizable_whitespace( + config, astoken.whitespace_before + ), + whitespace_after_as=parse_parenthesizable_whitespace( + config, astoken.whitespace_after + ), + name=cst.Name(name.string), + ), + ) + + +@with_production("import_as_names", "import_as_name (',' import_as_name)* [',']") +def convert_import_as_names(config: ParserConfig, children: Sequence[Any]) -> Any: + return _gather_import_names(config, children) + + +@with_production("dotted_as_names", "dotted_as_name (',' dotted_as_name)*") +def convert_dotted_as_names(config: ParserConfig, children: Sequence[Any]) -> Any: + return _gather_import_names(config, children) + + +def _gather_import_names( + config: ParserConfig, children: Sequence[Any] +) -> ImportPartial: + names = [] + for name, comma in grouper(children, 2): + if comma is None: + names.append(name) + else: + names.append( + name.with_changes( + comma=cst.Comma( + whitespace_before=parse_parenthesizable_whitespace( + config, comma.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, comma.whitespace_after + ), + ) + ) + ) + + return ImportPartial(names=names) + + +@with_production("dotted_name", "NAME ('.' 
NAME)*") +def convert_dotted_name(config: ParserConfig, children: Sequence[Any]) -> Any: + left, *rest = children + node = cst.Name(left.string) + + for dot, right in grouper(rest, 2): + node = cst.Attribute( + value=node, + dot=cst.Dot( + whitespace_before=parse_parenthesizable_whitespace( + config, dot.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, dot.whitespace_after + ), + ), + attr=cst.Name(right.string), + ) + + return node + + +@with_production("raise_stmt", "'raise' [test ['from' test]]") +def convert_raise_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (raise_token,) = children + whitespace_after_raise = MaybeSentinel.DEFAULT + exc = None + cause = None + elif len(children) == 2: + (raise_token, test) = children + whitespace_after_raise = parse_simple_whitespace(config, test.whitespace_before) + exc = test.value + cause = None + elif len(children) == 4: + (raise_token, test, from_token, source) = children + whitespace_after_raise = parse_simple_whitespace(config, test.whitespace_before) + exc = test.value + cause = cst.From( + whitespace_before_from=parse_simple_whitespace( + config, from_token.whitespace_before + ), + whitespace_after_from=parse_simple_whitespace( + config, source.whitespace_before + ), + item=source.value, + ) + else: + raise Exception("Logic error!") + + return WithLeadingWhitespace( + cst.Raise(whitespace_after_raise=whitespace_after_raise, exc=exc, cause=cause), + raise_token.whitespace_before, + ) + + +def _construct_nameitems( + config: ParserConfig, names: Sequence[Any] +) -> List[cst.NameItem]: + nameitems: List[cst.NameItem] = [] + for name, maybe_comma in grouper(names, 2): + if maybe_comma is None: + nameitems.append(cst.NameItem(cst.Name(name.string))) + else: + nameitems.append( + cst.NameItem( + cst.Name(name.string), + comma=cst.Comma( + whitespace_before=parse_simple_whitespace( + config, maybe_comma.whitespace_before + ), + 
whitespace_after=parse_simple_whitespace( + config, maybe_comma.whitespace_after + ), + ), + ) + ) + return nameitems + + +@with_production("global_stmt", "'global' NAME (',' NAME)*") +def convert_global_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (global_token, *names) = children + return WithLeadingWhitespace( + cst.Global( + names=tuple(_construct_nameitems(config, names)), + whitespace_after_global=parse_simple_whitespace( + config, names[0].whitespace_before + ), + ), + global_token.whitespace_before, + ) + + +@with_production("nonlocal_stmt", "'nonlocal' NAME (',' NAME)*") +def convert_nonlocal_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (nonlocal_token, *names) = children + return WithLeadingWhitespace( + cst.Nonlocal( + names=tuple(_construct_nameitems(config, names)), + whitespace_after_nonlocal=parse_simple_whitespace( + config, names[0].whitespace_before + ), + ), + nonlocal_token.whitespace_before, + ) + + +@with_production("assert_stmt", "'assert' test [',' test]") +def convert_assert_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 2: + (assert_token, test) = children + assert_node = cst.Assert( + whitespace_after_assert=parse_simple_whitespace( + config, test.whitespace_before + ), + test=test.value, + msg=None, + ) + else: + (assert_token, test, comma_token, msg) = children + assert_node = cst.Assert( + whitespace_after_assert=parse_simple_whitespace( + config, test.whitespace_before + ), + test=test.value, + comma=cst.Comma( + whitespace_before=parse_simple_whitespace( + config, comma_token.whitespace_before + ), + whitespace_after=parse_simple_whitespace(config, msg.whitespace_before), + ), + msg=msg.value, + ) + + return WithLeadingWhitespace(assert_node, assert_token.whitespace_before) + + +@with_production( + "compound_stmt", + ("if_stmt | while_stmt | asyncable_stmt | try_stmt | classdef | decorated"), +) +def convert_compound_stmt(config: ParserConfig, children: 
Sequence[Any]) -> Any: + (stmt,) = children + return stmt + + +@with_production("if_stmt", "'if' test ':' suite [if_stmt_elif|if_stmt_else]") +def convert_if_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + if_tok, test, colon_tok, suite, *tail = children + + if len(tail) > 0: + (orelse,) = tail + else: + orelse = None + + return cst.If( + leading_lines=parse_empty_lines(config, if_tok.whitespace_before), + whitespace_before_test=parse_simple_whitespace(config, if_tok.whitespace_after), + test=test.value, + whitespace_after_test=parse_simple_whitespace( + config, colon_tok.whitespace_before + ), + body=suite, + orelse=orelse, + ) + + +@with_production("if_stmt_elif", "'elif' test ':' suite [if_stmt_elif|if_stmt_else]") +def convert_if_stmt_elif(config: ParserConfig, children: Sequence[Any]) -> Any: + # this behaves exactly the same as `convert_if_stmt`, except that the leading token + # has a different string value. + return convert_if_stmt(config, children) + + +@with_production("if_stmt_else", "'else' ':' suite") +def convert_if_stmt_else(config: ParserConfig, children: Sequence[Any]) -> Any: + else_tok, colon_tok, suite = children + return cst.Else( + leading_lines=parse_empty_lines(config, else_tok.whitespace_before), + whitespace_before_colon=parse_simple_whitespace( + config, colon_tok.whitespace_before + ), + body=suite, + ) + + +@with_production("while_stmt", "'while' test ':' suite ['else' ':' suite]") +def convert_while_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + while_token, test, while_colon_token, while_suite, *else_block = children + + if len(else_block) > 0: + (else_token, else_colon_token, else_suite) = else_block + orelse = cst.Else( + leading_lines=parse_empty_lines(config, else_token.whitespace_before), + whitespace_before_colon=parse_simple_whitespace( + config, else_colon_token.whitespace_before + ), + body=else_suite, + ) + else: + orelse = None + + return cst.While( + leading_lines=parse_empty_lines(config, 
while_token.whitespace_before), + whitespace_after_while=parse_simple_whitespace( + config, while_token.whitespace_after + ), + test=test.value, + whitespace_before_colon=parse_simple_whitespace( + config, while_colon_token.whitespace_before + ), + body=while_suite, + orelse=orelse, + ) + + +@with_production( + "for_stmt", "'for' exprlist 'in' testlist ':' suite ['else' ':' suite]" +) +def convert_for_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + ( + for_token, + expr, + in_token, + test, + for_colon_token, + for_suite, + *else_block, + ) = children + + if len(else_block) > 0: + (else_token, else_colon_token, else_suite) = else_block + orelse = cst.Else( + leading_lines=parse_empty_lines(config, else_token.whitespace_before), + whitespace_before_colon=parse_simple_whitespace( + config, else_colon_token.whitespace_before + ), + body=else_suite, + ) + else: + orelse = None + + return WithLeadingWhitespace( + cst.For( + whitespace_after_for=parse_simple_whitespace( + config, for_token.whitespace_after + ), + target=expr.value, + whitespace_before_in=parse_simple_whitespace( + config, in_token.whitespace_before + ), + whitespace_after_in=parse_simple_whitespace( + config, in_token.whitespace_after + ), + iter=test.value, + whitespace_before_colon=parse_simple_whitespace( + config, for_colon_token.whitespace_before + ), + body=for_suite, + orelse=orelse, + ), + for_token.whitespace_before, + ) + + +@with_production( + "try_stmt", + "('try' ':' suite ((except_clause ':' suite)+ ['else' ':' suite] ['finally' ':' suite] | 'finally' ':' suite))", +) +def convert_try_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + trytoken, try_colon_token, try_suite, *rest = children + handlers: List[cst.ExceptHandler] = [] + orelse: Optional[cst.Else] = None + finalbody: Optional[cst.Finally] = None + + for clause, colon_token, suite in grouper(rest, 3): + if isinstance(clause, Token): + if clause.string == "else": + if orelse is not None: + raise 
Exception("Logic error!") + orelse = cst.Else( + leading_lines=parse_empty_lines(config, clause.whitespace_before), + whitespace_before_colon=parse_simple_whitespace( + config, colon_token.whitespace_before + ), + body=suite, + ) + elif clause.string == "finally": + if finalbody is not None: + raise Exception("Logic error!") + finalbody = cst.Finally( + leading_lines=parse_empty_lines(config, clause.whitespace_before), + whitespace_before_colon=parse_simple_whitespace( + config, colon_token.whitespace_before + ), + body=suite, + ) + else: + raise Exception("Logic error!") + elif isinstance(clause, ExceptClausePartial): + handlers.append( + cst.ExceptHandler( + body=suite, + type=clause.type, + name=clause.name, + leading_lines=clause.leading_lines, + whitespace_after_except=clause.whitespace_after_except, + whitespace_before_colon=parse_simple_whitespace( + config, colon_token.whitespace_before + ), + ) + ) + else: + raise Exception("Logic error!") + + return cst.Try( + leading_lines=parse_empty_lines(config, trytoken.whitespace_before), + whitespace_before_colon=parse_simple_whitespace( + config, try_colon_token.whitespace_before + ), + body=try_suite, + handlers=tuple(handlers), + orelse=orelse, + finalbody=finalbody, + ) + + +@with_production("except_clause", "'except' [test ['as' NAME]]") +def convert_except_clause(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 1: + (except_token,) = children + whitespace_after_except = cst.SimpleWhitespace("") + test = None + name = None + elif len(children) == 2: + (except_token, test_node) = children + whitespace_after_except = parse_simple_whitespace( + config, except_token.whitespace_after + ) + test = test_node.value + name = None + else: + (except_token, test_node, as_token, name_token) = children + whitespace_after_except = parse_simple_whitespace( + config, except_token.whitespace_after + ) + test = test_node.value + name = cst.AsName( + whitespace_before_as=parse_simple_whitespace( + 
config, as_token.whitespace_before + ), + whitespace_after_as=parse_simple_whitespace( + config, as_token.whitespace_after + ), + name=cst.Name(name_token.string), + ) + + return ExceptClausePartial( + leading_lines=parse_empty_lines(config, except_token.whitespace_before), + whitespace_after_except=whitespace_after_except, + type=test, + name=name, + ) + + +@with_production("with_stmt", "'with' with_item (',' with_item)* ':' suite") +def convert_with_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + (with_token, *items, colon_token, suite) = children + item_nodes: List[cst.WithItem] = [] + + for with_item, maybe_comma in grouper(items, 2): + if maybe_comma is not None: + item_nodes.append( + with_item.with_changes( + comma=cst.Comma( + whitespace_before=parse_parenthesizable_whitespace( + config, maybe_comma.whitespace_before + ), + whitespace_after=parse_parenthesizable_whitespace( + config, maybe_comma.whitespace_after + ), + ) + ) + ) + else: + item_nodes.append(with_item) + + return WithLeadingWhitespace( + cst.With( + whitespace_after_with=parse_simple_whitespace( + config, with_token.whitespace_after + ), + items=tuple(item_nodes), + whitespace_before_colon=parse_simple_whitespace( + config, colon_token.whitespace_before + ), + body=suite, + ), + with_token.whitespace_before, + ) + + +@with_production("with_item", "test ['as' expr]") +def convert_with_item(config: ParserConfig, children: Sequence[Any]) -> Any: + if len(children) == 3: + (test, as_token, expr_node) = children + test_node = test.value + asname = cst.AsName( + whitespace_before_as=parse_simple_whitespace( + config, as_token.whitespace_before + ), + whitespace_after_as=parse_simple_whitespace( + config, as_token.whitespace_after + ), + name=expr_node.value, + ) + else: + (test,) = children + test_node = test.value + asname = None + + return cst.WithItem(item=test_node, asname=asname) + + +def _extract_async( + config: ParserConfig, children: Sequence[Any] +) -> 
Tuple[List[cst.EmptyLine], Optional[cst.Asynchronous], Any]: + if len(children) == 1: + (stmt,) = children + + whitespace_before = stmt.whitespace_before + asyncnode = None + else: + asynctoken, stmt = children + + whitespace_before = asynctoken.whitespace_before + asyncnode = cst.Asynchronous( + whitespace_after=parse_simple_whitespace( + config, asynctoken.whitespace_after + ) + ) + + return (parse_empty_lines(config, whitespace_before), asyncnode, stmt.value) + + +@with_production("asyncable_funcdef", "['async'] funcdef") +def convert_asyncable_funcdef(config: ParserConfig, children: Sequence[Any]) -> Any: + leading_lines, asyncnode, funcdef = _extract_async(config, children) + + return funcdef.with_changes( + asynchronous=asyncnode, leading_lines=leading_lines, lines_after_decorators=() + ) + + +@with_production("funcdef", "'def' NAME parameters [funcdef_annotation] ':' suite") +def convert_funcdef(config: ParserConfig, children: Sequence[Any]) -> Any: + defnode, namenode, param_partial, *annotation, colon, suite = children + + # If the trailing paremeter doesn't have a comma, then it owns the trailing + # whitespace before the rpar. Otherwise, the comma owns it (and will have + # already parsed it). We don't check/update ParamStar because if it exists + # then we are guaranteed have at least one kwonly_param. 
+ parameters = param_partial.params + if parameters.star_kwarg is not None: + if parameters.star_kwarg.comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + star_kwarg=parameters.star_kwarg.with_changes( + whitespace_after_param=param_partial.rpar.whitespace_before + ) + ) + elif parameters.kwonly_params: + if parameters.kwonly_params[-1].comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + kwonly_params=( + *parameters.kwonly_params[:-1], + parameters.kwonly_params[-1].with_changes( + whitespace_after_param=param_partial.rpar.whitespace_before + ), + ) + ) + elif isinstance(parameters.star_arg, cst.Param): + if parameters.star_arg.comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + star_arg=parameters.star_arg.with_changes( + whitespace_after_param=param_partial.rpar.whitespace_before + ) + ) + elif parameters.default_params: + if parameters.default_params[-1].comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + default_params=( + *parameters.default_params[:-1], + parameters.default_params[-1].with_changes( + whitespace_after_param=param_partial.rpar.whitespace_before + ), + ) + ) + elif parameters.params: + if parameters.params[-1].comma == MaybeSentinel.DEFAULT: + parameters = parameters.with_changes( + params=( + *parameters.params[:-1], + parameters.params[-1].with_changes( + whitespace_after_param=param_partial.rpar.whitespace_before + ), + ) + ) + + return WithLeadingWhitespace( + cst.FunctionDef( + whitespace_after_def=parse_simple_whitespace( + config, defnode.whitespace_after + ), + name=cst.Name(namenode.string), + whitespace_after_name=parse_simple_whitespace( + config, namenode.whitespace_after + ), + whitespace_before_params=param_partial.lpar.whitespace_after, + params=parameters, + returns=None if not annotation else annotation[0], + whitespace_before_colon=parse_simple_whitespace( + config, colon.whitespace_before + ), + body=suite, + ), + 
defnode.whitespace_before, + ) + + +@with_production("parameters", "'(' [typedargslist] ')'") +def convert_parameters(config: ParserConfig, children: Sequence[Any]) -> Any: + lpar, *paramlist, rpar = children + return FuncdefPartial( + lpar=cst.LeftParen( + whitespace_after=parse_parenthesizable_whitespace( + config, lpar.whitespace_after + ) + ), + params=cst.Parameters() if not paramlist else paramlist[0], + rpar=cst.RightParen( + whitespace_before=parse_parenthesizable_whitespace( + config, rpar.whitespace_before + ) + ), + ) + + +@with_production("funcdef_annotation", "'->' test") +def convert_funcdef_annotation(config: ParserConfig, children: Sequence[Any]) -> Any: + arrow, typehint = children + return cst.Annotation( + whitespace_before_indicator=parse_parenthesizable_whitespace( + config, arrow.whitespace_before + ), + indicator="->", + whitespace_after_indicator=parse_parenthesizable_whitespace( + config, arrow.whitespace_after + ), + annotation=typehint.value, + ) + + +@with_production("classdef", "'class' NAME ['(' [arglist] ')'] ':' suite") +def convert_classdef(config: ParserConfig, children: Sequence[Any]) -> Any: + classdef, name, *arglist, colon, suite = children + + # First, parse out the comments and empty lines before the statement. 
+ leading_lines = parse_empty_lines(config, classdef.whitespace_before) + + # Compute common whitespace and nodes + whitespace_after_class = parse_simple_whitespace(config, classdef.whitespace_after) + namenode = cst.Name(name.string) + whitespace_after_name = parse_simple_whitespace(config, name.whitespace_after) + + # Now, construct the classdef node itself + if not arglist: + # No arglist, so no arguments to this class + return cst.ClassDef( + leading_lines=leading_lines, + lines_after_decorators=(), + whitespace_after_class=whitespace_after_class, + name=namenode, + whitespace_after_name=whitespace_after_name, + body=suite, + ) + else: + # Unwrap arglist partial, because its valid to not have any + lpar, *args, rpar = arglist + args = args[0].args if args else [] + + bases: List[cst.Arg] = [] + keywords: List[cst.Arg] = [] + + current_arg = bases + for arg in args: + if arg.star == "**" or arg.keyword is not None: + current_arg = keywords + # Some quick validation + if current_arg is keywords and ( + arg.star == "*" or (arg.star == "" and arg.keyword is None) + ): + # TODO: Need a real syntax error here + raise Exception("Syntax error!") + current_arg.append(arg) + + return cst.ClassDef( + leading_lines=leading_lines, + lines_after_decorators=(), + whitespace_after_class=whitespace_after_class, + name=namenode, + whitespace_after_name=whitespace_after_name, + lpar=cst.LeftParen( + whitespace_after=parse_parenthesizable_whitespace( + config, lpar.whitespace_after + ) + ), + bases=bases, + keywords=keywords, + rpar=cst.RightParen( + whitespace_before=parse_parenthesizable_whitespace( + config, rpar.whitespace_before + ) + ), + whitespace_before_colon=parse_simple_whitespace( + config, colon.whitespace_before + ), + body=suite, + ) + + +@with_production("decorator", "'@' dotted_name [ '(' [arglist] ')' ] NEWLINE") +def convert_decorator(config: ParserConfig, children: Sequence[Any]) -> Any: + atsign, name, *arglist, newline = children + if not arglist: + # This is 
either a name or an attribute node, so just extract it. + decoratornode = name + else: + # This needs to be converted into a call node, and we have the + # arglist partial. + lpar, *args, rpar = arglist + args = args[0].args if args else [] + + # If the trailing argument doesn't have a comma, then it owns the + # trailing whitespace before the rpar. Otherwise, the comma owns + # it. + if len(args) > 0 and args[-1].comma == MaybeSentinel.DEFAULT: + args[-1] = args[-1].with_changes( + whitespace_after_arg=parse_parenthesizable_whitespace( + config, rpar.whitespace_before + ) + ) + + decoratornode = cst.Call( + func=name, + whitespace_after_func=parse_simple_whitespace( + config, lpar.whitespace_before + ), + whitespace_before_args=parse_parenthesizable_whitespace( + config, lpar.whitespace_after + ), + args=tuple(args), + ) + + return cst.Decorator( + leading_lines=parse_empty_lines(config, atsign.whitespace_before), + whitespace_after_at=parse_simple_whitespace(config, atsign.whitespace_after), + decorator=decoratornode, + trailing_whitespace=newline, + ) + + +@with_production("decorators", "decorator+") +def convert_decorators(config: ParserConfig, children: Sequence[Any]) -> Any: + return DecoratorPartial(decorators=children) + + +@with_production("decorated", "decorators (classdef | asyncable_funcdef)") +def convert_decorated(config: ParserConfig, children: Sequence[Any]) -> Any: + partial, class_or_func = children + + # First, split up the spacing on the first decorator + leading_lines = partial.decorators[0].leading_lines + + # Now, redistribute ownership of the whitespace + decorators = ( + partial.decorators[0].with_changes(leading_lines=()), + *partial.decorators[1:], + ) + + # Now, modify the original function or class to add the decorators. 
+ return class_or_func.with_changes( + leading_lines=leading_lines, + lines_after_decorators=( + *class_or_func.leading_lines, + *class_or_func.lines_after_decorators, + ), + decorators=decorators, + ) + + +@with_production("asyncable_stmt", "['async'] (funcdef | with_stmt | for_stmt)") +def convert_asyncable_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: + leading_lines, asyncnode, stmtnode = _extract_async(config, children) + if isinstance(stmtnode, cst.FunctionDef): + return stmtnode.with_changes( + asynchronous=asyncnode, + leading_lines=leading_lines, + lines_after_decorators=(), + ) + elif isinstance(stmtnode, cst.With): + return stmtnode.with_changes( + asynchronous=asyncnode, leading_lines=leading_lines + ) + elif isinstance(stmtnode, cst.For): + return stmtnode.with_changes( + asynchronous=asyncnode, leading_lines=leading_lines + ) + else: + raise Exception("Logic error!") + + +@with_production("suite", "simple_stmt_suite | indented_suite") +def convert_suite(config: ParserConfig, children: Sequence[Any]) -> Any: + (suite,) = children + return suite + + +@with_production("indented_suite", "NEWLINE INDENT stmt+ DEDENT") +def convert_indented_suite(config: ParserConfig, children: Sequence[Any]) -> Any: + newline, indent, *stmts, dedent = children + return cst.IndentedBlock( + header=newline, + indent=( + None + if indent.relative_indent == config.default_indent + else indent.relative_indent + ), + body=stmts, + footer=parse_empty_lines(config, dedent.whitespace_after), + ) diff --git a/libcst/parser/_conversions/terminals.py b/libcst/parser/_conversions/terminals.py new file mode 100644 index 00000000..5f743ce0 --- /dev/null +++ b/libcst/parser/_conversions/terminals.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any + +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.token import Token +from libcst.parser._whitespace_parser import ( + parse_empty_lines, + parse_trailing_whitespace, +) + + +def convert_NAME(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_NUMBER(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_STRING(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_OP(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_NEWLINE(config: ParserConfig, token: Token) -> Any: + # A NEWLINE token is only emitted for semantic newlines, which means that this + # corresponds to a TrailingWhitespace, since that's the only semantic + # newline-containing node. + + # N.B. Because this token is whitespace, and because the whitespace parser doesn't + # try to prevent overflows, `token.whitespace_before` will end up overflowing into + # the value of this newline token, so `parse_trailing_whitespace` will include + # token.string's value. This is expected and desired behavior. + return parse_trailing_whitespace(config, token.whitespace_before) + + +def convert_INDENT(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_DEDENT(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_ENDMARKER(config: ParserConfig, token: Token) -> Any: + return parse_empty_lines(config, token.whitespace_before) + + +def convert_FSTRING_START(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_FSTRING_END(config: ParserConfig, token: Token) -> Any: + return token + + +def convert_FSTRING_STRING(config: ParserConfig, token: Token) -> Any: + return token diff --git a/libcst/parser/_custom_itertools.py b/libcst/parser/_custom_itertools.py new file mode 100644 index 00000000..71541cc9 --- /dev/null +++ b/libcst/parser/_custom_itertools.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from itertools import zip_longest +from typing import Iterable, Iterator, TypeVar + + +_T = TypeVar("_T") + + +# https://docs.python.org/3/library/itertools.html#itertools-recipes +def grouper(iterable: Iterable[_T], n: int, fillvalue: _T = None) -> Iterator[_T]: + "Collect data into fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" + args = [iter(iterable)] * n + return zip_longest(*args, fillvalue=fillvalue) diff --git a/libcst/parser/_detect_config.py b/libcst/parser/_detect_config.py new file mode 100644 index 00000000..de9ae4e1 --- /dev/null +++ b/libcst/parser/_detect_config.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import itertools +from dataclasses import dataclass +from io import BytesIO +from tokenize import detect_encoding as py_tokenize_detect_encoding +from typing import Iterable, Iterator, Union + +from parso.python.token import PythonTokenTypes, TokenType +from parso.utils import split_lines + +from libcst.nodes._whitespace import NEWLINE_RE +from libcst.parser._types.config import AutoConfig, ParserConfig, PartialParserConfig +from libcst.parser._types.token import Token +from libcst.parser._wrapped_tokenize import tokenize_lines + + +_INDENT: TokenType = PythonTokenTypes.INDENT +_FALLBACK_DEFAULT_NEWLINE = "\n" +_FALLBACK_DEFAULT_INDENT = " " + + +@dataclass(frozen=True) +class ConfigDetectionResult: + # The config is a set of constant values used by the parser. + config: ParserConfig + # The tokens iterator is mutated by the parser. 
+ tokens: Iterator[Token] + + +def _detect_encoding(source: Union[str, bytes]) -> str: + """ + Detects the encoding from the presence of a UTF-8 BOM or an encoding cookie as + specified in PEP 263. + + If given a string (instead of bytes) the encoding is assumed to be utf-8. + """ + + if isinstance(source, str): + return "utf-8" + return py_tokenize_detect_encoding(BytesIO(source).readline)[0] + + +def _detect_default_newline(source_str: str) -> str: + """ + Finds the first newline, and uses that value as the default newline. + """ + # Don't use `NEWLINE_RE` for this, because it might match multiple newlines as a + # single newline. + match = NEWLINE_RE.search(source_str) + return match.group(0) if match is not None else _FALLBACK_DEFAULT_NEWLINE + + +def _detect_indent(tokens: Iterable[Token]) -> str: + """ + Finds the first INDENT token, and uses that as the value of the default indent. + """ + try: + first_indent = next(t for t in tokens if t.type is _INDENT) + except StopIteration: + return _FALLBACK_DEFAULT_INDENT + first_indent_str = first_indent.relative_indent + assert first_indent_str is not None, "INDENT tokens must contain a relative_indent" + return first_indent_str + + +def detect_config( + source: Union[str, bytes], + *, + partial: PartialParserConfig, + detect_trailing_newline: bool, +) -> ConfigDetectionResult: + """ + Computes a ParserConfig given the current source code to be parsed and a partial + config. 
+ """ + + python_version = partial.parsed_python_version + + partial_encoding = partial.encoding + encoding = ( + _detect_encoding(source) + if isinstance(partial_encoding, AutoConfig) + else partial_encoding + ) + + source_str = source if isinstance(source, str) else source.decode(encoding) + + partial_default_newline = partial.default_newline + default_newline = ( + _detect_default_newline(source_str) + if isinstance(partial_default_newline, AutoConfig) + else partial_default_newline + ) + + # HACK: The grammar requires a trailing newline, but python doesn't actually require + # a trailing newline. Add one onto the end to make the parser happy. We'll strip it + # out again during cst.Module's codegen. + # + # I think parso relies on error recovery support to handle this, which we don't + # have. lib2to3 doesn't handle this case at all AFAICT. + has_trailing_newline = detect_trailing_newline and bool( + len(source_str) != 0 and NEWLINE_RE.match(source_str[-1]) + ) + if detect_trailing_newline and not has_trailing_newline: + source_str += default_newline + + lines = split_lines(source_str, keepends=True) + + tokens = tokenize_lines(lines, python_version) + + partial_default_indent = partial.default_indent + if isinstance(partial_default_indent, AutoConfig): + # We need to clone `tokens` before passing it to `_detect_indent`, because + # `_detect_indent` consumes some tokens, mutating `tokens`. + # + # Implementation detail: CPython's `itertools.tee` uses weakrefs to reduce the + # size of its FIFO, so this doesn't retain items (leak memory) for `tokens_dup` + # once `token_dup` is freed at the end of this method (subject to + # GC/refcounting). 
+ tokens, tokens_dup = itertools.tee(tokens) + default_indent = _detect_indent(tokens_dup) + else: + default_indent = partial_default_indent + + return ConfigDetectionResult( + config=ParserConfig( + lines=lines, + encoding=encoding, + default_indent=default_indent, + default_newline=default_newline, + has_trailing_newline=has_trailing_newline, + ), + tokens=tokens, + ) diff --git a/libcst/parser/_entrypoints.py b/libcst/parser/_entrypoints.py new file mode 100644 index 00000000..496afb8d --- /dev/null +++ b/libcst/parser/_entrypoints.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +""" +Parser entrypoints define the way users of our API are allowed to interact with the +parser. A parser entrypoint should take the source code and some configuration +information +""" + +from typing import TypeVar, Union + +import libcst.nodes as cst +from libcst.parser._detect_config import detect_config +from libcst.parser._grammar import get_grammar, validate_grammar +from libcst.parser._python_parser import PythonCSTParser +from libcst.parser._types.config import PartialParserConfig + + +_CSTNodeT = TypeVar("_CSTNodeT", bound=cst.CSTNode) +_DEFAULT_PARTIAL_PARSER_CONFIG: PartialParserConfig = PartialParserConfig() + + +def _parse( + entrypoint: str, + source: Union[str, bytes], + config: PartialParserConfig, + *, + detect_trailing_newline: bool, +) -> cst.CSTNode: + detection_result = detect_config( + source, partial=config, detect_trailing_newline=detect_trailing_newline + ) + validate_grammar() + grammar = get_grammar() + + parser = PythonCSTParser( + tokens=detection_result.tokens, + config=detection_result.config, + pgen_grammar=grammar, + start_nonterminal=entrypoint, + ) + # The parser has an Any return type, we can at least refine it to CSTNode here. 
+ result = parser.parse() + assert isinstance(result, cst.CSTNode) + return result + + +def parse_module( + source: Union[str, bytes], # the only entrypoint that accepts bytes + config: PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG, +) -> cst.Module: + result = _parse("file_input", source, config, detect_trailing_newline=True) + assert isinstance(result, cst.Module) + return result + + +def parse_statement( + source: str, config: PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG +) -> Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]: + """ + Accepts a statement followed by a trailing newline. If a trailing newline is not + provided, one will be added. + + Leading comments and trailing comments (on the same line) are accepted, but + whitespace (or anything else) after the statement's trailing newline is not valid + (there's nowhere to store it). + """ + # use detect_trailing_newline to insert a newline + result = _parse("stmt_input", source, config, detect_trailing_newline=True) + assert isinstance(result, (cst.SimpleStatementLine, cst.BaseCompoundStatement)) + return result + + +def parse_expression( + source: str, config: PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG +) -> cst.BaseExpression: + result = _parse("expression_input", source, config, detect_trailing_newline=False) + assert isinstance(result, cst.BaseExpression) + return result diff --git a/libcst/parser/_grammar.py b/libcst/parser/_grammar.py new file mode 100644 index 00000000..448c8b1a --- /dev/null +++ b/libcst/parser/_grammar.py @@ -0,0 +1,341 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +import re +from functools import lru_cache +from typing import Iterator, Mapping, Tuple + +from parso.pgen2.generator import Grammar, generate_grammar +from parso.python.token import PythonTokenTypes, TokenType + +from libcst.parser._conversions.expression import ( + convert_arg_assign_comp_for, + convert_arglist, + convert_argument, + convert_atom, + convert_atom_basic, + convert_atom_curlybrackets, + convert_atom_ellipses, + convert_atom_expr, + convert_atom_expr_await, + convert_atom_expr_trailer, + convert_atom_fstring, + convert_atom_parens, + convert_atom_squarebrackets, + convert_atom_string, + convert_binop, + convert_boolop, + convert_comp_for, + convert_comp_if, + convert_comp_iter, + convert_comp_op, + convert_comparison, + convert_dictorsetmaker, + convert_expression_input, + convert_exprlist, + convert_factor, + convert_fstring, + convert_fstring_content, + convert_fstring_conversion, + convert_fstring_expr, + convert_fstring_format_spec, + convert_lambda, + convert_not_test, + convert_power, + convert_sliceop, + convert_star_arg, + convert_star_expr, + convert_subscript, + convert_subscriptlist, + convert_sync_comp_for, + convert_test, + convert_test_nocond, + convert_testlist, + convert_testlist_comp, + convert_testlist_star_expr, + convert_trailer, + convert_trailer_arglist, + convert_trailer_attribute, + convert_trailer_subscriptlist, + convert_yield_arg, + convert_yield_expr, +) +from libcst.parser._conversions.module import convert_file_input +from libcst.parser._conversions.params import ( + convert_argslist, + convert_fpdef, + convert_fpdef_assign, + convert_fpdef_star, + convert_fpdef_starstar, +) +from libcst.parser._conversions.statement import ( + convert_annassign, + convert_assert_stmt, + convert_assign, + convert_asyncable_funcdef, + convert_asyncable_stmt, + convert_augassign, + convert_break_stmt, + convert_classdef, + convert_compound_stmt, + convert_continue_stmt, + convert_decorated, + convert_decorator, + 
convert_decorators, + convert_del_stmt, + convert_dotted_as_name, + convert_dotted_as_names, + convert_dotted_name, + convert_except_clause, + convert_expr_stmt, + convert_for_stmt, + convert_funcdef, + convert_funcdef_annotation, + convert_global_stmt, + convert_if_stmt, + convert_if_stmt_elif, + convert_if_stmt_else, + convert_import_as_name, + convert_import_as_names, + convert_import_from, + convert_import_name, + convert_import_relative, + convert_import_stmt, + convert_indented_suite, + convert_nonlocal_stmt, + convert_parameters, + convert_pass_stmt, + convert_raise_stmt, + convert_return_stmt, + convert_simple_stmt_line, + convert_simple_stmt_partial, + convert_simple_stmt_suite, + convert_small_stmt, + convert_stmt, + convert_stmt_input, + convert_suite, + convert_try_stmt, + convert_while_stmt, + convert_with_item, + convert_with_stmt, +) +from libcst.parser._conversions.terminals import ( + convert_DEDENT, + convert_ENDMARKER, + convert_FSTRING_END, + convert_FSTRING_START, + convert_FSTRING_STRING, + convert_INDENT, + convert_NAME, + convert_NEWLINE, + convert_NUMBER, + convert_OP, + convert_STRING, +) +from libcst.parser._production_decorator import get_productions +from libcst.parser._types.conversions import NonterminalConversion, TerminalConversion +from libcst.parser._types.production import Production + + +# Keep this sorted alphabetically +_TERMINAL_CONVERSIONS_SEQUENCE: Tuple[TerminalConversion, ...] = ( + convert_DEDENT, + convert_ENDMARKER, + convert_INDENT, + convert_NAME, + convert_NEWLINE, + convert_NUMBER, + convert_OP, + convert_STRING, + convert_FSTRING_START, + convert_FSTRING_END, + convert_FSTRING_STRING, +) + +# Try to match the order of https://docs.python.org/3/reference/grammar.html +_NONTERMINAL_CONVERSIONS_SEQUENCE: Tuple[NonterminalConversion, ...] 
= ( + convert_file_input, + convert_stmt_input, # roughly equivalent to single_input + convert_expression_input, # roughly equivalent to eval_input + convert_stmt, + convert_simple_stmt_partial, + convert_simple_stmt_line, + convert_simple_stmt_suite, + convert_small_stmt, + convert_expr_stmt, + convert_annassign, + convert_augassign, + convert_assign, + convert_pass_stmt, + convert_continue_stmt, + convert_break_stmt, + convert_del_stmt, + convert_import_stmt, + convert_import_name, + convert_import_relative, + convert_import_from, + convert_import_as_name, + convert_dotted_as_name, + convert_import_as_names, + convert_dotted_as_names, + convert_dotted_name, + convert_return_stmt, + convert_raise_stmt, + convert_global_stmt, + convert_nonlocal_stmt, + convert_assert_stmt, + convert_compound_stmt, + convert_if_stmt, + convert_if_stmt_elif, + convert_if_stmt_else, + convert_while_stmt, + convert_for_stmt, + convert_try_stmt, + convert_except_clause, + convert_with_stmt, + convert_with_item, + convert_asyncable_funcdef, + convert_funcdef, + convert_classdef, + convert_decorator, + convert_decorators, + convert_decorated, + convert_asyncable_stmt, + convert_parameters, + convert_argslist, + convert_fpdef_star, + convert_fpdef_starstar, + convert_fpdef_assign, + convert_fpdef, + convert_funcdef_annotation, + convert_suite, + convert_indented_suite, + convert_testlist_star_expr, + convert_test, + convert_test_nocond, + convert_lambda, + convert_boolop, + convert_not_test, + convert_comparison, + convert_comp_op, + convert_star_expr, + convert_binop, + convert_factor, + convert_power, + convert_atom_expr, + convert_atom_expr_await, + convert_atom_expr_trailer, + convert_trailer, + convert_trailer_attribute, + convert_trailer_subscriptlist, + convert_subscriptlist, + convert_subscript, + convert_sliceop, + convert_trailer_arglist, + convert_atom, + convert_atom_basic, + convert_atom_parens, + convert_atom_squarebrackets, + convert_atom_curlybrackets, + 
convert_atom_string, + convert_atom_fstring, + convert_fstring, + convert_fstring_content, + convert_fstring_conversion, + convert_fstring_expr, + convert_fstring_format_spec, + convert_atom_ellipses, + convert_testlist_comp, + convert_testlist, + convert_dictorsetmaker, + convert_exprlist, + convert_arglist, + convert_argument, + convert_arg_assign_comp_for, + convert_star_arg, + convert_comp_iter, + convert_sync_comp_for, + convert_comp_for, + convert_comp_if, + convert_yield_expr, + convert_yield_arg, +) + + +def get_grammar_str() -> str: + """ + Returns a BNF-like grammar text that `parso.pgen2.generator.generate_grammar` can + handle. + + While you should generally use `get_grammar` instead, this can be useful for + debugging the grammar. + """ + lines = [] + for p in get_nonterminal_productions(): + lines.append(str(p)) + return "\n".join(lines) + "\n" + + +# TODO: We should probably provide an on-disk cache like parso and lib2to3 do. Because +# of how we're defining our grammar, efficient cache invalidation is harder, though not +# impossible. +@lru_cache() +def get_grammar() -> "Grammar[TokenType]": + return generate_grammar(get_grammar_str(), PythonTokenTypes) + + +@lru_cache() +def get_terminal_conversions() -> Mapping[str, TerminalConversion]: + """ + Returns a mapping from terminal type name to the conversion function that should be + called by the parser. + """ + return { + # pyre-fixme[16]: Optional type has no attribute `group`.
+ re.match("convert_(.*)", fn.__name__).group(1): fn + for fn in _TERMINAL_CONVERSIONS_SEQUENCE + } + + +@lru_cache() +def validate_grammar() -> None: + for fn in _NONTERMINAL_CONVERSIONS_SEQUENCE: + fn_productions = get_productions(fn) + if all(p.name == fn_productions[0].name for p in fn_productions): + # all the production names are the same, ensure that the `convert_` function + # is named correctly + production_name = fn_productions[0].name + expected_name = f"convert_{production_name}" + if fn.__name__ != expected_name: + raise Exception( + f"The conversion function for '{production_name}' " + + f"must be called '{expected_name}', not '{fn.__name__}'." + ) + + +def get_nonterminal_productions() -> Iterator[Production]: + for conversion in _NONTERMINAL_CONVERSIONS_SEQUENCE: + # TODO: Filter out productions used by other python versions here + yield from get_productions(conversion) + + +@lru_cache() +def get_nonterminal_conversions() -> Mapping[str, NonterminalConversion]: + """ + Returns a mapping from nonterminal production name to the conversion function that + should be called by the parser. + """ + conversions = {} + for fn in _NONTERMINAL_CONVERSIONS_SEQUENCE: + for fn_production in get_productions(fn): + # TODO: Filter out productions used by other python versions here + if fn_production.name in conversions: + raise Exception( + f"Found duplicate '{fn_production.name}' production in grammar" + ) + conversions[fn_production.name] = fn + + return conversions diff --git a/libcst/parser/_production_decorator.py b/libcst/parser/_production_decorator.py new file mode 100644 index 00000000..b0d81382 --- /dev/null +++ b/libcst/parser/_production_decorator.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Callable, Iterable, TypeVar + +from libcst.parser._types.conversions import NonterminalConversion +from libcst.parser._types.production import Production + + +_NonterminalConversionT = TypeVar( + "_NonterminalConversionT", bound=NonterminalConversion +) + + +# We could version our grammar at a later point by adding a version metadata kwarg to +# this decorator. +def with_production( + production_name: str, children: str +) -> Callable[[_NonterminalConversionT], _NonterminalConversionT]: + """ + Attaches a bit of grammar to a conversion function. The parser extracts all of these + production strings, and uses them to form the language's full grammar. + + To attach multiple productions to the same conversion function, stack this decorator. + """ + + def inner(fn: _NonterminalConversionT) -> _NonterminalConversionT: + if not hasattr(fn, "productions"): + fn.productions = [] + # pyre-fixme[16]: `Callable[[ParserConfig, Sequence[Any]], Any]` has no attri... + fn_name = fn.__name__ + if not fn_name.startswith("convert_"): + raise Exception( + "A function with a production must be named 'convert_X', not " + + f"'{fn_name}'." + ) + # pyre-fixme[16]: Pyre doesn't know about this magic field we added + fn.productions.append(Production(production_name, children)) + return fn + + return inner + + +def get_productions(fn: NonterminalConversion) -> Iterable[Production]: + return fn.productions diff --git a/libcst/parser/_python_parser.py b/libcst/parser/_python_parser.py new file mode 100644 index 00000000..daca5800 --- /dev/null +++ b/libcst/parser/_python_parser.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+ +from typing import Any, Iterable, Mapping, Sequence + +from parso.pgen2.generator import Grammar +from parso.python.token import TokenType + +from libcst.parser._base_parser import BaseParser +from libcst.parser._grammar import get_nonterminal_conversions, get_terminal_conversions +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.conversions import NonterminalConversion, TerminalConversion +from libcst.parser._types.token import Token + + +class PythonCSTParser(BaseParser[Token, TokenType, Any]): + config: ParserConfig + terminal_conversions: Mapping[str, TerminalConversion] + nonterminal_conversions: Mapping[str, NonterminalConversion] + + def __init__( + self, + *, + tokens: Iterable[Token], + config: ParserConfig, + pgen_grammar: "Grammar[TokenType]", + start_nonterminal: str = "file_input", + ) -> None: + super().__init__( + tokens=tokens, + lines=config.lines, + pgen_grammar=pgen_grammar, + start_nonterminal=start_nonterminal, + ) + self.config = config + self.terminal_conversions = get_terminal_conversions() + self.nonterminal_conversions = get_nonterminal_conversions() + + def convert_nonterminal(self, nonterminal: str, children: Sequence[Any]) -> Any: + return self.nonterminal_conversions[nonterminal](self.config, children) + + def convert_terminal(self, token: Token) -> Any: + return self.terminal_conversions[token.type.name](self.config, token) diff --git a/libcst/parser/_types/__init__.py b/libcst/parser/_types/__init__.py new file mode 100644 index 00000000..62642369 --- /dev/null +++ b/libcst/parser/_types/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/libcst/parser/_types/config.py b/libcst/parser/_types/config.py new file mode 100644 index 00000000..1de9ab3a --- /dev/null +++ b/libcst/parser/_types/config.py @@ -0,0 +1,114 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import abc +import codecs +import re +from dataclasses import dataclass, field +from enum import Enum +from typing import Pattern, Sequence, Union + +from parso.utils import PythonVersionInfo, parse_version_string + +from libcst._add_slots import add_slots +from libcst.nodes._whitespace import NEWLINE_RE + + +_INDENT_RE: Pattern[str] = re.compile(r"[ \t]+") + + +class BaseWhitespaceParserConfig(abc.ABC): + """ + Represents the subset of ParserConfig that the whitespace parser requires. This + makes calling the whitespace parser in tests with a mocked configuration easier. + """ + + # pyre-fixme[13]: Uninitialized attribute + lines: Sequence[str] + # pyre-fixme[13]: Uninitialized attribute + default_newline: str + + +@add_slots # We'll access these properties frequently, so use slots +@dataclass(frozen=True) +class ParserConfig(BaseWhitespaceParserConfig): + """ + An internal configuration object that the python parser passes around. These values + are global to the parsed code and should not change during the lifetime of the + parser object. + """ + + lines: Sequence[str] + encoding: str + default_indent: str + default_newline: str + has_trailing_newline: bool + + +class AutoConfig(Enum): + """ + A sentinel value used in PartialParserConfig + """ + + token: int = 0 + + +@dataclass(frozen=True) +class PartialParserConfig: + """ + An optional object that can be supplied to the parser entrypoints (e.g. + `parse_module`) to configure the parser. + + Unspecified fields will be inferred from the input source code or from the execution + environment (the current Python version). 
+ """ + + # `python_version` only configures the tokenization/lexer right now. The grammar + # isn't currently versioned. + python_version: Union[str, AutoConfig] = AutoConfig.token + # parsed_python_version is derived from python_version in __post_init__ + parsed_python_version: PythonVersionInfo = field(init=False) + encoding: Union[str, AutoConfig] = AutoConfig.token + default_indent: Union[str, AutoConfig] = AutoConfig.token + default_newline: Union[str, AutoConfig] = AutoConfig.token + + def __post_init__(self) -> None: + raw_python_version = self.python_version + # `parse_version_string` will raise a ValueError if the version is invalid. + # + # We use object.__setattr__ because the dataclass is frozen. See: + # https://docs.python.org/3/library/dataclasses.html#frozen-instances + # This should be safe behavior inside of `__post_init__`. + object.__setattr__( + self, + "parsed_python_version", + parse_version_string( + None # parso will derive the version from `sys.version_info` + if isinstance(raw_python_version, AutoConfig) + else raw_python_version + ), + ) + + encoding = self.encoding + if not isinstance(encoding, AutoConfig): + try: + codecs.lookup(encoding) + except LookupError: + raise ValueError(f"{repr(encoding)} is not a supported encoding") + + newline = self.default_newline + if ( + not isinstance(newline, AutoConfig) + and NEWLINE_RE.fullmatch(newline) is None + ): + raise ValueError( + f"Got an invalid value for default_newline: {repr(newline)}" + ) + + indent = self.default_indent + if not isinstance(indent, AutoConfig) and _INDENT_RE.fullmatch(indent) is None: + raise ValueError(f"Got an invalid value for default_indent: {repr(indent)}") diff --git a/libcst/parser/_types/conversions.py b/libcst/parser/_types/conversions.py new file mode 100644 index 00000000..82b4a580 --- /dev/null +++ b/libcst/parser/_types/conversions.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Sequence + +from libcst.parser._types.config import ParserConfig +from libcst.parser._types.token import Token + + +NonterminalConversion = Callable[[ParserConfig, Sequence[Any]], Any] +TerminalConversion = Callable[[ParserConfig, Token], Any] diff --git a/libcst/parser/_types/partials.py b/libcst/parser/_types/partials.py new file mode 100644 index 00000000..874332c8 --- /dev/null +++ b/libcst/parser/_types/partials.py @@ -0,0 +1,146 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass +from enum import Enum +from typing import Generic, Optional, Sequence, TypeVar, Union + +import libcst.nodes as cst +from libcst._add_slots import add_slots +from libcst.parser._types.whitespace_state import WhitespaceState + + +_T = TypeVar("_T") +_V = TypeVar("_V") + + +@add_slots +@dataclass(frozen=True) +class WithLeadingWhitespace(Generic[_T]): + value: _T + whitespace_before: WhitespaceState + + +@add_slots +@dataclass(frozen=True) +class SimpleStatementPartial: + body: Sequence[cst.BaseSmallStatement] + whitespace_before: WhitespaceState + trailing_whitespace: cst.TrailingWhitespace + + +@add_slots +@dataclass(frozen=True) +class SlicePartial: + second_colon: cst.Colon + step: Optional[cst.BaseExpression] + + +@add_slots +@dataclass(frozen=True) +class AttributePartial: + dot: cst.Dot + attr: cst.Name + + +@add_slots +@dataclass(frozen=True) +class ArglistPartial: + args: Sequence[cst.Arg] + + +@add_slots +@dataclass(frozen=True) +class CallPartial: + lpar: WithLeadingWhitespace[cst.LeftParen] + args: Sequence[cst.Arg] + rpar: cst.RightParen + + +@add_slots +@dataclass(frozen=True) +class SubscriptPartial: + slice: 
Union[cst.Index, cst.Slice, Sequence[cst.ExtSlice]] + lbracket: cst.LeftSquareBracket + rbracket: cst.RightSquareBracket + whitespace_before: WhitespaceState + + +@add_slots +@dataclass(frozen=True) +class AnnAssignPartial: + annotation: cst.Annotation + equal: Optional[cst.AssignEqual] + value: Optional[cst.BaseExpression] + + +@add_slots +@dataclass(frozen=True) +class AugAssignPartial: + operator: cst.BaseAugOp + value: cst.BaseExpression + + +@add_slots +@dataclass(frozen=True) +class AssignPartial: + equal: cst.AssignEqual + value: cst.BaseExpression + + +class ParamStarPartial: + pass + + +@add_slots +@dataclass(frozen=True) +class FuncdefPartial: + lpar: cst.LeftParen + params: cst.Parameters + rpar: cst.RightParen + + +@add_slots +@dataclass(frozen=True) +class DecoratorPartial: + decorators: Sequence[cst.Decorator] + + +@add_slots +@dataclass(frozen=True) +class ImportPartial: + names: Sequence[cst.ImportAlias] + + +@add_slots +@dataclass(frozen=True) +class ImportRelativePartial: + relative: Sequence[cst.Dot] + module: Optional[Union[cst.Attribute, cst.Name]] + + +@add_slots +@dataclass(frozen=True) +class FormattedStringConversionPartial: + value: str + whitespace_before: WhitespaceState + + +@add_slots +@dataclass(frozen=True) +class FormattedStringFormatSpecPartial: + values: Sequence[cst.BaseFormattedStringContent] + whitespace_before: WhitespaceState + + +@add_slots +@dataclass(frozen=True) +class ExceptClausePartial: + leading_lines: Sequence[cst.EmptyLine] + whitespace_after_except: cst.SimpleWhitespace + type: Optional[cst.BaseExpression] = None + name: Optional[cst.AsName] = None diff --git a/libcst/parser/_types/production.py b/libcst/parser/_types/production.py new file mode 100644 index 00000000..6a4ba2f8 --- /dev/null +++ b/libcst/parser/_types/production.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Production: + name: str + children: str + + def __str__(self) -> str: + return f"{self.name}: {self.children}" diff --git a/libcst/parser/_types/tests/test_config.py b/libcst/parser/_types/tests/test_config.py new file mode 100644 index 00000000..b108e08b --- /dev/null +++ b/libcst/parser/_types/tests/test_config.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Callable + +from libcst.parser._types.config import PartialParserConfig +from libcst.testing.utils import UnitTest, data_provider + + +class TestConfig(UnitTest): + @data_provider( + { + "empty": (lambda: PartialParserConfig(),), + "python_version_a": (lambda: PartialParserConfig(python_version="3"),), + "python_version_b": (lambda: PartialParserConfig(python_version="3.2"),), + "python_version_c": (lambda: PartialParserConfig(python_version="3.2.1"),), + "encoding": (lambda: PartialParserConfig(encoding="latin-1"),), + "default_indent": (lambda: PartialParserConfig(default_indent="\t "),), + "default_newline": (lambda: PartialParserConfig(default_newline="\r\n"),), + } + ) + def test_valid_partial_parser_config( + self, factory: Callable[[], PartialParserConfig] + ) -> None: + self.assertIsInstance(factory(), PartialParserConfig) + + @data_provider( + { + "python_version": (lambda: PartialParserConfig(python_version="3.2.1.0"),), + "encoding": (lambda: PartialParserConfig(encoding="utf-42"),), + "default_indent": (lambda: PartialParserConfig(default_indent="badinput"),), + "default_newline": (lambda: PartialParserConfig(default_newline="\n\r"),), + } + ) + def test_invalid_partial_parser_config( + self, 
factory: Callable[[], PartialParserConfig] + ) -> None: + with self.assertRaises(ValueError): + factory() diff --git a/libcst/parser/_types/token.py b/libcst/parser/_types/token.py new file mode 100644 index 00000000..da183db6 --- /dev/null +++ b/libcst/parser/_types/token.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass +from typing import Optional, Tuple + +from parso.python.token import TokenType + +from libcst._add_slots import add_slots +from libcst.parser._types.whitespace_state import WhitespaceState + + +@add_slots +@dataclass(frozen=True) +class Token: + type: TokenType + string: str + # The start of where `string` is in the source, not including leading whitespace. + start_pos: Tuple[int, int] + # The end of where `string` is in the source, not including trailing whitespace. + end_pos: Tuple[int, int] + whitespace_before: WhitespaceState + whitespace_after: WhitespaceState + # The relative indent this token adds. + relative_indent: Optional[str] diff --git a/libcst/parser/_types/whitespace_state.py b/libcst/parser/_types/whitespace_state.py new file mode 100644 index 00000000..4df9a111 --- /dev/null +++ b/libcst/parser/_types/whitespace_state.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +""" +Defines the state object used by the whitespace parser. +""" + +from dataclasses import dataclass + +from libcst._add_slots import add_slots + + +@add_slots +@dataclass(frozen=False) +class WhitespaceState: + """ + A frequently mutated store of the whitespace parser's current state. This object + must be cloned prior to speculative parsing. 
+ + This is in contrast to the `config` object each whitespace parser function takes, + which is frozen and never mutated. + + Whitespace parsing works by mutating this state object. By encapsulating saving, and + re-using state objects inside the top-level python parser, the whitespace parser is + able to be reentrant. One 'convert' function can consume part of the whitespace, and + another 'convert' function can consume the rest, depending on who owns what + whitespace. + + This is similar to the approach you might take to parse nested languages (e.g. + JavaScript inside of HTML). We're treating whitespace as a separate language and + grammar from the rest of Python's grammar. + """ + + line: int # one-indexed (to match parso's behavior) + column: int # zero-indexed (to match parso's behavior) + # What to look for when executing `_parse_indent`. + absolute_indent: str + is_parenthesized: bool diff --git a/libcst/parser/_whitespace_parser.py b/libcst/parser/_whitespace_parser.py new file mode 100644 index 00000000..05ebcaf7 --- /dev/null +++ b/libcst/parser/_whitespace_parser.py @@ -0,0 +1,208 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +""" +Parso doesn't attempt to parse (or even emit tokens for) whitespace or comments that +aren't syntactically important. Instead, we're just given the whitespace as a "prefix" +of the token. + +However, in our CST, whitespace is gathered into far more detailed objects than a simple +str. + +Fortunately this isn't hard for us to parse ourselves, so we just use our own +hand-rolled recursive descent parser.
+""" + +from typing import Optional, Sequence, Union + +import libcst.nodes as cst +from libcst.nodes._whitespace import COMMENT_RE, NEWLINE_RE, SIMPLE_WHITESPACE_RE +from libcst.parser._types.config import BaseWhitespaceParserConfig +from libcst.parser._types.whitespace_state import WhitespaceState as State + + +# BEGIN PARSER ENTRYPOINTS + + +def parse_simple_whitespace( + config: BaseWhitespaceParserConfig, state: State +) -> cst.SimpleWhitespace: + # The match never fails because the pattern can match an empty string + lines = config.lines + # pyre-fixme[16]: Optional type has no attribute `group`. + ws_line = SIMPLE_WHITESPACE_RE.match(lines[state.line - 1], state.column).group(0) + ws_line_list = [ws_line] + while "\\" in ws_line: + # continuation character + state.line += 1 + state.column = 0 + # pyre-fixme[16]: Optional type has no attribute `group`. + ws_line = SIMPLE_WHITESPACE_RE.match(lines[state.line - 1], state.column).group( + 0 + ) + ws_line_list.append(ws_line) + + # TODO: we could special-case the common case where there's no continuation + # character to avoid list construction and joining. + + # once we've finished collecting continuation characters + state.column += len(ws_line) + return cst.SimpleWhitespace("".join(ws_line_list)) + + +def parse_empty_lines( + config: BaseWhitespaceParserConfig, state: State +) -> Sequence[cst.EmptyLine]: + result = [] + el = _parse_empty_line(config, state) + while el is not None: + result.append(el) + el = _parse_empty_line(config, state) + return result + + +def parse_trailing_whitespace( + config: BaseWhitespaceParserConfig, state: State +) -> cst.TrailingWhitespace: + trailing_whitespace = _parse_trailing_whitespace(config, state) + if trailing_whitespace is None: + raise Exception( + "Internal Error: Failed to parse TrailingWhitespace. This should never " + + "happen because a TrailingWhitespace is never optional in the grammar, " + + "so this error should've been caught by parso first." 
+ ) + return trailing_whitespace + + +def parse_parenthesizable_whitespace( + config: BaseWhitespaceParserConfig, state: State +) -> Union[cst.SimpleWhitespace, cst.ParenthesizedWhitespace]: + if state.is_parenthesized: + # First, try parenthesized (don't need speculation because it either + # parses or doesn't modify state). + parenthesized_whitespace = _parse_parenthesized_whitespace(config, state) + if parenthesized_whitespace is not None: + return parenthesized_whitespace + # Now, just parse and return a simple whitespace + return parse_simple_whitespace(config, state) + + +# END PARSER ENTRYPOINTS +# BEGIN PARSER INTERNAL PRODUCTIONS + + +def _parse_empty_line( + config: BaseWhitespaceParserConfig, state: State +) -> Optional[cst.EmptyLine]: + # begin speculative parsing + speculative_state = State( + state.line, state.column, state.absolute_indent, state.is_parenthesized + ) + indent = _parse_indent(config, speculative_state) + whitespace = parse_simple_whitespace(config, speculative_state) + comment = _parse_comment(config, speculative_state) + newline = _parse_newline(config, speculative_state) + if newline is None: + # speculative parsing failed + return None + # speculative parsing succeeded + state.line = speculative_state.line + state.column = speculative_state.column + # don't need to copy absolute_indent/is_parenthesized because they don't change. + return cst.EmptyLine(indent, whitespace, comment, newline) + + +def _parse_indent(config: BaseWhitespaceParserConfig, state: State) -> bool: + """ + Returns True if indentation was found, otherwise False. 
+ """ + absolute_indent = state.absolute_indent + line_str = config.lines[state.line - 1] + if state.column != 0: + if state.column == len(line_str) and state.line == len(config.lines): + # We're at EOF, treat this as a failed speculative parse + return False + raise Exception("Internal Error: Column should be 0 when parsing an indent.") + if line_str.startswith(absolute_indent, state.column): + state.column += len(absolute_indent) + return True + return False + + +def _parse_comment( + config: BaseWhitespaceParserConfig, state: State +) -> Optional[cst.Comment]: + comment_match = COMMENT_RE.match(config.lines[state.line - 1], state.column) + if comment_match is None: + return None + comment = comment_match.group(0) + state.column += len(comment) + return cst.Comment(comment) + + +def _parse_newline( + config: BaseWhitespaceParserConfig, state: State +) -> Optional[cst.Newline]: + # begin speculative parsing + line_str = config.lines[state.line - 1] + newline_match = NEWLINE_RE.match(line_str, state.column) + if newline_match is not None: + # speculative parsing succeeded + newline_str = newline_match.group(0) + state.column += len(newline_str) + if state.column != len(line_str): + raise Exception("Internal Error: Found a newline, but it wasn't the EOL.") + if state.line < len(config.lines): + # this newline was the end of a line, and there's another line, + # therefore we should move to the next line + state.line += 1 + state.column = 0 + if newline_str == config.default_newline: + # Just inherit it from the Module instead of explicitly setting it. 
+ return cst.Newline() + else: + return cst.Newline(newline_str) + else: # no newline was found, speculative parsing failed + return None + + +def _parse_trailing_whitespace( + config: BaseWhitespaceParserConfig, state: State +) -> Optional[cst.TrailingWhitespace]: + # Begin speculative parsing + speculative_state = State( + state.line, state.column, state.absolute_indent, state.is_parenthesized + ) + whitespace = parse_simple_whitespace(config, speculative_state) + comment = _parse_comment(config, speculative_state) + newline = _parse_newline(config, speculative_state) + if newline is None: + # Speculative parsing failed + return None + # Speculative parsing succeeded + state.line = speculative_state.line + state.column = speculative_state.column + # don't need to copy absolute_indent/is_parenthesized because they don't change. + return cst.TrailingWhitespace(whitespace, comment, newline) + + +def _parse_parenthesized_whitespace( + config: BaseWhitespaceParserConfig, state: State +) -> Optional[cst.ParenthesizedWhitespace]: + first_line = _parse_trailing_whitespace(config, state) + if first_line is None: + # Speculative parsing failed + return None + empty_lines = () + while True: + empty_line = _parse_empty_line(config, state) + if empty_line is None: + # This isn't an empty line, so parse it below + break + empty_lines = empty_lines + (empty_line,) + indent = _parse_indent(config, state) + last_line = parse_simple_whitespace(config, state) + return cst.ParenthesizedWhitespace(first_line, empty_lines, indent, last_line) diff --git a/libcst/parser/_wrapped_tokenize.py b/libcst/parser/_wrapped_tokenize.py new file mode 100644 index 00000000..dfbe528b --- /dev/null +++ b/libcst/parser/_wrapped_tokenize.py @@ -0,0 +1,198 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +""" +Parso's tokenize doesn't give us tokens in the format that we'd ideally like, so this +performs a small number of transformations to the token stream: + +- `end_pos` is precomputed as a property, instead of lazily as a method, for more + efficient access. +- `whitespace_before` and `whitespace_after` have been added. These include the correct + indentation information. +- `prefix` is removed, since we don't use it anywhere. +- `ERRORTOKEN` and `ERROR_DEDENT` have been removed, because we don't intend to support + error recovery. If we encounter token errors, we'll raise a ParserSyntaxError instead. + +If performance becomes a concern, we can rewrite this later as a fork of the original +tokenize module, instead of as a wrapper. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Generator, List, Optional, Sequence + +from parso.python.token import PythonTokenTypes, TokenType +from parso.python.tokenize import ( + Token as OrigToken, + tokenize_lines as orig_tokenize_lines, +) +from parso.utils import PythonVersionInfo, split_lines + +from libcst._add_slots import add_slots +from libcst.exceptions import ParserSyntaxError +from libcst.parser._types.token import Token +from libcst.parser._types.whitespace_state import WhitespaceState + + +_ERRORTOKEN: TokenType = PythonTokenTypes.ERRORTOKEN +_ERROR_DEDENT: TokenType = PythonTokenTypes.ERROR_DEDENT + +_INDENT: TokenType = PythonTokenTypes.INDENT +_DEDENT: TokenType = PythonTokenTypes.DEDENT +_ENDMARKER: TokenType = PythonTokenTypes.ENDMARKER + +_FSTRING_START: TokenType = PythonTokenTypes.FSTRING_START +_FSTRING_END: TokenType = PythonTokenTypes.FSTRING_END + +_OP: TokenType = PythonTokenTypes.OP + + +class _ParenthesisOrFStringStackEntry(Enum): + PARENTHESIS = 0 + FSTRING = 1 # must differ from PARENTHESIS: equal values would make FSTRING an Enum alias + + +_PARENTHESIS_STACK_ENTRY: _ParenthesisOrFStringStackEntry = ( + _ParenthesisOrFStringStackEntry.PARENTHESIS +) +_FSTRING_STACK_ENTRY: _ParenthesisOrFStringStackEntry = ( +
_ParenthesisOrFStringStackEntry.FSTRING +) + + +@add_slots +@dataclass(frozen=False) +class _TokenizeState: + lines: Sequence[str] + previous_whitespace_state: WhitespaceState = field( + default_factory=lambda: WhitespaceState( + line=1, column=0, absolute_indent="", is_parenthesized=False + ) + ) + indents: List[str] = field(default_factory=lambda: [""]) + parenthesis_or_fstring_stack: List[_ParenthesisOrFStringStackEntry] = field( + default_factory=list + ) + + +def tokenize( + code: str, version_info: PythonVersionInfo +) -> Generator[Token, None, None]: + lines = split_lines(code, keepends=True) + return tokenize_lines(lines, version_info) + + +def tokenize_lines( + lines: Sequence[str], version_info: PythonVersionInfo +) -> Generator[Token, None, None]: + state = _TokenizeState(lines) + orig_tokens_iter = iter(orig_tokenize_lines(lines, version_info)) + + # Iterate over the tokens and pass them to _convert_token, providing a one-token + # lookahead, to enable proper indent handling. 
+ try: + curr_token = next(orig_tokens_iter) + except StopIteration: + pass # empty file + else: + for next_token in orig_tokens_iter: + yield _convert_token(state, curr_token, next_token) + curr_token = next_token + yield _convert_token(state, curr_token, None) + + +def _convert_token( + state: _TokenizeState, curr_token: OrigToken, next_token: Optional[OrigToken] +) -> Token: + ct_type = curr_token.type + ct_string = curr_token.string + ct_start_pos = curr_token.start_pos + if ct_type is _ERRORTOKEN: + raise ParserSyntaxError( + message="invalid token", + encountered=ct_string, + expected=["TOKEN"], + pos=curr_token.start_pos, + lines=state.lines, + ) + if ct_type is _ERROR_DEDENT: + raise ParserSyntaxError( + message="inconsistent indentation", + encountered=next_token.string if next_token is not None else None, + expected=["DEDENT"], + pos=curr_token.start_pos, + lines=state.lines, + ) + + # Compute relative indent changes for indent/dedent nodes + relative_indent: Optional[str] = None + if ct_type is _INDENT: + old_indent = "" if len(state.indents) < 2 else state.indents[-2] + new_indent = state.indents[-1] + relative_indent = new_indent[len(old_indent) :] + + if next_token is not None: + nt_type = next_token.type + if nt_type is _INDENT: + nt_line, nt_column = next_token.start_pos + state.indents.append(state.lines[nt_line - 1][:nt_column]) + elif nt_type is _DEDENT: + state.indents.pop() + + whitespace_before = state.previous_whitespace_state + + if ct_type is _INDENT or ct_type is _DEDENT or ct_type is _ENDMARKER: + # Don't update whitespace state for these dummy tokens. + whitespace_after = whitespace_before + ct_end_pos = ct_start_pos + else: + # Not a dummy token, so update the whitespace state. + + # Compute our own end_pos, since parso's end_pos is wrong for triple-strings. 
+ lines = split_lines(ct_string) + if len(lines) > 1: + ct_end_pos = ct_start_pos[0] + len(lines) - 1, len(lines[-1]) + else: + ct_end_pos = (ct_start_pos[0], ct_start_pos[1] + len(ct_string)) + + # Figure out what mode the whitespace parser should use. If we're inside + # parentheses, certain whitespace (e.g. newlines) are allowed where they would + # otherwise not be. f-strings override and disable this behavior, however. + # + # Parso's tokenizer tracks this internally, but doesn't expose it, so we have to + # duplicate that logic here. + pof_stack = state.parenthesis_or_fstring_stack + if ct_type is _FSTRING_START: + pof_stack.append(_FSTRING_STACK_ENTRY) + elif ct_type is _FSTRING_END: + pof_stack.pop() + elif ct_type is _OP: + if ct_string in "([{": + pof_stack.append(_PARENTHESIS_STACK_ENTRY) + elif ct_string in ")]}": + pof_stack.pop() + is_parenthesized = ( + len(pof_stack) > 0 and pof_stack[-1] == _PARENTHESIS_STACK_ENTRY + ) + + whitespace_after = WhitespaceState( + ct_end_pos[0], ct_end_pos[1], state.indents[-1], is_parenthesized + ) + + # Hold onto whitespace_after, so we can use it as whitespace_before in the next + # node. + state.previous_whitespace_state = whitespace_after + + return Token( + ct_type, + ct_string, + ct_start_pos, + ct_end_pos, + whitespace_before, + whitespace_after, + relative_indent, + ) diff --git a/libcst/parser/tests/test_detect_config.py b/libcst/parser/tests/test_detect_config.py new file mode 100644 index 00000000..2faefe8a --- /dev/null +++ b/libcst/parser/tests/test_detect_config.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from typing import Union + +from libcst.parser._detect_config import detect_config +from libcst.parser._types.config import ParserConfig, PartialParserConfig +from libcst.testing.utils import UnitTest, data_provider + + +class TestDetectConfig(UnitTest): + @data_provider( + { + "empty_input": { + "source": b"", + "partial": PartialParserConfig(), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=["\n", ""], + encoding="utf-8", + default_indent=" ", + default_newline="\n", + has_trailing_newline=False, + ), + }, + "detect_trailing_newline_disabled": { + "source": b"", + "partial": PartialParserConfig(), + "detect_trailing_newline": False, + "expected_config": ParserConfig( + lines=[""], # the trailing newline isn't inserted + encoding="utf-8", + default_indent=" ", + default_newline="\n", + has_trailing_newline=False, + ), + }, + "newline_inferred": { + "source": b"first_line\r\n\nsomething\n", + "partial": PartialParserConfig(), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=["first_line\r\n", "\n", "something\n", ""], + encoding="utf-8", + default_indent=" ", + default_newline="\r\n", + has_trailing_newline=True, + ), + }, + "newline_partial_given": { + "source": b"first_line\r\nsecond_line\r\n", + "partial": PartialParserConfig(default_newline="\n"), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=["first_line\r\n", "second_line\r\n", ""], + encoding="utf-8", + default_indent=" ", + default_newline="\n", # The given partial disables inference + has_trailing_newline=True, + ), + }, + "indent_inferred": { + "source": b"if test:\n\t something\n", + "partial": PartialParserConfig(), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=["if test:\n", "\t something\n", ""], + encoding="utf-8", + default_indent="\t ", + default_newline="\n", + has_trailing_newline=True, + ), + }, + "indent_partial_given": { + "source": b"if test:\n\t 
something\n", + "partial": PartialParserConfig(default_indent=" "), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=["if test:\n", "\t something\n", ""], + encoding="utf-8", + default_indent=" ", + default_newline="\n", + has_trailing_newline=True, + ), + }, + "encoding_inferred": { + "source": b"#!/usr/bin/python3\n# -*- coding: latin-1 -*-\npass\n", + "partial": PartialParserConfig(), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=[ + "#!/usr/bin/python3\n", + "# -*- coding: latin-1 -*-\n", + "pass\n", + "", + ], + encoding="iso-8859-1", # this is an alias for latin-1 + default_indent=" ", + default_newline="\n", + has_trailing_newline=True, + ), + }, + "encoding_partial_given": { + "source": b"#!/usr/bin/python3\n# -*- coding: latin-1 -*-\npass\n", + "partial": PartialParserConfig(encoding="us-ascii"), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=[ + "#!/usr/bin/python3\n", + "# -*- coding: latin-1 -*-\n", + "pass\n", + "", + ], + encoding="us-ascii", + default_indent=" ", + default_newline="\n", + has_trailing_newline=True, + ), + }, + "encoding_str_not_bytes_disables_inference": { + "source": "#!/usr/bin/python3\n# -*- coding: latin-1 -*-\npass\n", + "partial": PartialParserConfig(), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=[ + "#!/usr/bin/python3\n", + "# -*- coding: latin-1 -*-\n", + "pass\n", + "", + ], + encoding="utf-8", # because source is a str, don't infer latin-1 + default_indent=" ", + default_newline="\n", + has_trailing_newline=True, + ), + }, + "encoding_non_ascii_compatible_utf_16_with_bom": { + "source": b"\xff\xfet\x00e\x00s\x00t\x00", + "partial": PartialParserConfig(encoding="utf-16"), + "detect_trailing_newline": True, + "expected_config": ParserConfig( + lines=["test\n", ""], + encoding="utf-16", + default_indent=" ", + default_newline="\n", + has_trailing_newline=False, + ), + }, + } + ) + def 
test_detect_module_config( + self, + *, + source: Union[str, bytes], + partial: PartialParserConfig, + detect_trailing_newline: bool, + expected_config: ParserConfig, + ) -> None: + self.assertEqual( + detect_config( + source, partial=partial, detect_trailing_newline=detect_trailing_newline + ).config, + expected_config, + ) diff --git a/libcst/parser/tests/test_whitespace_parser.py b/libcst/parser/tests/test_whitespace_parser.py new file mode 100644 index 00000000..34e0de88 --- /dev/null +++ b/libcst/parser/tests/test_whitespace_parser.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from dataclasses import dataclass +from typing import Callable, Sequence, TypeVar + +import libcst.nodes as cst +from libcst.nodes._deep_equals import deep_equals +from libcst.parser._types.config import BaseWhitespaceParserConfig +from libcst.parser._types.whitespace_state import WhitespaceState as State +from libcst.parser._whitespace_parser import ( + parse_empty_lines, + parse_simple_whitespace, + parse_trailing_whitespace, +) +from libcst.testing.utils import UnitTest, data_provider + + +_T = TypeVar("_T") + + +@dataclass(frozen=True) +class Config(BaseWhitespaceParserConfig): + lines: Sequence[str] + default_newline: str + + +class WhitespaceParserTest(UnitTest): + @data_provider( + { + "simple_whitespace_empty": { + "parser": parse_simple_whitespace, + "config": Config( + lines=["not whitespace\n", " another line\n"], default_newline="\n" + ), + "start_state": State( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + "expected_node": cst.SimpleWhitespace(""), + }, + "simple_whitespace_start_of_line": { + "parser": parse_simple_whitespace, + "config": Config( + lines=["\t <-- There's some whitespace 
there\n"], + default_newline="\n", + ), + "start_state": State( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=1, column=3, absolute_indent="", is_parenthesized=False + ), + "expected_node": cst.SimpleWhitespace("\t "), + }, + "simple_whitespace_end_of_line": { + "parser": parse_simple_whitespace, + "config": Config(lines=["prefix "], default_newline="\n"), + "start_state": State( + line=1, column=6, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=1, column=9, absolute_indent="", is_parenthesized=False + ), + "expected_node": cst.SimpleWhitespace(" "), + }, + "simple_whitespace_line_continuation": { + "parser": parse_simple_whitespace, + "config": Config( + lines=["prefix \\\n", " \\\n", " # suffix\n"], + default_newline="\n", + ), + "start_state": State( + line=1, column=6, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=3, column=4, absolute_indent="", is_parenthesized=False + ), + "expected_node": cst.SimpleWhitespace(" \\\n \\\n "), + }, + "empty_lines_empty_list": { + "parser": parse_empty_lines, + "config": Config( + lines=["this is not an empty line"], default_newline="\n" + ), + "start_state": State( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + "expected_node": [], + }, + "empty_lines_single_line": { + "parser": parse_empty_lines, + "config": Config( + lines=[" # comment\n", "this is not an empty line\n"], + default_newline="\n", + ), + "start_state": State( + line=1, column=0, absolute_indent=" ", is_parenthesized=False + ), + "end_state": State( + line=2, column=0, absolute_indent=" ", is_parenthesized=False + ), + "expected_node": [ + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=cst.Comment("# comment"), + newline=cst.Newline(), + ) + ], + }, + "empty_lines_multiple": { + "parser": parse_empty_lines, + 
"config": Config( + lines=[ + "\n", + " \n", + " # comment with indent and whitespace\n", + "# comment without indent\n", + " # comment with no indent but some whitespace\n", + ], + default_newline="\n", + ), + "start_state": State( + line=1, column=0, absolute_indent=" ", is_parenthesized=False + ), + "end_state": State( + line=5, column=47, absolute_indent=" ", is_parenthesized=False + ), + "expected_node": [ + cst.EmptyLine( + indent=False, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline(), + ), + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline(), + ), + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# comment with indent and whitespace"), + newline=cst.Newline(), + ), + cst.EmptyLine( + indent=False, + whitespace=cst.SimpleWhitespace(""), + comment=cst.Comment("# comment without indent"), + newline=cst.Newline(), + ), + cst.EmptyLine( + indent=False, + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment( + "# comment with no indent but some whitespace" + ), + newline=cst.Newline(), + ), + ], + }, + "empty_lines_non_default_newline": { + "parser": parse_empty_lines, + "config": Config(lines=["\n", "\r\n", "\r"], default_newline="\n"), + "start_state": State( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=3, column=1, absolute_indent="", is_parenthesized=False + ), + "expected_node": [ + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline(None), # default newline + ), + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline("\r\n"), # non-default + ), + cst.EmptyLine( + indent=True, + whitespace=cst.SimpleWhitespace(""), + comment=None, + newline=cst.Newline("\r"), # non-default + ), + ], + }, + "trailing_whitespace": { + "parser": parse_trailing_whitespace, + "config": 
Config( + lines=["some code # comment\n"], default_newline="\n" + ), + "start_state": State( + line=1, column=9, absolute_indent="", is_parenthesized=False + ), + "end_state": State( + line=1, column=21, absolute_indent="", is_parenthesized=False + ), + "expected_node": cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment("# comment"), + newline=cst.Newline(), + ), + }, + } + ) + def test_parsers( + self, + parser: Callable[[Config, State], _T], + config: Config, + start_state: State, + end_state: State, + expected_node: _T, + ) -> None: + # Uses internal `deep_equals` function instead of `CSTNode.deep_equals`, because + # we need to compare sequences of nodes, and this is the easiest way. :/ + parsed_node = parser(config, start_state) + self.assertTrue( + deep_equals(parsed_node, expected_node), + msg=f"\n{parsed_node!r}\nis not deeply equal to \n{expected_node!r}", + ) + self.assertEqual(start_state, end_state) diff --git a/libcst/parser/tests/test_wrapped_tokenize.py b/libcst/parser/tests/test_wrapped_tokenize.py new file mode 100644 index 00000000..75c6c411 --- /dev/null +++ b/libcst/parser/tests/test_wrapped_tokenize.py @@ -0,0 +1,241 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Sequence + +from parso.python.token import PythonTokenTypes +from parso.utils import parse_version_string + +from libcst.exceptions import ParserSyntaxError +from libcst.parser._types.whitespace_state import WhitespaceState +from libcst.parser._wrapped_tokenize import Token, tokenize +from libcst.testing.utils import UnitTest, data_provider + + +_PY38 = parse_version_string("3.8.0") + + +class WrappedTokenizeTest(UnitTest): + maxDiff = 10000 + + @data_provider( + { + "simple": ( + "pass;\n", + ( + Token( + type=PythonTokenTypes.NAME, + string="pass", + start_pos=(1, 0), + end_pos=(1, 4), + whitespace_before=WhitespaceState( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=1, column=4, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.OP, + string=";", + start_pos=(1, 4), + end_pos=(1, 5), + whitespace_before=WhitespaceState( + line=1, column=4, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=1, column=5, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.NEWLINE, + string="\n", + start_pos=(1, 5), + end_pos=(2, 0), + whitespace_before=WhitespaceState( + line=1, column=5, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=2, column=0, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.ENDMARKER, + string="", + start_pos=(2, 0), + end_pos=(2, 0), + whitespace_before=WhitespaceState( + line=2, column=0, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=2, column=0, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + ), + ), + "with_indent": ( + "if foo:\n bar\n", + ( + Token( + type=PythonTokenTypes.NAME, + string="if", + start_pos=(1, 0), + end_pos=(1, 2), + 
whitespace_before=WhitespaceState( + line=1, column=0, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=1, column=2, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.NAME, + string="foo", + start_pos=(1, 3), + end_pos=(1, 6), + whitespace_before=WhitespaceState( + line=1, column=2, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=1, column=6, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.OP, + string=":", + start_pos=(1, 6), + end_pos=(1, 7), + whitespace_before=WhitespaceState( + line=1, column=6, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=1, column=7, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.NEWLINE, + string="\n", + start_pos=(1, 7), + end_pos=(2, 0), + whitespace_before=WhitespaceState( + line=1, column=7, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=2, + column=0, + absolute_indent=" ", + is_parenthesized=False, + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.INDENT, + string="", + start_pos=(2, 4), + end_pos=(2, 4), + whitespace_before=WhitespaceState( + line=2, + column=0, + absolute_indent=" ", + is_parenthesized=False, + ), + whitespace_after=WhitespaceState( + line=2, + column=0, + absolute_indent=" ", + is_parenthesized=False, + ), + relative_indent=" ", + ), + Token( + type=PythonTokenTypes.NAME, + string="bar", + start_pos=(2, 4), + end_pos=(2, 7), + whitespace_before=WhitespaceState( + line=2, + column=0, + absolute_indent=" ", + is_parenthesized=False, + ), + whitespace_after=WhitespaceState( + line=2, + column=7, + absolute_indent=" ", + is_parenthesized=False, + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.NEWLINE, + string="\n", + start_pos=(2, 
7), + end_pos=(3, 0), + whitespace_before=WhitespaceState( + line=2, + column=7, + absolute_indent=" ", + is_parenthesized=False, + ), + whitespace_after=WhitespaceState( + line=3, column=0, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.DEDENT, + string="", + start_pos=(3, 0), + end_pos=(3, 0), + whitespace_before=WhitespaceState( + line=3, column=0, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=3, column=0, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + Token( + type=PythonTokenTypes.ENDMARKER, + string="", + start_pos=(3, 0), + end_pos=(3, 0), + whitespace_before=WhitespaceState( + line=3, column=0, absolute_indent="", is_parenthesized=False + ), + whitespace_after=WhitespaceState( + line=3, column=0, absolute_indent="", is_parenthesized=False + ), + relative_indent=None, + ), + ), + ), + } + ) + def test_tokenize(self, code: str, expected: Sequence[Token]) -> None: + tokens = tuple(tokenize(code, _PY38)) + self.assertSequenceEqual(tokens, expected) + for a, b in zip(tokens, tokens[1:]): + # These must be the same object, so if whitespace gets consumed (mutated) at + # the end of token a, it shows up at the beginning of token b. 
+ self.assertIs(a.whitespace_after, b.whitespace_before) + + def test_errortoken(self) -> None: + with self.assertRaisesRegex(ParserSyntaxError, "invalid token"): + # use tuple() to read everything + # The copyright symbol isn't a valid token + tuple(tokenize("\u00a9", _PY38)) + + def test_error_dedent(self) -> None: + with self.assertRaisesRegex(ParserSyntaxError, "inconsistent indentation"): + # create some inconsistent indents to generate an ERROR_DEDENT token + tuple(tokenize(" a\n b", _PY38)) diff --git a/libcst/testing/__init__.py b/libcst/testing/__init__.py new file mode 100644 index 00000000..62642369 --- /dev/null +++ b/libcst/testing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/libcst/testing/utils.py b/libcst/testing/utils.py new file mode 100644 index 00000000..7ed8122b --- /dev/null +++ b/libcst/testing/utils.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import re +from functools import wraps +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) +from unittest import TestCase + + +DATA_PROVIDER_DATA_ATTR_NAME = "__data_provider_data" +DATA_PROVIDER_DESCRIPTION_PREFIX = "_data_provider_" +PROVIDER_TEST_LIMIT_ATTR_NAME = "__provider_test_limit" +DEFAULT_TEST_LIMIT = 256 + + +T = TypeVar("T") + + +def none_throws(value: Optional[T], message: str = "Unexpected None value") -> T: + assert value is not None, message + return value + + +def update_test_limit(test_method: Any, test_limit: int) -> None: + # Store the maximum number of generated tests on the test_method. 
Since + # contextmanager_provider can be specified multiple times, we need to + # take the maximum of the existing attribute and the current value + existing_test_limit = getattr( + test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, test_limit + ) + setattr( + test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, max(existing_test_limit, test_limit) + ) + + +def try_get_provider_attr( + member_name: str, member: Any, attr_name: str +) -> Optional[Any]: + if inspect.isfunction(member) and member_name.startswith("test"): + return getattr(member, attr_name, None) + return None + + +def populate_data_provider_tests(dct: Dict[str, Any]) -> None: + test_methods_to_add: Dict[str, Callable] = {} + test_methods_to_remove: List[str] = [] + for member_name, member in dct.items(): + provider_data = try_get_provider_attr( + member_name, member, DATA_PROVIDER_DATA_ATTR_NAME + ) + if provider_data is not None: + + for description, data in ( + provider_data.items() + if isinstance(provider_data, dict) + else enumerate(provider_data) + ): + if isinstance(provider_data, dict): + description = f"{DATA_PROVIDER_DESCRIPTION_PREFIX}{description}" + + assert re.fullmatch( + r"[a-zA-Z0-9_]+", str(description) + ), f"Testcase description must be a valid python identifier: '{description}'" + + @wraps(member) + def new_test( + self: object, + data: Iterable[object] = data, + member: Callable[..., object] = member, + ) -> object: + if isinstance(data, dict): + return member(self, **data) + else: + return member(self, *data) + + name = f"{member_name}_{description}" + new_test.__name__ = name + test_methods_to_add[name] = new_test + if not test_methods_to_add: + raise ValueError( + f"No data_provider tests were created for {member_name}! Please double check your data." 
+ ) + test_methods_to_remove.append(member_name) + dct.update(test_methods_to_add) + + # Remove all old methods + for test_name in test_methods_to_remove: + del dct[test_name] + + +def validate_provider_tests(dct: Dict[str, Any]) -> None: + members_to_replace = {} + + for member_name, member in dct.items(): + test_limit = try_get_provider_attr( + member_name, member, PROVIDER_TEST_LIMIT_ATTR_NAME + ) + if test_limit is not None: + data = try_get_provider_attr( + member_name, member, DATA_PROVIDER_DATA_ATTR_NAME + ) + num_tests = len(data) if data else 1 + + if num_tests > test_limit: + # We don't use wraps() here so that the test isn't expanded + # as it normally would be by whichever provider it uses + def test_replacement( + self: Any, + member_name: Any = member_name, + num_tests: Any = num_tests, + test_limit: Any = test_limit, + ) -> None: + raise AssertionError( + f"{member_name} generated {num_tests} tests but the limit is " + + f"{test_limit}. You can increase the number of " + + "allowed tests by specifying test_limit, but please " + + "consider whether you really need to test all of " + + "these combinations." 
+ ) + + test_replacement.__name__ = member_name + members_to_replace[member_name] = test_replacement + + for member_name, new_member in members_to_replace.items(): + dct[member_name] = new_member + + +TestCaseType = Union[Sequence[object], Mapping[str, object]] +# Can't use Sequence[TestCaseType] here as some clients may pass in a Generator[TestCaseType] +StaticDataType = Union[Iterable[TestCaseType], Mapping[str, TestCaseType]] + + +def data_provider( + static_data: StaticDataType, *, test_limit: int = DEFAULT_TEST_LIMIT +) -> Callable[[Callable], Callable]: + # We need to be able to iterate over static_data more than once + # (for validation), so if we weren't passed in a dict, list, or tuple + # then we'll just create a list from the data + if not isinstance(static_data, (dict, list, tuple)): + static_data = list(static_data) + + def test_decorator(test_method: Callable) -> Callable: + update_test_limit(test_method, test_limit) + + setattr(test_method, DATA_PROVIDER_DATA_ATTR_NAME, static_data) + return test_method + + return test_decorator + + +class BaseTestMeta(type): + def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> object: + validate_provider_tests(dct) + populate_data_provider_tests(dct) + return super().__new__(mcs, name, bases, dict(dct)) + + +class UnitTest(TestCase, metaclass=BaseTestMeta): + pass diff --git a/libcst/tests/test_exceptions.py b/libcst/tests/test_exceptions.py new file mode 100644 index 00000000..e1e13c3c --- /dev/null +++ b/libcst/tests/test_exceptions.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +from textwrap import dedent + +from libcst.exceptions import ParserSyntaxError +from libcst.testing.utils import UnitTest, data_provider + + +class ExceptionsTest(UnitTest): + @data_provider( + [ + ( + ParserSyntaxError( + message="some message", + encountered=None, # EOF + expected=None, # EOF + pos=(1, 0), + lines=["abcd"], + ), + dedent( + """ + Syntax Error: some message @ 1:1. + Encountered end of file (EOF), but expected end of file (EOF). + + abcd + ^ + """ + ).strip(), + ), + ( + ParserSyntaxError( + message="some message", + encountered="encountered_value", + expected=["expected_value"], + pos=(1, 2), + lines=["\tabcd\r\n"], + ), + dedent( + """ + Syntax Error: some message @ 1:10. + Encountered 'encountered_value', but expected one of ['expected_value']. + + abcd + ^ + """ + ).strip(), + ), + ] + ) + def test_parser_syntax_error_str( + self, err: ParserSyntaxError, expected: str + ) -> None: + self.assertEqual(str(err), expected) diff --git a/libcst/tests/test_tabs.py b/libcst/tests/test_tabs.py new file mode 100644 index 00000000..1be5c421 --- /dev/null +++ b/libcst/tests/test_tabs.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from libcst._tabs import expand_tabs +from libcst.testing.utils import UnitTest, data_provider + + +class ExpandTabsTest(UnitTest): + @data_provider( + [ + ("\t", " " * 8), + ("\t\t", " " * 16), + (" \t", " " * 8), + ("\t ", " " * 12), + ("abcd\t", "abcd "), + ("abcdefg\t", "abcdefg "), + ("abcdefgh\t", "abcdefgh "), + ("\tsuffix", " suffix"), + ] + ) + def test_expand_tabs(self, input, output) -> None: + self.assertEqual(expand_tabs(input), output) diff --git a/libcst/tool.py b/libcst/tool.py new file mode 100644 index 00000000..652f1021 --- /dev/null +++ b/libcst/tool.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Usage: +# +# python -m libcst.tool --help +# python -m libcst.tool print python_file.py + +import argparse +import sys +from typing import List + +from libcst.parser import parse_module + + +def print_tree(args: argparse.Namespace) -> int: + infile = args.infile + + # Grab input file + if infile == "-": + code = sys.stdin.read() + else: + with open(infile, "rb") as fp: + code = fp.read() + + tree = parse_module(code) + print(tree) + return 0 + + +def main(cli_args: List[str]) -> int: + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(title="commands", description="valid commands") + + print_parser = subparsers.add_parser( + "print", help="Print LibCST tree for a python file." + ) + print_parser.set_defaults(func=print_tree) + print_parser.add_argument( + "infile", + metavar="INFILE", + help='File to print tree for. Use "-" for stdin', + type=str, + ) + args = parser.parse_args(cli_args) + if "func" in args: + return args.func(args) + else: + print("Please specify a command!\n", file=sys.stderr) + parser.print_help(sys.stderr) + return 1 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..927fefd6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +parso +typing_extensions diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..f3f4f567 --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from os import path + +# pyre-ignore Pyre doesn't know about setuptools. 
+import setuptools + + +this_directory = path.abspath(path.dirname(__file__)) +with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +setuptools.setup( + name="libcst", + description="A concrete syntax tree with AST-like properties for Python 3.7.", + long_description=long_description, + long_description_content_type="text/markdown", + version="0.1.dev", + packages=setuptools.find_packages(), +) diff --git a/stubs/parso/__init__.pyi b/stubs/parso/__init__.pyi new file mode 100644 index 00000000..bf4f2554 --- /dev/null +++ b/stubs/parso/__init__.pyi @@ -0,0 +1,6 @@ +from parso.grammar import Grammar, load_grammar +from parso.parser import ParserSyntaxError +from parso.utils import python_bytes_to_unicode, split_lines + + +__version__: str = ... diff --git a/stubs/parso/grammar.pyi b/stubs/parso/grammar.pyi new file mode 100644 index 00000000..03386406 --- /dev/null +++ b/stubs/parso/grammar.pyi @@ -0,0 +1,40 @@ +from typing import Any, Callable, Generic, Sequence, TypeVar + +from parso.utils import PythonVersionInfo + + +_Token = Any +_NodeT = TypeVar("_NodeT") + +class Grammar(Generic[_NodeT]): + _default_normalizer_config: Optional[Any] = ... + _error_normalizer_config: Optional[Any] = None + _start_nonterminal: str = ... + _token_namespace: Optional[str] = None + def __init__( + self, + text: str, + tokenizer: Callable[[Sequence[str], int], Sequence[_Token]], + parser: Any = ..., + diff_parser: Any = None, + ) -> None: ... + def parse( + self, + code: Union[str, bytes] = None, + error_recovery: bool = True, + path: Optional[str] = None, + start_symbol: Optional[str] = None, + cache: bool = False, + diff_cache: bool = False, + cache_path: Optional[str] = None, + ) -> _NodeT: ... + +class PythonGrammar(Grammar): + version_info: PythonVersionInfo + def __init__(self, bnf_text: str, version_info: PythonVersionInfo) -> None: ... 
+ +# Realistically, this should be `language: Literal["python"]` since only python is +# supported, but pyre doesn't support literal types yet. +def load_grammar( + language: str = "python", version: Optional[str] = None, path: str = None +) -> Grammar: ... diff --git a/stubs/parso/pgen2/__init__.pyi b/stubs/parso/pgen2/__init__.pyi new file mode 100644 index 00000000..c62c856b --- /dev/null +++ b/stubs/parso/pgen2/__init__.pyi @@ -0,0 +1 @@ +from parso.pgen2.generator import generate_grammar diff --git a/stubs/parso/pgen2/generator.pyi b/stubs/parso/pgen2/generator.pyi new file mode 100644 index 00000000..9ad3c43d --- /dev/null +++ b/stubs/parso/pgen2/generator.pyi @@ -0,0 +1,39 @@ +from typing import Any, Generic, Mapping, Sequence, TypeVar, Union + +from parso.pgen2.grammar_parser import NFAState + + +_TokenTypeT = TypeVar("_TokenTypeT") + +class Grammar(Generic[_TokenTypeT]): + nonterminal_to_dfas: Mapping[str, Sequence[DFAState[_TokenTypeT]]] + reserved_syntax_strings: Mapping[str, ReservedString] + start_nonterminal: str + def __init__( + self, + start_nonterminal: str, + rule_to_dfas: Mapping[str, Sequence[DFAState]], + reserved_syntax_strings: Mapping[str, ReservedString], + ) -> None: ... + +class DFAPlan: + next_dfa: DFAState + dfa_pushes: Sequence[DFAState] + +class DFAState(Generic[_TokenTypeT]): + from_rule: str + nfa_set: Set[NFAState] + is_final: bool + arcs: Mapping[str, DFAState] # map from all terminals/nonterminals to DFAState + nonterminal_arcs: Mapping[str, DFAState] + transitions: Mapping[Union[_TokenTypeT, ReservedString], DFAPlan] + def __init__( + self, from_rule: str, nfa_set: Set[NFAState], final: NFAState + ) -> None: ... + +class ReservedString: + value: str + def __init__(self, value: str) -> None: ... + def __repr__(self) -> str: ... + +def generate_grammar(bnf_grammar: str, token_namespace: Any) -> Grammar[Any]: ... 
diff --git a/stubs/parso/pgen2/grammar_parser.pyi b/stubs/parso/pgen2/grammar_parser.pyi new file mode 100644 index 00000000..0adb90cd --- /dev/null +++ b/stubs/parso/pgen2/grammar_parser.pyi @@ -0,0 +1,21 @@ +from typing import Generator, List, Optional, Tuple + +from parso.python.token import PythonToken + + +class GrammarParser: + generator: Generator[PythonToken, None, None] + def __init__(self, bnf_grammar: str) -> None: ... + def parse(self) -> Generator[Tuple[NFAState, NFAState], None, None]: ... + +class NFAArc: + next: NFAState + nonterminal_or_string: Optional[str] + def __init__( + self, next_: NFAState, nonterminal_or_string: Optional[str] + ) -> None: ... + +class NFAState: + from_rule: str + arcs: List[NFAArc] + def __init__(self, from_rule: str) -> None: ... diff --git a/stubs/parso/python/token.pyi b/stubs/parso/python/token.pyi new file mode 100644 index 00000000..d533fd77 --- /dev/null +++ b/stubs/parso/python/token.pyi @@ -0,0 +1,31 @@ +from typing import Container, Iterable + + +class TokenType: + name: str + contains_syntax: bool + def __init__(self, name: str, contains_syntax: bool) -> None: ... + +class TokenTypes: + def __init__( + self, names: Iterable[str], contains_syntax: Container[str] + ) -> None: ... + +# not an actual class in the source code, but we need this class to type the fields of +# PythonTokenTypes +class _FakePythonTokenTypesClass(TokenTypes): + STRING: TokenType + NUMBER: TokenType + NAME: TokenType + ERRORTOKEN: TokenType + NEWLINE: TokenType + INDENT: TokenType + DEDENT: TokenType + ERROR_DEDENT: TokenType + FSTRING_STRING: TokenType + FSTRING_START: TokenType + FSTRING_END: TokenType + OP: TokenType + ENDMARKER: TokenType + +PythonTokenTypes: _FakePythonTokenTypesClass = ... 
diff --git a/stubs/parso/python/tokenize.pyi b/stubs/parso/python/tokenize.pyi new file mode 100644 index 00000000..ad7fc60a --- /dev/null +++ b/stubs/parso/python/tokenize.pyi @@ -0,0 +1,25 @@ +from typing import Generator, Iterable, NamedTuple, Tuple + +from parso.python.token import TokenType +from parso.utils import PythonVersionInfo + + +class Token(NamedTuple): + type: TokenType + string: str + start_pos: Tuple[int, int] + prefix: str + @property + def end_pos(self) -> Tuple[int, int]: ... + +class PythonToken(Token): + def __repr__(self) -> str: ... + +def tokenize( + code: str, version_info: PythonVersionInfo, start_pos: Tuple[int, int] = (1, 0) +) -> Generator[PythonToken, None, None]: ... +def tokenize_lines( + lines: Iterable[str], + version_info: PythonVersionInfo, + start_pos: Tuple[int, int] = (1, 0), +) -> Generator[PythonToken, None, None]: ... diff --git a/stubs/parso/utils.pyi b/stubs/parso/utils.pyi new file mode 100644 index 00000000..8e0eeb98 --- /dev/null +++ b/stubs/parso/utils.pyi @@ -0,0 +1,30 @@ +from typing import NamedTuple, Optional, Sequence, Union + + +class Version(NamedTuple): + major: int + minor: int + micro: int + +def split_lines(string: str, keepends: bool = False) -> Sequence[str]: ... +def python_bytes_to_unicode( + source: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict" +) -> str: ... +def version_info() -> Version: + """ + Returns a namedtuple of parso's version, similar to Python's + ``sys.version_info``. + """ + ... + +class PythonVersionInfo(NamedTuple): + major: int + minor: int + +def parse_version_string(version: Optional[str]) -> PythonVersionInfo: + """ + Checks for a valid version number (e.g. `3.2` or `2.7.1` or `3`) and + returns a corresponding version info that is always two characters long in + decimal. + """ + ... 
diff --git a/stubs/tokenize.pyi b/stubs/tokenize.pyi new file mode 100644 index 00000000..38f9a5c2 --- /dev/null +++ b/stubs/tokenize.pyi @@ -0,0 +1,96 @@ +from token import ( + AMPER, + AMPEREQUAL, + AT, + ATEQUAL, + CIRCUMFLEX, + CIRCUMFLEXEQUAL, + COLON, + COLONEQUAL, + COMMA, + COMMENT, + DEDENT, + DOT, + DOUBLESLASH, + DOUBLESLASHEQUAL, + DOUBLESTAR, + DOUBLESTAREQUAL, + ELLIPSIS, + ENCODING, + ENDMARKER, + EQEQUAL, + EQUAL, + ERRORTOKEN, + EXACT_TOKEN_TYPES, + GREATER, + GREATEREQUAL, + INDENT, + LBRACE, + LEFTSHIFT, + LEFTSHIFTEQUAL, + LESS, + LESSEQUAL, + LPAR, + LSQB, + MINEQUAL, + MINUS, + N_TOKENS, + NAME, + NEWLINE, + NL, + NOTEQUAL, + NT_OFFSET, + NUMBER, + OP, + PERCENT, + PERCENTEQUAL, + PLUS, + PLUSEQUAL, + RARROW, + RBRACE, + RIGHTSHIFT, + RIGHTSHIFTEQUAL, + RPAR, + RSQB, + SEMI, + SLASH, + SLASHEQUAL, + STAR, + STAREQUAL, + STRING, + TILDE, + TYPE_COMMENT, + TYPE_IGNORE, + VBAR, + VBAREQUAL, +) +from typing import Callable, Tuple + + +Hexnumber: str = ... +Binnumber: str = ... +Octnumber: str = ... +Decnumber: str = ... +Intnumber: str = ... +Exponent: str = ... +Pointfloat: str = ... +Expfloat: str = ... +Floatnumber: str = ... +Imagnumber: str = ... +Number: str = ... +Whitespace: str = ... +Comment: str = ... +Ignore: str = ... +Name: str = ... + +class TokenInfo(Tuple[int, str, Tuple[int, int], Tuple[int, int], int]): + exact_type: int = ... + type: int = ... + string: str = ... + start: Tuple[int, int] = ... + end: Tuple[int, int] = ... + line: int = ... + def __repr__(self) -> str: ... + +def detect_encoding(readline: Callable[[], bytes]) -> Tuple[str, Sequence[bytes]]: ... +def tokenize(Callable) -> TokenInfo: ...