Make the codegen enter/leave tracking more generic

I need to do some additional work on visit/leave to make codegen
re-entrant, so this makes it more generic.

This should have an additional small positive effect of creating less
throwaway objects when we're doing codegen without position calculation.
This commit is contained in:
Benjamin Woodruff 2019-10-22 15:39:40 -07:00 committed by jimmylai
parent 522f7c9a67
commit 3bfbb4b2dd
4 changed files with 23 additions and 17 deletions

View file

@ -19,7 +19,6 @@ from typing import (
)
from libcst._nodes.internal import CodegenState
from libcst._position import CodePosition, CodeRange
from libcst._removal_sentinel import RemovalSentinel
from libcst._type_enforce import is_value_of_type
from libcst._types import CSTNodeT
@ -310,10 +309,9 @@ class CSTNode(ABC):
...
def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
start = CodePosition(state.line, state.column)
state.before_visit(self)
self._codegen_impl(state, **kwargs)
end = CodePosition(state.line, state.column)
state.record_position(self, CodeRange(start, end))
state.after_leave(self)
def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:
"""

View file

@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Sequence,
from libcst._add_slots import add_slots
from libcst._maybe_sentinel import MaybeSentinel
from libcst._position import CodeRange
from libcst._removal_sentinel import RemovalSentinel
from libcst._types import CSTNodeT
@ -48,7 +47,10 @@ class CodegenState:
def add_token(self, value: str) -> None:
self.tokens.append(value)
def record_position(self, node: "CSTNode", position: CodeRange) -> None:
def before_visit(self, node: "CSTNode") -> None:
pass
def after_leave(self, node: "CSTNode") -> None:
pass
@contextmanager

View file

@ -7,8 +7,8 @@
import re
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional, Pattern
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Pattern
from libcst._add_slots import add_slots
from libcst._nodes.base import CSTNode
@ -25,6 +25,7 @@ NEWLINE_RE: Pattern[str] = re.compile(r"\r\n?|\n")
@dataclass(frozen=False)
class WhitespaceInclusivePositionProvidingCodegenState(CodegenState):
provider: BaseMetadataProvider[CodeRange]
_stack: List[CodePosition] = field(init=False, default_factory=list)
def add_indent_tokens(self) -> None:
self.tokens.extend(self.indent_tokens)
@ -48,11 +49,19 @@ class WhitespaceInclusivePositionProvidingCodegenState(CodegenState):
# newline resets column back to 0, but a trailing token may shift column
self.column = len(segments[-1])
def record_position(self, node: CSTNode, position: CodeRange) -> None:
def before_visit(self, node: "CSTNode") -> None:
self._stack.append(CodePosition(self.line, self.column))
def after_leave(self, node: "CSTNode") -> None:
# we must unconditionally pop the stack, else we could end up in a broken state
start_pos = self._stack.pop()
# Don't overwrite existing position information
# (i.e. semantic position has already been recorded)
if node not in self.provider._computed:
self.provider._computed[node] = position
end_pos = CodePosition(self.line, self.column)
node_range = CodeRange(start_pos, end_pos)
self.provider._computed[node] = node_range
class WhitespaceInclusivePositionProvider(BaseMetadataProvider[CodeRange]):

View file

@ -13,7 +13,6 @@ from libcst._batched_visitor import BatchableCSTVisitor
from libcst._nodes.internal import CodegenState
from libcst._visitors import CSTTransformer
from libcst.metadata import (
CodePosition,
CodeRange,
MetadataWrapper,
PositionProvider,
@ -121,13 +120,12 @@ class PositionProvidingCodegenStateTest(UnitTest):
state = WhitespaceInclusivePositionProvidingCodegenState(
" " * 4, "\n", WhitespaceInclusivePositionProvider()
)
start = CodePosition(state.line, state.column)
state.before_visit(node)
state.add_token(" ")
with state.record_syntactic_position(node):
state.add_token("pass")
state.add_token(" ")
end = CodePosition(state.line, state.column)
state.record_position(node, CodeRange(start, end))
state.after_leave(node)
# check whitespace is correctly recorded
self.assertEqual(state.provider._computed[node], CodeRange((1, 0), (1, 6)))
@ -139,13 +137,12 @@ class PositionProvidingCodegenStateTest(UnitTest):
# simulate codegen behavior for the dummy node
# generates the code " pass "
state = PositionProvidingCodegenState(" " * 4, "\n", PositionProvider())
start = CodePosition(state.line, state.column)
state.before_visit(node)
state.add_token(" ")
with state.record_syntactic_position(node):
state.add_token("pass")
state.add_token(" ")
end = CodePosition(state.line, state.column)
state.record_position(node, CodeRange(start, end))
state.after_leave(node)
# check syntactic position ignores whitespace
self.assertEqual(state.provider._computed[node], CodeRange((1, 1), (1, 5)))