diff --git a/libcst/_nodes/base.py b/libcst/_nodes/base.py index d1a3ed7b..421a4281 100644 --- a/libcst/_nodes/base.py +++ b/libcst/_nodes/base.py @@ -19,7 +19,6 @@ from typing import ( ) from libcst._nodes.internal import CodegenState -from libcst._position import CodePosition, CodeRange from libcst._removal_sentinel import RemovalSentinel from libcst._type_enforce import is_value_of_type from libcst._types import CSTNodeT @@ -310,10 +309,9 @@ class CSTNode(ABC): ... def _codegen(self, state: CodegenState, **kwargs: Any) -> None: - start = CodePosition(state.line, state.column) + state.before_visit(self) self._codegen_impl(state, **kwargs) - end = CodePosition(state.line, state.column) - state.record_position(self, CodeRange(start, end)) + state.after_leave(self) def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT: """ diff --git a/libcst/_nodes/internal.py b/libcst/_nodes/internal.py index 4ded3541..a311e80a 100644 --- a/libcst/_nodes/internal.py +++ b/libcst/_nodes/internal.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Sequence, from libcst._add_slots import add_slots from libcst._maybe_sentinel import MaybeSentinel -from libcst._position import CodeRange from libcst._removal_sentinel import RemovalSentinel from libcst._types import CSTNodeT @@ -48,7 +47,10 @@ class CodegenState: def add_token(self, value: str) -> None: self.tokens.append(value) - def record_position(self, node: "CSTNode", position: CodeRange) -> None: + def before_visit(self, node: "CSTNode") -> None: + pass + + def after_leave(self, node: "CSTNode") -> None: pass @contextmanager diff --git a/libcst/metadata/position_provider.py b/libcst/metadata/position_provider.py index e52b3272..9716b263 100644 --- a/libcst/metadata/position_provider.py +++ b/libcst/metadata/position_provider.py @@ -7,8 +7,8 @@ import re from contextlib import contextmanager -from dataclasses import dataclass -from typing import Iterator, Optional, Pattern +from dataclasses import dataclass, field +from typing import Iterator, List, Optional, Pattern from libcst._add_slots import add_slots from libcst._nodes.base import CSTNode @@ -25,6 +25,7 @@ NEWLINE_RE: Pattern[str] = re.compile(r"\r\n?|\n") @dataclass(frozen=False) class WhitespaceInclusivePositionProvidingCodegenState(CodegenState): provider: BaseMetadataProvider[CodeRange] + _stack: List[CodePosition] = field(init=False, default_factory=list) def add_indent_tokens(self) -> None: self.tokens.extend(self.indent_tokens) @@ -48,11 +49,19 @@ class WhitespaceInclusivePositionProvidingCodegenState(CodegenState): # newline resets column back to 0, but a trailing token may shift column self.column = len(segments[-1]) - def record_position(self, node: CSTNode, position: CodeRange) -> None: + def before_visit(self, node: "CSTNode") -> None: + self._stack.append(CodePosition(self.line, self.column)) + + def after_leave(self, node: "CSTNode") -> None: + # we must unconditionally pop the stack, else we could end up in a broken state + start_pos = self._stack.pop() + # Don't overwrite existing position information # (i.e. semantic position has already been recorded) if node not in self.provider._computed: - self.provider._computed[node] = position + end_pos = CodePosition(self.line, self.column) + node_range = CodeRange(start_pos, end_pos) + self.provider._computed[node] = node_range class WhitespaceInclusivePositionProvider(BaseMetadataProvider[CodeRange]): diff --git a/libcst/metadata/tests/test_position_provider.py b/libcst/metadata/tests/test_position_provider.py index 12b8e44e..45f2b859 100644 --- a/libcst/metadata/tests/test_position_provider.py +++ b/libcst/metadata/tests/test_position_provider.py @@ -13,7 +13,6 @@ from libcst._batched_visitor import BatchableCSTVisitor from libcst._nodes.internal import CodegenState from libcst._visitors import CSTTransformer from libcst.metadata import ( - CodePosition, CodeRange, MetadataWrapper, PositionProvider, @@ -121,13 +120,12 @@ class PositionProvidingCodegenStateTest(UnitTest): state = WhitespaceInclusivePositionProvidingCodegenState( " " * 4, "\n", WhitespaceInclusivePositionProvider() ) - start = CodePosition(state.line, state.column) + state.before_visit(node) state.add_token(" ") with state.record_syntactic_position(node): state.add_token("pass") state.add_token(" ") - end = CodePosition(state.line, state.column) - state.record_position(node, CodeRange(start, end)) + state.after_leave(node) # check whitespace is correctly recorded self.assertEqual(state.provider._computed[node], CodeRange((1, 0), (1, 6))) @@ -139,13 +137,12 @@ class PositionProvidingCodegenStateTest(UnitTest): # simulate codegen behavior for the dummy node # generates the code " pass " state = PositionProvidingCodegenState(" " * 4, "\n", PositionProvider()) - start = CodePosition(state.line, state.column) + state.before_visit(node) state.add_token(" ") with state.record_syntactic_position(node): state.add_token("pass") state.add_token(" ") - end = CodePosition(state.line, state.column) - state.record_position(node, CodeRange(start, end)) + state.after_leave(node) # check syntactic position ignores whitespace self.assertEqual(state.provider._computed[node], CodeRange((1, 1), (1, 5)))