From dbbe3d192741b3424002ef02c910cc70fa1cdb71 Mon Sep 17 00:00:00 2001 From: Ray Zeng Date: Wed, 3 Jul 2019 17:40:51 -0700 Subject: [PATCH] Add metadata dependencies to CSTVisitor and add metadata runner Add metadata access and dependency logic in `CSTVisitor` and `Module.visit` to generate all metadata dependencies before performing a visit pass over a tree and validate access to metadata. --- libcst/_base_visitor.py | 26 ++++- libcst/exceptions.py | 4 + libcst/metadata/base_provider.py | 25 ++--- libcst/metadata/position_provider.py | 2 +- libcst/metadata/runner.py | 32 ++++++ libcst/metadata/tests/test_base_provider.py | 15 +-- libcst/metadata/tests/test_runner.py | 113 ++++++++++++++++++++ libcst/nodes/_base.py | 4 +- libcst/nodes/_internal.py | 6 +- libcst/nodes/_module.py | 14 ++- libcst/nodes/tests/base.py | 2 +- libcst/nodes/tests/test_internal.py | 5 +- libcst/nodes/tests/test_module.py | 4 +- 13 files changed, 205 insertions(+), 47 deletions(-) create mode 100644 libcst/metadata/runner.py create mode 100644 libcst/metadata/tests/test_runner.py diff --git a/libcst/_base_visitor.py b/libcst/_base_visitor.py index 48a94a90..f691b76b 100644 --- a/libcst/_base_visitor.py +++ b/libcst/_base_visitor.py @@ -5,7 +5,7 @@ # pyre-strict from abc import ABC -from typing import TYPE_CHECKING, Type, TypeVar, Union +from typing import TYPE_CHECKING, ClassVar, Sequence, Type, TypeVar, Union from libcst._removal_sentinel import RemovalSentinel @@ -23,6 +23,8 @@ __all__ = ["CSTVisitor", "RemovalSentinel"] CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode") _T = TypeVar("_T") +_UNDEFINED_DEFAULT = object() + class CSTVisitor(ABC): """ @@ -32,6 +34,8 @@ class CSTVisitor(ABC): subclass. """ + METADATA_DEPENDENCIES: ClassVar[Sequence[Type["BaseMetadataProvider[object]"]]] = () + def on_visit(self, node: "CSTNode") -> bool: """ Called every time a node is visited, before we've visited its children. @@ -52,12 +56,26 @@ class CSTVisitor(ABC): return updated_node @classmethod - def get_metadata(cls, key: Type["BaseMetadataProvider[_T]"], node: CSTNodeT) -> _T: + def get_metadata( + cls, + key: Type["BaseMetadataProvider[_T]"], + node: CSTNodeT, + default: _T = _UNDEFINED_DEFAULT, + ) -> _T: """ Gets metadata provided by the [key] provider if it is accessible from this vistor. Metadata is accessible if [key] is the same as [cls] or if [key] is in METADATA_DEPENDENCIES. """ - # TODO: runtime checks that metadata is available in this visitor + if key not in cls.METADATA_DEPENDENCIES and key is not cls: + raise KeyError( + f"{key.__name__} is not declared as a dependency from {cls.__name__}" + ) - return node.__metadata__[key] + try: + return node._metadata[key] + except KeyError as err: + if default is not _UNDEFINED_DEFAULT: + return default + else: + raise err diff --git a/libcst/exceptions.py b/libcst/exceptions.py index 9918fa38..0bdbad6f 100644 --- a/libcst/exceptions.py +++ b/libcst/exceptions.py @@ -78,3 +78,7 @@ class ParserSyntaxError(Exception): # Text editors use a one-indexed column, so we need to add one to our # zero-indexed column to get a human-readable result. return tab_adjusted_column + 1 + + +class MetadataException(Exception): + pass diff --git a/libcst/metadata/base_provider.py b/libcst/metadata/base_provider.py index d06fd2a9..266ef9e1 100644 --- a/libcst/metadata/base_provider.py +++ b/libcst/metadata/base_provider.py @@ -10,31 +10,22 @@ import libcst.nodes as cst from libcst._base_visitor import CSTVisitor -_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) -class BaseMetadataProvider(CSTVisitor, Generic[_T]): +class BaseMetadataProvider(CSTVisitor, Generic[_T_co]): """ - Base class for metadata providers to subclass off of. + Base class for visitor-based metadata providers. """ - def leave_Module( - self, orig_node: cst.Module, updated_node: cst.Module - ) -> cst.Module: - # TODO: We may need to change the behavior of CSTVisitor or create a - # subclass to make sure MetaDataProviders don't mutate the tree. - - # A metadata provider can't modify the tree. - # This ensures no copying is done (which would erase metadata). - assert orig_node.deep_equals(updated_node) - return orig_node - - def generate(self, module: cst.Module) -> None: + def run(self, module: cst.Module) -> None: """ Convenience method to run metadata provider over a module. """ module.visit(self) @classmethod - def set_metadata(cls, node: cst.CSTNode, value: _T) -> None: - node.__metadata__[cls] = value + # pyre-ignore[35]: Parameter type cannot be covariant. Pyre can't + # detect that this method is not mutating the Provider class. + def set_metadata(cls, node: cst.CSTNode, value: _T_co) -> None: + node._metadata[cls] = value diff --git a/libcst/metadata/position_provider.py b/libcst/metadata/position_provider.py index dfefc06f..96b74b96 100644 --- a/libcst/metadata/position_provider.py +++ b/libcst/metadata/position_provider.py @@ -16,7 +16,7 @@ class BasicPositionProvider(BaseMetadataProvider[CodeRange]): owned by that node. """ - def generate(self, module: cst.Module) -> None: + def run(self, module: cst.Module) -> None: """ Override default generate behavior as position information is calculated through codegen instead of a standard visitor. diff --git a/libcst/metadata/runner.py b/libcst/metadata/runner.py new file mode 100644 index 00000000..3907190d --- /dev/null +++ b/libcst/metadata/runner.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import TYPE_CHECKING + +from libcst.exceptions import MetadataException + + +if TYPE_CHECKING: + from libcst._base_visitor import CSTVisitor + from libcst.nodes import Module + + +def run(module: "Module", visitor: "CSTVisitor") -> None: + """ + Called by Module.visit to generate metadata dependencies before performing + a visitor pass. + """ + for Provider in visitor.METADATA_DEPENDENCIES: + if Provider in module._remaining_dependencies: + raise MetadataException( + f"Detected circular dependency between {type(visitor).__name__} and {Provider.__name__}" + ) + + if Provider not in module._satisfied_dependencies: + module._remaining_dependencies.add(Provider) + Provider().run(module) + module._satisfied_dependencies.add(Provider) + module._remaining_dependencies.remove(Provider) diff --git a/libcst/metadata/tests/test_base_provider.py b/libcst/metadata/tests/test_base_provider.py index edccd986..3c1d77e3 100644 --- a/libcst/metadata/tests/test_base_provider.py +++ b/libcst/metadata/tests/test_base_provider.py @@ -4,32 +4,23 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import TYPE_CHECKING, TypeVar - from libcst.metadata.base_provider import BaseMetadataProvider +from libcst.nodes._base import CSTNode from libcst.parser import parse_module from libcst.testing.utils import UnitTest -if TYPE_CHECKING: - # Circular dependency for typing reasons only - from libcst.nodes._base import CSTNode - - -CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode") - - class BaseMetadataProviderTest(UnitTest): def test_simple_provider(self) -> None: class SimpleProvider(BaseMetadataProvider[int]): - def on_visit(self, node: "CSTNode") -> bool: + def on_visit(self, node: CSTNode) -> bool: self.set_metadata(node, 1) return True module = parse_module("pass") pass_node = module.body[0] provider = SimpleProvider() - provider.generate(module) + provider.run(module) self.assertEqual(provider.get_metadata(SimpleProvider, module), 1) self.assertEqual(provider.get_metadata(SimpleProvider, pass_node), 1) diff --git a/libcst/metadata/tests/test_runner.py b/libcst/metadata/tests/test_runner.py new file mode 100644 index 00000000..9f35def1 --- /dev/null +++ b/libcst/metadata/tests/test_runner.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from typing import Union + +from libcst._base_visitor import CSTVisitor +from libcst._removal_sentinel import RemovalSentinel +from libcst.exceptions import MetadataException +from libcst.metadata.base_provider import BaseMetadataProvider +from libcst.nodes import CSTNode, Module +from libcst.parser import parse_module +from libcst.testing.utils import UnitTest + + +class MetadataRunnerTest(UnitTest): + def test_visitor_with_dependencies(self) -> None: + class SimpleProvider(BaseMetadataProvider[int]): + def on_visit(self, node: CSTNode) -> bool: + self.set_metadata(node, 1) + return True + + class DependentProvider(BaseMetadataProvider[int]): + METADATA_DEPENDENCIES = (SimpleProvider,) + + def on_visit(self, node: CSTNode) -> bool: + self.set_metadata(node, self.get_metadata(SimpleProvider, node) + 1) + return True + + class DependentVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (DependentProvider, SimpleProvider) + + module = parse_module("pass") + pass_node = module.body[0] + visitor = DependentVisitor() + module.visit(visitor) + + self.assertEqual(visitor.get_metadata(SimpleProvider, module), 1) + self.assertEqual(visitor.get_metadata(DependentProvider, module), 2) + self.assertEqual(visitor.get_metadata(SimpleProvider, pass_node), 1) + self.assertEqual(visitor.get_metadata(DependentProvider, pass_node), 2) + + def test_provider_with_circular_dependency(self) -> None: + class ProviderA(BaseMetadataProvider[str]): + pass + + ProviderA.METADATA_DEPENDENCIES = (ProviderA,) + + class BadVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (ProviderA,) + + with self.assertRaisesRegex( + MetadataException, + "Detected circular dependency between ProviderA and ProviderA", + ): + Module([]).visit(BadVisitor()) + + def test_self_access_metadata(self) -> None: + test_runner = self + + class ProviderA(BaseMetadataProvider[bool]): + def on_visit(self, node: CSTNode) -> bool: + self.set_metadata(node, True) + return True + + def on_leave( + self, original_node: CSTNode, updated_node: CSTNode + ) -> Union[CSTNode, RemovalSentinel]: + test_runner.assertEqual( + self.get_metadata(type(self), original_node), True + ) + return original_node + + class AVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (ProviderA,) + + Module([]).visit(AVisitor()) + + def test_access_unset_metadata(self) -> None: + class ProviderA(BaseMetadataProvider[bool]): + pass + + class AVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (ProviderA,) + + def on_visit(self, node: CSTNode) -> bool: + self.get_metadata(ProviderA, node) + return True + + with self.assertRaises(KeyError): + Module([]).visit(AVisitor()) + + def test_access_invalid_metadata(self) -> None: + class ProviderA(BaseMetadataProvider[bool]): + pass + + class ProviderB(BaseMetadataProvider[bool]): + pass + + class AVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (ProviderA,) + + def on_visit(self, node: CSTNode) -> bool: + self.get_metadata(ProviderA, node, True) + self.get_metadata(ProviderB, node) + return True + + with self.assertRaisesRegex( + KeyError, "ProviderB is not declared as a dependency from AVisitor" + ): + Module([]).visit(AVisitor()) diff --git a/libcst/nodes/_base.py b/libcst/nodes/_base.py index 537c29b0..938b2973 100644 --- a/libcst/nodes/_base.py +++ b/libcst/nodes/_base.py @@ -72,7 +72,7 @@ def _indent(value: str) -> str: @dataclass(frozen=True) class CSTNode(ABC): - __metadata__: MutableMapping[Type["BaseMetaDataProvider[_T]"], _T] = field( + _metadata: MutableMapping[Type["BaseMetaDataProvider[_T]"], _T] = field( default_factory=dict, init=False, repr=False, compare=False ) @@ -130,7 +130,7 @@ class CSTNode(ABC): does. """ for f in fields(self): - if f.name == "__metadata__": # skip typechecking metadata field + if f.name == "_metadata": # skip typechecking metadata field continue value = getattr(self, f.name) diff --git a/libcst/nodes/_internal.py b/libcst/nodes/_internal.py index fb9e70fb..55f6cca5 100644 --- a/libcst/nodes/_internal.py +++ b/libcst/nodes/_internal.py @@ -114,8 +114,8 @@ class CodegenState: def record_position(self, node: _CSTNodeT, position: CodeRange) -> None: # Don't overwrite existing position information # (i.e. semantic position has already been recorded) - if self.provider not in node.__metadata__: - node.__metadata__[self.provider] = position + if self.provider not in node._metadata: + node._metadata[self.provider] = position @contextmanager def record_syntactic_position(self, node: _CSTNodeT) -> Iterator[None]: @@ -139,7 +139,7 @@ class SyntacticCodegenState(CodegenState): yield finally: end = CodePosition(self.line, self.column) - node.__metadata__[self.provider] = CodeRange(start, end) + node._metadata[self.provider] = CodeRange(start, end) def visit_required(fieldname: str, node: _CSTNodeT, visitor: "CSTVisitor") -> _CSTNodeT: diff --git a/libcst/nodes/_module.py b/libcst/nodes/_module.py index d3e4c75b..ea58c85e 100644 --- a/libcst/nodes/_module.py +++ b/libcst/nodes/_module.py @@ -3,12 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, MutableSet, Optional, Sequence, Type, TypeVar, Union from libcst._add_slots import add_slots from libcst._base_visitor import CSTVisitor from libcst._removal_sentinel import RemovalSentinel +from libcst.metadata.runner import run as compute_metadata from libcst.nodes._base import CSTNode from libcst.nodes._internal import CodegenState, SyntacticCodegenState, visit_sequence from libcst.nodes._statement import BaseCompoundStatement, SimpleStatementLine @@ -50,6 +51,13 @@ class Module(CSTNode): default_newline: str = "\n" has_trailing_newline: bool = True + _satisfied_dependencies: MutableSet["BaseMetadataProvider[Any]"] = field( + default_factory=set, init=False, repr=False, compare=False + ) + _remaining_dependencies: MutableSet["BaseMetadataProvider[Any]"] = field( + default_factory=set, init=False, repr=False, compare=False + ) + def _visit_and_replace_children(self, visitor: CSTVisitor) -> "Module": return Module( header=visit_sequence("header", self.header, visitor), @@ -62,6 +70,8 @@ class Module(CSTNode): ) def visit(self: _ModuleSelfT, visitor: CSTVisitor) -> _ModuleSelfT: + compute_metadata(self, visitor) + result = CSTNode.visit(self, visitor) if isinstance(result, RemovalSentinel): return self.with_changes(body=(), header=(), footer=()) diff --git a/libcst/nodes/tests/base.py b/libcst/nodes/tests/base.py index 71f6cb49..5aa04339 100644 --- a/libcst/nodes/tests/base.py +++ b/libcst/nodes/tests/base.py @@ -96,7 +96,7 @@ class CSTNodeTest(UnitTest): self.assertEqual(module.code_for_node(node), expected) if expected_position is not None: self.assertEqual( - node.__metadata__[SyntacticPositionProvider], expected_position + node._metadata[SyntacticPositionProvider], expected_position ) def __assert_children_match_codegen(self, node: cst.CSTNode) -> None: diff --git a/libcst/nodes/tests/test_internal.py b/libcst/nodes/tests/test_internal.py index 7dc6e8c3..c25a4bd2 100644 --- a/libcst/nodes/tests/test_internal.py +++ b/libcst/nodes/tests/test_internal.py @@ -76,7 +76,7 @@ class InternalTest(UnitTest): # check syntactic whitespace is correctly recorded self.assertEqual( - node.__metadata__[BasicPositionProvider], CodeRange.create((1, 0), (1, 6)) + node._metadata[BasicPositionProvider], CodeRange.create((1, 0), (1, 6)) ) def test_semantic_position(self) -> None: @@ -96,6 +96,5 @@ class InternalTest(UnitTest): # check semantic whitespace is correctly recorded (ignoring whitespace) self.assertEqual( - node.__metadata__[SyntacticPositionProvider], - CodeRange.create((1, 1), (1, 5)), + node._metadata[SyntacticPositionProvider], CodeRange.create((1, 1), (1, 5)) ) diff --git a/libcst/nodes/tests/test_module.py b/libcst/nodes/tests/test_module.py index 430d9045..e0fbfac5 100644 --- a/libcst/nodes/tests/test_module.py +++ b/libcst/nodes/tests/test_module.py @@ -142,13 +142,13 @@ class ModuleTest(CSTNodeTest): module = parse_module(code) module.code - self.assertEqual(module.__metadata__[SyntacticPositionProvider], expected) + self.assertEqual(module._metadata[SyntacticPositionProvider], expected) def cmp_position( self, node: cst.CSTNode, start: Tuple[int, int], end: Tuple[int, int] ) -> None: self.assertEqual( - node.__metadata__[SyntacticPositionProvider], CodeRange.create(start, end) + node._metadata[SyntacticPositionProvider], CodeRange.create(start, end) ) def test_function_position(self) -> None: