Wrap _codegen methods in a helper function to track where nodes start and end.

Converts `_codegen` methods into `_codegen_impl` to wrap implementations
to calls to update the position of each node in the `CodegenState`. The
stored position is the syntactic position of a node (that includes any
whitespace attached to that particular node).

Also updates implementation of tool and `CSTNode.__repr__` to not print
fields of `CSTNode` objects prefixed with "_".
This commit is contained in:
Ray Zeng 2019-06-27 17:19:03 -07:00 committed by Benjamin Woodruff
parent 65cea1ce21
commit d3544824fc
11 changed files with 247 additions and 104 deletions

View file

@ -10,7 +10,7 @@ from typing import Any, List, Sequence, TypeVar, Union, cast
from libcst._base_visitor import CSTVisitor
from libcst._removal_sentinel import RemovalSentinel
from libcst.nodes._internal import CodegenState
from libcst.nodes._internal import CodegenState, CodePosition
_CSTNodeSelfT = TypeVar("_CSTNodeSelfT", bound="CSTNode")
@ -184,9 +184,15 @@ class CSTNode(ABC):
...
@abstractmethod
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
...
def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
start = state.line, state.column
self._codegen_impl(state, **kwargs)
end = state.line, state.column
state.update_position(self, CodePosition(start, end))
def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:
"""
A convenience method for performing mutation-like operations on immutable nodes.
@ -246,8 +252,9 @@ class CSTNode(ABC):
lines.append(f"{type(self).__name__}(")
for field in fields(self):
key = field.name
value = getattr(self, key)
lines.append(_indent(f"{key}={_pretty_repr(value)},"))
if key[0] != "_":
value = getattr(self, key)
lines.append(_indent(f"{key}={_pretty_repr(value)},"))
lines.append(")")
return "\n".join(lines)
@ -274,7 +281,7 @@ class BaseValueToken(BaseLeaf, ABC):
value: str
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token(self.value)

View file

@ -57,7 +57,7 @@ class DummyNode(CSTNode):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
def _codegen_impl(self, state: CodegenState, **kwargs: Any) -> None:
for lpar in self.lpar:
lpar._codegen(state)
for child in self.children:

View file

@ -72,7 +72,7 @@ class LeftSquareBracket(CSTNode):
)
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("[")
self.whitespace_after._codegen(state)
@ -94,7 +94,7 @@ class RightSquareBracket(CSTNode):
)
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token("]")
@ -116,7 +116,7 @@ class LeftParen(CSTNode):
)
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("(")
self.whitespace_after._codegen(state)
@ -138,7 +138,7 @@ class RightParen(CSTNode):
)
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token(")")
@ -262,7 +262,7 @@ class Name(BaseAssignTargetExpression, BaseDelTargetExpression, BaseAtom):
if not self.value.isidentifier():
raise CSTValidationError("Name is not a valid identifier.")
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.value)
@ -286,7 +286,7 @@ class Ellipses(BaseAtom):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token("...")
@ -314,7 +314,7 @@ class Integer(_BaseParenthesizedNode):
if not re.fullmatch(INTNUMBER_RE, self.value):
raise CSTValidationError("Number is not a valid integer.")
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.value)
@ -342,7 +342,7 @@ class Float(_BaseParenthesizedNode):
if not re.fullmatch(FLOATNUMBER_RE, self.value):
raise CSTValidationError("Number is not a valid float.")
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.value)
@ -370,7 +370,7 @@ class Imaginary(_BaseParenthesizedNode):
if not re.fullmatch(IMAGNUMBER_RE, self.value):
raise CSTValidationError("Number is not a valid imaginary.")
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.value)
@ -408,7 +408,7 @@ class Number(BaseAtom):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
operator = self.operator
if operator is not None:
@ -483,7 +483,7 @@ class SimpleString(BaseString):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.value)
@ -505,7 +505,7 @@ class FormattedStringText(BaseFormattedStringContent):
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "FormattedStringText":
return FormattedStringText(value=self.value)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token(self.value)
@ -551,7 +551,7 @@ class FormattedStringExpression(BaseFormattedStringContent):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("{")
self.whitespace_before_expression._codegen(state)
self.expression._codegen(state)
@ -620,7 +620,7 @@ class FormattedString(BaseString):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.start)
for part in self.parts:
@ -679,7 +679,7 @@ class ConcatenatedString(BaseString):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
self.whitespace_between._codegen(state)
@ -718,7 +718,7 @@ class ComparisonTarget(CSTNode):
comparator=visit_required("comparator", self.comparator, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.operator._codegen(state)
self.comparator._codegen(state)
@ -768,7 +768,7 @@ class Comparison(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
for comp in self.comparisons:
@ -817,7 +817,7 @@ class UnaryOperation(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.operator._codegen(state)
self.expression._codegen(state)
@ -854,7 +854,7 @@ class BinaryOperation(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
self.operator._codegen(state)
@ -911,7 +911,7 @@ class BooleanOperation(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.left._codegen(state)
self.operator._codegen(state)
@ -951,7 +951,7 @@ class Attribute(BaseAssignTargetExpression, BaseDelTargetExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.value._codegen(state)
self.dot._codegen(state)
@ -971,7 +971,7 @@ class Index(CSTNode):
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Index":
return Index(value=visit_required("value", self.value, visitor))
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.value._codegen(state)
@ -1008,7 +1008,7 @@ class Slice(CSTNode):
step=visit_optional("step", self.step, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
lower = self.lower
if lower is not None:
lower._codegen(state)
@ -1046,7 +1046,7 @@ class ExtSlice(CSTNode):
comma=visit_sentinel("comma", self.comma, visitor),
)
def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
self.slice._codegen(state)
comma = self.comma
if comma is MaybeSentinel.DEFAULT and default_comma:
@ -1105,7 +1105,7 @@ class Subscript(BaseAssignTargetExpression, BaseDelTargetExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.value._codegen(state)
self.whitespace_after_value._codegen(state)
@ -1160,7 +1160,7 @@ class Annotation(CSTNode):
annotation=visit_required("annotation", self.annotation, visitor),
)
def _codegen(
def _codegen_impl(
self, state: CodegenState, default_indicator: Optional[str] = None
) -> None:
# First, figure out the indicator which tells us default whitespace.
@ -1201,7 +1201,7 @@ class ParamStar(CSTNode):
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "ParamStar":
return ParamStar(comma=visit_required("comma", self.comma, visitor))
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("*")
self.comma._codegen(state)
@ -1266,7 +1266,7 @@ class Param(CSTNode):
),
)
def _codegen(
def _codegen_impl(
self,
state: CodegenState,
default_star: Optional[str] = None,
@ -1401,7 +1401,7 @@ class Parameters(CSTNode):
star_kwarg=visit_optional("star_kwarg", self.star_kwarg, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
# Compute the star existence first so we can ask about whether
# each element is the last in the list or not.
star_arg = self.star_arg
@ -1520,7 +1520,7 @@ class Lambda(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token("lambda")
whitespace_after_lambda = self.whitespace_after_lambda
@ -1593,7 +1593,7 @@ class Arg(CSTNode):
),
)
def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
state.add_token(self.star)
self.whitespace_after_star._codegen(state)
keyword = self.keyword
@ -1757,7 +1757,7 @@ class Call(_BaseExpressionWithArgs):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.func._codegen(state)
self.whitespace_after_func._codegen(state)
@ -1801,7 +1801,7 @@ class Await(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token("await")
self.whitespace_after_await._codegen(state)
@ -1890,7 +1890,7 @@ class IfExp(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.body._codegen(state)
self.whitespace_before_if._codegen(state)
@ -1938,7 +1938,7 @@ class From(CSTNode):
),
)
def _codegen(self, state: CodegenState, default_space: str = "") -> None:
def _codegen_impl(self, state: CodegenState, default_space: str = "") -> None:
whitespace_before_from = self.whitespace_before_from
if isinstance(whitespace_before_from, BaseParenthesizableWhitespace):
whitespace_before_from._codegen(state)
@ -2001,7 +2001,7 @@ class Yield(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token("yield")
whitespace_after_yield = self.whitespace_after_yield
@ -2032,7 +2032,7 @@ class BaseElement(CSTNode, ABC):
whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace("")
@abstractmethod
def _codegen(
def _codegen_impl(
self,
state: CodegenState,
default_comma: bool = False,
@ -2061,7 +2061,7 @@ class Element(BaseElement):
),
)
def _codegen(
def _codegen_impl(
self,
state: CodegenState,
default_comma: bool = False,
@ -2110,7 +2110,7 @@ class StarredElement(BaseElement, _BaseParenthesizedNode):
comma=visit_sentinel("comma", self.comma, visitor),
)
def _codegen(
def _codegen_impl(
self,
state: CodegenState,
default_comma: bool = False,
@ -2186,7 +2186,7 @@ class Tuple(BaseExpression):
rpar=visit_sequence("rpar", self.rpar, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
elements = self.elements
if len(elements) == 1:

View file

@ -11,9 +11,11 @@ from typing import (
TYPE_CHECKING,
Iterable,
List,
MutableMapping,
Optional,
Pattern,
Sequence,
Tuple,
TypeVar,
Union,
)
@ -35,6 +37,13 @@ _CSTNodeT = TypeVar("_CSTNodeT", bound="CSTNode")
NEWLINE_RE: Pattern[str] = re.compile(r"\r\n?|\n")
@dataclass(frozen=True)
class CodePosition:
# start and end are each a tuple of (line, column) numbers
start: Tuple[int, int]
end: Tuple[int, int]
@add_slots
@dataclass(frozen=False)
class CodegenState:
@ -48,6 +57,10 @@ class CodegenState:
line: int = 1 # one-indexed
column: int = 0 # zero-indexed
positions: MutableMapping["CSTNode", CodePosition] = field(
default_factory=lambda: {}
)
def increase_indent(self, value: str) -> None:
self.indent_tokens.append(value)
@ -76,6 +89,9 @@ class CodegenState:
# newline resets column back to 0, but a trailing token may shift column
self.column = len(segments[-1])
def update_position(self, node: _CSTNodeT, position: CodePosition) -> None:
self.positions[node] = position
def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT:
"""

View file

@ -4,13 +4,13 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Sequence, TypeVar, Union
from typing import MutableMapping, Sequence, TypeVar, Union
from libcst._add_slots import add_slots
from libcst._base_visitor import CSTVisitor
from libcst._removal_sentinel import RemovalSentinel
from libcst.nodes._base import CSTNode
from libcst.nodes._internal import CodegenState, visit_sequence
from libcst.nodes._internal import CodegenState, CodePosition, visit_sequence
from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine
from libcst.nodes._whitespace import EmptyLine
@ -22,7 +22,7 @@ builtin_bytes = bytes
@add_slots
@dataclass(frozen=True)
@dataclass(frozen=False)
class Module(CSTNode):
"""
Contains some top-level information inferred from the file letting us set correct
@ -41,6 +41,8 @@ class Module(CSTNode):
default_newline: str = "\n"
has_trailing_newline: bool = True
_positions: MutableMapping["CSTNode", CodePosition] = None
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module":
return Module(
header=visit_sequence("header", self.header, visitor),
@ -50,6 +52,7 @@ class Module(CSTNode):
default_indent=self.default_indent,
default_newline=self.default_newline,
has_trailing_newline=self.has_trailing_newline,
_positions=self._positions,
)
def visit(self: _ModuleSelfT, visitor: CSTVisitor) -> _ModuleSelfT:
@ -59,7 +62,7 @@ class Module(CSTNode):
else: # is a Module
return result
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for h in self.header:
h._codegen(state)
for stmt in self.body:
@ -95,4 +98,5 @@ class Module(CSTNode):
default_indent=self.default_indent, default_newline=self.default_newline
)
node._codegen(state)
self._positions = state.positions
return "".join(state.tokens)

View file

@ -32,7 +32,7 @@ class _BaseOneTokenOp(CSTNode, ABC):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token(self._get_token())
self.whitespace_after._codegen(state)
@ -69,7 +69,7 @@ class _BaseTwoTokenOp(CSTNode, ABC):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token(self._get_tokens()[0])
self.whitespace_between._codegen(state)
@ -95,7 +95,7 @@ class BaseUnaryOp(CSTNode, ABC):
)
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token(self._get_token())
self.whitespace_after._codegen(state)
@ -204,7 +204,7 @@ class ImportStar(BaseLeaf):
Used by ImportFrom to denote a star import.
"""
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("*")
@ -463,7 +463,7 @@ class NotEqual(BaseCompOp):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token(self.value)
self.whitespace_after._codegen(state)

View file

@ -80,7 +80,9 @@ class BaseSmallStatement(CSTNode, ABC):
semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT
@abstractmethod
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
...
@ -113,7 +115,9 @@ class Del(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("del")
self.whitespace_after_del._codegen(state)
self.target._codegen(state)
@ -135,7 +139,9 @@ class Pass(BaseSmallStatement):
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Pass":
return Pass(semicolon=visit_sentinel("semicolon", self.semicolon, visitor))
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("pass")
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
@ -155,7 +161,9 @@ class Break(BaseSmallStatement):
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Break":
return Break(semicolon=visit_sentinel("semicolon", self.semicolon, visitor))
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("break")
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
@ -175,7 +183,9 @@ class Continue(BaseSmallStatement):
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Continue":
return Continue(semicolon=visit_sentinel("semicolon", self.semicolon, visitor))
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("continue")
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
@ -219,7 +229,9 @@ class Return(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
value = self.value
state.add_token("return")
@ -261,7 +273,9 @@ class Expr(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
self.value._codegen(state)
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
@ -299,7 +313,7 @@ class _BaseSimpleStatement(CSTNode, ABC):
+ "on the same line."
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
body = self.body
laststmt = len(body) - 1
for idx, stmt in enumerate(body):
@ -335,11 +349,11 @@ class SimpleStatementLine(_BaseSimpleStatement):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
_BaseSimpleStatement._codegen(self, state)
_BaseSimpleStatement._codegen_impl(self, state)
@add_slots
@ -376,9 +390,9 @@ class SimpleStatementSuite(_BaseSimpleStatement, BaseSuite):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.leading_whitespace._codegen(state)
_BaseSimpleStatement._codegen(self, state)
_BaseSimpleStatement._codegen_impl(self, state)
@add_slots
@ -405,7 +419,7 @@ class Else(CSTNode):
body=visit_required("body", self.body, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -466,7 +480,7 @@ class If(BaseCompoundStatement):
orelse=visit_optional("orelse", self.orelse, visitor),
)
def _codegen(self, state: CodegenState, is_elif: bool = False) -> None:
def _codegen_impl(self, state: CodegenState, is_elif: bool = False) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -547,7 +561,7 @@ class IndentedBlock(BaseSuite):
footer=visit_sequence("footer", self.footer, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.header._codegen(state)
indent = self.indent
@ -598,7 +612,7 @@ class AsName(CSTNode):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before_as._codegen(state)
state.add_token("as")
self.whitespace_after_as._codegen(state)
@ -653,7 +667,7 @@ class ExceptHandler(CSTNode):
body=visit_required("body", self.body, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -690,7 +704,7 @@ class Finally(CSTNode):
body=visit_required("body", self.body, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -746,7 +760,7 @@ class Try(BaseCompoundStatement):
finalbody=visit_optional("finalbody", self.finalbody, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -794,7 +808,7 @@ class ImportAlias(CSTNode):
comma=visit_sentinel("comma", self.comma, visitor),
)
def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
self.name._codegen(state)
asname = self.asname
if asname is not None:
@ -842,7 +856,9 @@ class Import(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("import")
self.whitespace_after_import._codegen(state)
lastname = len(self.names) - 1
@ -947,7 +963,9 @@ class ImportFrom(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("from")
self.whitespace_after_from._codegen(state)
for dot in self.relative:
@ -1002,7 +1020,7 @@ class AssignTarget(CSTNode):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.target._codegen(state)
self.whitespace_before_equal._codegen(state)
state.add_token("=")
@ -1037,7 +1055,9 @@ class Assign(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
for target in self.targets:
target._codegen(state)
self.value._codegen(state)
@ -1090,7 +1110,9 @@ class AnnAssign(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
self.target._codegen(state)
self.annotation._codegen(state, default_indicator=":")
equal = self.equal
@ -1135,7 +1157,9 @@ class AugAssign(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
self.target._codegen(state)
self.operator._codegen(state)
self.value._codegen(state)
@ -1167,7 +1191,7 @@ class Asynchronous(CSTNode):
)
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("async")
self.whitespace_after._codegen(state)
@ -1218,7 +1242,7 @@ class Decorator(CSTNode):
),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -1310,7 +1334,7 @@ class FunctionDef(BaseCompoundStatement):
body=visit_required("body", self.body, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
for decorator in self.decorators:
@ -1436,7 +1460,7 @@ class ClassDef(BaseCompoundStatement):
body=visit_required("body", self.body, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
for decorator in self.decorators:
@ -1492,7 +1516,7 @@ class WithItem(CSTNode):
comma=visit_sentinel("comma", self.comma, visitor),
)
def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
self.item._codegen(state)
asname = self.asname
if asname is not None:
@ -1554,7 +1578,7 @@ class With(BaseCompoundStatement):
body=visit_required("body", self.body, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -1645,7 +1669,7 @@ class For(BaseCompoundStatement):
orelse=visit_optional("orelse", self.orelse, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -1711,7 +1735,7 @@ class While(BaseCompoundStatement):
orelse=visit_optional("orelse", self.orelse, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
@ -1783,7 +1807,9 @@ class Raise(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
exc = self.exc
cause = self.cause
@ -1854,7 +1880,9 @@ class Assert(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("assert")
self.whitespace_after_assert._codegen(state)
self.test._codegen(state)
@ -1901,7 +1929,7 @@ class NameItem(CSTNode):
comma=visit_sentinel("comma", self.comma, visitor),
)
def _codegen(self, state: CodegenState, default_comma: bool = False) -> None:
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
self.name._codegen(state)
comma = self.comma
if comma is MaybeSentinel.DEFAULT and default_comma:
@ -1949,7 +1977,9 @@ class Global(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("global")
self.whitespace_after_global._codegen(state)
last_name = len(self.names) - 1
@ -2003,7 +2033,9 @@ class Nonlocal(BaseSmallStatement):
semicolon=visit_sentinel("semicolon", self.semicolon, visitor),
)
def _codegen(self, state: CodegenState, default_semicolon: bool = False) -> None:
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("nonlocal")
self.whitespace_after_nonlocal._codegen(state)
last_name = len(self.names) - 1

View file

@ -98,7 +98,7 @@ class Newline(BaseLeaf):
f"Got an invalid value for newline node: {repr(self.value)}"
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token(state.default_newline if self.value is None else self.value)
@ -146,7 +146,7 @@ class TrailingWhitespace(CSTNode):
newline=visit_required("newline", self.newline, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace._codegen(state)
if self.comment is not None:
self.comment._codegen(state)
@ -178,7 +178,7 @@ class EmptyLine(CSTNode):
newline=visit_required("newline", self.newline, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
if self.indent:
state.add_indent_tokens()
self.whitespace._codegen(state)
@ -214,7 +214,7 @@ class ParenthesizedWhitespace(BaseParenthesizableWhitespace):
last_line=visit_required("last_line", self.last_line, visitor),
)
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
self.first_line._codegen(state)
for line in self.empty_lines:
line._codegen(state)

View file

@ -6,7 +6,7 @@
import dataclasses
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional, Sequence, Type, TypeVar
from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, TypeVar
from unittest.mock import patch
import libcst.nodes as cst
@ -29,7 +29,9 @@ class _NOOPVisitor(CSTVisitor):
pass
def _cst_node_equality_func(a: cst.CSTNode, b: cst.CSTNode, msg=None) -> None:
def _cst_node_equality_func(
a: cst.CSTNode, b: cst.CSTNode, msg: Optional[str] = None
) -> None:
"""
For use with addTypeEqualityFunc.
"""
@ -121,8 +123,10 @@ class CSTNodeTest(UnitTest):
children: List[cst.CSTNode] = []
codegen_stack: List[cst.CSTNode] = []
def _get_codegen_override(target: _CSTCodegenPatchTarget):
def _codegen(self, *args, **kwargs) -> None:
def _get_codegen_override(
target: _CSTCodegenPatchTarget
) -> Callable[..., None]:
def _codegen_impl(self: _CSTNodeT, *args: Any, **kwargs: Any) -> None:
should_pop = False
# Don't stick duplicates in the stack. This is needed so that we don't
# track calls to `super()._codegen()`.
@ -138,7 +142,7 @@ class CSTNodeTest(UnitTest):
if should_pop:
codegen_stack.pop()
return _codegen
return _codegen_impl
with ExitStack() as patch_stack:
for t in patch_targets:
@ -202,7 +206,7 @@ class DummyIndentedBlock(cst.CSTNode):
value: str
child: cst.CSTNode
def _codegen(self, state: CodegenState) -> None:
def _codegen_impl(self, state: CodegenState) -> None:
state.increase_indent(self.value)
self.child._codegen(state)
state.decrease_indent()

View file

@ -4,7 +4,10 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Tuple, cast
import libcst.nodes as cst
from libcst.nodes._internal import CodePosition
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_module
from libcst.testing.utils import data_provider
@ -112,3 +115,77 @@ class ModuleTest(CSTNodeTest):
)
def test_parser(self, *, code: str, expected: cst.Module) -> None:
self.assertEqual(parse_module(code), expected)
@data_provider(
{
"empty": {"code": "", "expected": CodePosition((1, 0), (1, 0))},
"empty_with_newline": {
"code": "\n",
"expected": CodePosition((1, 0), (2, 0)),
},
"empty_program_with_comments": {
"code": "# 2345",
"expected": CodePosition((1, 0), (2, 0)),
},
"simple_pass": {"code": "pass\n", "expected": CodePosition((1, 0), (2, 0))},
"simple_pass_with_header_footer": {
"code": "# header\npass # trailing\n# footer\n",
"expected": CodePosition((1, 0), (4, 0)),
},
}
)
def test_module_position(self, *, code: str, expected: CodePosition) -> None:
module = parse_module(code)
module.code
self.assertEqual(module._positions[module], expected)
def cmp_position(
self,
module: cst.Module,
node: cst.CSTNode,
start: Tuple[int, int],
end: Tuple[int, int],
) -> None:
self.assertEqual(module._positions[node], CodePosition(start, end))
def test_function_position(self) -> None:
module = parse_module("def foo():\n pass")
module.code
fn = cast(cst.FunctionDef, module.body[0])
stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
pass_stmt = cast(cst.Pass, stmt.body[0])
self.cmp_position(module, stmt, (2, 0), (3, 0))
self.cmp_position(module, pass_stmt, (2, 4), (2, 8))
def test_nested_indent_position(self) -> None:
module = parse_module(
"if True:\n if False:\n x = 1\nelse:\n return"
)
module.code
outer_if = cast(cst.If, module.body[0])
inner_if = cast(cst.If, outer_if.body.body[0])
assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]
outer_else = cast(cst.Else, outer_if.orelse)
return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]
self.cmp_position(module, outer_if, (1, 0), (6, 0))
self.cmp_position(module, inner_if, (2, 0), (4, 0))
self.cmp_position(module, assign, (3, 8), (3, 13))
self.cmp_position(module, outer_else, (4, 0), (6, 0))
self.cmp_position(module, return_stmt, (5, 4), (5, 10))
def test_multiline_string_position(self) -> None:
module = parse_module('"abc"\\\n"def"')
module.code
stmt = cast(cst.SimpleStatementLine, module.body[0])
expr = cast(cst.Expr, stmt.body[0])
string = expr.value
self.cmp_position(module, stmt, (1, 0), (3, 0))
self.cmp_position(module, expr, (1, 0), (2, 5))
self.cmp_position(module, string, (1, 0), (2, 5))

View file

@ -34,6 +34,9 @@ def _node_repr_recursive(
tokens: List[str] = [node.__class__.__name__]
fields: Sequence[dataclasses.Field] = dataclasses.fields(node)
# Hide all fields prefixed with "_"
fields = [f for f in fields if f.name[0] != "_"]
# Filter whitespace nodes if needed
if not show_whitespace: