mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Refactor metadata runner
Refactors various metadata related classes to better reflect intended usage of the API. The metadata runner now deep clones the module and returns a copy containing the metadata information. Metadata providers also return the tree to enforce the idea that the tree is immutable (even though no copying is actually done and providers write directly to the original tree).
This commit is contained in:
parent
e452bd3a55
commit
36cfb512ea
7 changed files with 147 additions and 76 deletions
|
|
@ -4,12 +4,17 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Generic, Iterable, TypeVar, cast
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.batched_visitor import BatchableCSTVisitor
|
||||
from libcst.batched_visitor import (
|
||||
BatchableCSTVisitor,
|
||||
_BatchedCSTVisitor,
|
||||
_get_visitor_methods,
|
||||
)
|
||||
from libcst.exceptions import MetadataException
|
||||
from libcst.metadata._interface import _MetadataInterface
|
||||
from libcst.nodes._module import _ModuleSelfT as _ModuleT
|
||||
from libcst.visitors import CSTVisitor
|
||||
|
||||
|
||||
|
|
@ -22,9 +27,14 @@ class BaseMetadataProvider(_MetadataInterface, Generic[_T_co]):
|
|||
Abstract base class for all metadata providers.
|
||||
"""
|
||||
|
||||
def _run(self, module: cst.Module) -> None:
|
||||
def _run(self, module: _ModuleT) -> _ModuleT:
|
||||
"""
|
||||
Entry point for metadata runner.
|
||||
Returns the given module with metadata from this provider.
|
||||
|
||||
This is a hook for metadata runner and should not be called directly.
|
||||
Any implementation of this method should not handle any dependencies
|
||||
declared by this provider and should not have any side effects besides
|
||||
setting metadata computed by this provider.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
@ -32,31 +42,45 @@ class BaseMetadataProvider(_MetadataInterface, Generic[_T_co]):
|
|||
# pyre-ignore[35]: Parameter type cannot be covariant. Pyre can't
|
||||
# detect that this method is not mutating the Provider class.
|
||||
def set_metadata(cls, node: cst.CSTNode, value: _T_co) -> None:
|
||||
"""
|
||||
Stores given metadata from this provider on the given node.
|
||||
"""
|
||||
node._metadata[cls] = value
|
||||
|
||||
|
||||
class BatchableMetadataProvider(BatchableCSTVisitor, BaseMetadataProvider[_T_co]):
|
||||
"""
|
||||
Base class for batchable visitor metadata providers.
|
||||
"""
|
||||
|
||||
def _run(self, module: cst.Module) -> None:
|
||||
"""
|
||||
Batchable providers are batched by the runner and should not be
|
||||
called directly.
|
||||
"""
|
||||
raise MetadataException(
|
||||
"BatchableMetadataProvider should not be called directly."
|
||||
)
|
||||
|
||||
|
||||
class VisitorMetadataProvider(CSTVisitor, BaseMetadataProvider[_T_co]):
|
||||
"""
|
||||
Base class for non-batchable visitor metadata providers.
|
||||
Extend this to compute metadata with a non-batchable visitor.
|
||||
"""
|
||||
|
||||
def _run(self, module: cst.Module) -> None:
|
||||
def _run(self, module: _ModuleT) -> _ModuleT:
|
||||
"""
|
||||
Does not compute dependencies declared by this provider.
|
||||
Returns the given module with metadata from this provider.
|
||||
"""
|
||||
module._visit_impl(self)
|
||||
# Cast is safe as metadata providers should never mutate the tree
|
||||
return cast(_ModuleT, module._visit_impl(self))
|
||||
|
||||
|
||||
class BatchableMetadataProvider(BatchableCSTVisitor, BaseMetadataProvider[_T_co]):
|
||||
"""
|
||||
Extend this to compute metadata with a batchable visitor.
|
||||
"""
|
||||
|
||||
def _run(self, module: _ModuleT) -> _ModuleT:
|
||||
"""
|
||||
Batchable providers are resolved using [_run_batchable].
|
||||
"""
|
||||
raise MetadataException("BatchableMetadataProvider cannot be called directly.")
|
||||
|
||||
|
||||
def _run_batchable(
|
||||
module: _ModuleT, providers: Iterable[BatchableMetadataProvider[object]]
|
||||
) -> _ModuleT:
|
||||
"""
|
||||
Returns the given module with metadata from the given batchable providers.
|
||||
"""
|
||||
|
||||
visitor_methods = _get_visitor_methods(providers)
|
||||
batched_visitor = _BatchedCSTVisitor(visitor_methods)
|
||||
# Cast is safe as metadata providers should never mutate the tree
|
||||
return cast(_ModuleT, module._visit_impl(batched_visitor))
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
import libcst.nodes as cst
|
||||
from libcst.metadata.base_provider import BaseMetadataProvider
|
||||
from libcst.metadata.base_provider import BaseMetadataProvider, _ModuleT
|
||||
from libcst.nodes._internal import CodeRange
|
||||
|
||||
|
||||
|
|
@ -16,12 +15,13 @@ class BasicPositionProvider(BaseMetadataProvider[CodeRange]):
|
|||
owned by that node.
|
||||
"""
|
||||
|
||||
def run(self, module: cst.Module) -> None:
|
||||
def _run(self, module: _ModuleT) -> _ModuleT:
|
||||
"""
|
||||
Override default generate behavior as position information is
|
||||
calculated through codegen instead of a standard visitor.
|
||||
"""
|
||||
module.code_for_node(module, provider=self.__class__)
|
||||
return module
|
||||
|
||||
|
||||
class SyntacticPositionProvider(BasicPositionProvider):
|
||||
|
|
|
|||
|
|
@ -6,24 +6,20 @@
|
|||
# pyre-strict
|
||||
from typing import MutableSet, Type
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.batched_visitor import visit as batched_visit
|
||||
from libcst.exceptions import MetadataException
|
||||
from libcst.metadata._interface import _MetadataInterface
|
||||
from libcst.metadata.base_provider import (
|
||||
BaseMetadataProvider,
|
||||
BatchableMetadataProvider,
|
||||
_run_batchable,
|
||||
)
|
||||
from libcst.nodes._module import _ModuleSelfT as _ModuleT
|
||||
from libcst.visitors import CSTVisitorT
|
||||
|
||||
|
||||
ProviderT = Type[BaseMetadataProvider[object]]
|
||||
|
||||
|
||||
class _MetadataRunner:
|
||||
"""
|
||||
Helper class for resolving metadata dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.providers: MutableSet[ProviderT] = set()
|
||||
self.satisfied: MutableSet[ProviderT] = set()
|
||||
|
|
@ -34,34 +30,43 @@ class _MetadataRunner:
|
|||
for dep in root.METADATA_DEPENDENCIES:
|
||||
self.gather_providers(dep)
|
||||
|
||||
@staticmethod
|
||||
def resolve(module: _ModuleT, visitor: CSTVisitorT) -> _ModuleT:
|
||||
"""
|
||||
Returns a copy of the module that contains all metadata dependencies
|
||||
declared by the visitor.
|
||||
"""
|
||||
|
||||
def run(module: cst.Module, root: _MetadataInterface) -> None:
|
||||
"""
|
||||
Called by Module.visit to resolve metadata dependencies before performing
|
||||
a visitor pass.
|
||||
"""
|
||||
if len(visitor.METADATA_DEPENDENCIES) == 0:
|
||||
return module
|
||||
|
||||
runner = _MetadataRunner()
|
||||
for dep in root.METADATA_DEPENDENCIES:
|
||||
runner.gather_providers(dep)
|
||||
# We need to deep clone to ensure that there are no duplicate nodes
|
||||
module = module.deep_clone()
|
||||
|
||||
while len(runner.providers) > 0:
|
||||
completed = set()
|
||||
batchable = set()
|
||||
runner = _MetadataRunner()
|
||||
for dep in visitor.METADATA_DEPENDENCIES:
|
||||
runner.gather_providers(dep)
|
||||
|
||||
for P in runner.providers:
|
||||
if set(P.METADATA_DEPENDENCIES).issubset(runner.satisfied):
|
||||
if issubclass(P, BatchableMetadataProvider):
|
||||
batchable.add(P)
|
||||
else:
|
||||
P()._run(module)
|
||||
completed.add(P)
|
||||
while len(runner.providers) > 0:
|
||||
completed = set()
|
||||
batchable = set()
|
||||
|
||||
batched_visit(module, [p() for p in batchable])
|
||||
runner.providers -= completed | batchable
|
||||
runner.satisfied |= completed | batchable
|
||||
for P in runner.providers:
|
||||
if set(P.METADATA_DEPENDENCIES).issubset(runner.satisfied):
|
||||
if issubclass(P, BatchableMetadataProvider):
|
||||
batchable.add(P)
|
||||
else:
|
||||
module = P()._run(module)
|
||||
completed.add(P)
|
||||
|
||||
if len(completed) == 0 and len(batchable) == 0:
|
||||
# runner.providers must be non-empty at this point
|
||||
names = ", ".join([P.__name__ for P in runner.providers])
|
||||
raise MetadataException(f"Detected circular dependencies in {names}")
|
||||
module = _run_batchable(module, [p() for p in batchable])
|
||||
|
||||
runner.providers -= completed | batchable
|
||||
runner.satisfied |= completed | batchable
|
||||
|
||||
if len(completed) == 0 and len(batchable) == 0:
|
||||
# runner.providers must be non-empty at this point
|
||||
names = ", ".join([P.__name__ for P in runner.providers])
|
||||
raise MetadataException(f"Detected circular dependencies in {names}")
|
||||
|
||||
return module
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@
|
|||
from typing import cast
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.batched_visitor import visit
|
||||
from libcst.metadata.base_provider import (
|
||||
BatchableMetadataProvider,
|
||||
VisitorMetadataProvider,
|
||||
_run_batchable,
|
||||
)
|
||||
from libcst.parser import parse_module
|
||||
from libcst.testing.utils import UnitTest
|
||||
|
|
@ -47,17 +47,17 @@ class BaseMetadataProviderTest(UnitTest):
|
|||
def visit_Pass(self, node: cst.Pass) -> None:
|
||||
self.set_metadata(node, 1)
|
||||
|
||||
def leave_Return(self, node: cst.Return) -> None:
|
||||
def visit_Return(self, node: cst.Return) -> None:
|
||||
self.set_metadata(node, 2)
|
||||
|
||||
module = parse_module("pass; return; pass")
|
||||
provider = SimpleProvider()
|
||||
|
||||
module = _run_batchable(module, [provider])
|
||||
pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0]
|
||||
return_ = cast(cst.SimpleStatementLine, module.body[0]).body[1]
|
||||
pass_2 = cast(cst.SimpleStatementLine, module.body[0]).body[2]
|
||||
|
||||
provider = SimpleProvider()
|
||||
visit(module, [provider])
|
||||
|
||||
self.assertEqual(provider.get_metadata(SimpleProvider, pass_), 1)
|
||||
self.assertEqual(provider.get_metadata(SimpleProvider, return_), 2)
|
||||
self.assertEqual(provider.get_metadata(SimpleProvider, pass_2), 1)
|
||||
|
|
|
|||
33
libcst/metadata/tests/test_position_provider.py
Normal file
33
libcst/metadata/tests/test_position_provider.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
import libcst.nodes as cst
|
||||
from libcst.metadata.position_provider import SyntacticPositionProvider
|
||||
from libcst.nodes._internal import CodeRange
|
||||
from libcst.parser import parse_module
|
||||
from libcst.testing.utils import UnitTest
|
||||
from libcst.visitors import CSTTransformer
|
||||
|
||||
|
||||
class PositionProviderTest(UnitTest):
|
||||
def test_visitor_provider(self) -> None:
|
||||
"""
|
||||
Sets 2 metadata entries for every node:
|
||||
SimpleProvider -> 1
|
||||
DependentProvider - > 2
|
||||
"""
|
||||
|
||||
test = self
|
||||
|
||||
class DependentVisitor(CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (SyntacticPositionProvider,)
|
||||
|
||||
def visit_Pass(self, node: cst.Pass) -> None:
|
||||
range = self.get_metadata(SyntacticPositionProvider, node)
|
||||
test.assertEqual(range, CodeRange.create((1, 0), (1, 4)))
|
||||
|
||||
module = parse_module("pass")
|
||||
module.visit(DependentVisitor())
|
||||
|
|
@ -22,7 +22,7 @@ from libcst._removal_sentinel import RemovalSentinel
|
|||
from libcst._type_enforce import is_value_of_type
|
||||
from libcst.exceptions import MetadataException
|
||||
from libcst.nodes._internal import CodegenState, CodePosition, CodeRange
|
||||
from libcst.visitors import CSTVisitor, CSTVisitorT
|
||||
from libcst.visitors import CSTTransformer, CSTVisitor, CSTVisitorT
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -53,6 +53,10 @@ class _ChildrenCollectionVisitor(CSTVisitor):
|
|||
return False # Don't include transitive children
|
||||
|
||||
|
||||
class _NOOPVisitor(CSTTransformer):
|
||||
pass
|
||||
|
||||
|
||||
def _pretty_repr(value: object) -> str:
|
||||
if not isinstance(value, str) and isinstance(value, Sequence):
|
||||
return _pretty_repr_sequence(value)
|
||||
|
|
@ -180,15 +184,16 @@ class CSTNode(ABC):
|
|||
self: _CSTNodeSelfT, visitor: CSTVisitorT
|
||||
) -> Union[_CSTNodeSelfT, RemovalSentinel]:
|
||||
"""
|
||||
Main entry point for visitors.
|
||||
|
||||
This wraps [_visit_impl] to validate metadata dependencies prior to
|
||||
performing a visitor pass.
|
||||
Public hook to visit the current node and all transitive children using
|
||||
the given visitor.
|
||||
"""
|
||||
|
||||
# Only modules can be visited a visitor that declare metadata dependencies.
|
||||
# Module overrides this method to resolve metadata dependencies.
|
||||
if len(visitor.METADATA_DEPENDENCIES) > 0:
|
||||
raise MetadataException(
|
||||
f"{type(visitor).__name__} declares metadata dependencies and should only be called from the module level"
|
||||
f"{type(visitor).__name__} declares metadata dependencies "
|
||||
+ "and should only be called from the module level"
|
||||
)
|
||||
|
||||
return self._visit_impl(visitor)
|
||||
|
|
@ -197,8 +202,8 @@ class CSTNode(ABC):
|
|||
self: _CSTNodeSelfT, visitor: CSTVisitorT
|
||||
) -> Union[_CSTNodeSelfT, RemovalSentinel]:
|
||||
"""
|
||||
Visits the current node, its children, and all transitive children using the
|
||||
given CSTVisitor's callbacks.
|
||||
Visits the current node, its children, and all transitive children using
|
||||
the given visitor's callbacks.
|
||||
"""
|
||||
# visit self
|
||||
should_visit_children = visitor.on_visit(self)
|
||||
|
|
@ -311,6 +316,9 @@ class CSTNode(ABC):
|
|||
"""
|
||||
return replace(self, **changes)
|
||||
|
||||
def deep_clone(self: _CSTNodeSelfT) -> _CSTNodeSelfT:
|
||||
return cast(_CSTNodeSelfT, self._visit_impl(_NOOPVisitor()))
|
||||
|
||||
def deep_equals(self: _CSTNodeSelfT, other: _CSTNodeSelfT) -> bool:
|
||||
"""
|
||||
Recursively inspects the entire tree under `self` and `other` to determine if
|
||||
|
|
|
|||
|
|
@ -68,15 +68,16 @@ class Module(CSTNode):
|
|||
|
||||
def visit(self: _ModuleSelfT, visitor: CSTVisitorT) -> _ModuleSelfT:
|
||||
"""
|
||||
Returns the result of running a visitor over this module.
|
||||
|
||||
Module overrides the default visitor entry point to resolve metadata
|
||||
dependencies for the visitor.
|
||||
dependencies declared by [visitor].
|
||||
"""
|
||||
|
||||
from libcst.metadata.runner import run
|
||||
from libcst.metadata.runner import _MetadataRunner
|
||||
|
||||
run(self, visitor)
|
||||
|
||||
result = CSTNode._visit_impl(self, visitor)
|
||||
module = _MetadataRunner.resolve(self, visitor)
|
||||
result = CSTNode._visit_impl(module, visitor)
|
||||
if isinstance(result, RemovalSentinel):
|
||||
return self.with_changes(body=(), header=(), footer=())
|
||||
else: # is a Module
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue