diff --git a/libcst/matchers/_visitors.py b/libcst/matchers/_visitors.py index e4454eac..cc5cc25d 100644 --- a/libcst/matchers/_visitors.py +++ b/libcst/matchers/_visitors.py @@ -334,12 +334,16 @@ def _gather_constructed_leave_funcs( def _visit_matchers( - matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], node: cst.CSTNode + matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], + node: cst.CSTNode, + metadata_resolver: cst.MetadataDependent, ) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]: new_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {} for matcher, existing_node in matchers.items(): # We don't care about visiting matchers that are already true. - if existing_node is None and matches(node, matcher): + if existing_node is None and matches( + node, matcher, metadata_resolver=metadata_resolver + ): # This node matches! Remember which node it was so we can # cancel it later. new_matchers[matcher] = node @@ -397,9 +401,10 @@ def _visit_constructed_funcs( visit_funcs: Dict[BaseMatcherNode, Sequence[Callable[[cst.CSTNode], None]]], all_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]], node: cst.CSTNode, + metadata_resolver: cst.MetadataDependent, ) -> None: for matcher, visit_funcs in visit_funcs.items(): - if matches(node, matcher): + if matches(node, matcher, metadata_resolver=metadata_resolver): for visit_func in visit_funcs: if _should_allow_visit(all_matchers, visit_func): visit_func(node) @@ -455,10 +460,10 @@ class MatcherDecoratableTransformer(CSTTransformer): def on_visit(self, node: cst.CSTNode) -> bool: # First, evaluate any matchers that we have which we are not inside already. - self._matchers = _visit_matchers(self._matchers, node) + self._matchers = _visit_matchers(self._matchers, node, self) # Now, call any visitors that were hooked using a visit decorator. - _visit_constructed_funcs(self._extra_visit_funcs, self._matchers, node) + _visit_constructed_funcs(self._extra_visit_funcs, self._matchers, node, self) # Now, evaluate whether this current function has any matchers it requires. if not _should_allow_visit( @@ -485,7 +490,7 @@ class MatcherDecoratableTransformer(CSTTransformer): # Now, call any visitors that were hooked using a leave decorator. for matcher, leave_funcs in reversed(list(self._extra_leave_funcs.items())): - if not matches(original_node, matcher): + if not self.matches(original_node, matcher): continue for leave_func in leave_funcs: if _should_allow_visit(self._matchers, leave_func) and isinstance( @@ -522,6 +527,20 @@ class MatcherDecoratableTransformer(CSTTransformer): # matchers. In either case, just call the superclass behavior. CSTVisitor.on_leave_attribute(self, original_node, attribute) + def matches( + self, + node: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode], + matcher: BaseMatcherNode, + ) -> bool: + """ + A convenience method to call :func:`~libcst.matchers.matches` without requiring + an explicit parameter for metadata. Since our instance is an instance of + :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see + documentation for :func:`~libcst.matchers.matches` as it is identical to this + function. + """ + return matches(node, matcher, metadata_resolver=self) + def _transform_module_impl(self, tree: cst.Module) -> cst.Module: return tree.visit(self) @@ -571,10 +590,10 @@ class MatcherDecoratableVisitor(CSTVisitor): def on_visit(self, node: cst.CSTNode) -> bool: # First, evaluate any matchers that we have which we are not inside already. - self._matchers = _visit_matchers(self._matchers, node) + self._matchers = _visit_matchers(self._matchers, node, self) # Now, call any visitors that were hooked using a visit decorator. - _visit_constructed_funcs(self._extra_visit_funcs, self._matchers, node) + _visit_constructed_funcs(self._extra_visit_funcs, self._matchers, node, self) # Now, evaluate whether this current function has a decorator on it. if not _should_allow_visit( @@ -597,7 +616,7 @@ class MatcherDecoratableVisitor(CSTVisitor): # Now, call any visitors that were hooked using a leave decorator. for matcher, leave_funcs in reversed(list(self._extra_leave_funcs.items())): - if not matches(original_node, matcher): + if not self.matches(original_node, matcher): continue for leave_func in leave_funcs: if _should_allow_visit(self._matchers, leave_func): @@ -625,3 +644,17 @@ class MatcherDecoratableVisitor(CSTVisitor): # Either the visit_func doesn't exist, we have no matchers, or we passed all # matchers. In either case, just call the superclass behavior. CSTVisitor.on_leave_attribute(self, original_node, attribute) + + def matches( + self, + node: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode], + matcher: BaseMatcherNode, + ) -> bool: + """ + A convenience method to call :func:`~libcst.matchers.matches` without requiring + an explicit parameter for metadata. Since our instance is an instance of + :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see + documentation for :func:`~libcst.matchers.matches` as it is identical to this + function. + """ + return matches(node, matcher, metadata_resolver=self) diff --git a/libcst/matchers/tests/test_matchers_with_metadata.py b/libcst/matchers/tests/test_matchers_with_metadata.py index 8e4cfaa9..27cf7f02 100644 --- a/libcst/matchers/tests/test_matchers_with_metadata.py +++ b/libcst/matchers/tests/test_matchers_with_metadata.py @@ -5,7 +5,7 @@ # pyre-strict from textwrap import dedent -from typing import Tuple +from typing import Sequence, Set, Tuple import libcst as cst import libcst.matchers as m @@ -199,3 +199,170 @@ class MatchersMetadataTest(UnitTest): ) node, wrapper = self._make_fixture("a + b") self.assertTrue(matches(node, matcher, metadata_resolver=wrapper)) + + +class MatchersVisitorMetadataTest(UnitTest): + def _make_fixture(self, code: str) -> cst.MetadataWrapper: + return cst.MetadataWrapper(cst.parse_module(dedent(code))) + + def test_matches_on_visitors(self) -> None: + # Set up a simple visitor that has a metadata dependency, try to use it in matchers. + class TestVisitor(m.MatcherDecoratableVisitor): + METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = ( + meta.ExpressionContextProvider, + ) + + def __init__(self) -> None: + super().__init__() + self.match_names: Set[str] = set() + + def visit_Name(self, node: cst.Name) -> None: + # Only match name nodes that are being assigned to. + if self.matches( + node, + m.Name( + metadata=m.MatchMetadata( + meta.ExpressionContextProvider, meta.ExpressionContext.STORE + ) + ), + ): + self.match_names.add(node.value) + + module = self._make_fixture( + """ + a = 1 + 2 + b = 3 + 4 + d + e + def foo() -> str: + c = "baz" + return c + def bar() -> int: + return b + del foo + del bar + """ + ) + visitor = TestVisitor() + module.visit(visitor) + + self.assertEqual(visitor.match_names, {"a", "b", "c"}) + + def test_matches_on_transformers(self) -> None: + # Set up a simple visitor that has a metadata dependency, try to use it in matchers. + class TestTransformer(m.MatcherDecoratableTransformer): + METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = ( + meta.ExpressionContextProvider, + ) + + def __init__(self) -> None: + super().__init__() + self.match_names: Set[str] = set() + + def visit_Name(self, node: cst.Name) -> None: + # Only match name nodes that are being assigned to. + if self.matches( + node, + m.Name( + metadata=m.MatchMetadata( + meta.ExpressionContextProvider, meta.ExpressionContext.STORE + ) + ), + ): + self.match_names.add(node.value) + + module = self._make_fixture( + """ + a = 1 + 2 + b = 3 + 4 + d + e + def foo() -> str: + c = "baz" + return c + def bar() -> int: + return b + del foo + del bar + """ + ) + visitor = TestTransformer() + module.visit(visitor) + + self.assertEqual(visitor.match_names, {"a", "b", "c"}) + + def test_matches_decorator_on_visitors(self) -> None: + # Set up a simple visitor that has a metadata dependency, try to use it in matchers. + class TestVisitor(m.MatcherDecoratableVisitor): + METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = ( + meta.ExpressionContextProvider, + ) + + def __init__(self) -> None: + super().__init__() + self.match_names: Set[str] = set() + + @m.visit( + m.Name( + metadata=m.MatchMetadata( + meta.ExpressionContextProvider, meta.ExpressionContext.STORE + ) + ) + ) + def _visit_assignments(self, node: cst.Name) -> None: + # Only match name nodes that are being assigned to. + self.match_names.add(node.value) + + module = self._make_fixture( + """ + a = 1 + 2 + b = 3 + 4 + d + e + def foo() -> str: + c = "baz" + return c + def bar() -> int: + return b + del foo + del bar + """ + ) + visitor = TestVisitor() + module.visit(visitor) + + self.assertEqual(visitor.match_names, {"a", "b", "c"}) + + def test_matches_decorator_on_transformers(self) -> None: + # Set up a simple visitor that has a metadata dependency, try to use it in matchers. + class TestTransformer(m.MatcherDecoratableTransformer): + METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = ( + meta.ExpressionContextProvider, + ) + + def __init__(self) -> None: + super().__init__() + self.match_names: Set[str] = set() + + @m.visit( + m.Name( + metadata=m.MatchMetadata( + meta.ExpressionContextProvider, meta.ExpressionContext.STORE + ) + ) + ) + def _visit_assignments(self, node: cst.Name) -> None: + # Only match name nodes that are being assigned to. + self.match_names.add(node.value) + + module = self._make_fixture( + """ + a = 1 + 2 + b = 3 + 4 + d + e + def foo() -> str: + c = "baz" + return c + def bar() -> int: + return b + del foo + del bar + """ + ) + visitor = TestTransformer() + module.visit(visitor) + + self.assertEqual(visitor.match_names, {"a", "b", "c"})