diff --git a/docs/source/matchers.rst b/docs/source/matchers.rst index 6bdf214e..37398f40 100644 --- a/docs/source/matchers.rst +++ b/docs/source/matchers.rst @@ -33,6 +33,7 @@ selectively control when LibCST calls visitor functions. .. autofunction:: libcst.matchers.findall .. autofunction:: libcst.matchers.extract .. autofunction:: libcst.matchers.extractall +.. autofunction:: libcst.matchers.replace .. _libcst-matcher-decorators: diff --git a/libcst/codegen/gen_matcher_classes.py b/libcst/codegen/gen_matcher_classes.py index adb3aaad..c8bd204d 100644 --- a/libcst/codegen/gen_matcher_classes.py +++ b/libcst/codegen/gen_matcher_classes.py @@ -482,7 +482,7 @@ generated_code.append("from typing_extensions import Literal") generated_code.append("import libcst as cst") generated_code.append("") generated_code.append( - "from libcst.matchers._matcher_base import BaseMatcherNode, DoNotCareSentinel, DoNotCare, OneOf, AllOf, DoesNotMatch, MatchIfTrue, MatchRegex, MatchMetadata, MatchMetadataIfTrue, ZeroOrMore, AtLeastN, ZeroOrOne, AtMostN, SaveMatchedNode, extract, extractall, findall, matches" + "from libcst.matchers._matcher_base import BaseMatcherNode, DoNotCareSentinel, DoNotCare, OneOf, AllOf, DoesNotMatch, MatchIfTrue, MatchRegex, MatchMetadata, MatchMetadataIfTrue, ZeroOrMore, AtLeastN, ZeroOrOne, AtMostN, SaveMatchedNode, extract, extractall, findall, matches, replace" ) all_exports.update( [ @@ -505,6 +505,7 @@ all_exports.update( "extractall", "findall", "matches", + "replace", ] ) generated_code.append( diff --git a/libcst/matchers/__init__.py b/libcst/matchers/__init__.py index aa4ef486..ae0b38d1 100644 --- a/libcst/matchers/__init__.py +++ b/libcst/matchers/__init__.py @@ -34,6 +34,7 @@ from libcst.matchers._matcher_base import ( extractall, findall, matches, + replace, ) from libcst.matchers._visitors import ( MatchDecoratorMismatch, @@ -13471,5 +13472,6 @@ __all__ = [ "findall", "leave", "matches", + "replace", "visit", ] diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index dedf9ca7..37b5bdbd 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -5,6 +5,8 @@ # pyre-strict import collections.abc +import copy +import inspect import re from dataclasses import fields from enum import Enum, auto @@ -1611,3 +1613,198 @@ def extractall( tree, matcher, metadata_resolver=metadata_resolver ) return extractions + + +class _ReplaceTransformer(libcst.CSTTransformer): + def __init__( + self, + matcher: Union[ + BaseMatcherNode, + MatchIfTrue[Callable[[object], bool]], + _BaseMetadataMatcher, + _InverseOf[ + Union[ + BaseMatcherNode, + MatchIfTrue[Callable[[object], bool]], + _BaseMetadataMatcher, + ] + ], + ], + metadata_lookup: Callable[[meta.ProviderT, libcst.CSTNode], object], + replacement: Union[ + MaybeSentinel, + RemovalSentinel, + libcst.CSTNode, + Callable[ + [ + libcst.CSTNode, + Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]], + ], + Union[MaybeSentinel, RemovalSentinel, libcst.CSTNode], + ], + ], + ) -> None: + self.matcher = matcher + self.metadata_lookup = metadata_lookup + if inspect.isfunction(replacement): + # pyre-ignore Pyre knows replacement is a function, but somehow drops + # the type hint from the init signature. + self.replacement: Callable[ + [ + libcst.CSTNode, + Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]], + ] + ] = replacement + elif isinstance(replacement, (MaybeSentinel, RemovalSentinel)): + self.replacement: Callable[ + [ + libcst.CSTNode, + Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]], + ] + ] = lambda node, matches: copy.deepcopy(replacement) + else: + self.replacement: Callable[ + [ + libcst.CSTNode, + Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]], + ] + # pyre-ignore We know this is a CSTNode. + ] = lambda node, matches: replacement.deep_clone() + # We run into a really weird problem here, where we need to run the match + # and extract step on the original node in order for metadata to work. + # However, if we do that, then using things like `deep_replace` will fail + # since any extracted nodes are the originals, not the updates and LibCST + # does replacement by identity for safety reasons. If we try to run the + # match and extract step on the updated node (or twice, once for the match + # and once for the extract), it will fail to extract if any metadata-based + # matchers are used. So, we try to compromise with the best of both worlds. + # We track all node updates, and when we send the extracted nodes to the + # replacement callable, we look up the original nodes and replace them with + # updated nodes. In the case that an update made the node no-longer exist, + # we act as if there was not a match (because in reality, there would not + # have been if we had run the matcher on the update). + self.node_lut: Dict[libcst.CSTNode, libcst.CSTNode] = {} + + def _node_translate( + self, node_or_sequence: Union[libcst.CSTNode, Sequence[libcst.CSTNode]], + ) -> Union[libcst.CSTNode, Sequence[libcst.CSTNode]]: + if isinstance(node_or_sequence, Sequence): + return tuple(self.node_lut[node] for node in node_or_sequence) + else: + return self.node_lut[node_or_sequence] + + def _extraction_translate( + self, extracted: Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]] + ) -> Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]]: + return {key: self._node_translate(val) for key, val in extracted.items()} + + def on_leave( + self, original_node: libcst.CSTNode, updated_node: libcst.CSTNode + ) -> Union[libcst.CSTNode, MaybeSentinel, RemovalSentinel]: + # Track original to updated node mapping for this node. + self.node_lut[original_node] = updated_node + + # This gets complicated. We need to do the match on the original node, + # but we want to do the extraction on the updated node. This is so + # metadata works properly in matchers. So, if we get a match, we fix + # up the nodes in the match and return that to the replacement lambda. + extracted = _matches(original_node, self.matcher, self.metadata_lookup) + if extracted is not None: + try: + # Attempt to do a translation from original to updated node. + extracted = self._extraction_translate(extracted) + except KeyError: + # One of the nodes we looked up doesn't exist anymore, this + # is no longer a match. This can happen if a child node was + # modified, making this original match not applicable anymore. + extracted = None + if extracted is not None: + # We're replacing this node entirely, so don't save the original + # updated node. We don't want this to be part of a parent match + # since we can't guarantee that the update matches anymore. + del self.node_lut[original_node] + return self.replacement(updated_node, extracted) + return updated_node + + +def replace( + tree: Union[MaybeSentinel, RemovalSentinel, libcst.CSTNode, meta.MetadataWrapper], + matcher: Union[ + BaseMatcherNode, MatchIfTrue[Callable[[object], bool]], _BaseMetadataMatcher, + ], + replacement: Union[ + MaybeSentinel, + RemovalSentinel, + libcst.CSTNode, + Callable[ + [ + libcst.CSTNode, + Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]], + ], + Union[MaybeSentinel, RemovalSentinel, libcst.CSTNode], + ], + ], + *, + metadata_resolver: Optional[ + Union[libcst.MetadataDependent, libcst.MetadataWrapper] + ] = None, +) -> Union[MaybeSentinel, RemovalSentinel, libcst.CSTNode]: + """ + Given an arbitrary node from a LibCST tree and an arbitrary matcher, iterates + over that node and all children and replaces each node that matches the supplied + matcher with a supplied replacement. Note that the replacement can either be a + valid node type, or a callable which takes the matched node and a dictionary of + any extracted child values and returns a valid node type. If you provide a + valid LibCST node type, :func:`replace` will replace every node that matches + the supplied matcher with the replacement node. If you provide a callable, + :func:`replace` will run :func:`extract` over all matched nodes and call the + callable with both the node that should be replaced and the dictionary returned + by :func:`extract`. Under all circumstances a new tree is returned. + :func:`extract` should be viewed as a short-cut to writing a transform which + also returns a new tree even when no changes are applied. + + Note that the tree can also be a :class:`~libcst.RemovalSentinel` or a + :class:`~libcst.MaybeSentinel` in order to use replace directly on transform + results and node attributes. In these cases, :func:`replace` will return the + same :class:`~libcst.RemovalSentinel` or :class:`~libcst.MaybeSentinel`. + Note also that instead of a LibCST tree, you can instead pass in a + :class:`~libcst.metadata.MetadataWrapper`. This mirrors the fact that you can + call ``visit`` on a :class:`~libcst.metadata.MetadataWrapper` in order to + iterate over it with a transform. If you provide a wrapper for the tree and + do not set the ``metadata_resolver`` parameter specifically, it will + automatically be set to the wrapper for you. + + The matcher can be any concrete matcher that subclasses from :class:`BaseMatcherNode`, + or a :class:`OneOf`/:class:`AllOf` special matcher. Unlike :func:`matches`, it can + also be a :class:`MatchIfTrue` or :func:`DoesNotMatch` matcher, since we are + traversing the tree looking for matches. It cannot be a :class:`AtLeastN` or + :class:`AtMostN` matcher because these types are wildcards which can only be usedi + inside sequences. + """ + if isinstance(tree, (RemovalSentinel, MaybeSentinel)): + # We can't do any replacements on this, so return the tree exactly. + return copy.deepcopy(tree) + if isinstance(matcher, (AtLeastN, AtMostN)): + # We can't match this, since these matchers are forbidden at top level. + # These are not subclasses of BaseMatcherNode, but in the case that the + # user is not using type checking, this should still behave correctly. + if isinstance(tree, libcst.CSTNode): + return tree.deep_clone() + elif isinstance(tree, meta.MetadataWrapper): + return tree.module.deep_clone() + else: + raise Exception("Logic error!") + + if isinstance(tree, meta.MetadataWrapper) and metadata_resolver is None: + # Provide a convenience for calling replace directly on a MetadataWrapper. + metadata_resolver = tree + + if metadata_resolver is None: + fetcher = _construct_metadata_fetcher_null() + elif isinstance(metadata_resolver, libcst.MetadataWrapper): + fetcher = _construct_metadata_fetcher_wrapper(metadata_resolver) + else: + fetcher = _construct_metadata_fetcher_dependent(metadata_resolver) + + replacer = _ReplaceTransformer(matcher, fetcher, replacement) + return tree.visit(replacer) diff --git a/libcst/matchers/_visitors.py b/libcst/matchers/_visitors.py index 92a0d956..924597bb 100644 --- a/libcst/matchers/_visitors.py +++ b/libcst/matchers/_visitors.py @@ -42,6 +42,7 @@ from libcst.matchers._matcher_base import ( extractall, findall, matches, + replace, ) from libcst.matchers._return_types import TYPED_FUNCTION_RETURN_MAPPING @@ -611,6 +612,34 @@ class MatcherDecoratableTransformer(CSTTransformer): """ return extractall(tree, matcher, metadata_resolver=self) + def replace( + self, + tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode], + matcher: Union[ + BaseMatcherNode, + MatchIfTrue[Callable[..., bool]], + MatchMetadata, + MatchMetadataIfTrue, + ], + replacement: Union[ + cst.MaybeSentinel, + cst.RemovalSentinel, + cst.CSTNode, + Callable[ + [cst.CSTNode, Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],], + Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode], + ], + ], + ) -> Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode]: + """ + A convenience method to call :func:`~libcst.matchers.replace` without requiring + an explicit parameter for metadata. Since our instance is an instance of + :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see + documentation for :func:`~libcst.matchers.replace` as it is identical to this + function. + """ + return replace(tree, matcher, replacement, metadata_resolver=self) + def _transform_module_impl(self, tree: cst.Module) -> cst.Module: return tree.visit(self) @@ -780,3 +809,31 @@ class MatcherDecoratableVisitor(CSTVisitor): function. """ return extractall(tree, matcher, metadata_resolver=self) + + def replace( + self, + tree: Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode], + matcher: Union[ + BaseMatcherNode, + MatchIfTrue[Callable[..., bool]], + MatchMetadata, + MatchMetadataIfTrue, + ], + replacement: Union[ + cst.MaybeSentinel, + cst.RemovalSentinel, + cst.CSTNode, + Callable[ + [cst.CSTNode, Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],], + Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode], + ], + ], + ) -> Union[cst.MaybeSentinel, cst.RemovalSentinel, cst.CSTNode]: + """ + A convenience method to call :func:`~libcst.matchers.replace` without requiring + an explicit parameter for metadata. Since our instance is an instance of + :class:`libcst.MetadataDependent`, we work as a metadata resolver. Please see + documentation for :func:`~libcst.matchers.replace` as it is identical to this + function. + """ + return replace(tree, matcher, replacement, metadata_resolver=self) diff --git a/libcst/matchers/tests/test_replace.py b/libcst/matchers/tests/test_replace.py new file mode 100644 index 00000000..8840db13 --- /dev/null +++ b/libcst/matchers/tests/test_replace.py @@ -0,0 +1,283 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict +from typing import Dict, Sequence, Union + +import libcst as cst +import libcst.matchers as m +import libcst.metadata as meta +from libcst.testing.utils import UnitTest + + +class MatchersReplaceTest(UnitTest): + def test_replace_sentinel(self) -> None: + def _swap_bools( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.Name( + "True" if cst.ensure_type(node, cst.Name).value == "False" else "False" + ) + + # Verify behavior when provided a sentinel + replaced = m.replace( + cst.RemovalSentinel.REMOVE, m.Name("True") | m.Name("False"), _swap_bools, + ) + self.assertEqual(replaced, cst.RemovalSentinel.REMOVE) + replaced = m.replace( + cst.MaybeSentinel.DEFAULT, m.Name("True") | m.Name("False"), _swap_bools, + ) + self.assertEqual(replaced, cst.MaybeSentinel.DEFAULT) + + def test_replace_noop(self) -> None: + def _swap_bools( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.Name( + "True" if cst.ensure_type(node, cst.Name).value == "False" else "False" + ) + + # Verify behavior when there's nothing to replace. + original = cst.parse_module("foo: int = 5\ndef bar() -> str:\n return 's'\n") + replaced = cst.ensure_type( + m.replace(original, m.Name("True") | m.Name("False"), _swap_bools,), + cst.Module, + ) + # Should be identical tree contents + self.assertTrue(original.deep_equals(replaced)) + # However, should be a new tree by identity + self.assertNotEqual(id(original), id(replaced)) + + def test_replace_simple(self) -> None: + # Verify behavior when there's a static node as a replacement + original = cst.parse_module( + "foo: bool = True\ndef bar() -> bool:\n return False\n" + ) + replaced = cst.ensure_type( + m.replace(original, m.Name("True") | m.Name("False"), cst.Name("boolean")), + cst.Module, + ).code + self.assertEqual( + replaced, "foo: bool = boolean\ndef bar() -> bool:\n return boolean\n" + ) + + def test_replace_simple_sentinel(self) -> None: + # Verify behavior when there's a sentinel as a replacement + original = cst.parse_module( + "def bar(x: int, y: int) -> bool:\n return False\n" + ) + replaced = cst.ensure_type( + m.replace(original, m.Param(), cst.RemoveFromParent()), cst.Module, + ).code + self.assertEqual(replaced, "def bar() -> bool:\n return False\n") + + def test_replace_actual(self) -> None: + def _swap_bools( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.Name( + "True" if cst.ensure_type(node, cst.Name).value == "False" else "False" + ) + + # Verify behavior when there's lots to replace. + original = cst.parse_module( + "foo: bool = True\ndef bar() -> bool:\n return False\n" + ) + replaced = cst.ensure_type( + m.replace(original, m.Name("True") | m.Name("False"), _swap_bools), + cst.Module, + ).code + self.assertEqual( + replaced, "foo: bool = False\ndef bar() -> bool:\n return True\n" + ) + + def test_replace_add_one(self) -> None: + def _add_one( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.Integer(str(int(cst.ensure_type(node, cst.Integer).value) + 1)) + + # Verify slightly more complex transform behavior. + original = cst.parse_module("foo: int = 36\ndef bar() -> int:\n return 41\n") + replaced = cst.ensure_type( + m.replace(original, m.Integer(), _add_one), cst.Module, + ).code + self.assertEqual(replaced, "foo: int = 37\ndef bar() -> int:\n return 42\n") + + def test_replace_add_one_to_foo_args(self) -> None: + def _add_one_to_arg( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return node.deep_replace( + # This can be either a node or a sequence, pyre doesn't know. + cst.ensure_type(extraction["arg"], cst.CSTNode), + # Grab the arg and add one to its value. + cst.Integer( + str(int(cst.ensure_type(extraction["arg"], cst.Integer).value) + 1) + ), + ) + + # Verify way more complex transform behavior. + original = cst.parse_module( + "foo: int = 37\ndef bar(baz: int) -> int:\n return baz\n\nbiz: int = bar(41)\n" + ) + replaced = cst.ensure_type( + m.replace( + original, + m.Call( + func=m.Name("bar"), + args=[m.Arg(m.SaveMatchedNode(m.Integer(), "arg"))], + ), + _add_one_to_arg, + ), + cst.Module, + ).code + self.assertEqual( + replaced, + "foo: int = 37\ndef bar(baz: int) -> int:\n return baz\n\nbiz: int = bar(42)\n", + ) + + def test_replace_sequence_extract(self) -> None: + def _reverse_params( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.ensure_type(node, cst.FunctionDef).with_changes( + # pyre-ignore We know "params" is a Sequence[Parameters] but asserting that + # to pyre is difficult. + params=cst.Parameters(params=list(reversed(extraction["params"]))), + ) + + # Verify that we can still extract sequences with replace. + original = cst.parse_module( + "def bar(baz: int, foo: int, ) -> int:\n return baz + foo\n" + ) + replaced = cst.ensure_type( + m.replace( + original, + m.FunctionDef( + params=m.Parameters( + params=m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params"), + ) + ), + _reverse_params, + ), + cst.Module, + ).code + self.assertEqual( + replaced, "def bar(foo: int, baz: int, ) -> int:\n return baz + foo\n" + ) + + def test_replace_metadata(self) -> None: + def _rename_foo( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.ensure_type(node, cst.Name).with_changes(value="replaced") + + original = cst.parse_module( + "foo: int = 37\ndef bar(foo: int) -> int:\n return foo\n\nbiz: int = bar(42)\n" + ) + wrapper = cst.MetadataWrapper(original) + replaced = cst.ensure_type( + m.replace( + wrapper, + m.Name( + metadata=m.MatchMetadataIfTrue( + meta.QualifiedNameProvider, + lambda qualnames: any(n.name == "foo" for n in qualnames), + ), + ), + _rename_foo, + ), + cst.Module, + ).code + self.assertEqual( + replaced, + "replaced: int = 37\ndef bar(foo: int) -> int:\n return foo\n\nbiz: int = bar(42)\n", + ) + + def test_replace_metadata_on_transform(self) -> None: + def _rename_foo( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.ensure_type(node, cst.Name).with_changes(value="replaced") + + original = cst.parse_module( + "foo: int = 37\ndef bar(foo: int) -> int:\n return foo\n\nbiz: int = bar(42)\n" + ) + wrapper = cst.MetadataWrapper(original) + + class TestTransformer(m.MatcherDecoratableTransformer): + METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = ( + meta.QualifiedNameProvider, + ) + + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: + # Somewhat contrived scenario to test codepaths. + return cst.ensure_type( + self.replace( + original_node, + m.Name( + metadata=m.MatchMetadataIfTrue( + meta.QualifiedNameProvider, + lambda qualnames: any( + n.name == "foo" for n in qualnames + ), + ), + ), + _rename_foo, + ), + cst.Module, + ) + + replaced = cst.ensure_type(wrapper.visit(TestTransformer()), cst.Module).code + self.assertEqual( + replaced, + "replaced: int = 37\ndef bar(foo: int) -> int:\n return foo\n\nbiz: int = bar(42)\n", + ) + + def test_replace_updated_node_changes(self) -> None: + def _replace_nested( + node: cst.CSTNode, + extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]], + ) -> cst.CSTNode: + return cst.ensure_type(node, cst.Call).with_changes( + args=[ + cst.Arg( + cst.Name( + value=cst.ensure_type( + cst.ensure_type(extraction["inner"], cst.Call).func, + cst.Name, + ).value + + "_immediate" + ) + ) + ] + ) + + original = cst.parse_module( + "def foo(val: int) -> int:\n return val\nbar = foo\nbaz = foo\nbiz = foo\nfoo(bar(baz(biz(5))))\n" + ) + replaced = cst.ensure_type( + m.replace( + original, + m.Call(args=[m.Arg(m.SaveMatchedNode(m.Call(), "inner"))]), + _replace_nested, + ), + cst.Module, + ).code + self.assertEqual( + replaced, + "def foo(val: int) -> int:\n return val\nbar = foo\nbaz = foo\nbiz = foo\nfoo(bar_immediate)\n", + )