mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Port add import updates to open-source.
This commit is contained in:
parent
109a0bbc16
commit
0932049c75
2 changed files with 296 additions and 50 deletions
|
|
@ -4,6 +4,7 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# pyre-strict
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
import libcst
|
||||
|
|
@ -60,7 +61,7 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
@staticmethod
|
||||
def _get_imports_from_context(
|
||||
context: CodemodContext,
|
||||
) -> List[Tuple[str, Optional[str]]]:
|
||||
) -> List[Tuple[str, Optional[str], Optional[str]]]:
|
||||
imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
|
||||
if not isinstance(imports, list):
|
||||
raise Exception("Logic error!")
|
||||
|
|
@ -68,7 +69,10 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
|
||||
@staticmethod
|
||||
def add_needed_import(
|
||||
context: CodemodContext, module: str, obj: Optional[str] = None
|
||||
context: CodemodContext,
|
||||
module: str,
|
||||
obj: Optional[str] = None,
|
||||
asname: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Schedule an import to be added in a future invocation of this class by
|
||||
|
|
@ -84,38 +88,72 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
if module == "__future__" and obj is None:
|
||||
raise Exception("Cannot import __future__ directly!")
|
||||
imports = AddImportsVisitor._get_imports_from_context(context)
|
||||
imports.append((module, obj))
|
||||
imports.append((module, obj, asname))
|
||||
context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports
|
||||
|
||||
def __init__(
|
||||
self, context: CodemodContext, imports: Sequence[Tuple[str, Optional[str]]] = ()
|
||||
self,
|
||||
context: CodemodContext,
|
||||
imports: Sequence[Tuple[str, Optional[str], Optional[str]]] = (),
|
||||
) -> None:
|
||||
# Allow for instantiation from either a context (used when multiple transforms
|
||||
# get chained) or from a direct instantiation.
|
||||
super().__init__(context)
|
||||
imports: List[Tuple[str, Optional[str]]] = [
|
||||
imports: List[Tuple[str, Optional[str], Optional[str]]] = [
|
||||
*AddImportsVisitor._get_imports_from_context(context),
|
||||
*imports,
|
||||
]
|
||||
|
||||
# Verify that the imports are valid
|
||||
for module, obj in imports:
|
||||
for module, obj, alias in imports:
|
||||
if module == "__future__" and obj is None:
|
||||
raise Exception("Cannot import __future__ directly!")
|
||||
if module == "__future__" and alias is not None:
|
||||
raise Exception("Cannot import __future__ objects with aliases!")
|
||||
|
||||
# List of modules we need to ensure are imported
|
||||
self.module_imports: Set[str] = {
|
||||
module for (module, obj) in imports if obj is None
|
||||
module for (module, obj, alias) in imports if obj is None and alias is None
|
||||
}
|
||||
|
||||
# List of modules we need to check for object imports on
|
||||
from_imports: Set[str] = {
|
||||
module for (module, obj) in imports if obj is not None
|
||||
module
|
||||
for (module, obj, alias) in imports
|
||||
if obj is not None and alias is None
|
||||
}
|
||||
# Mapping of modules we're adding to the object they should import
|
||||
self.module_mapping: Dict[str, Set[str]] = {
|
||||
module: {o for (m, o) in imports if m == module and o is not None}
|
||||
module: {
|
||||
o
|
||||
for (m, o, n) in imports
|
||||
if m == module and o is not None and n is None
|
||||
}
|
||||
for module in sorted(from_imports)
|
||||
}
|
||||
|
||||
# List of aliased modules we need to ensure are imported
|
||||
self.module_aliases: Dict[str, str] = {
|
||||
module: alias
|
||||
for (module, obj, alias) in imports
|
||||
if obj is None and alias is not None
|
||||
}
|
||||
# List of modules we need to check for object imports on
|
||||
from_imports_aliases: Set[str] = {
|
||||
module
|
||||
for (module, obj, alias) in imports
|
||||
if obj is not None and alias is not None
|
||||
}
|
||||
# Mapping of modules we're adding to the object with alias they should import
|
||||
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
|
||||
module: [
|
||||
(o, n)
|
||||
for (m, o, n) in imports
|
||||
if m == module and o is not None and n is not None
|
||||
]
|
||||
for module in sorted(from_imports_aliases)
|
||||
}
|
||||
|
||||
# Track the list of imports found in the file
|
||||
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
|
||||
|
||||
|
|
@ -124,7 +162,21 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
gatherer = GatherImportsVisitor(self.context)
|
||||
node.visit(gatherer)
|
||||
self.all_imports = gatherer.all_imports
|
||||
|
||||
self.module_imports = self.module_imports - gatherer.module_imports
|
||||
for module, alias in gatherer.module_aliases.items():
|
||||
if module in self.module_aliases and self.module_aliases[module] == alias:
|
||||
del self.module_aliases[module]
|
||||
for module, aliases in gatherer.alias_mapping.items():
|
||||
for (obj, alias) in aliases:
|
||||
if (
|
||||
module in self.alias_mapping
|
||||
and (obj, alias) in self.alias_mapping[module]
|
||||
):
|
||||
self.alias_mapping[module].remove((obj, alias))
|
||||
if len(self.alias_mapping[module]) == 0:
|
||||
del self.alias_mapping[module]
|
||||
|
||||
for module, imports in gatherer.object_mapping.items():
|
||||
if module not in self.module_mapping:
|
||||
# We don't care about this import at all
|
||||
|
|
@ -161,17 +213,28 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
|
||||
# Get the module we're importing as a string, see if we have work to do
|
||||
module = self._get_string_name(updated_node.module)
|
||||
if module not in self.module_mapping:
|
||||
if module not in self.module_mapping and module not in self.alias_mapping:
|
||||
return updated_node
|
||||
|
||||
# We have work to do, mark that we won't modify this again.
|
||||
imports_to_add = self.module_mapping[module]
|
||||
del self.module_mapping[module]
|
||||
imports_to_add = self.module_mapping.get(module, [])
|
||||
if module in self.module_mapping:
|
||||
del self.module_mapping[module]
|
||||
aliases_to_add = self.alias_mapping.get(module, [])
|
||||
if module in self.alias_mapping:
|
||||
del self.alias_mapping[module]
|
||||
|
||||
# Now, do the actual update.
|
||||
return updated_node.with_changes(
|
||||
names=(
|
||||
*[libcst.ImportAlias(name=libcst.Name(imp)) for imp in imports_to_add],
|
||||
*[
|
||||
libcst.ImportAlias(
|
||||
name=libcst.Name(imp),
|
||||
asname=libcst.AsName(name=libcst.Name(alias)),
|
||||
)
|
||||
for (imp, alias) in aliases_to_add
|
||||
],
|
||||
*updated_node.names,
|
||||
)
|
||||
)
|
||||
|
|
@ -249,7 +312,12 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
self, original_node: libcst.Module, updated_node: libcst.Module
|
||||
) -> libcst.Module:
|
||||
# Don't try to modify if we have nothing to do
|
||||
if not self.module_imports and not self.module_mapping:
|
||||
if (
|
||||
not self.module_imports
|
||||
and not self.module_mapping
|
||||
and not self.module_aliases
|
||||
and not self.alias_mapping
|
||||
):
|
||||
return updated_node
|
||||
|
||||
# First, find the insertion point for imports
|
||||
|
|
@ -260,15 +328,34 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
# Make sure there's at least one empty line before the first non-import
|
||||
statements_after_imports = self._insert_empty_line(statements_after_imports)
|
||||
|
||||
# Mapping of modules we're adding to the object with and without alias they should import
|
||||
module_and_alias_mapping = defaultdict(list)
|
||||
for module, aliases in self.alias_mapping.items():
|
||||
module_and_alias_mapping[module].extend(aliases)
|
||||
for module, imports in self.module_mapping.items():
|
||||
module_and_alias_mapping[module].extend(
|
||||
[(object, None) for object in imports]
|
||||
)
|
||||
module_and_alias_mapping = {
|
||||
module: sorted(aliases)
|
||||
for module, aliases in module_and_alias_mapping.items()
|
||||
}
|
||||
# import ptvsd; ptvsd.set_trace()
|
||||
# Now, add all of the imports we need!
|
||||
return updated_node.with_changes(
|
||||
body=(
|
||||
*[
|
||||
parse_statement(
|
||||
f"from {module} import {', '.join(sorted(imports))}",
|
||||
f"from {module} import "
|
||||
+ ", ".join(
|
||||
[
|
||||
obj if alias is None else f"{obj} as {alias}"
|
||||
for (obj, alias) in aliases
|
||||
]
|
||||
),
|
||||
config=updated_node.config_for_parsing,
|
||||
)
|
||||
for module, imports in self.module_mapping.items()
|
||||
for module, aliases in module_and_alias_mapping.items()
|
||||
if module == "__future__"
|
||||
],
|
||||
*statements_before_imports,
|
||||
|
|
@ -280,10 +367,23 @@ class AddImportsVisitor(ContextAwareTransformer):
|
|||
],
|
||||
*[
|
||||
parse_statement(
|
||||
f"from {module} import {', '.join(sorted(imports))}",
|
||||
f"import {module} as {asname}",
|
||||
config=updated_node.config_for_parsing,
|
||||
)
|
||||
for module, imports in self.module_mapping.items()
|
||||
for (module, asname) in self.module_aliases.items()
|
||||
],
|
||||
*[
|
||||
parse_statement(
|
||||
f"from {module} import "
|
||||
+ ", ".join(
|
||||
[
|
||||
obj if alias is None else f"{obj} as {alias}"
|
||||
for (obj, alias) in aliases
|
||||
]
|
||||
),
|
||||
config=updated_node.config_for_parsing,
|
||||
)
|
||||
for module, aliases in module_and_alias_mapping.items()
|
||||
if module != "__future__"
|
||||
],
|
||||
*statements_after_imports,
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", None)])
|
||||
self.assertCodemod(before, after, [("a.b.c", None, None)])
|
||||
|
||||
def test_dont_add_module_simple(self) -> None:
|
||||
"""
|
||||
|
|
@ -82,7 +82,57 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", None)])
|
||||
self.assertCodemod(before, after, [("a.b.c", None, None)])
|
||||
|
||||
def test_add_module_alias_simple(self) -> None:
|
||||
"""
|
||||
Should add module with alias as an import.
|
||||
"""
|
||||
|
||||
before = """
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
import a.b.c as d
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", None, "d")])
|
||||
|
||||
def test_dont_add_module_alias_simple(self) -> None:
|
||||
"""
|
||||
Should not add module with alias as an import since it exists
|
||||
"""
|
||||
|
||||
before = """
|
||||
import a.b.c as d
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
import a.b.c as d
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", None, "d")])
|
||||
|
||||
def test_add_module_complex(self) -> None:
|
||||
"""
|
||||
|
|
@ -104,6 +154,8 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
import sys
|
||||
import a.b.c
|
||||
import defg.hi
|
||||
import jkl as h
|
||||
import i.j as k
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
|
@ -113,7 +165,15 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
"""
|
||||
|
||||
self.assertCodemod(
|
||||
before, after, [("a.b.c", None), ("defg.hi", None), ("argparse", None)]
|
||||
before,
|
||||
after,
|
||||
[
|
||||
("a.b.c", None, None),
|
||||
("defg.hi", None, None),
|
||||
("argparse", None, None),
|
||||
("jkl", None, "h"),
|
||||
("i.j", None, "k"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_add_object_simple(self) -> None:
|
||||
|
|
@ -138,7 +198,31 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D")])
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", None)])
|
||||
|
||||
def test_add_object_alias_simple(self) -> None:
|
||||
"""
|
||||
Should add object with alias as an import.
|
||||
"""
|
||||
|
||||
before = """
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
from a.b.c import D as E
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", "E")])
|
||||
|
||||
def test_add_future(self) -> None:
|
||||
"""
|
||||
|
|
@ -167,7 +251,7 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("__future__", "dummy_feature")])
|
||||
self.assertCodemod(before, after, [("__future__", "dummy_feature", None)])
|
||||
|
||||
def test_dont_add_object_simple(self) -> None:
|
||||
"""
|
||||
|
|
@ -193,7 +277,33 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D")])
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", None)])
|
||||
|
||||
def test_dont_add_object_alias_simple(self) -> None:
|
||||
"""
|
||||
Should not add object as an import since it exists.
|
||||
"""
|
||||
|
||||
before = """
|
||||
from a.b.c import D as E
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
from a.b.c import D as E
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", "E")])
|
||||
|
||||
def test_add_object_modify_simple(self) -> None:
|
||||
"""
|
||||
|
|
@ -219,7 +329,33 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D")])
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", None)])
|
||||
|
||||
def test_add_object_alias_modify_simple(self) -> None:
|
||||
"""
|
||||
Should modify existing import with alias to add new object
|
||||
"""
|
||||
|
||||
before = """
|
||||
from a.b.c import E, F
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
after = """
|
||||
from a.b.c import D as _, E, F
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
def bar() -> int:
|
||||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", "_")])
|
||||
|
||||
def test_add_object_modify_complex(self) -> None:
|
||||
"""
|
||||
|
|
@ -227,7 +363,7 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
"""
|
||||
|
||||
before = """
|
||||
from a.b.c import E, F
|
||||
from a.b.c import E, F, G as H
|
||||
from d.e.f import Foo, Bar
|
||||
|
||||
def foo() -> None:
|
||||
|
|
@ -237,9 +373,9 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
after = """
|
||||
from a.b.c import D, E, F
|
||||
from d.e.f import Foo, Bar
|
||||
from g.h.i import X, Y, Z
|
||||
from a.b.c import D, E, F, G as H
|
||||
from d.e.f import Baz as Qux, Foo, Bar
|
||||
from g.h.i import V as W, X, Y, Z
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
|
@ -252,14 +388,17 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
before,
|
||||
after,
|
||||
[
|
||||
("a.b.c", "D"),
|
||||
("a.b.c", "F"),
|
||||
("d.e.f", "Foo"),
|
||||
("g.h.i", "Z"),
|
||||
("g.h.i", "X"),
|
||||
("d.e.f", "Bar"),
|
||||
("g.h.i", "Y"),
|
||||
("a.b.c", "F"),
|
||||
("a.b.c", "D", None),
|
||||
("a.b.c", "F", None),
|
||||
("a.b.c", "G", "H"),
|
||||
("d.e.f", "Foo", None),
|
||||
("g.h.i", "Z", None),
|
||||
("g.h.i", "X", None),
|
||||
("d.e.f", "Bar", None),
|
||||
("d.e.f", "Baz", "Qux"),
|
||||
("g.h.i", "Y", None),
|
||||
("g.h.i", "V", "W"),
|
||||
("a.b.c", "F", None),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -273,6 +412,7 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
import sys
|
||||
from a.b.c import E, F
|
||||
from d.e.f import Foo, Bar
|
||||
import bar as baz
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
|
@ -285,7 +425,9 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
import sys
|
||||
from a.b.c import D, E, F
|
||||
from d.e.f import Foo, Bar
|
||||
import bar as baz
|
||||
import foo
|
||||
import qux as quux
|
||||
from g.h.i import X, Y, Z
|
||||
|
||||
def foo() -> None:
|
||||
|
|
@ -299,16 +441,18 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
before,
|
||||
after,
|
||||
[
|
||||
("a.b.c", "D"),
|
||||
("a.b.c", "F"),
|
||||
("d.e.f", "Foo"),
|
||||
("sys", None),
|
||||
("g.h.i", "Z"),
|
||||
("g.h.i", "X"),
|
||||
("d.e.f", "Bar"),
|
||||
("g.h.i", "Y"),
|
||||
("foo", None),
|
||||
("a.b.c", "F"),
|
||||
("a.b.c", "D", None),
|
||||
("a.b.c", "F", None),
|
||||
("d.e.f", "Foo", None),
|
||||
("sys", None, None),
|
||||
("g.h.i", "Z", None),
|
||||
("g.h.i", "X", None),
|
||||
("d.e.f", "Bar", None),
|
||||
("g.h.i", "Y", None),
|
||||
("foo", None, None),
|
||||
("a.b.c", "F", None),
|
||||
("bar", None, "baz"),
|
||||
("qux", None, "quux"),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -338,7 +482,7 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D")])
|
||||
self.assertCodemod(before, after, [("a.b.c", "D", None)])
|
||||
|
||||
def test_add_import_preserve_doctring_multiples(self) -> None:
|
||||
"""
|
||||
|
|
@ -367,7 +511,9 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
return 5
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("a.b.c", "D"), ("argparse", None)])
|
||||
self.assertCodemod(
|
||||
before, after, [("a.b.c", "D", None), ("argparse", None, None)]
|
||||
)
|
||||
|
||||
def test_strict_module_no_imports(self) -> None:
|
||||
"""
|
||||
|
|
@ -387,7 +533,7 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
pass
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("argparse", None)])
|
||||
self.assertCodemod(before, after, [("argparse", None, None)])
|
||||
|
||||
def test_strict_module_with_imports(self) -> None:
|
||||
"""
|
||||
|
|
@ -411,4 +557,4 @@ class TestAddImportsCodemod(CodemodTest):
|
|||
pass
|
||||
"""
|
||||
|
||||
self.assertCodemod(before, after, [("argparse", None)])
|
||||
self.assertCodemod(before, after, [("argparse", None, None)])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue