Add metadata dependencies to CSTVisitor and add metadata runner

Add metadata access and dependency logic in `CSTVisitor` and
`Module.visit` to generate all metadata dependencies before performing a
visit pass over a tree and validate access to metadata.
This commit is contained in:
Ray Zeng 2019-07-03 17:40:51 -07:00 committed by Benjamin Woodruff
parent a293787f8c
commit dbbe3d1927
13 changed files with 205 additions and 47 deletions

View file

@ -5,7 +5,7 @@
# pyre-strict
from abc import ABC
from typing import TYPE_CHECKING, Type, TypeVar, Union
from typing import TYPE_CHECKING, ClassVar, Sequence, Type, TypeVar, Union
from libcst._removal_sentinel import RemovalSentinel
@ -23,6 +23,8 @@ __all__ = ["CSTVisitor", "RemovalSentinel"]
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")
_T = TypeVar("_T")
_UNDEFINED_DEFAULT = object()
class CSTVisitor(ABC):
"""
@ -32,6 +34,8 @@ class CSTVisitor(ABC):
subclass.
"""
METADATA_DEPENDENCIES: ClassVar[Sequence[Type["BaseMetadataProvider[object]"]]] = ()
def on_visit(self, node: "CSTNode") -> bool:
"""
Called every time a node is visited, before we've visited its children.
@ -52,12 +56,26 @@ class CSTVisitor(ABC):
return updated_node
@classmethod
def get_metadata(cls, key: Type["BaseMetadataProvider[_T]"], node: CSTNodeT) -> _T:
def get_metadata(
cls,
key: Type["BaseMetadataProvider[_T]"],
node: CSTNodeT,
default: _T = _UNDEFINED_DEFAULT,
) -> _T:
"""
Gets metadata provided by the [key] provider if it is accessible from
this vistor. Metadata is accessible if [key] is the same as [cls] or
if [key] is in METADATA_DEPENDENCIES.
"""
# TODO: runtime checks that metadata is available in this visitor
if key not in cls.METADATA_DEPENDENCIES and key is not cls:
raise KeyError(
f"{key.__name__} is not declared as a dependency from {cls.__name__}"
)
return node.__metadata__[key]
try:
return node._metadata[key]
except KeyError as err:
if default is not _UNDEFINED_DEFAULT:
return default
else:
raise err

View file

@ -78,3 +78,7 @@ class ParserSyntaxError(Exception):
# Text editors use a one-indexed column, so we need to add one to our
# zero-indexed column to get a human-readable result.
return tab_adjusted_column + 1
class MetadataException(Exception):
pass

View file

@ -10,31 +10,22 @@ import libcst.nodes as cst
from libcst._base_visitor import CSTVisitor
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
class BaseMetadataProvider(CSTVisitor, Generic[_T]):
class BaseMetadataProvider(CSTVisitor, Generic[_T_co]):
"""
Base class for metadata providers to subclass off of.
Base class for visitor-based metadata providers.
"""
def leave_Module(
self, orig_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
# TODO: We may need to change the behavior of CSTVisitor or create a
# subclass to make sure MetaDataProviders don't mutate the tree.
# A metadata provider can't modify the tree.
# This ensures no copying is done (which would erase metadata).
assert orig_node.deep_equals(updated_node)
return orig_node
def generate(self, module: cst.Module) -> None:
def run(self, module: cst.Module) -> None:
"""
Convenience method to run metadata provider over a module.
"""
module.visit(self)
@classmethod
def set_metadata(cls, node: cst.CSTNode, value: _T) -> None:
node.__metadata__[cls] = value
# pyre-ignore[35]: Parameter type cannot be covariant. Pyre can't
# detect that this method is not mutating the Provider class.
def set_metadata(cls, node: cst.CSTNode, value: _T_co) -> None:
node._metadata[cls] = value

View file

@ -16,7 +16,7 @@ class BasicPositionProvider(BaseMetadataProvider[CodeRange]):
owned by that node.
"""
def generate(self, module: cst.Module) -> None:
def run(self, module: cst.Module) -> None:
"""
Override default generate behavior as position information is
calculated through codegen instead of a standard visitor.

32
libcst/metadata/runner.py Normal file
View file

@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import TYPE_CHECKING
from libcst.exceptions import MetadataException
if TYPE_CHECKING:
from libcst._base_visitor import CSTVisitor
from libcst.nodes import Module
def run(module: "Module", visitor: "CSTVisitor") -> None:
"""
Called by Module.visit to generate metadata dependencies before performing
a visitor pass.
"""
for Provider in visitor.METADATA_DEPENDENCIES:
if Provider in module._remaining_dependencies:
raise MetadataException(
f"Detected circular dependency between {type(visitor).__name__} and {Provider.__name__}"
)
if Provider not in module._satisfied_dependencies:
module._remaining_dependencies.add(Provider)
Provider().run(module)
module._satisfied_dependencies.add(Provider)
module._remaining_dependencies.remove(Provider)

View file

@ -4,32 +4,23 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import TYPE_CHECKING, TypeVar
from libcst.metadata.base_provider import BaseMetadataProvider
from libcst.nodes._base import CSTNode
from libcst.parser import parse_module
from libcst.testing.utils import UnitTest
if TYPE_CHECKING:
# Circular dependency for typing reasons only
from libcst.nodes._base import CSTNode
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")
class BaseMetadataProviderTest(UnitTest):
def test_simple_provider(self) -> None:
class SimpleProvider(BaseMetadataProvider[int]):
def on_visit(self, node: "CSTNode") -> bool:
def on_visit(self, node: CSTNode) -> bool:
self.set_metadata(node, 1)
return True
module = parse_module("pass")
pass_node = module.body[0]
provider = SimpleProvider()
provider.generate(module)
provider.run(module)
self.assertEqual(provider.get_metadata(SimpleProvider, module), 1)
self.assertEqual(provider.get_metadata(SimpleProvider, pass_node), 1)

View file

@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Union
from libcst._base_visitor import CSTVisitor
from libcst._removal_sentinel import RemovalSentinel
from libcst.exceptions import MetadataException
from libcst.metadata.base_provider import BaseMetadataProvider
from libcst.nodes import CSTNode, Module
from libcst.parser import parse_module
from libcst.testing.utils import UnitTest
class MetadataRunnerTest(UnitTest):
def test_visitor_with_dependencies(self) -> None:
class SimpleProvider(BaseMetadataProvider[int]):
def on_visit(self, node: CSTNode) -> bool:
self.set_metadata(node, 1)
return True
class DependentProvider(BaseMetadataProvider[int]):
METADATA_DEPENDENCIES = (SimpleProvider,)
def on_visit(self, node: CSTNode) -> bool:
self.set_metadata(node, self.get_metadata(SimpleProvider, node) + 1)
return True
class DependentVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (DependentProvider, SimpleProvider)
module = parse_module("pass")
pass_node = module.body[0]
visitor = DependentVisitor()
module.visit(visitor)
self.assertEqual(visitor.get_metadata(SimpleProvider, module), 1)
self.assertEqual(visitor.get_metadata(DependentProvider, module), 2)
self.assertEqual(visitor.get_metadata(SimpleProvider, pass_node), 1)
self.assertEqual(visitor.get_metadata(DependentProvider, pass_node), 2)
def test_provider_with_circular_dependency(self) -> None:
class ProviderA(BaseMetadataProvider[str]):
pass
ProviderA.METADATA_DEPENDENCIES = (ProviderA,)
class BadVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (ProviderA,)
with self.assertRaisesRegex(
MetadataException,
"Detected circular dependency between ProviderA and ProviderA",
):
Module([]).visit(BadVisitor())
def test_self_access_metadata(self) -> None:
test_runner = self
class ProviderA(BaseMetadataProvider[bool]):
def on_visit(self, node: CSTNode) -> bool:
self.set_metadata(node, True)
return True
def on_leave(
self, original_node: CSTNode, updated_node: CSTNode
) -> Union[CSTNode, RemovalSentinel]:
test_runner.assertEqual(
self.get_metadata(type(self), original_node), True
)
return original_node
class AVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (ProviderA,)
Module([]).visit(AVisitor())
def test_access_unset_metadata(self) -> None:
class ProviderA(BaseMetadataProvider[bool]):
pass
class AVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (ProviderA,)
def on_visit(self, node: CSTNode) -> bool:
self.get_metadata(ProviderA, node)
return True
with self.assertRaises(KeyError):
Module([]).visit(AVisitor())
def test_access_invalid_metadata(self) -> None:
class ProviderA(BaseMetadataProvider[bool]):
pass
class ProviderB(BaseMetadataProvider[bool]):
pass
class AVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (ProviderA,)
def on_visit(self, node: CSTNode) -> bool:
self.get_metadata(ProviderA, node, True)
self.get_metadata(ProviderB, node)
return True
with self.assertRaisesRegex(
KeyError, "ProviderB is not declared as a dependency from AVisitor"
):
Module([]).visit(AVisitor())

View file

@ -72,7 +72,7 @@ def _indent(value: str) -> str:
@dataclass(frozen=True)
class CSTNode(ABC):
__metadata__: MutableMapping[Type["BaseMetaDataProvider[_T]"], _T] = field(
_metadata: MutableMapping[Type["BaseMetaDataProvider[_T]"], _T] = field(
default_factory=dict, init=False, repr=False, compare=False
)
@ -130,7 +130,7 @@ class CSTNode(ABC):
does.
"""
for f in fields(self):
if f.name == "__metadata__": # skip typechecking metadata field
if f.name == "_metadata": # skip typechecking metadata field
continue
value = getattr(self, f.name)

View file

@ -114,8 +114,8 @@ class CodegenState:
def record_position(self, node: _CSTNodeT, position: CodeRange) -> None:
# Don't overwrite existing position information
# (i.e. semantic position has already been recorded)
if self.provider not in node.__metadata__:
node.__metadata__[self.provider] = position
if self.provider not in node._metadata:
node._metadata[self.provider] = position
@contextmanager
def record_syntactic_position(self, node: _CSTNodeT) -> Iterator[None]:
@ -139,7 +139,7 @@ class SyntacticCodegenState(CodegenState):
yield
finally:
end = CodePosition(self.line, self.column)
node.__metadata__[self.provider] = CodeRange(start, end)
node._metadata[self.provider] = CodeRange(start, end)
def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT:

View file

@ -3,12 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, MutableSet, Optional, Sequence, Type, TypeVar, Union
from libcst._add_slots import add_slots
from libcst._base_visitor import CSTVisitor
from libcst._removal_sentinel import RemovalSentinel
from libcst.metadata.runner import run as compute_metadata
from libcst.nodes._base import CSTNode
from libcst.nodes._internal import CodegenState, SyntacticCodegenState, visit_sequence
from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine
@ -50,6 +51,13 @@ class Module(CSTNode):
default_newline: str = "\n"
has_trailing_newline: bool = True
_satisfied_dependencies: MutableSet["BaseMetadataProvider[Any]"] = field(
default_factory=set, init=False, repr=False, compare=False
)
_remaining_dependencies: MutableSet["BaseMetadataProvider[Any]"] = field(
default_factory=set, init=False, repr=False, compare=False
)
def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module":
return Module(
header=visit_sequence("header", self.header, visitor),
@ -62,6 +70,8 @@ class Module(CSTNode):
)
def visit(self: _ModuleSelfT, visitor: CSTVisitor) -> _ModuleSelfT:
compute_metadata(self, visitor)
result = CSTNode.visit(self, visitor)
if isinstance(result, RemovalSentinel):
return self.with_changes(body=(), header=(), footer=())

View file

@ -96,7 +96,7 @@ class CSTNodeTest(UnitTest):
self.assertEqual(module.code_for_node(node), expected)
if expected_position is not None:
self.assertEqual(
node.__metadata__[SyntacticPositionProvider], expected_position
node._metadata[SyntacticPositionProvider], expected_position
)
def __assert_children_match_codegen(self, node: cst.CSTNode) -> None:

View file

@ -76,7 +76,7 @@ class InternalTest(UnitTest):
# check syntactic whitespace is correctly recorded
self.assertEqual(
node.__metadata__[BasicPositionProvider], CodeRange.create((1, 0), (1, 6))
node._metadata[BasicPositionProvider], CodeRange.create((1, 0), (1, 6))
)
def test_semantic_position(self) -> None:
@ -96,6 +96,5 @@ class InternalTest(UnitTest):
# check semantic whitespace is correctly recorded (ignoring whitespace)
self.assertEqual(
node.__metadata__[SyntacticPositionProvider],
CodeRange.create((1, 1), (1, 5)),
node._metadata[SyntacticPositionProvider], CodeRange.create((1, 1), (1, 5))
)

View file

@ -142,13 +142,13 @@ class ModuleTest(CSTNodeTest):
module = parse_module(code)
module.code
self.assertEqual(module.__metadata__[SyntacticPositionProvider], expected)
self.assertEqual(module._metadata[SyntacticPositionProvider], expected)
def cmp_position(
self, node: cst.CSTNode, start: Tuple[int, int], end: Tuple[int, int]
) -> None:
self.assertEqual(
node.__metadata__[SyntacticPositionProvider], CodeRange.create(start, end)
node._metadata[SyntacticPositionProvider], CodeRange.create(start, end)
)
def test_function_position(self) -> None: