mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Add metadata dependencies to CSTVisitor and add metadata runner
Add metadata access and dependency logic in `CSTVisitor` and `Module.visit` to generate all metadata dependencies before performing a visit pass over a tree and validate access to metadata.
This commit is contained in:
parent
a293787f8c
commit
dbbe3d1927
13 changed files with 205 additions and 47 deletions
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
# pyre-strict
|
||||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, ClassVar, Sequence, Type, TypeVar, Union
|
||||
|
||||
from libcst._removal_sentinel import RemovalSentinel
|
||||
|
||||
|
|
@ -23,6 +23,8 @@ __all__ = ["CSTVisitor", "RemovalSentinel"]
|
|||
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
_UNDEFINED_DEFAULT = object()
|
||||
|
||||
|
||||
class CSTVisitor(ABC):
|
||||
"""
|
||||
|
|
@ -32,6 +34,8 @@ class CSTVisitor(ABC):
|
|||
subclass.
|
||||
"""
|
||||
|
||||
METADATA_DEPENDENCIES: ClassVar[Sequence[Type["BaseMetadataProvider[object]"]]] = ()
|
||||
|
||||
def on_visit(self, node: "CSTNode") -> bool:
|
||||
"""
|
||||
Called every time a node is visited, before we've visited its children.
|
||||
|
|
@ -52,12 +56,26 @@ class CSTVisitor(ABC):
|
|||
return updated_node
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls, key: Type["BaseMetadataProvider[_T]"], node: CSTNodeT) -> _T:
|
||||
def get_metadata(
|
||||
cls,
|
||||
key: Type["BaseMetadataProvider[_T]"],
|
||||
node: CSTNodeT,
|
||||
default: _T = _UNDEFINED_DEFAULT,
|
||||
) -> _T:
|
||||
"""
|
||||
Gets metadata provided by the [key] provider if it is accessible from
|
||||
this vistor. Metadata is accessible if [key] is the same as [cls] or
|
||||
if [key] is in METADATA_DEPENDENCIES.
|
||||
"""
|
||||
# TODO: runtime checks that metadata is available in this visitor
|
||||
if key not in cls.METADATA_DEPENDENCIES and key is not cls:
|
||||
raise KeyError(
|
||||
f"{key.__name__} is not declared as a dependency from {cls.__name__}"
|
||||
)
|
||||
|
||||
return node.__metadata__[key]
|
||||
try:
|
||||
return node._metadata[key]
|
||||
except KeyError as err:
|
||||
if default is not _UNDEFINED_DEFAULT:
|
||||
return default
|
||||
else:
|
||||
raise err
|
||||
|
|
|
|||
|
|
@ -78,3 +78,7 @@ class ParserSyntaxError(Exception):
|
|||
# Text editors use a one-indexed column, so we need to add one to our
|
||||
# zero-indexed column to get a human-readable result.
|
||||
return tab_adjusted_column + 1
|
||||
|
||||
|
||||
class MetadataException(Exception):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -10,31 +10,22 @@ import libcst.nodes as cst
|
|||
from libcst._base_visitor import CSTVisitor
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_T_co = TypeVar("_T_co", covariant=True)
|
||||
|
||||
|
||||
class BaseMetadataProvider(CSTVisitor, Generic[_T]):
|
||||
class BaseMetadataProvider(CSTVisitor, Generic[_T_co]):
|
||||
"""
|
||||
Base class for metadata providers to subclass off of.
|
||||
Base class for visitor-based metadata providers.
|
||||
"""
|
||||
|
||||
def leave_Module(
|
||||
self, orig_node: cst.Module, updated_node: cst.Module
|
||||
) -> cst.Module:
|
||||
# TODO: We may need to change the behavior of CSTVisitor or create a
|
||||
# subclass to make sure MetaDataProviders don't mutate the tree.
|
||||
|
||||
# A metadata provider can't modify the tree.
|
||||
# This ensures no copying is done (which would erase metadata).
|
||||
assert orig_node.deep_equals(updated_node)
|
||||
return orig_node
|
||||
|
||||
def generate(self, module: cst.Module) -> None:
|
||||
def run(self, module: cst.Module) -> None:
|
||||
"""
|
||||
Convenience method to run metadata provider over a module.
|
||||
"""
|
||||
module.visit(self)
|
||||
|
||||
@classmethod
|
||||
def set_metadata(cls, node: cst.CSTNode, value: _T) -> None:
|
||||
node.__metadata__[cls] = value
|
||||
# pyre-ignore[35]: Parameter type cannot be covariant. Pyre can't
|
||||
# detect that this method is not mutating the Provider class.
|
||||
def set_metadata(cls, node: cst.CSTNode, value: _T_co) -> None:
|
||||
node._metadata[cls] = value
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class BasicPositionProvider(BaseMetadataProvider[CodeRange]):
|
|||
owned by that node.
|
||||
"""
|
||||
|
||||
def generate(self, module: cst.Module) -> None:
|
||||
def run(self, module: cst.Module) -> None:
|
||||
"""
|
||||
Override default generate behavior as position information is
|
||||
calculated through codegen instead of a standard visitor.
|
||||
|
|
|
|||
32
libcst/metadata/runner.py
Normal file
32
libcst/metadata/runner.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from libcst.exceptions import MetadataException
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst._base_visitor import CSTVisitor
|
||||
from libcst.nodes import Module
|
||||
|
||||
|
||||
def run(module: "Module", visitor: "CSTVisitor") -> None:
|
||||
"""
|
||||
Called by Module.visit to generate metadata dependencies before performing
|
||||
a visitor pass.
|
||||
"""
|
||||
for Provider in visitor.METADATA_DEPENDENCIES:
|
||||
if Provider in module._remaining_dependencies:
|
||||
raise MetadataException(
|
||||
f"Detected circular dependency between {type(visitor).__name__} and {Provider.__name__}"
|
||||
)
|
||||
|
||||
if Provider not in module._satisfied_dependencies:
|
||||
module._remaining_dependencies.add(Provider)
|
||||
Provider().run(module)
|
||||
module._satisfied_dependencies.add(Provider)
|
||||
module._remaining_dependencies.remove(Provider)
|
||||
|
|
@ -4,32 +4,23 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from libcst.metadata.base_provider import BaseMetadataProvider
|
||||
from libcst.nodes._base import CSTNode
|
||||
from libcst.parser import parse_module
|
||||
from libcst.testing.utils import UnitTest
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Circular dependency for typing reasons only
|
||||
from libcst.nodes._base import CSTNode
|
||||
|
||||
|
||||
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")
|
||||
|
||||
|
||||
class BaseMetadataProviderTest(UnitTest):
|
||||
def test_simple_provider(self) -> None:
|
||||
class SimpleProvider(BaseMetadataProvider[int]):
|
||||
def on_visit(self, node: "CSTNode") -> bool:
|
||||
def on_visit(self, node: CSTNode) -> bool:
|
||||
self.set_metadata(node, 1)
|
||||
return True
|
||||
|
||||
module = parse_module("pass")
|
||||
pass_node = module.body[0]
|
||||
provider = SimpleProvider()
|
||||
provider.generate(module)
|
||||
provider.run(module)
|
||||
|
||||
self.assertEqual(provider.get_metadata(SimpleProvider, module), 1)
|
||||
self.assertEqual(provider.get_metadata(SimpleProvider, pass_node), 1)
|
||||
|
|
|
|||
113
libcst/metadata/tests/test_runner.py
Normal file
113
libcst/metadata/tests/test_runner.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
from typing import Union
|
||||
|
||||
from libcst._base_visitor import CSTVisitor
|
||||
from libcst._removal_sentinel import RemovalSentinel
|
||||
from libcst.exceptions import MetadataException
|
||||
from libcst.metadata.base_provider import BaseMetadataProvider
|
||||
from libcst.nodes import CSTNode, Module
|
||||
from libcst.parser import parse_module
|
||||
from libcst.testing.utils import UnitTest
|
||||
|
||||
|
||||
class MetadataRunnerTest(UnitTest):
|
||||
def test_visitor_with_dependencies(self) -> None:
|
||||
class SimpleProvider(BaseMetadataProvider[int]):
|
||||
def on_visit(self, node: CSTNode) -> bool:
|
||||
self.set_metadata(node, 1)
|
||||
return True
|
||||
|
||||
class DependentProvider(BaseMetadataProvider[int]):
|
||||
METADATA_DEPENDENCIES = (SimpleProvider,)
|
||||
|
||||
def on_visit(self, node: CSTNode) -> bool:
|
||||
self.set_metadata(node, self.get_metadata(SimpleProvider, node) + 1)
|
||||
return True
|
||||
|
||||
class DependentVisitor(CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (DependentProvider, SimpleProvider)
|
||||
|
||||
module = parse_module("pass")
|
||||
pass_node = module.body[0]
|
||||
visitor = DependentVisitor()
|
||||
module.visit(visitor)
|
||||
|
||||
self.assertEqual(visitor.get_metadata(SimpleProvider, module), 1)
|
||||
self.assertEqual(visitor.get_metadata(DependentProvider, module), 2)
|
||||
self.assertEqual(visitor.get_metadata(SimpleProvider, pass_node), 1)
|
||||
self.assertEqual(visitor.get_metadata(DependentProvider, pass_node), 2)
|
||||
|
||||
def test_provider_with_circular_dependency(self) -> None:
|
||||
class ProviderA(BaseMetadataProvider[str]):
|
||||
pass
|
||||
|
||||
ProviderA.METADATA_DEPENDENCIES = (ProviderA,)
|
||||
|
||||
class BadVisitor(CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (ProviderA,)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MetadataException,
|
||||
"Detected circular dependency between ProviderA and ProviderA",
|
||||
):
|
||||
Module([]).visit(BadVisitor())
|
||||
|
||||
def test_self_access_metadata(self) -> None:
|
||||
test_runner = self
|
||||
|
||||
class ProviderA(BaseMetadataProvider[bool]):
|
||||
def on_visit(self, node: CSTNode) -> bool:
|
||||
self.set_metadata(node, True)
|
||||
return True
|
||||
|
||||
def on_leave(
|
||||
self, original_node: CSTNode, updated_node: CSTNode
|
||||
) -> Union[CSTNode, RemovalSentinel]:
|
||||
test_runner.assertEqual(
|
||||
self.get_metadata(type(self), original_node), True
|
||||
)
|
||||
return original_node
|
||||
|
||||
class AVisitor(CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (ProviderA,)
|
||||
|
||||
Module([]).visit(AVisitor())
|
||||
|
||||
def test_access_unset_metadata(self) -> None:
|
||||
class ProviderA(BaseMetadataProvider[bool]):
|
||||
pass
|
||||
|
||||
class AVisitor(CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (ProviderA,)
|
||||
|
||||
def on_visit(self, node: CSTNode) -> bool:
|
||||
self.get_metadata(ProviderA, node)
|
||||
return True
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
Module([]).visit(AVisitor())
|
||||
|
||||
def test_access_invalid_metadata(self) -> None:
|
||||
class ProviderA(BaseMetadataProvider[bool]):
|
||||
pass
|
||||
|
||||
class ProviderB(BaseMetadataProvider[bool]):
|
||||
pass
|
||||
|
||||
class AVisitor(CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (ProviderA,)
|
||||
|
||||
def on_visit(self, node: CSTNode) -> bool:
|
||||
self.get_metadata(ProviderA, node, True)
|
||||
self.get_metadata(ProviderB, node)
|
||||
return True
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
KeyError, "ProviderB is not declared as a dependency from AVisitor"
|
||||
):
|
||||
Module([]).visit(AVisitor())
|
||||
|
|
@ -72,7 +72,7 @@ def _indent(value: str) -> str:
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class CSTNode(ABC):
|
||||
__metadata__: MutableMapping[Type["BaseMetaDataProvider[_T]"], _T] = field(
|
||||
_metadata: MutableMapping[Type["BaseMetaDataProvider[_T]"], _T] = field(
|
||||
default_factory=dict, init=False, repr=False, compare=False
|
||||
)
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ class CSTNode(ABC):
|
|||
does.
|
||||
"""
|
||||
for f in fields(self):
|
||||
if f.name == "__metadata__": # skip typechecking metadata field
|
||||
if f.name == "_metadata": # skip typechecking metadata field
|
||||
continue
|
||||
|
||||
value = getattr(self, f.name)
|
||||
|
|
|
|||
|
|
@ -114,8 +114,8 @@ class CodegenState:
|
|||
def record_position(self, node: _CSTNodeT, position: CodeRange) -> None:
|
||||
# Don't overwrite existing position information
|
||||
# (i.e. semantic position has already been recorded)
|
||||
if self.provider not in node.__metadata__:
|
||||
node.__metadata__[self.provider] = position
|
||||
if self.provider not in node._metadata:
|
||||
node._metadata[self.provider] = position
|
||||
|
||||
@contextmanager
|
||||
def record_syntactic_position(self, node: _CSTNodeT) -> Iterator[None]:
|
||||
|
|
@ -139,7 +139,7 @@ class SyntacticCodegenState(CodegenState):
|
|||
yield
|
||||
finally:
|
||||
end = CodePosition(self.line, self.column)
|
||||
node.__metadata__[self.provider] = CodeRange(start, end)
|
||||
node._metadata[self.provider] = CodeRange(start, end)
|
||||
|
||||
|
||||
def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, MutableSet, Optional, Sequence, Type, TypeVar, Union
|
||||
|
||||
from libcst._add_slots import add_slots
|
||||
from libcst._base_visitor import CSTVisitor
|
||||
from libcst._removal_sentinel import RemovalSentinel
|
||||
from libcst.metadata.runner import run as compute_metadata
|
||||
from libcst.nodes._base import CSTNode
|
||||
from libcst.nodes._internal import CodegenState, SyntacticCodegenState, visit_sequence
|
||||
from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine
|
||||
|
|
@ -50,6 +51,13 @@ class Module(CSTNode):
|
|||
default_newline: str = "\n"
|
||||
has_trailing_newline: bool = True
|
||||
|
||||
_satisfied_dependencies: MutableSet["BaseMetadataProvider[Any]"] = field(
|
||||
default_factory=set, init=False, repr=False, compare=False
|
||||
)
|
||||
_remaining_dependencies: MutableSet["BaseMetadataProvider[Any]"] = field(
|
||||
default_factory=set, init=False, repr=False, compare=False
|
||||
)
|
||||
|
||||
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module":
|
||||
return Module(
|
||||
header=visit_sequence("header", self.header, visitor),
|
||||
|
|
@ -62,6 +70,8 @@ class Module(CSTNode):
|
|||
)
|
||||
|
||||
def visit(self: _ModuleSelfT, visitor: CSTVisitor) -> _ModuleSelfT:
|
||||
compute_metadata(self, visitor)
|
||||
|
||||
result = CSTNode.visit(self, visitor)
|
||||
if isinstance(result, RemovalSentinel):
|
||||
return self.with_changes(body=(), header=(), footer=())
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class CSTNodeTest(UnitTest):
|
|||
self.assertEqual(module.code_for_node(node), expected)
|
||||
if expected_position is not None:
|
||||
self.assertEqual(
|
||||
node.__metadata__[SyntacticPositionProvider], expected_position
|
||||
node._metadata[SyntacticPositionProvider], expected_position
|
||||
)
|
||||
|
||||
def __assert_children_match_codegen(self, node: cst.CSTNode) -> None:
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class InternalTest(UnitTest):
|
|||
|
||||
# check syntactic whitespace is correctly recorded
|
||||
self.assertEqual(
|
||||
node.__metadata__[BasicPositionProvider], CodeRange.create((1, 0), (1, 6))
|
||||
node._metadata[BasicPositionProvider], CodeRange.create((1, 0), (1, 6))
|
||||
)
|
||||
|
||||
def test_semantic_position(self) -> None:
|
||||
|
|
@ -96,6 +96,5 @@ class InternalTest(UnitTest):
|
|||
|
||||
# check semantic whitespace is correctly recorded (ignoring whitespace)
|
||||
self.assertEqual(
|
||||
node.__metadata__[SyntacticPositionProvider],
|
||||
CodeRange.create((1, 1), (1, 5)),
|
||||
node._metadata[SyntacticPositionProvider], CodeRange.create((1, 1), (1, 5))
|
||||
)
|
||||
|
|
|
|||
|
|
@ -142,13 +142,13 @@ class ModuleTest(CSTNodeTest):
|
|||
module = parse_module(code)
|
||||
module.code
|
||||
|
||||
self.assertEqual(module.__metadata__[SyntacticPositionProvider], expected)
|
||||
self.assertEqual(module._metadata[SyntacticPositionProvider], expected)
|
||||
|
||||
def cmp_position(
|
||||
self, node: cst.CSTNode, start: Tuple[int, int], end: Tuple[int, int]
|
||||
) -> None:
|
||||
self.assertEqual(
|
||||
node.__metadata__[SyntacticPositionProvider], CodeRange.create(start, end)
|
||||
node._metadata[SyntacticPositionProvider], CodeRange.create(start, end)
|
||||
)
|
||||
|
||||
def test_function_position(self) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue