Use helpers in various codemod transforms.

This commit is contained in:
Jennifer Taylor 2020-02-06 16:26:56 -08:00 committed by Jennifer Taylor
parent f2b738deda
commit 6f3f541812
4 changed files with 109 additions and 44 deletions

View file

@ -12,6 +12,7 @@ from libcst import matchers as m, parse_statement
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.helpers import get_absolute_module_for_import
class AddImportsVisitor(ContextAwareTransformer):
@ -191,29 +192,22 @@ class AddImportsVisitor(ContextAwareTransformer):
# There's nothing left, so lets delete this work item
del self.module_mapping[module]
def _get_string_name(self, node: Optional[libcst.CSTNode]) -> str:
if node is None:
return ""
elif isinstance(node, libcst.Name):
return node.value
elif isinstance(node, libcst.Attribute):
return self._get_string_name(node.value) + "." + node.attr.value
else:
raise Exception(f"Invalid node type {type(node)}!")
def leave_ImportFrom(
self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
) -> libcst.ImportFrom:
if len(updated_node.relative) > 0 or updated_node.module is None:
# Don't support relative-only imports at the moment.
return updated_node
if updated_node.names == "*":
if isinstance(updated_node.names, libcst.ImportStar):
# There's nothing to do here!
return updated_node
# Get the module we're importing as a string, see if we have work to do
module = self._get_string_name(updated_node.module)
if module not in self.module_mapping and module not in self.alias_mapping:
# Get the module we're importing as a string, see if we have work to do.
module = get_absolute_module_for_import(
self.context.full_module_name, updated_node
)
if (
module is None
or module not in self.module_mapping
and module not in self.alias_mapping
):
return updated_node
# We have work to do, mark that we won't modify this again.

View file

@ -4,12 +4,12 @@
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Dict, List, Sequence, Set, Tuple, Union
import libcst
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.helpers import get_full_name_for_node
from libcst.helpers import get_absolute_module_for_import
class GatherImportsVisitor(ContextAwareVisitor):
@ -62,38 +62,29 @@ class GatherImportsVisitor(ContextAwareVisitor):
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
def _get_string_name(self, node: Optional[libcst.CSTNode]) -> str:
name = "" if node is None else get_full_name_for_node(node)
if name is None:
raise Exception(f"Invalid node type {type(node)}!")
return name
def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
for name in node.names:
asname = name.asname
if asname is not None:
alias = name.evaluated_alias
if alias is not None:
# Track this as an aliased module
self.module_aliases[
self._get_string_name(name.name)
] = libcst.ensure_type(asname.name, libcst.Name).value
self.module_aliases[name.evaluated_name] = alias
else:
# Get the module we're importing as a string.
self.module_imports.add(self._get_string_name(name.name))
self.module_imports.add(name.evaluated_name)
def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
if len(node.relative) > 0 or node.module is None:
# Don't support relative-only imports at the moment.
return
# Get the module we're importing as a string.
module = self._get_string_name(node.module)
module = get_absolute_module_for_import(self.context.full_module_name, node)
if module is None:
# Can't get the absolute import from relative, so we can't
# support this.
return
nodenames = node.names
if isinstance(nodenames, libcst.ImportStar):
# We cover everything, no need to bother tracking other things
@ -102,20 +93,18 @@ class GatherImportsVisitor(ContextAwareVisitor):
elif isinstance(nodenames, Sequence):
# Get the list of imports we're aliasing in this import
new_aliases = [
# pyre-ignore We check ia.asname below, this is safe
(self._get_string_name(ia.name), ia.asname.name.value)
(ia.evaluated_name, ia.evaluated_alias)
for ia in nodenames
if ia.asname is not None
]
if new_aliases:
if module not in self.alias_mapping:
self.alias_mapping[module] = []
# pyre-ignore We know that aliases are not None here.
self.alias_mapping[module].extend(new_aliases)
# Get the list of imports we're importing in this import
new_objects = {
self._get_string_name(ia.name) for ia in nodenames if ia.asname is None
}
new_objects = {ia.evaluated_name for ia in nodenames if ia.asname is None}
if new_objects:
if module not in self.object_mapping:
self.object_mapping[module] = set()

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from libcst.codemod import CodemodTest
from libcst.codemod import CodemodContext, CodemodTest
from libcst.codemod.visitors import AddImportsVisitor
@ -558,3 +558,65 @@ class TestAddImportsCodemod(CodemodTest):
"""
self.assertCodemod(before, after, [("argparse", None, None)])
def test_dont_add_relative_object_simple(self) -> None:
"""
Should not add object as an import since it exists.
"""
before = """
from .c import D
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
from .c import D
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(
before,
after,
[("a.b.c", "D", None)],
context_override=CodemodContext(full_module_name="a.b.foobar"),
)
def test_add_object_relative_modify_simple(self) -> None:
"""
Should modify existing import to add new object
"""
before = """
from .c import E, F
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
from .c import D, E, F
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(
before,
after,
[("a.b.c", "D", None)],
context_override=CodemodContext(full_module_name="a.b.foobar"),
)

View file

@ -12,7 +12,9 @@ from libcst.testing.utils import UnitTest
class TestGatherImportsVisitor(UnitTest):
def gather_imports(self, code: str) -> GatherImportsVisitor:
transform_instance = GatherImportsVisitor(CodemodContext())
transform_instance = GatherImportsVisitor(
CodemodContext(full_module_name="a.b.foobar")
)
input_tree = parse_module(CodemodTest.make_fixture_data(code))
input_tree.visit(transform_instance)
return transform_instance
@ -154,3 +156,21 @@ class TestGatherImportsVisitor(UnitTest):
self.assertEqual(gatherer.module_aliases, {})
self.assertEqual(gatherer.alias_mapping, {"a.b.c": [("d", "e")]})
self.assertEqual(len(gatherer.all_imports), 1)
def test_gather_relative_object(self) -> None:
code = """
from .c import d as e, f, g
from a.b.c import h, i, j
def foo() -> None:
pass
def bar() -> int:
return 5
"""
gatherer = self.gather_imports(code)
self.assertEqual(gatherer.module_imports, set())
self.assertEqual(gatherer.object_mapping, {"a.b.c": {"f", "g", "h", "i", "j"}})
self.assertEqual(gatherer.module_aliases, {})
self.assertEqual(gatherer.alias_mapping, {"a.b.c": [("d", "e")]})
self.assertEqual(len(gatherer.all_imports), 2)