Define a context manager that can be used to track "semantic" start and end positions for nodes

Created a context manager `record_semantic_position()` in `CodegenState`
to be used for tracking semantic positions of nodes in the CodegenState.
This commit is contained in:
Ray Zeng 2019-06-27 17:33:43 -07:00 committed by Benjamin Woodruff
parent d3544824fc
commit f2f2510297
3 changed files with 39 additions and 5 deletions

View file

@ -188,9 +188,9 @@ class CSTNode(ABC):
...
def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
start = state.line, state.column
start = (state.line, state.column)
self._codegen_impl(state, **kwargs)
end = state.line, state.column
end = (state.line, state.column)
state.update_position(self, CodePosition(start, end))
def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:

View file

@ -6,10 +6,12 @@
# pyre-strict
import re
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Iterable,
Iterator,
List,
MutableMapping,
Optional,
@ -57,8 +59,10 @@ class CodegenState:
line: int = 1 # one-indexed
column: int = 0 # zero-indexed
positions: MutableMapping["CSTNode", CodePosition] = field(
default_factory=lambda: {}
positions: MutableMapping["CSTNode", CodePosition] = field(default_factory=dict)
semantic_positions: MutableMapping["CSTNode", CodePosition] = field(
default_factory=dict
)
def increase_indent(self, value: str) -> None:
@ -92,6 +96,15 @@ class CodegenState:
def update_position(self, node: _CSTNodeT, position: CodePosition) -> None:
self.positions[node] = position
@contextmanager
def record_semantic_position(self, node: _CSTNodeT) -> Iterator[None]:
start = (self.line, self.column)
try:
yield
finally:
end = (self.line, self.column)
self.semantic_positions[node] = CodePosition(start, end)
def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT:
"""

View file

@ -6,7 +6,8 @@
# pyre-strict
from typing import Tuple
from libcst.nodes._internal import CodegenState
import libcst.nodes as cst
from libcst.nodes._internal import CodegenState, CodePosition
from libcst.testing.utils import UnitTest
@ -48,3 +49,23 @@ class InternalTest(UnitTest):
state.decrease_indent()
state.add_indent_tokens()
self.assertEqual(position(state), (1, 8))
def test_context_manager(self) -> None:
# create a dummy node
node = cst.Pass()
# simulate codegen behavior for the dummy node
# generates the code " pass "
state = CodegenState(" " * 4, "\n")
start = (state.line, state.column)
state.add_token(" ")
with state.record_semantic_position(node):
state.add_token("pass")
state.add_token(" ")
end = (state.line, state.column)
state.update_position(node, CodePosition(start, end))
# check syntactic whitespace is correctly recorded (includes whitespace)
self.assertEqual(state.positions[node], CodePosition((1, 0), (1, 6)))
# check semantic whitespace is correctly recorded (ignoring whitespace)
self.assertEqual(state.semantic_positions[node], CodePosition((1, 1), (1, 5)))