diff --git a/libcst/nodes/_base.py b/libcst/nodes/_base.py index b6f95937..36b992b6 100644 --- a/libcst/nodes/_base.py +++ b/libcst/nodes/_base.py @@ -10,7 +10,7 @@ from typing import Any, List, Sequence, TypeVar, Union, cast from libcst._base_visitor import CSTVisitor from libcst._removal_sentinel import RemovalSentinel -from libcst.nodes._internal import CodegenState +from libcst.nodes._internal import CodegenState, CodePosition _CSTNodeSelfT = TypeVar("_CSTNodeSelfT", bound="CSTNode") @@ -184,9 +184,15 @@ class CSTNode(ABC): ... @abstractmethod - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: ... + def _codegen(self, state: CodegenState, **kwargs: Any) -> None: + start = state.line, state.column + self._codegen_impl(state, **kwargs) + end = state.line, state.column + state.update_position(self, CodePosition(start, end)) + def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT: """ A convenience method for performing mutation-like operations on immutable nodes. @@ -246,8 +252,9 @@ class CSTNode(ABC): lines.append(f"{type(self).__name__}(") for field in fields(self): key = field.name - value = getattr(self, key) - lines.append(_indent(f"{key}={_pretty_repr(value)},")) + if key[0] != "_": + value = getattr(self, key) + lines.append(_indent(f"{key}={_pretty_repr(value)},")) lines.append(")") return "\n".join(lines) @@ -274,7 +281,7 @@ class BaseValueToken(BaseLeaf, ABC): value: str - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token(self.value) diff --git a/libcst/nodes/_dummy.py b/libcst/nodes/_dummy.py index ca36496d..d22f1576 100644 --- a/libcst/nodes/_dummy.py +++ b/libcst/nodes/_dummy.py @@ -57,7 +57,7 @@ class DummyNode(CSTNode): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState, **kwargs: Any) -> None: + def _codegen_impl(self, state: CodegenState, **kwargs: Any) -> None: for lpar in self.lpar: lpar._codegen(state) for child in self.children: diff --git a/libcst/nodes/_expression.py b/libcst/nodes/_expression.py index 83a45b54..69ba145f 100644 --- a/libcst/nodes/_expression.py +++ b/libcst/nodes/_expression.py @@ -72,7 +72,7 @@ class LeftSquareBracket(CSTNode): ) ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token("[") self.whitespace_after._codegen(state) @@ -94,7 +94,7 @@ class RightSquareBracket(CSTNode): ) ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_before._codegen(state) state.add_token("]") @@ -116,7 +116,7 @@ class LeftParen(CSTNode): ) ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token("(") self.whitespace_after._codegen(state) @@ -138,7 +138,7 @@ class RightParen(CSTNode): ) ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_before._codegen(state) state.add_token(")") @@ -262,7 +262,7 @@ class Name(BaseAssignTargetExpression, BaseDelTargetExpression, BaseAtom): if not self.value.isidentifier(): raise CSTValidationError("Name is not a valid identifier.") - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token(self.value) @@ -286,7 +286,7 @@ class Ellipses(BaseAtom): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token("...") @@ -314,7 +314,7 @@ class Integer(_BaseParenthesizedNode): if not re.fullmatch(INTNUMBER_RE, self.value): raise CSTValidationError("Number is not a valid integer.") - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token(self.value) @@ -342,7 +342,7 @@ class Float(_BaseParenthesizedNode): if not re.fullmatch(FLOATNUMBER_RE, self.value): raise CSTValidationError("Number is not a valid float.") - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token(self.value) @@ -370,7 +370,7 @@ class Imaginary(_BaseParenthesizedNode): if not re.fullmatch(IMAGNUMBER_RE, self.value): raise CSTValidationError("Number is not a valid imaginary.") - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token(self.value) @@ -408,7 +408,7 @@ class Number(BaseAtom): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): operator = self.operator if operator is not None: @@ -483,7 +483,7 @@ class SimpleString(BaseString): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token(self.value) @@ -505,7 +505,7 @@ class FormattedStringText(BaseFormattedStringContent): def _visit_and_replace_children(self, visitor: CSTVisitor) -> "FormattedStringText": return FormattedStringText(value=self.value) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token(self.value) @@ -551,7 +551,7 @@ class FormattedStringExpression(BaseFormattedStringContent): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token("{") self.whitespace_before_expression._codegen(state) self.expression._codegen(state) @@ -620,7 +620,7 @@ class FormattedString(BaseString): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token(self.start) for part in self.parts: @@ -679,7 +679,7 @@ class ConcatenatedString(BaseString): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.left._codegen(state) self.whitespace_between._codegen(state) @@ -718,7 +718,7 @@ class ComparisonTarget(CSTNode): comparator=visit_required("comparator", self.comparator, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.operator._codegen(state) self.comparator._codegen(state) @@ -768,7 +768,7 @@ class Comparison(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.left._codegen(state) for comp in self.comparisons: @@ -817,7 +817,7 @@ class UnaryOperation(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.operator._codegen(state) self.expression._codegen(state) @@ -854,7 +854,7 @@ class BinaryOperation(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.left._codegen(state) self.operator._codegen(state) @@ -911,7 +911,7 @@ class BooleanOperation(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.left._codegen(state) self.operator._codegen(state) @@ -951,7 +951,7 @@ class Attribute(BaseAssignTargetExpression, BaseDelTargetExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.value._codegen(state) self.dot._codegen(state) @@ -971,7 +971,7 @@ class Index(CSTNode): def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Index": return Index(value=visit_required("value", self.value, visitor)) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.value._codegen(state) @@ -1008,7 +1008,7 @@ class Slice(CSTNode): step=visit_optional("step", self.step, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: lower = self.lower if lower is not None: lower._codegen(state) @@ -1046,7 +1046,7 @@ class ExtSlice(CSTNode): comma=visit_sentinel("comma", self.comma, visitor), ) - def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: self.slice._codegen(state) comma = self.comma if comma is MaybeSentinel.DEFAULT and default_comma: @@ -1105,7 +1105,7 @@ class Subscript(BaseAssignTargetExpression, BaseDelTargetExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.value._codegen(state) self.whitespace_after_value._codegen(state) @@ -1160,7 +1160,7 @@ class Annotation(CSTNode): annotation=visit_required("annotation", self.annotation, visitor), ) - def _codegen( + def _codegen_impl( self, state: CodegenState, default_indicator: Optional[str] = None ) -> None: # First, figure out the indicator which tells us default whitespace. @@ -1201,7 +1201,7 @@ class ParamStar(CSTNode): def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ParamStar": return ParamStar(comma=visit_required("comma", self.comma, visitor)) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token("*") self.comma._codegen(state) @@ -1266,7 +1266,7 @@ class Param(CSTNode): ), ) - def _codegen( + def _codegen_impl( self, state: CodegenState, default_star: Optional[str] = None, @@ -1401,7 +1401,7 @@ class Parameters(CSTNode): star_kwarg=visit_optional("star_kwarg", self.star_kwarg, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: # Compute the star existence first so we can ask about whether # each element is the last in the list or not. star_arg = self.star_arg @@ -1520,7 +1520,7 @@ class Lambda(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token("lambda") whitespace_after_lambda = self.whitespace_after_lambda @@ -1593,7 +1593,7 @@ class Arg(CSTNode): ), ) - def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: state.add_token(self.star) self.whitespace_after_star._codegen(state) keyword = self.keyword @@ -1757,7 +1757,7 @@ class Call(_BaseExpressionWithArgs): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.func._codegen(state) self.whitespace_after_func._codegen(state) @@ -1801,7 +1801,7 @@ class Await(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token("await") self.whitespace_after_await._codegen(state) @@ -1890,7 +1890,7 @@ class IfExp(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): self.body._codegen(state) self.whitespace_before_if._codegen(state) @@ -1938,7 +1938,7 @@ class From(CSTNode): ), ) - def _codegen(self, state: CodegenState, default_space: str = "") -> None: + def _codegen_impl(self, state: CodegenState, default_space: str = "") -> None: whitespace_before_from = self.whitespace_before_from if isinstance(whitespace_before_from, BaseParenthesizableWhitespace): whitespace_before_from._codegen(state) @@ -2001,7 +2001,7 @@ class Yield(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): state.add_token("yield") whitespace_after_yield = self.whitespace_after_yield @@ -2032,7 +2032,7 @@ class BaseElement(CSTNode, ABC): whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("") @abstractmethod - def _codegen( + def _codegen_impl( self, state: CodegenState, default_comma: bool = False, @@ -2061,7 +2061,7 @@ class Element(BaseElement): ), ) - def _codegen( + def _codegen_impl( self, state: CodegenState, default_comma: bool = False, @@ -2110,7 +2110,7 @@ class StarredElement(BaseElement, _BaseParenthesizedNode): comma=visit_sentinel("comma", self.comma, visitor), ) - def _codegen( + def _codegen_impl( self, state: CodegenState, default_comma: bool = False, @@ -2186,7 +2186,7 @@ class Tuple(BaseExpression): rpar=visit_sequence("rpar", self.rpar, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: with self._parenthesize(state): elements = self.elements if len(elements) == 1: diff --git a/libcst/nodes/_internal.py b/libcst/nodes/_internal.py index dc532ef6..500e7711 100644 --- a/libcst/nodes/_internal.py +++ b/libcst/nodes/_internal.py @@ -11,9 +11,11 @@ from typing import ( TYPE_CHECKING, Iterable, List, + MutableMapping, Optional, Pattern, Sequence, + Tuple, TypeVar, Union, ) @@ -35,6 +37,13 @@ _CSTNodeT = TypeVar("_CSTNodeT", bound="CSTNode") NEWLINE_RE: Pattern[str] = re.compile(r"\r\n?|\n") +@dataclass(frozen=True) +class CodePosition: + # start and end are each a tuple of (line, column) numbers + start: Tuple[int, int] + end: Tuple[int, int] + + @add_slots @dataclass(frozen=False) class CodegenState: @@ -48,6 +57,10 @@ class CodegenState: line: int = 1 # one-indexed column: int = 0 # zero-indexed + positions: MutableMapping["CSTNode", CodePosition] = field( + default_factory=lambda: {} + ) + def increase_indent(self, value: str) -> None: self.indent_tokens.append(value) @@ -76,6 +89,9 @@ class CodegenState: # newline resets column back to 0, but a trailing token may shift column self.column = len(segments[-1]) + def update_position(self, node: _CSTNodeT, position: CodePosition) -> None: + self.positions[node] = position + def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT: """ diff --git a/libcst/nodes/_module.py b/libcst/nodes/_module.py index 15163ee9..db2b6895 100644 --- a/libcst/nodes/_module.py +++ b/libcst/nodes/_module.py @@ -4,13 +4,13 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Sequence, TypeVar, Union +from typing import MutableMapping, Sequence, TypeVar, Union from libcst._add_slots import add_slots from libcst._base_visitor import CSTVisitor from libcst._removal_sentinel import RemovalSentinel from libcst.nodes._base import CSTNode -from libcst.nodes._internal import CodegenState, visit_sequence +from libcst.nodes._internal import CodegenState, CodePosition, visit_sequence from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine from libcst.nodes._whitespace import EmptyLine @@ -22,7 +22,7 @@ builtin_bytes = bytes @add_slots -@dataclass(frozen=True) +@dataclass(frozen=False) class Module(CSTNode): """ Contains some top-level information inferred from the file letting us set correct @@ -41,6 +41,8 @@ class Module(CSTNode): default_newline: str = "\n" has_trailing_newline: bool = True + _positions: MutableMapping["CSTNode", CodePosition] = None + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module": return Module( header=visit_sequence("header", self.header, visitor), @@ -50,6 +52,7 @@ class Module(CSTNode): default_indent=self.default_indent, default_newline=self.default_newline, has_trailing_newline=self.has_trailing_newline, + _positions=self._positions, ) def visit(self: _ModuleSelfT, visitor: CSTVisitor) -> _ModuleSelfT: @@ -59,7 +62,7 @@ class Module(CSTNode): else: # is a Module return result - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for h in self.header: h._codegen(state) for stmt in self.body: @@ -95,4 +98,5 @@ class Module(CSTNode): default_indent=self.default_indent, default_newline=self.default_newline ) node._codegen(state) + self._positions = state.positions return "".join(state.tokens) diff --git a/libcst/nodes/_op.py b/libcst/nodes/_op.py index 5b016991..5591c543 100644 --- a/libcst/nodes/_op.py +++ b/libcst/nodes/_op.py @@ -32,7 +32,7 @@ class _BaseOneTokenOp(CSTNode, ABC): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_before._codegen(state) state.add_token(self._get_token()) self.whitespace_after._codegen(state) @@ -69,7 +69,7 @@ class _BaseTwoTokenOp(CSTNode, ABC): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_before._codegen(state) state.add_token(self._get_tokens()[0]) self.whitespace_between._codegen(state) @@ -95,7 +95,7 @@ class BaseUnaryOp(CSTNode, ABC): ) ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token(self._get_token()) self.whitespace_after._codegen(state) @@ -204,7 +204,7 @@ class ImportStar(BaseLeaf): Used by ImportFrom to denote a star import. """ - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token("*") @@ -463,7 +463,7 @@ class NotEqual(BaseCompOp): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_before._codegen(state) state.add_token(self.value) self.whitespace_after._codegen(state) diff --git a/libcst/nodes/_statement.py b/libcst/nodes/_statement.py index f23b8d2d..a9eb75ca 100644 --- a/libcst/nodes/_statement.py +++ b/libcst/nodes/_statement.py @@ -80,7 +80,9 @@ class BaseSmallStatement(CSTNode, ABC): semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT @abstractmethod - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: ... @@ -113,7 +115,9 @@ class Del(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("del") self.whitespace_after_del._codegen(state) self.target._codegen(state) @@ -135,7 +139,9 @@ class Pass(BaseSmallStatement): def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Pass": return Pass(semicolon=visit_sentinel("semicolon", self.semicolon, visitor)) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("pass") semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): @@ -155,7 +161,9 @@ class Break(BaseSmallStatement): def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Break": return Break(semicolon=visit_sentinel("semicolon", self.semicolon, visitor)) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("break") semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): @@ -175,7 +183,9 @@ class Continue(BaseSmallStatement): def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Continue": return Continue(semicolon=visit_sentinel("semicolon", self.semicolon, visitor)) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("continue") semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): @@ -219,7 +229,9 @@ class Return(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: value = self.value state.add_token("return") @@ -261,7 +273,9 @@ class Expr(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: self.value._codegen(state) semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): @@ -299,7 +313,7 @@ class _BaseSimpleStatement(CSTNode, ABC): + "on the same line." ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: body = self.body laststmt = len(body) - 1 for idx, stmt in enumerate(body): @@ -335,11 +349,11 @@ class SimpleStatementLine(_BaseSimpleStatement): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() - _BaseSimpleStatement._codegen(self, state) + _BaseSimpleStatement._codegen_impl(self, state) @add_slots @@ -376,9 +390,9 @@ class SimpleStatementSuite(_BaseSimpleStatement, BaseSuite): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.leading_whitespace._codegen(state) - _BaseSimpleStatement._codegen(self, state) + _BaseSimpleStatement._codegen_impl(self, state) @add_slots @@ -405,7 +419,7 @@ class Else(CSTNode): body=visit_required("body", self.body, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -466,7 +480,7 @@ class If(BaseCompoundStatement): orelse=visit_optional("orelse", self.orelse, visitor), ) - def _codegen(self, state: CodegenState, is_elif: bool = False) -> None: + def _codegen_impl(self, state: CodegenState, is_elif: bool = False) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -547,7 +561,7 @@ class IndentedBlock(BaseSuite): footer=visit_sequence("footer", self.footer, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.header._codegen(state) indent = self.indent @@ -598,7 +612,7 @@ class AsName(CSTNode): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace_before_as._codegen(state) state.add_token("as") self.whitespace_after_as._codegen(state) @@ -653,7 +667,7 @@ class ExceptHandler(CSTNode): body=visit_required("body", self.body, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -690,7 +704,7 @@ class Finally(CSTNode): body=visit_required("body", self.body, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -746,7 +760,7 @@ class Try(BaseCompoundStatement): finalbody=visit_optional("finalbody", self.finalbody, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -794,7 +808,7 @@ class ImportAlias(CSTNode): comma=visit_sentinel("comma", self.comma, visitor), ) - def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: self.name._codegen(state) asname = self.asname if asname is not None: @@ -842,7 +856,9 @@ class Import(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("import") self.whitespace_after_import._codegen(state) lastname = len(self.names) - 1 @@ -947,7 +963,9 @@ class ImportFrom(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("from") self.whitespace_after_from._codegen(state) for dot in self.relative: @@ -1002,7 +1020,7 @@ class AssignTarget(CSTNode): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.target._codegen(state) self.whitespace_before_equal._codegen(state) state.add_token("=") @@ -1037,7 +1055,9 @@ class Assign(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: for target in self.targets: target._codegen(state) self.value._codegen(state) @@ -1090,7 +1110,9 @@ class AnnAssign(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: self.target._codegen(state) self.annotation._codegen(state, default_indicator=":") equal = self.equal @@ -1135,7 +1157,9 @@ class AugAssign(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: self.target._codegen(state) self.operator._codegen(state) self.value._codegen(state) @@ -1167,7 +1191,7 @@ class Asynchronous(CSTNode): ) ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token("async") self.whitespace_after._codegen(state) @@ -1218,7 +1242,7 @@ class Decorator(CSTNode): ), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -1310,7 +1334,7 @@ class FunctionDef(BaseCompoundStatement): body=visit_required("body", self.body, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) for decorator in self.decorators: @@ -1436,7 +1460,7 @@ class ClassDef(BaseCompoundStatement): body=visit_required("body", self.body, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) for decorator in self.decorators: @@ -1492,7 +1516,7 @@ class WithItem(CSTNode): comma=visit_sentinel("comma", self.comma, visitor), ) - def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: self.item._codegen(state) asname = self.asname if asname is not None: @@ -1554,7 +1578,7 @@ class With(BaseCompoundStatement): body=visit_required("body", self.body, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -1645,7 +1669,7 @@ class For(BaseCompoundStatement): orelse=visit_optional("orelse", self.orelse, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -1711,7 +1735,7 @@ class While(BaseCompoundStatement): orelse=visit_optional("orelse", self.orelse, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() @@ -1783,7 +1807,9 @@ class Raise(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: exc = self.exc cause = self.cause @@ -1854,7 +1880,9 @@ class Assert(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("assert") self.whitespace_after_assert._codegen(state) self.test._codegen(state) @@ -1901,7 +1929,7 @@ class NameItem(CSTNode): comma=visit_sentinel("comma", self.comma, visitor), ) - def _codegen(self, state: CodegenState, default_comma: bool = False) -> None: + def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: self.name._codegen(state) comma = self.comma if comma is MaybeSentinel.DEFAULT and default_comma: @@ -1949,7 +1977,9 @@ class Global(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("global") self.whitespace_after_global._codegen(state) last_name = len(self.names) - 1 @@ -2003,7 +2033,9 @@ class Nonlocal(BaseSmallStatement): semicolon=visit_sentinel("semicolon", self.semicolon, visitor), ) - def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None: + def _codegen_impl( + self, state: CodegenState, default_semicolon: bool = False + ) -> None: state.add_token("nonlocal") self.whitespace_after_nonlocal._codegen(state) last_name = len(self.names) - 1 diff --git a/libcst/nodes/_whitespace.py b/libcst/nodes/_whitespace.py index 7b45fa93..4246e82a 100644 --- a/libcst/nodes/_whitespace.py +++ b/libcst/nodes/_whitespace.py @@ -98,7 +98,7 @@ class Newline(BaseLeaf): f"Got an invalid value for newline node: {repr(self.value)}" ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.add_token(state.default_newline if self.value is None else self.value) @@ -146,7 +146,7 @@ class TrailingWhitespace(CSTNode): newline=visit_required("newline", self.newline, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.whitespace._codegen(state) if self.comment is not None: self.comment._codegen(state) @@ -178,7 +178,7 @@ class EmptyLine(CSTNode): newline=visit_required("newline", self.newline, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: if self.indent: state.add_indent_tokens() self.whitespace._codegen(state) @@ -214,7 +214,7 @@ class ParenthesizedWhitespace(BaseParenthesizableWhitespace): last_line=visit_required("last_line", self.last_line, visitor), ) - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: self.first_line._codegen(state) for line in self.empty_lines: line._codegen(state) diff --git a/libcst/nodes/tests/base.py b/libcst/nodes/tests/base.py index 7643fbec..d65f62e3 100644 --- a/libcst/nodes/tests/base.py +++ b/libcst/nodes/tests/base.py @@ -6,7 +6,7 @@ import dataclasses from contextlib import ExitStack from dataclasses import dataclass -from typing import Callable, Iterable, List, Optional, Sequence, Type, TypeVar +from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, TypeVar from unittest.mock import patch import libcst.nodes as cst @@ -29,7 +29,9 @@ class _NOOPVisitor(CSTVisitor): pass -def _cst_node_equality_func(a: cst.CSTNode, b: cst.CSTNode, msg=None) -> None: +def _cst_node_equality_func( + a: cst.CSTNode, b: cst.CSTNode, msg: Optional[str] = None +) -> None: """ For use with addTypeEqualityFunc. """ @@ -121,8 +123,10 @@ class CSTNodeTest(UnitTest): children: List[cst.CSTNode] = [] codegen_stack: List[cst.CSTNode] = [] - def _get_codegen_override(target: _CSTCodegenPatchTarget): - def _codegen(self, *args, **kwargs) -> None: + def _get_codegen_override( + target: _CSTCodegenPatchTarget + ) -> Callable[..., None]: + def _codegen_impl(self: _CSTNodeT, *args: Any, **kwargs: Any) -> None: should_pop = False # Don't stick duplicates in the stack. This is needed so that we don't # track calls to `super()._codegen()`. @@ -138,7 +142,7 @@ class CSTNodeTest(UnitTest): if should_pop: codegen_stack.pop() - return _codegen + return _codegen_impl with ExitStack() as patch_stack: for t in patch_targets: @@ -202,7 +206,7 @@ class DummyIndentedBlock(cst.CSTNode): value: str child: cst.CSTNode - def _codegen(self, state: CodegenState) -> None: + def _codegen_impl(self, state: CodegenState) -> None: state.increase_indent(self.value) self.child._codegen(state) state.decrease_indent() diff --git a/libcst/nodes/tests/test_module.py b/libcst/nodes/tests/test_module.py index 0fd4d776..a825df3d 100644 --- a/libcst/nodes/tests/test_module.py +++ b/libcst/nodes/tests/test_module.py @@ -4,7 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +from typing import Tuple, cast + import libcst.nodes as cst +from libcst.nodes._internal import CodePosition from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_module from libcst.testing.utils import data_provider @@ -112,3 +115,77 @@ class ModuleTest(CSTNodeTest): ) def test_parser(self, *, code: str, expected: cst.Module) -> None: self.assertEqual(parse_module(code), expected) + + @data_provider( + { + "empty": {"code": "", "expected": CodePosition((1, 0), (1, 0))}, + "empty_with_newline": { + "code": "\n", + "expected": CodePosition((1, 0), (2, 0)), + }, + "empty_program_with_comments": { + "code": "# 2345", + "expected": CodePosition((1, 0), (2, 0)), + }, + "simple_pass": {"code": "pass\n", "expected": CodePosition((1, 0), (2, 0))}, + "simple_pass_with_header_footer": { + "code": "# header\npass # trailing\n# footer\n", + "expected": CodePosition((1, 0), (4, 0)), + }, + } + ) + def test_module_position(self, *, code: str, expected: CodePosition) -> None: + module = parse_module(code) + module.code + + self.assertEqual(module._positions[module], expected) + + def cmp_position( + self, + module: cst.Module, + node: cst.CSTNode, + start: Tuple[int, int], + end: Tuple[int, int], + ) -> None: + self.assertEqual(module._positions[node], CodePosition(start, end)) + + def test_function_position(self) -> None: + module = parse_module("def foo():\n pass") + module.code + + fn = cast(cst.FunctionDef, module.body[0]) + stmt = cast(cst.SimpleStatementLine, fn.body.body[0]) + pass_stmt = cast(cst.Pass, stmt.body[0]) + self.cmp_position(module, stmt, (2, 0), (3, 0)) + self.cmp_position(module, pass_stmt, (2, 4), (2, 8)) + + def test_nested_indent_position(self) -> None: + module = parse_module( + "if True:\n if False:\n x = 1\nelse:\n return" + ) + module.code + + outer_if = cast(cst.If, module.body[0]) + inner_if = cast(cst.If, outer_if.body.body[0]) + assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0] + + outer_else = cast(cst.Else, outer_if.orelse) + return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0] + + self.cmp_position(module, outer_if, (1, 0), (6, 0)) + self.cmp_position(module, inner_if, (2, 0), (4, 0)) + self.cmp_position(module, assign, (3, 8), (3, 13)) + self.cmp_position(module, outer_else, (4, 0), (6, 0)) + self.cmp_position(module, return_stmt, (5, 4), (5, 10)) + + def test_multiline_string_position(self) -> None: + module = parse_module('"abc"\\\n"def"') + module.code + + stmt = cast(cst.SimpleStatementLine, module.body[0]) + expr = cast(cst.Expr, stmt.body[0]) + string = expr.value + + self.cmp_position(module, stmt, (1, 0), (3, 0)) + self.cmp_position(module, expr, (1, 0), (2, 5)) + self.cmp_position(module, string, (1, 0), (2, 5)) diff --git a/libcst/tool.py b/libcst/tool.py index 224b18ae..b12b17ab 100644 --- a/libcst/tool.py +++ b/libcst/tool.py @@ -34,6 +34,9 @@ def _node_repr_recursive( tokens: List[str] = [node.__class__.__name__] fields: Sequence[dataclasses.Field] = dataclasses.fields(node) + # Hide all fields prefixed with "_" + fields = [f for f in fields if f.name[0] != "_"] + # Filter whitespace nodes if needed if not show_whitespace: