From 21e237c78b57e69bcdde5d6aa7fae980cf64fbb9 Mon Sep 17 00:00:00 2001 From: Jimmy Lai Date: Fri, 22 Nov 2019 09:54:26 -0800 Subject: [PATCH] [metadata] add cache field to metadata wrapper --- libcst/metadata/base_provider.py | 10 +++++- .../metadata/tests/test_metadata_wrapper.py | 35 ++++++++++++++++++- libcst/metadata/wrapper.py | 24 ++++++++++--- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/libcst/metadata/base_provider.py b/libcst/metadata/base_provider.py index 3cff524e..4019a417 100644 --- a/libcst/metadata/base_provider.py +++ b/libcst/metadata/base_provider.py @@ -44,9 +44,17 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]): # explanation. _computed: MutableMapping["CSTNode", _ProvidedMetadataT] - def __init__(self) -> None: + is_cache_required: bool = False + + def __init__(self, cache: object = None) -> None: super().__init__() self._computed = {} + if self.is_cache_required and not cache: + # The metadata provider implementation is responsible to store and use cache. + raise Exception( + f"Cache is required for initializing {self.__class__.__name__}." + ) + self.cache = cache def _gen( self, wrapper: "MetadataWrapper" diff --git a/libcst/metadata/tests/test_metadata_wrapper.py b/libcst/metadata/tests/test_metadata_wrapper.py index fd0df591..25079416 100644 --- a/libcst/metadata/tests/test_metadata_wrapper.py +++ b/libcst/metadata/tests/test_metadata_wrapper.py @@ -5,8 +5,10 @@ # pyre-strict +from typing import Optional + import libcst as cst -from libcst.metadata import MetadataWrapper +from libcst.metadata import BatchableMetadataProvider, MetadataWrapper from libcst.testing.utils import UnitTest @@ -41,3 +43,34 @@ class MetadataWrapperTest(UnitTest): self.assertNotEqual(hash(mw1), hash(mw2)) self.assertNotEqual(hash(mw1), hash(mw3)) self.assertNotEqual(hash(mw2), hash(mw3)) + + def test_metadata_cache(self) -> None: + class DummyMetadataProvider(BatchableMetadataProvider[None]): + is_cache_required = True + + m = cst.parse_module("pass") + mw = MetadataWrapper(m) + with self.assertRaisesRegex( + Exception, "Cache is required for initializing DummyMetadataProvider." + ): + mw.resolve(DummyMetadataProvider) + + class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]): + is_cache_required = True + + def __init__(self, cache: object) -> None: + super().__init__(cache) + self.cache = cache + + def visit_Pass(self, node: cst.Pass) -> Optional[bool]: + self.set_metadata(node, self.cache) + + cached_data = object() + # pyre-fixme[6]: Expected `Mapping[Type[BaseMetadataProvider[object]], + # object]` for 2nd param but got `Dict[Type[SimpleCacheMetadataProvider], + # object]`. + mw = MetadataWrapper(m, cache={SimpleCacheMetadataProvider: cached_data}) + pass_node = cst.ensure_type(mw.module.body[0], cst.SimpleStatementLine).body[0] + self.assertEqual( + mw.resolve(SimpleCacheMetadataProvider)[pass_node], cached_data + ) diff --git a/libcst/metadata/wrapper.py b/libcst/metadata/wrapper.py index bfaa2958..d0e52026 100644 --- a/libcst/metadata/wrapper.py +++ b/libcst/metadata/wrapper.py @@ -6,6 +6,7 @@ # pyre-strict import textwrap +from collections import defaultdict from contextlib import ExitStack from types import MappingProxyType from typing import ( @@ -87,10 +88,17 @@ def _resolve_impl( if issubclass(P, BatchableMetadataProvider): batchable.add(P) else: - wrapper._metadata[P] = P()._gen(wrapper) + wrapper._metadata[P] = ( + P(wrapper._cache[P])._gen(wrapper) + if P.is_cache_required + else P()._gen(wrapper) + ) completed.add(P) - metadata_batch = _gen_batchable(wrapper, [p() for p in batchable]) + initialized_batchable = [ + p(wrapper._cache[p]) if p.is_cache_required else p() for p in batchable + ] + metadata_batch = _gen_batchable(wrapper, initialized_batchable) wrapper._metadata.update(metadata_batch) completed |= batchable @@ -116,18 +124,25 @@ class MetadataWrapper: node's identity. """ - __slots__ = ["__module", "_metadata"] + __slots__ = ["__module", "_metadata", "_cache"] __module: "Module" _metadata: MutableMapping["ProviderT", Mapping["CSTNode", object]] + _cache: Mapping["ProviderT", object] - def __init__(self, module: "Module", unsafe_skip_copy: bool = False) -> None: + def __init__( + self, + module: "Module", + unsafe_skip_copy: bool = False, + cache: Mapping["ProviderT", object] = defaultdict(dict), + ) -> None: """ :param module: The module to wrap. This is deeply copied by default. :param unsafe_skip_copy: When true, this skips the deep cloning of the module. This can provide a small performance benefit, but you should only use this if you know that there are no duplicate nodes in your tree (e.g. this module came from the parser). + :param cache: Pass the needed cache to wrapper to be used when resolving metadata. """ # Ensure that module is safe to use by copying the module to remove # any duplicate nodes. @@ -135,6 +150,7 @@ class MetadataWrapper: module = module.deep_clone() self.__module = module self._metadata = {} + self._cache = cache def __repr__(self) -> str: return f"MetadataWrapper(\n{textwrap.indent(repr(self.module), ' ' * 4)},\n)"