mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
[metadata] add cache field to metadata wrapper
This commit is contained in:
parent
e7315d2c28
commit
21e237c78b
3 changed files with 63 additions and 6 deletions
|
|
@ -44,9 +44,17 @@ class BaseMetadataProvider(MetadataDependent, Generic[_ProvidedMetadataT]):
|
|||
# explanation.
|
||||
_computed: MutableMapping["CSTNode", _ProvidedMetadataT]
|
||||
|
||||
def __init__(self) -> None:
|
||||
is_cache_required: bool = False
|
||||
|
||||
def __init__(self, cache: object = None) -> None:
|
||||
super().__init__()
|
||||
self._computed = {}
|
||||
if self.is_cache_required and not cache:
|
||||
# The metadata provider implementation is responsible to store and use cache.
|
||||
raise Exception(
|
||||
f"Cache is required for initializing {self.__class__.__name__}."
|
||||
)
|
||||
self.cache = cache
|
||||
|
||||
def _gen(
|
||||
self, wrapper: "MetadataWrapper"
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@
|
|||
|
||||
# pyre-strict
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import MetadataWrapper
|
||||
from libcst.metadata import BatchableMetadataProvider, MetadataWrapper
|
||||
from libcst.testing.utils import UnitTest
|
||||
|
||||
|
||||
|
|
@ -41,3 +43,34 @@ class MetadataWrapperTest(UnitTest):
|
|||
self.assertNotEqual(hash(mw1), hash(mw2))
|
||||
self.assertNotEqual(hash(mw1), hash(mw3))
|
||||
self.assertNotEqual(hash(mw2), hash(mw3))
|
||||
|
||||
def test_metadata_cache(self) -> None:
|
||||
class DummyMetadataProvider(BatchableMetadataProvider[None]):
|
||||
is_cache_required = True
|
||||
|
||||
m = cst.parse_module("pass")
|
||||
mw = MetadataWrapper(m)
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "Cache is required for initializing DummyMetadataProvider."
|
||||
):
|
||||
mw.resolve(DummyMetadataProvider)
|
||||
|
||||
class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]):
|
||||
is_cache_required = True
|
||||
|
||||
def __init__(self, cache: object) -> None:
|
||||
super().__init__(cache)
|
||||
self.cache = cache
|
||||
|
||||
def visit_Pass(self, node: cst.Pass) -> Optional[bool]:
|
||||
self.set_metadata(node, self.cache)
|
||||
|
||||
cached_data = object()
|
||||
# pyre-fixme[6]: Expected `Mapping[Type[BaseMetadataProvider[object]],
|
||||
# object]` for 2nd param but got `Dict[Type[SimpleCacheMetadataProvider],
|
||||
# object]`.
|
||||
mw = MetadataWrapper(m, cache={SimpleCacheMetadataProvider: cached_data})
|
||||
pass_node = cst.ensure_type(mw.module.body[0], cst.SimpleStatementLine).body[0]
|
||||
self.assertEqual(
|
||||
mw.resolve(SimpleCacheMetadataProvider)[pass_node], cached_data
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
# pyre-strict
|
||||
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from contextlib import ExitStack
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
|
|
@ -87,10 +88,17 @@ def _resolve_impl(
|
|||
if issubclass(P, BatchableMetadataProvider):
|
||||
batchable.add(P)
|
||||
else:
|
||||
wrapper._metadata[P] = P()._gen(wrapper)
|
||||
wrapper._metadata[P] = (
|
||||
P(wrapper._cache[P])._gen(wrapper)
|
||||
if P.is_cache_required
|
||||
else P()._gen(wrapper)
|
||||
)
|
||||
completed.add(P)
|
||||
|
||||
metadata_batch = _gen_batchable(wrapper, [p() for p in batchable])
|
||||
initialized_batchable = [
|
||||
p(wrapper._cache[p]) if p.is_cache_required else p() for p in batchable
|
||||
]
|
||||
metadata_batch = _gen_batchable(wrapper, initialized_batchable)
|
||||
wrapper._metadata.update(metadata_batch)
|
||||
completed |= batchable
|
||||
|
||||
|
|
@ -116,18 +124,25 @@ class MetadataWrapper:
|
|||
node's identity.
|
||||
"""
|
||||
|
||||
__slots__ = ["__module", "_metadata"]
|
||||
__slots__ = ["__module", "_metadata", "_cache"]
|
||||
|
||||
__module: "Module"
|
||||
_metadata: MutableMapping["ProviderT", Mapping["CSTNode", object]]
|
||||
_cache: Mapping["ProviderT", object]
|
||||
|
||||
def __init__(self, module: "Module", unsafe_skip_copy: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
module: "Module",
|
||||
unsafe_skip_copy: bool = False,
|
||||
cache: Mapping["ProviderT", object] = defaultdict(dict),
|
||||
) -> None:
|
||||
"""
|
||||
:param module: The module to wrap. This is deeply copied by default.
|
||||
:param unsafe_skip_copy: When true, this skips the deep cloning of the module.
|
||||
This can provide a small performance benefit, but you should only use this
|
||||
if you know that there are no duplicate nodes in your tree (e.g. this
|
||||
module came from the parser).
|
||||
:param cache: Pass the needed cache to wrapper to be used when resolving metadata.
|
||||
"""
|
||||
# Ensure that module is safe to use by copying the module to remove
|
||||
# any duplicate nodes.
|
||||
|
|
@ -135,6 +150,7 @@ class MetadataWrapper:
|
|||
module = module.deep_clone()
|
||||
self.__module = module
|
||||
self._metadata = {}
|
||||
self._cache = cache
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MetadataWrapper(\n{textwrap.indent(repr(self.module), ' ' * 4)},\n)"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue