mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Use helpers in various codemod transforms.
This commit is contained in:
parent
f2b738deda
commit
6f3f541812
4 changed files with 109 additions and 44 deletions
|
|
@ -12,6 +12,7 @@ from libcst import matchers as m, parse_statement
|
|||
from libcst.codemod._context import CodemodContext
|
||||
from libcst.codemod._visitor import ContextAwareTransformer
|
||||
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
|
||||
from libcst.helpers import get_absolute_module_for_import
|
||||
|
||||
|
||||
class AddImportsVisitor(ContextAwareTransformer):
|
||||
|
|
@ -191,29 +192,22 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
# There's nothing left, so lets delete this work item
|
||||
del self.module_mapping[module]
|
||||
|
||||
def _get_string_name(self, node: Optional[libcst.CSTNode]) -> str:
|
||||
if node is None:
|
||||
return ""
|
||||
elif isinstance(node, libcst.Name):
|
||||
return node.value
|
||||
elif isinstance(node, libcst.Attribute):
|
||||
return self._get_string_name(node.value) + "." + node.attr.value
|
||||
else:
|
||||
raise Exception(f"Invalid node type {type(node)}!")
|
||||
|
||||
def leave_ImportFrom(
|
||||
self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
|
||||
) -> libcst.ImportFrom:
|
||||
if len(updated_node.relative) > 0 or updated_node.module is None:
|
||||
# Don't support relative-only imports at the moment.
|
||||
return updated_node
|
||||
if updated_node.names == "*":
|
||||
if isinstance(updated_node.names, libcst.ImportStar):
|
||||
# There's nothing to do here!
|
||||
return updated_node
|
||||
|
||||
# Get the module we're importing as a string, see if we have work to do
|
||||
module = self._get_string_name(updated_node.module)
|
||||
if module not in self.module_mapping and module not in self.alias_mapping:
|
||||
# Get the module we're importing as a string, see if we have work to do.
|
||||
module = get_absolute_module_for_import(
|
||||
self.context.full_module_name, updated_node
|
||||
)
|
||||
if (
|
||||
module is None
|
||||
or module not in self.module_mapping
|
||||
and module not in self.alias_mapping
|
||||
):
|
||||
return updated_node
|
||||
|
||||
# We have work to do, mark that we won't modify this again.
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# pyre-strict
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Dict, List, Sequence, Set, Tuple, Union
|
||||
|
||||
import libcst
|
||||
from libcst.codemod._context import CodemodContext
|
||||
from libcst.codemod._visitor import ContextAwareVisitor
|
||||
from libcst.helpers import get_full_name_for_node
|
||||
from libcst.helpers import get_absolute_module_for_import
|
||||
|
||||
|
||||
class GatherImportsVisitor(ContextAwareVisitor):
|
||||
|
|
@ -62,38 +62,29 @@ class GatherImportsVisitor(ContextAwareVisitor):
|
|||
# Track all of the imports found in this transform
|
||||
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
|
||||
|
||||
def _get_string_name(self, node: Optional[libcst.CSTNode]) -> str:
|
||||
name = "" if node is None else get_full_name_for_node(node)
|
||||
if name is None:
|
||||
raise Exception(f"Invalid node type {type(node)}!")
|
||||
|
||||
return name
|
||||
|
||||
def visit_Import(self, node: libcst.Import) -> None:
|
||||
# Track this import statement for later analysis.
|
||||
self.all_imports.append(node)
|
||||
|
||||
for name in node.names:
|
||||
asname = name.asname
|
||||
if asname is not None:
|
||||
alias = name.evaluated_alias
|
||||
if alias is not None:
|
||||
# Track this as an aliased module
|
||||
self.module_aliases[
|
||||
self._get_string_name(name.name)
|
||||
] = libcst.ensure_type(asname.name, libcst.Name).value
|
||||
self.module_aliases[name.evaluated_name] = alias
|
||||
else:
|
||||
# Get the module we're importing as a string.
|
||||
self.module_imports.add(self._get_string_name(name.name))
|
||||
self.module_imports.add(name.evaluated_name)
|
||||
|
||||
def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
|
||||
# Track this import statement for later analysis.
|
||||
self.all_imports.append(node)
|
||||
|
||||
if len(node.relative) > 0 or node.module is None:
|
||||
# Don't support relative-only imports at the moment.
|
||||
return
|
||||
|
||||
# Get the module we're importing as a string.
|
||||
module = self._get_string_name(node.module)
|
||||
module = get_absolute_module_for_import(self.context.full_module_name, node)
|
||||
if module is None:
|
||||
# Can't get the absolute import from relative, so we can't
|
||||
# support this.
|
||||
return
|
||||
nodenames = node.names
|
||||
if isinstance(nodenames, libcst.ImportStar):
|
||||
# We cover everything, no need to bother tracking other things
|
||||
|
|
@ -102,20 +93,18 @@ class GatherImportsVisitor(ContextAwareVisitor):
|
|||
elif isinstance(nodenames, Sequence):
|
||||
# Get the list of imports we're aliasing in this import
|
||||
new_aliases = [
|
||||
# pyre-ignore We check ia.asname below, this is safe
|
||||
(self._get_string_name(ia.name), ia.asname.name.value)
|
||||
(ia.evaluated_name, ia.evaluated_alias)
|
||||
for ia in nodenames
|
||||
if ia.asname is not None
|
||||
]
|
||||
if new_aliases:
|
||||
if module not in self.alias_mapping:
|
||||
self.alias_mapping[module] = []
|
||||
# pyre-ignore We know that aliases are not None here.
|
||||
self.alias_mapping[module].extend(new_aliases)
|
||||
|
||||
# Get the list of imports we're importing in this import
|
||||
new_objects = {
|
||||
self._get_string_name(ia.name) for ia in nodenames if ia.asname is None
|
||||
}
|
||||
new_objects = {ia.evaluated_name for ia in nodenames if ia.asname is None}
|
||||
if new_objects:
|
||||
if module not in self.object_mapping:
|
||||
self.object_mapping[module] = set()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# pyre-strict
|
||||
from libcst.codemod import CodemodTest
|
||||
from libcst.codemod import CodemodContext, CodemodTest
|
||||
from libcst.codemod.visitors import AddImportsVisitor
|
||||
|
||||
|
||||
|
|
@ -558,3 +558,65 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("argparse", None, None)])
|
||||
|
||||
def test_dont_add_relative_object_simple(self) -> None:
|
||||
"""
|
||||
Should not add object as an import since it exists.
|
||||
"""
|
||||
|
||||
before = """
|
||||
from .c import D
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
from .c import D
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(
|
||||
before,
|
||||
after,
|
||||
[("a.b.c", "D", None)],
|
||||
context_override=CodemodContext(full_module_name="a.b.foobar"),
|
||||
)
|
||||
|
||||
def test_add_object_relative_modify_simple(self) -> None:
|
||||
"""
|
||||
Should modify existing import to add new object
|
||||
"""
|
||||
|
||||
before = """
|
||||
from .c import E, F
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
from .c import D, E, F
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(
|
||||
before,
|
||||
after,
|
||||
[("a.b.c", "D", None)],
|
||||
context_override=CodemodContext(full_module_name="a.b.foobar"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ from libcst.testing.utils import UnitTest
|
|||
|
||||
class TestGatherImportsVisitor(UnitTest):
|
||||
def gather_imports(self, code: str) -> GatherImportsVisitor:
|
||||
transform_instance = GatherImportsVisitor(CodemodContext())
|
||||
transform_instance = GatherImportsVisitor(
|
||||
CodemodContext(full_module_name="a.b.foobar")
|
||||
)
|
||||
input_tree = parse_module(CodemodTest.make_fixture_data(code))
|
||||
input_tree.visit(transform_instance)
|
||||
return transform_instance
|
||||
|
|
@ -154,3 +156,21 @@ class TestGatherImportsVisitor(UnitTest):
|
|||
self.assertEqual(gatherer.module_aliases, {})
|
||||
self.assertEqual(gatherer.alias_mapping, {"a.b.c": [("d", "e")]})
|
||||
self.assertEqual(len(gatherer.all_imports), 1)
|
||||
|
||||
def test_gather_relative_object(self) -> None:
|
||||
code = """
|
||||
from .c import d as e, f, g
|
||||
from a.b.c import h, i, j
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
gatherer = self.gather_imports(code)
|
||||
self.assertEqual(gatherer.module_imports, set())
|
||||
self.assertEqual(gatherer.object_mapping, {"a.b.c": {"f", "g", "h", "i", "j"}})
|
||||
self.assertEqual(gatherer.module_aliases, {})
|
||||
self.assertEqual(gatherer.alias_mapping, {"a.b.c": [("d", "e")]})
|
||||
self.assertEqual(len(gatherer.all_imports), 2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue