[metadata] add cache field to metadata wrapper

This commit is contained in:
Jimmy Lai 2019-11-22 09:54:26 -08:00
parent e7315d2c28
commit 21e237c78b
3 changed files with 63 additions and 6 deletions

View file

@ -44,9 +44,17 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
# explanation.
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
def __init__(self) -> None:
is_cache_required: bool = False
def __init__(self, cache: object = None) -> None:
super().__init__()
self._computed = {}
if self.is_cache_required and not cache:
# The metadata provider implementation is responsible to store and use cache.
raise Exception(
f"Cache is required for initializing {self.__class__.__name__}."
)
self.cache = cache
def _gen(
self, wrapper: "MetadataWrapper"

View file

@ -5,8 +5,10 @@
# pyre-strict
from typing import Optional
import libcst as cst
from libcst.metadata import MetadataWrapper
from libcst.metadata import BatchableMetadataProvider, MetadataWrapper
from libcst.testing.utils import UnitTest
@ -41,3 +43,34 @@ class MetadataWrapperTest(UnitTest):
self.assertNotEqual(hash(mw1), hash(mw2))
self.assertNotEqual(hash(mw1), hash(mw3))
self.assertNotEqual(hash(mw2), hash(mw3))
def test_metadata_cache(self) -> None:
class DummyMetadataProvider(BatchableMetadataProvider[None]):
is_cache_required = True
m = cst.parse_module("pass")
mw = MetadataWrapper(m)
with self.assertRaisesRegex(
Exception, "Cache is required for initializing DummyMetadataProvider."
):
mw.resolve(DummyMetadataProvider)
class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]):
is_cache_required = True
def __init__(self, cache: object) -> None:
super().__init__(cache)
self.cache = cache
def visit_Pass(self, node: cst.Pass) -> Optional[bool]:
self.set_metadata(node, self.cache)
cached_data = object()
# pyre-fixme[6]: Expected `Mapping[Type[BaseMetadataProvider[object]],
# object]` for 2nd param but got `Dict[Type[SimpleCacheMetadataProvider],
# object]`.
mw = MetadataWrapper(m, cache={SimpleCacheMetadataProvider: cached_data})
pass_node = cst.ensure_type(mw.module.body[0], cst.SimpleStatementLine).body[0]
self.assertEqual(
mw.resolve(SimpleCacheMetadataProvider)[pass_node], cached_data
)

View file

@ -6,6 +6,7 @@
# pyre-strict
import textwrap
from collections import defaultdict
from contextlib import ExitStack
from types import MappingProxyType
from typing import (
@ -87,10 +88,17 @@ def _resolve_impl(
if issubclass(P, BatchableMetadataProvider):
batchable.add(P)
else:
wrapper._metadata[P] = P()._gen(wrapper)
wrapper._metadata[P] = (
P(wrapper._cache[P])._gen(wrapper)
if P.is_cache_required
else P()._gen(wrapper)
)
completed.add(P)
metadata_batch = _gen_batchable(wrapper, [p() for p in batchable])
initialized_batchable = [
p(wrapper._cache[p]) if p.is_cache_required else p() for p in batchable
]
metadata_batch = _gen_batchable(wrapper, initialized_batchable)
wrapper._metadata.update(metadata_batch)
completed |= batchable
@ -116,18 +124,25 @@ class MetadataWrapper:
node's identity.
"""
__slots__ = ["__module", "_metadata"]
__slots__ = ["__module", "_metadata", "_cache"]
__module: "Module"
_metadata: MutableMapping["ProviderT", Mapping["CSTNode", object]]
_cache: Mapping["ProviderT", object]
def __init__(self, module: "Module", unsafe_skip_copy: bool = False) -> None:
def __init__(
self,
module: "Module",
unsafe_skip_copy: bool = False,
cache: Mapping["ProviderT", object] = defaultdict(dict),
) -> None:
"""
:param module: The module to wrap. This is deeply copied by default.
:param unsafe_skip_copy: When true, this skips the deep cloning of the module.
This can provide a small performance benefit, but you should only use this
if you know that there are no duplicate nodes in your tree (e.g. this
module came from the parser).
:param cache: Pass the needed cache to wrapper to be used when resolving metadata.
"""
# Ensure that module is safe to use by copying the module to remove
# any duplicate nodes.
@ -135,6 +150,7 @@ class MetadataWrapper:
module = module.deep_clone()
self.__module = module
self._metadata = {}
self._cache = cache
def __repr__(self) -> str:
return f"MetadataWrapper(\n{textwrap.indent(repr(self.module), ' ' * 4)},\n)"