Refactor metadata runner

Refactors various metadata-related classes to better reflect the intended
usage of the API.

The metadata runner now deep clones the module and returns a copy
containing the metadata information. Metadata providers also return the
tree to enforce the idea that the tree is immutable (even though no
copying is actually done and providers write directly to the original
tree).
This commit is contained in:
Ray Zeng 2019-07-17 14:42:49 -07:00 committed by Benjamin Woodruff
parent e452bd3a55
commit 36cfb512ea
7 changed files with 147 additions and 76 deletions

View file

@ -4,12 +4,17 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Generic, TypeVar
from typing import Generic, Iterable, TypeVar, cast
import libcst.nodes as cst
from libcst.batched_visitor import BatchableCSTVisitor
from libcst.batched_visitor import (
BatchableCSTVisitor,
_BatchedCSTVisitor,
_get_visitor_methods,
)
from libcst.exceptions import MetadataException
from libcst.metadata._interface import _MetadataInterface
from libcst.nodes._module import _ModuleSelfT as _ModuleT
from libcst.visitors import CSTVisitor
@ -22,9 +27,14 @@ class BaseMetadataProvider(_MetadataInterface, Generic[_T_co]):
Abstract base class for all metadata providers.
"""
def _run(self, module: cst.Module) -> None:
def _run(self, module: _ModuleT) -> _ModuleT:
"""
Entry point for metadata runner.
Returns the given module with metadata from this provider.
This is a hook for metadata runner and should not be called directly.
Any implementation of this method should not handle any dependencies
declared by this provider and should not have any side effects besides
setting metadata computed by this provider.
"""
...
@ -32,31 +42,45 @@ class BaseMetadataProvider(_MetadataInterface, Generic[_T_co]):
# pyre-ignore[35]: Parameter type cannot be covariant. Pyre can't
# detect that this method is not mutating the Provider class.
def set_metadata(cls, node: cst.CSTNode, value: _T_co) -> None:
"""
Stores given metadata from this provider on the given node.
"""
node._metadata[cls] = value
class BatchableMetadataProvider(BatchableCSTVisitor, BaseMetadataProvider[_T_co]):
"""
Base class for batchable visitor metadata providers.
"""
def _run(self, module: cst.Module) -> None:
"""
Batchable providers are batched by the runner and should not be
called directly.
"""
raise MetadataException(
"BatchableMetadataProvider should not be called directly."
)
class VisitorMetadataProvider(CSTVisitor, BaseMetadataProvider[_T_co]):
"""
Base class for non-batchable visitor metadata providers.
Extend this to compute metadata with a non-batchable visitor.
"""
def _run(self, module: cst.Module) -> None:
def _run(self, module: _ModuleT) -> _ModuleT:
"""
Does not compute dependencies declared by this provider.
Returns the given module with metadata from this provider.
"""
module._visit_impl(self)
# Cast is safe as metadata providers should never mutate the tree
return cast(_ModuleT, module._visit_impl(self))
class BatchableMetadataProvider(BatchableCSTVisitor, BaseMetadataProvider[_T_co]):
"""
Extend this to compute metadata with a batchable visitor.
"""
def _run(self, module: _ModuleT) -> _ModuleT:
"""
Batchable providers are resolved using [_run_batchable].
"""
raise MetadataException("BatchableMetadataProvider cannot be called directly.")
def _run_batchable(
module: _ModuleT, providers: Iterable[BatchableMetadataProvider[object]]
) -> _ModuleT:
"""
Returns the given module with metadata from the given batchable providers.
The visitor methods of all given providers are collected and wrapped in a
single batched visitor, which is then run over the module with one
[_visit_impl] call.
This function does not inspect or resolve any dependencies declared by
the providers.
"""
# Gather the visit/leave callbacks of every provider into one visitor.
visitor_methods = _get_visitor_methods(providers)
batched_visitor = _BatchedCSTVisitor(visitor_methods)
# Cast is safe as metadata providers should never mutate the tree
return cast(_ModuleT, module._visit_impl(batched_visitor))

View file

@ -4,8 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
import libcst.nodes as cst
from libcst.metadata.base_provider import BaseMetadataProvider
from libcst.metadata.base_provider import BaseMetadataProvider, _ModuleT
from libcst.nodes._internal import CodeRange
@ -16,12 +15,13 @@ class BasicPositionProvider(BaseMetadataProvider[CodeRange]):
owned by that node.
"""
def run(self, module: cst.Module) -> None:
def _run(self, module: _ModuleT) -> _ModuleT:
"""
Override default generate behavior as position information is
calculated through codegen instead of a standard visitor.
"""
module.code_for_node(module, provider=self.__class__)
return module
class SyntacticPositionProvider(BasicPositionProvider):

View file

@ -6,24 +6,20 @@
# pyre-strict
from typing import MutableSet, Type
import libcst.nodes as cst
from libcst.batched_visitor import visit as batched_visit
from libcst.exceptions import MetadataException
from libcst.metadata._interface import _MetadataInterface
from libcst.metadata.base_provider import (
BaseMetadataProvider,
BatchableMetadataProvider,
_run_batchable,
)
from libcst.nodes._module import _ModuleSelfT as _ModuleT
from libcst.visitors import CSTVisitorT
ProviderT = Type[BaseMetadataProvider[object]]
class _MetadataRunner:
"""
Helper class for resolving metadata dependencies.
"""
def __init__(self) -> None:
self.providers: MutableSet[ProviderT] = set()
self.satisfied: MutableSet[ProviderT] = set()
@ -34,34 +30,43 @@ class _MetadataRunner:
for dep in root.METADATA_DEPENDENCIES:
self.gather_providers(dep)
@staticmethod
def resolve(module: _ModuleT, visitor: CSTVisitorT) -> _ModuleT:
"""
Returns a copy of the module that contains all metadata dependencies
declared by the visitor.
"""
def run(module: cst.Module, root: _MetadataInterface) -> None:
"""
Called by Module.visit to resolve metadata dependencies before performing
a visitor pass.
"""
if len(visitor.METADATA_DEPENDENCIES) == 0:
return module
runner = _MetadataRunner()
for dep in root.METADATA_DEPENDENCIES:
runner.gather_providers(dep)
# We need to deep clone to ensure that there are no duplicate nodes
module = module.deep_clone()
while len(runner.providers) > 0:
completed = set()
batchable = set()
runner = _MetadataRunner()
for dep in visitor.METADATA_DEPENDENCIES:
runner.gather_providers(dep)
for P in runner.providers:
if set(P.METADATA_DEPENDENCIES).issubset(runner.satisfied):
if issubclass(P, BatchableMetadataProvider):
batchable.add(P)
else:
P()._run(module)
completed.add(P)
while len(runner.providers) > 0:
completed = set()
batchable = set()
batched_visit(module, [p() for p in batchable])
runner.providers -= completed | batchable
runner.satisfied |= completed | batchable
for P in runner.providers:
if set(P.METADATA_DEPENDENCIES).issubset(runner.satisfied):
if issubclass(P, BatchableMetadataProvider):
batchable.add(P)
else:
module = P()._run(module)
completed.add(P)
if len(completed) == 0 and len(batchable) == 0:
# runner.providers must be non-empty at this point
names = ", ".join([P.__name__ for P in runner.providers])
raise MetadataException(f"Detected circular dependencies in {names}")
module = _run_batchable(module, [p() for p in batchable])
runner.providers -= completed | batchable
runner.satisfied |= completed | batchable
if len(completed) == 0 and len(batchable) == 0:
# runner.providers must be non-empty at this point
names = ", ".join([P.__name__ for P in runner.providers])
raise MetadataException(f"Detected circular dependencies in {names}")
return module

View file

@ -7,10 +7,10 @@
from typing import cast
import libcst.nodes as cst
from libcst.batched_visitor import visit
from libcst.metadata.base_provider import (
BatchableMetadataProvider,
VisitorMetadataProvider,
_run_batchable,
)
from libcst.parser import parse_module
from libcst.testing.utils import UnitTest
@ -47,17 +47,17 @@ class BaseMetadataProviderTest(UnitTest):
def visit_Pass(self, node: cst.Pass) -> None:
self.set_metadata(node, 1)
def leave_Return(self, node: cst.Return) -> None:
def visit_Return(self, node: cst.Return) -> None:
self.set_metadata(node, 2)
module = parse_module("pass; return; pass")
provider = SimpleProvider()
module = _run_batchable(module, [provider])
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]
pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2]
provider = SimpleProvider()
visit(module, [provider])
self.assertEqual(provider.get_metadata(SimpleProvider, pass_), 1)
self.assertEqual(provider.get_metadata(SimpleProvider, return_), 2)
self.assertEqual(provider.get_metadata(SimpleProvider, pass_2), 1)

View file

@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import libcst.nodes as cst
from libcst.metadata.position_provider import SyntacticPositionProvider
from libcst.nodes._internal import CodeRange
from libcst.parser import parse_module
from libcst.testing.utils import UnitTest
from libcst.visitors import CSTTransformer
class PositionProviderTest(UnitTest):
def test_visitor_provider(self) -> None:
"""
Checks that a visitor declaring SyntacticPositionProvider as a metadata
dependency can read the position stored for a node while the module is
being visited.
"""
# Alias the test case so the nested visitor class can call its assertions.
test = self
class DependentVisitor(CSTTransformer):
METADATA_DEPENDENCIES = (SyntacticPositionProvider,)
def visit_Pass(self, node: cst.Pass) -> None:
range = self.get_metadata(SyntacticPositionProvider, node)
# Presumably the (line, column) start and end of "pass" -- TODO confirm
# against CodeRange.create.
test.assertEqual(range, CodeRange.create((1, 0), (1, 4)))
module = parse_module("pass")
# NOTE(review): the assertion above only runs if visit_Pass is invoked
# during this visit.
module.visit(DependentVisitor())

View file

@ -22,7 +22,7 @@ from libcst._removal_sentinel import RemovalSentinel
from libcst._type_enforce import is_value_of_type
from libcst.exceptions import MetadataException
from libcst.nodes._internal import CodegenState, CodePosition, CodeRange
from libcst.visitors import CSTVisitor, CSTVisitorT
from libcst.visitors import CSTTransformer, CSTVisitor, CSTVisitorT
if TYPE_CHECKING:
@ -53,6 +53,10 @@ class _ChildrenCollectionVisitor(CSTVisitor):
return False # Don't include transitive children
class _NOOPVisitor(CSTTransformer):
"""
Transformer that overrides no callbacks. Used by [CSTNode.deep_clone].
"""
pass
def _pretty_repr(value: object) -> str:
if not isinstance(value, str) and isinstance(value, Sequence):
return _pretty_repr_sequence(value)
@ -180,15 +184,16 @@ class CSTNode(ABC):
self: _CSTNodeSelfT, visitor: CSTVisitorT
) -> Union[_CSTNodeSelfT, RemovalSentinel]:
"""
Main entry point for visitors.
This wraps [_visit_impl] to validate metadata dependencies prior to
performing a visitor pass.
Public hook to visit the current node and all transitive children using
the given visitor.
"""
# Only modules can be visited by a visitor that declares metadata dependencies.
# Module overrides this method to resolve metadata dependencies.
if len(visitor.METADATA_DEPENDENCIES) > 0:
raise MetadataException(
f"{type(visitor).__name__} declares metadata dependencies and should only be called from the module level"
f"{type(visitor).__name__} declares metadata dependencies "
+ "and should only be called from the module level"
)
return self._visit_impl(visitor)
@ -197,8 +202,8 @@ class CSTNode(ABC):
self: _CSTNodeSelfT, visitor: CSTVisitorT
) -> Union[_CSTNodeSelfT, RemovalSentinel]:
"""
Visits the current node, its children, and all transitive children using the
given CSTVisitor's callbacks.
Visits the current node, its children, and all transitive children using
the given visitor's callbacks.
"""
# visit self
should_visit_children = visitor.on_visit(self)
@ -311,6 +316,9 @@ class CSTNode(ABC):
"""
return replace(self, **changes)
def deep_clone(self: _CSTNodeSelfT) -> _CSTNodeSelfT:
"""
Returns the result of visiting this node and all transitive children
with a no-op transformer.
The metadata runner calls this on a module before attaching metadata.
NOTE(review): presumably [_visit_impl] rebuilds every node when given a
transformer, which is what makes the result a copy -- confirm.
"""
# _visit_impl may return a RemovalSentinel by its signature; the cast
# assumes a no-op transformer never removes the node -- TODO confirm.
return cast(_CSTNodeSelfT, self._visit_impl(_NOOPVisitor()))
def deep_equals(self: _CSTNodeSelfT, other: _CSTNodeSelfT) -> bool:
"""
Recursively inspects the entire tree under `self` and `other` to determine if

View file

@ -68,15 +68,16 @@ class Module(CSTNode):
def visit(self: _ModuleSelfT, visitor: CSTVisitorT) -> _ModuleSelfT:
"""
Returns the result of running a visitor over this module.
Module overrides the default visitor entry point to resolve metadata
dependencies for the visitor.
dependencies declared by [visitor].
"""
from libcst.metadata.runner import run
from libcst.metadata.runner import _MetadataRunner
run(self, visitor)
result = CSTNode._visit_impl(self, visitor)
module = _MetadataRunner.resolve(self, visitor)
result = CSTNode._visit_impl(module, visitor)
if isinstance(result, RemovalSentinel):
return self.with_changes(body=(), header=(), footer=())
else: # is a Module