Port line and column numbers to metadata framework

Refactors codegen to write position information to the `__metadata__`
fields in nodes keyed by `BasicPositionProvider` and
`SyntacticPositionProvider` as defined in position_metadata.py.

This commit also updates `deep_equals` to ignore dataclass fields that
are marked `compare=False` to avoid comparing metadata when doing
equality checks.
This commit is contained in:
Ray Zeng 2019-07-03 16:24:39 -07:00 committed by Benjamin Woodruff
parent 8fba418f2c
commit d0ecb018e1
11 changed files with 167 additions and 74 deletions

View file

@ -5,7 +5,7 @@
# pyre-strict
from abc import ABC
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast
from typing import TYPE_CHECKING, Type, TypeVar, Union
from libcst._removal_sentinel import RemovalSentinel
@ -60,5 +60,4 @@ class CSTVisitor(ABC):
"""
# TODO: runtime checks that metadata is available in this visitor
# pyre-fixme[33]: Given annotation cannot be `Any`.
return cast(Any, node).__metadata__[key]
return node.__metadata__[key]

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Any, Generic, TypeVar, cast
from typing import Generic, TypeVar
import libcst.nodes as cst
from libcst._base_visitor import CSTVisitor
@ -37,5 +37,4 @@ class BaseMetadataProvider(CSTVisitor, Generic[_T]):
@classmethod
def set_metadata(cls, node: cst.CSTNode, value: _T) -> None:
# pyre-fixme[33]: Explicit annotation for `typing.cast` cannot be `Any`.
cast(Any, node).__metadata__[cls] = value
node.__metadata__[cls] = value

View file

@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import libcst.nodes as cst
from libcst.metadata.base_provider import BaseMetadataProvider
from libcst.nodes._internal import CodePosition
class BasicPositionProvider(BaseMetadataProvider[CodePosition]):
"""
Generates basic line and column metadata. Basic position is
defined by the start and ending bounds of a node including all whitespace
owned by that node.
"""
def generate(self, module: cst.Module) -> None:
"""
Override default generate behavior as position information is
calculated through codegen instead of a standard visitor.
"""
module.code_for_node(module, provider=self.__class__)
class SyntacticPositionProvider(BasicPositionProvider):
"""
Generates Syntactic line and column metadata. Syntactic position is
defined by the start and ending bounds of a node ignoring most instances
of leading and trailing whitespace when it is not syntactically significant.
"""

View file

@ -254,7 +254,7 @@ class CSTNode(ABC):
start = (state.line, state.column)
self._codegen_impl(state, **kwargs)
end = (state.line, state.column)
state.update_position(self, CodePosition(start, end))
state.record_position(self, CodePosition(start, end))
def with_changes(self: _CSTNodeSelfT, **changes: Any) -> _CSTNodeSelfT:
"""

View file

@ -48,7 +48,8 @@ def _deep_equals_cst_node(a: "CSTNode", b: "CSTNode") -> bool:
return False
if a is b: # short-circuit
return True
for field in fields(a):
# Ignore metadata and other hidden fields
for field in (f for f in fields(a) if f.compare is True):
a_value = getattr(a, field.name)
b_value = getattr(b, field.name)
if not deep_equals(a_value, b_value):

View file

@ -169,7 +169,7 @@ class _BaseParenthesizedNode(CSTNode, ABC):
def _parenthesize(self, state: CodegenState) -> Generator[None, None, None]:
for lpar in self.lpar:
lpar._codegen(state)
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
yield
for rpar in self.rpar:
rpar._codegen(state)
@ -556,7 +556,7 @@ class FormattedStringExpression(BaseFormattedStringContent):
)
def _codegen_impl(self, state: CodegenState) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
state.add_token("{")
self.whitespace_before_expression._codegen(state)
self.expression._codegen(state)
@ -977,7 +977,7 @@ class Index(CSTNode):
return Index(value=visit_required("value", self.value, visitor))
def _codegen_impl(self, state: CodegenState) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
self.value._codegen(state)
@ -1015,7 +1015,7 @@ class Slice(CSTNode):
)
def _codegen_impl(self, state: CodegenState) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
lower = self.lower
if lower is not None:
lower._codegen(state)
@ -1055,7 +1055,7 @@ class ExtSlice(CSTNode):
)
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
self.slice._codegen(state)
comma = self.comma
@ -1198,7 +1198,7 @@ class Annotation(CSTNode):
state.add_token(indicator)
self.whitespace_after_indicator._codegen(state)
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
self.annotation._codegen(state)
@ -1287,7 +1287,7 @@ class Param(CSTNode):
default_star: Optional[str] = None,
default_comma: bool = False,
) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
star = self.star
if isinstance(star, MaybeSentinel):
if default_star is None:
@ -1421,7 +1421,7 @@ class Parameters(CSTNode):
def _codegen_impl(self, state: CodegenState) -> None:
# TODO: remove this when fallback to syntactic whitespace becomes available
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
# Compute the star existence first so we can ask about whether
# each element is the last in the list or not.
star_arg = self.star_arg
@ -1616,7 +1616,7 @@ class Arg(CSTNode):
)
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
state.add_token(self.star)
self.whitespace_after_star._codegen(state)
keyword = self.keyword
@ -1783,7 +1783,7 @@ class Call(_BaseExpressionWithArgs):
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
self.func._codegen(state)
self.whitespace_after_func._codegen(state)
state.add_token("(")
@ -1971,7 +1971,7 @@ class From(CSTNode):
else:
state.add_token(default_space)
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
state.add_token("from")
self.whitespace_after_from._codegen(state)
self.item._codegen(state)
@ -2095,7 +2095,7 @@ class Element(BaseElement):
default_comma: bool = False,
default_comma_whitespace: bool = False,
) -> None:
with state.record_semantic_position(self):
with state.record_syntactic_position(self):
self.value._codegen(state)
self.whitespace_after._codegen(state)

View file

@ -13,11 +13,11 @@ from typing import (
Iterable,
Iterator,
List,
MutableMapping,
Optional,
Pattern,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
@ -31,14 +31,20 @@ if TYPE_CHECKING:
# These are circular dependencies only used for typing purposes
from libcst.nodes._base import CSTNode
from libcst._base_visitor import CSTVisitor
from libcst.metadata.position_provider import (
BasicPositionProvider,
SyntacticPositionProvider,
)
_CSTNodeT = TypeVar("_CSTNodeT", bound="CSTNode")
_ProviderT = Union[Type["BasicPositionProvider"], Type["SyntacticPositionProvider"]]
NEWLINE_RE: Pattern[str] = re.compile(r"\r\n?|\n")
@add_slots
@dataclass(frozen=True)
class CodePosition:
# start and end are each a tuple of (line, column) numbers
@ -53,17 +59,18 @@ class CodegenState:
default_indent: str
default_newline: str
provider: _ProviderT = field(init=False)
indent_tokens: List[str] = field(default_factory=list)
tokens: List[str] = field(default_factory=list)
line: int = 1 # one-indexed
column: int = 0 # zero-indexed
positions: MutableMapping["CSTNode", CodePosition] = field(default_factory=dict)
def __post_init__(self) -> None:
from libcst.metadata.position_provider import BasicPositionProvider
semantic_positions: MutableMapping["CSTNode", CodePosition] = field(
default_factory=dict
)
self.provider = BasicPositionProvider
def increase_indent(self, value: str) -> None:
self.indent_tokens.append(value)
@ -93,17 +100,35 @@ class CodegenState:
# newline resets column back to 0, but a trailing token may shift column
self.column = len(segments[-1])
def update_position(self, node: _CSTNodeT, position: CodePosition) -> None:
self.positions[node] = position
def record_position(self, node: _CSTNodeT, position: CodePosition) -> None:
# Don't overwrite existing position information
# (i.e. semantic position has already been recorded)
if self.provider not in node.__metadata__:
node.__metadata__[self.provider] = position
@contextmanager
def record_semantic_position(self, node: _CSTNodeT) -> Iterator[None]:
def record_syntactic_position(self, node: _CSTNodeT) -> Iterator[None]:
yield
class SyntacticCodegenState(CodegenState):
"""
Pass to codegen to record the syntatic position of nodes.
"""
def __post_init__(self) -> None:
from libcst.metadata.position_provider import SyntacticPositionProvider
self.provider = SyntacticPositionProvider
@contextmanager
def record_syntactic_position(self, node: _CSTNodeT) -> Iterator[None]:
start = (self.line, self.column)
try:
yield
finally:
end = (self.line, self.column)
self.semantic_positions[node] = CodePosition(start, end)
node.__metadata__[self.provider] = CodePosition(start, end)
def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT:

View file

@ -3,19 +3,28 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import MutableMapping, Sequence, TypeVar, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
from libcst._add_slots import add_slots
from libcst._base_visitor import CSTVisitor
from libcst._removal_sentinel import RemovalSentinel
from libcst.nodes._base import CSTNode
from libcst.nodes._internal import CodegenState, CodePosition, visit_sequence
from libcst.nodes._internal import CodegenState, SyntacticCodegenState, visit_sequence
from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine
from libcst.nodes._whitespace import EmptyLine
if TYPE_CHECKING:
# These are circular dependencies only used for typing purposes
from libcst.metadata.position_provider import (
BasicPositionProvider,
SyntacticPositionProvider,
)
_ModuleSelfT = TypeVar("_ModuleSelfT", bound="Module")
_ProviderT = Union[Type["BasicPositionProvider"], Type["SyntacticPositionProvider"]]
# type alias needed for scope overlap in type definition
builtin_bytes = bytes
@ -41,14 +50,6 @@ class Module(CSTNode):
default_newline: str = "\n"
has_trailing_newline: bool = True
# TODO: remove these fields when metadata API is in place
_positions: MutableMapping["CSTNode", CodePosition] = field(
default_factory=dict, init=False, compare=False, repr=False
)
_semantic_positions: MutableMapping["CSTNode", CodePosition] = field(
default_factory=dict, init=False, compare=False, repr=False
)
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module":
return Module(
header=visit_sequence("header", self.header, visitor),
@ -93,19 +94,30 @@ class Module(CSTNode):
def bytes(self) -> builtin_bytes:
return self.code.encode(self.encoding)
def code_for_node(self, node: CSTNode) -> str:
def code_for_node(
self, node: CSTNode, provider: Optional[_ProviderT] = None
) -> str:
"""
Generates the code for the given node in the context of this module. This is a
method of Module, not CSTNode, because we need to know the module's default
indentation and newline formats.
By default, this also generates syntactic line and column metadata for each
node. Passing BasicPositionProvider will generate basic line and column
metadata instead.
"""
state = CodegenState(
default_indent=self.default_indent, default_newline=self.default_newline
)
from libcst.metadata.position_provider import SyntacticPositionProvider
if provider is None or provider is SyntacticPositionProvider:
state = SyntacticCodegenState(
default_indent=self.default_indent, default_newline=self.default_newline
)
else:
state = CodegenState(
default_indent=self.default_indent, default_newline=self.default_newline
)
node._codegen(state)
# TODO remove when metadata API is in place
object.__setattr__(self, "_positions", state.positions)
object.__setattr__(self, "_semantic_positions", state.semantic_positions)
return "".join(state.tokens)

View file

@ -12,6 +12,7 @@ from unittest.mock import patch
import libcst.nodes as cst
from libcst._base_visitor import CSTVisitor
from libcst.metadata.position_provider import SyntacticPositionProvider
from libcst.nodes._internal import CodegenState, CodePosition, visit_required
from libcst.testing.utils import UnitTest
@ -94,8 +95,9 @@ class CSTNodeTest(UnitTest):
module = cst.Module([])
self.assertEqual(module.code_for_node(node), expected)
if expected_position is not None:
# TODO: replace this when metadata framework is in place
self.assertEqual(module._semantic_positions[node], expected_position)
self.assertEqual(
node.__metadata__[SyntacticPositionProvider], expected_position
)
def __assert_children_match_codegen(self, node: cst.CSTNode) -> None:
children = node.children

View file

@ -7,7 +7,11 @@
from typing import Tuple
import libcst.nodes as cst
from libcst.nodes._internal import CodegenState, CodePosition
from libcst.metadata.position_provider import (
BasicPositionProvider,
SyntacticPositionProvider,
)
from libcst.nodes._internal import CodegenState, CodePosition, SyntacticCodegenState
from libcst.testing.utils import UnitTest
@ -50,7 +54,7 @@ class InternalTest(UnitTest):
state.add_indent_tokens()
self.assertEqual(position(state), (1, 8))
def test_context_manager(self) -> None:
def test_position(self) -> None:
# create a dummy node
node = cst.Pass()
@ -59,13 +63,33 @@ class InternalTest(UnitTest):
state = CodegenState(" " * 4, "\n")
start = (state.line, state.column)
state.add_token(" ")
with state.record_semantic_position(node):
with state.record_syntactic_position(node):
state.add_token("pass")
state.add_token(" ")
end = (state.line, state.column)
state.update_position(node, CodePosition(start, end))
state.record_position(node, CodePosition(start, end))
# check syntactic whitespace is correctly recorded
self.assertEqual(
node.__metadata__[BasicPositionProvider], CodePosition((1, 0), (1, 6))
)
def test_semantic_position(self) -> None:
# create a dummy node
node = cst.Pass()
# simulate codegen behavior for the dummy node
# generates the code " pass "
state = SyntacticCodegenState(" " * 4, "\n")
start = (state.line, state.column)
state.add_token(" ")
with state.record_syntactic_position(node):
state.add_token("pass")
state.add_token(" ")
end = (state.line, state.column)
state.record_position(node, CodePosition(start, end))
# check syntactic whitespace is correctly recorded (includes whitespace)
self.assertEqual(state.positions[node], CodePosition((1, 0), (1, 6)))
# check semantic whitespace is correctly recorded (ignoring whitespace)
self.assertEqual(state.semantic_positions[node], CodePosition((1, 1), (1, 5)))
self.assertEqual(
node.__metadata__[SyntacticPositionProvider], CodePosition((1, 1), (1, 5))
)

View file

@ -7,6 +7,7 @@
from typing import Tuple, cast
import libcst.nodes as cst
from libcst.metadata.position_provider import SyntacticPositionProvider
from libcst.nodes._internal import CodePosition
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_module
@ -138,16 +139,14 @@ class ModuleTest(CSTNodeTest):
module = parse_module(code)
module.code
self.assertEqual(module._positions[module], expected)
self.assertEqual(module.__metadata__[SyntacticPositionProvider], expected)
def cmp_position(
self,
module: cst.Module,
node: cst.CSTNode,
start: Tuple[int, int],
end: Tuple[int, int],
self, node: cst.CSTNode, start: Tuple[int, int], end: Tuple[int, int]
) -> None:
self.assertEqual(module._positions[node], CodePosition(start, end))
self.assertEqual(
node.__metadata__[SyntacticPositionProvider], CodePosition(start, end)
)
def test_function_position(self) -> None:
module = parse_module("def foo():\n pass")
@ -156,8 +155,8 @@ class ModuleTest(CSTNodeTest):
fn = cast(cst.FunctionDef, module.body[0])
stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
pass_stmt = cast(cst.Pass, stmt.body[0])
self.cmp_position(module, stmt, (2, 0), (3, 0))
self.cmp_position(module, pass_stmt, (2, 4), (2, 8))
self.cmp_position(stmt, (2, 0), (3, 0))
self.cmp_position(pass_stmt, (2, 4), (2, 8))
def test_nested_indent_position(self) -> None:
module = parse_module(
@ -172,11 +171,11 @@ class ModuleTest(CSTNodeTest):
outer_else = cast(cst.Else, outer_if.orelse)
return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]
self.cmp_position(module, outer_if, (1, 0), (6, 0))
self.cmp_position(module, inner_if, (2, 0), (4, 0))
self.cmp_position(module, assign, (3, 8), (3, 13))
self.cmp_position(module, outer_else, (4, 0), (6, 0))
self.cmp_position(module, return_stmt, (5, 4), (5, 10))
self.cmp_position(outer_if, (1, 0), (6, 0))
self.cmp_position(inner_if, (2, 0), (4, 0))
self.cmp_position(assign, (3, 8), (3, 13))
self.cmp_position(outer_else, (4, 0), (6, 0))
self.cmp_position(return_stmt, (5, 4), (5, 10))
def test_multiline_string_position(self) -> None:
module = parse_module('"abc"\\\n"def"')
@ -186,6 +185,6 @@ class ModuleTest(CSTNodeTest):
expr = cast(cst.Expr, stmt.body[0])
string = expr.value
self.cmp_position(module, stmt, (1, 0), (3, 0))
self.cmp_position(module, expr, (1, 0), (2, 5))
self.cmp_position(module, string, (1, 0), (2, 5))
self.cmp_position(stmt, (1, 0), (3, 0))
self.cmp_position(expr, (1, 0), (2, 5))
self.cmp_position(string, (1, 0), (2, 5))