Port add import updates to open-source.

This commit is contained in:
Jennifer Taylor 2020-01-06 13:46:07 -08:00 committed by Jennifer Taylor
parent 109a0bbc16
commit 0932049c75
2 changed files with 296 additions and 50 deletions

View file

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import libcst
@ -60,7 +61,7 @@ class AddImportsVisitor(ContextAwareTransformer):
@staticmethod
def _get_imports_from_context(
context: CodemodContext,
) -> List[Tuple[str, Optional[str]]]:
) -> List[Tuple[str, Optional[str], Optional[str]]]:
imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
if not isinstance(imports, list):
raise Exception("Logic error!")
@ -68,7 +69,10 @@ class AddImportsVisitor(ContextAwareTransformer):
@staticmethod
def add_needed_import(
context: CodemodContext, module: str, obj: Optional[str] = None
context: CodemodContext,
module: str,
obj: Optional[str] = None,
asname: Optional[str] = None,
) -> None:
"""
Schedule an import to be added in a future invocation of this class by
@ -84,38 +88,72 @@ class AddImportsVisitor(ContextAwareTransformer):
if module == "__future__" and obj is None:
raise Exception("Cannot import __future__ directly!")
imports = AddImportsVisitor._get_imports_from_context(context)
imports.append((module, obj))
imports.append((module, obj, asname))
context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports
def __init__(
self, context: CodemodContext, imports: Sequence[Tuple[str, Optional[str]]] = ()
self,
context: CodemodContext,
imports: Sequence[Tuple[str, Optional[str], Optional[str]]] = (),
) -> None:
# Allow for instantiation from either a context (used when multiple transforms
# get chained) or from a direct instantiation.
super().__init__(context)
imports: List[Tuple[str, Optional[str]]] = [
imports: List[Tuple[str, Optional[str], Optional[str]]] = [
*AddImportsVisitor._get_imports_from_context(context),
*imports,
]
# Verify that the imports are valid
for module, obj in imports:
for module, obj, alias in imports:
if module == "__future__" and obj is None:
raise Exception("Cannot import __future__ directly!")
if module == "__future__" and alias is not None:
raise Exception("Cannot import __future__ objects with aliases!")
# List of modules we need to ensure are imported
self.module_imports: Set[str] = {
module for (module, obj) in imports if obj is None
module for (module, obj, alias) in imports if obj is None and alias is None
}
# List of modules we need to check for object imports on
from_imports: Set[str] = {
module for (module, obj) in imports if obj is not None
module
for (module, obj, alias) in imports
if obj is not None and alias is None
}
# Mapping of modules we're adding to the object they should import
self.module_mapping: Dict[str, Set[str]] = {
module: {o for (m, o) in imports if m == module and o is not None}
module: {
o
for (m, o, n) in imports
if m == module and o is not None and n is None
}
for module in sorted(from_imports)
}
# List of aliased modules we need to ensure are imported
self.module_aliases: Dict[str, str] = {
module: alias
for (module, obj, alias) in imports
if obj is None and alias is not None
}
# List of modules we need to check for object imports on
from_imports_aliases: Set[str] = {
module
for (module, obj, alias) in imports
if obj is not None and alias is not None
}
# Mapping of modules we're adding to the object with alias they should import
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
module: [
(o, n)
for (m, o, n) in imports
if m == module and o is not None and n is not None
]
for module in sorted(from_imports_aliases)
}
# Track the list of imports found in the file
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
@ -124,7 +162,21 @@ class AddImportsVisitor(ContextAwareTransformer):
gatherer = GatherImportsVisitor(self.context)
node.visit(gatherer)
self.all_imports = gatherer.all_imports
self.module_imports = self.module_imports - gatherer.module_imports
for module, alias in gatherer.module_aliases.items():
if module in self.module_aliases and self.module_aliases[module] == alias:
del self.module_aliases[module]
for module, aliases in gatherer.alias_mapping.items():
for (obj, alias) in aliases:
if (
module in self.alias_mapping
and (obj, alias) in self.alias_mapping[module]
):
self.alias_mapping[module].remove((obj, alias))
if len(self.alias_mapping[module]) == 0:
del self.alias_mapping[module]
for module, imports in gatherer.object_mapping.items():
if module not in self.module_mapping:
# We don't care about this import at all
@ -161,17 +213,28 @@ class AddImportsVisitor(ContextAwareTransformer):
# Get the module we're importing as a string, see if we have work to do
module = self._get_string_name(updated_node.module)
if module not in self.module_mapping:
if module not in self.module_mapping and module not in self.alias_mapping:
return updated_node
# We have work to do, mark that we won't modify this again.
imports_to_add = self.module_mapping[module]
del self.module_mapping[module]
imports_to_add = self.module_mapping.get(module, [])
if module in self.module_mapping:
del self.module_mapping[module]
aliases_to_add = self.alias_mapping.get(module, [])
if module in self.alias_mapping:
del self.alias_mapping[module]
# Now, do the actual update.
return updated_node.with_changes(
names=(
*[libcst.ImportAlias(name=libcst.Name(imp)) for imp in imports_to_add],
*[
libcst.ImportAlias(
name=libcst.Name(imp),
asname=libcst.AsName(name=libcst.Name(alias)),
)
for (imp, alias) in aliases_to_add
],
*updated_node.names,
)
)
@ -249,7 +312,12 @@ class AddImportsVisitor(ContextAwareTransformer):
self, original_node: libcst.Module, updated_node: libcst.Module
) -> libcst.Module:
# Don't try to modify if we have nothing to do
if not self.module_imports and not self.module_mapping:
if (
not self.module_imports
and not self.module_mapping
and not self.module_aliases
and not self.alias_mapping
):
return updated_node
# First, find the insertion point for imports
@ -260,15 +328,34 @@ class AddImportsVisitor(ContextAwareTransformer):
# Make sure there's at least one empty line before the first non-import
statements_after_imports = self._insert_empty_line(statements_after_imports)
# Mapping of modules we're adding to the object with and without alias they should import
module_and_alias_mapping = defaultdict(list)
for module, aliases in self.alias_mapping.items():
module_and_alias_mapping[module].extend(aliases)
for module, imports in self.module_mapping.items():
module_and_alias_mapping[module].extend(
[(object, None) for object in imports]
)
module_and_alias_mapping = {
module: sorted(aliases)
for module, aliases in module_and_alias_mapping.items()
}
# import ptvsd; ptvsd.set_trace()
# Now, add all of the imports we need!
return updated_node.with_changes(
body=(
*[
parse_statement(
f"from {module} import {', '.join(sorted(imports))}",
f"from {module} import "
+ ", ".join(
[
obj if alias is None else f"{obj} as {alias}"
for (obj, alias) in aliases
]
),
config=updated_node.config_for_parsing,
)
for module, imports in self.module_mapping.items()
for module, aliases in module_and_alias_mapping.items()
if module == "__future__"
],
*statements_before_imports,
@ -280,10 +367,23 @@ class AddImportsVisitor(ContextAwareTransformer):
],
*[
parse_statement(
f"from {module} import {', '.join(sorted(imports))}",
f"import {module} as {asname}",
config=updated_node.config_for_parsing,
)
for module, imports in self.module_mapping.items()
for (module, asname) in self.module_aliases.items()
],
*[
parse_statement(
f"from {module} import "
+ ", ".join(
[
obj if alias is None else f"{obj} as {alias}"
for (obj, alias) in aliases
]
),
config=updated_node.config_for_parsing,
)
for module, aliases in module_and_alias_mapping.items()
if module != "__future__"
],
*statements_after_imports,

View file

@ -56,7 +56,7 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", None)])
self.assertCodemod(before, after, [("a.b.c", None, None)])
def test_dont_add_module_simple(self) -> None:
"""
@ -82,7 +82,57 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", None)])
self.assertCodemod(before, after, [("a.b.c", None, None)])
def test_add_module_alias_simple(self) -> None:
"""
Should add module with alias as an import.
"""
before = """
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
import a.b.c as d
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(before, after, [("a.b.c", None, "d")])
def test_dont_add_module_alias_simple(self) -> None:
"""
Should not add module with alias as an import since it exists
"""
before = """
import a.b.c as d
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
import a.b.c as d
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(before, after, [("a.b.c", None, "d")])
def test_add_module_complex(self) -> None:
"""
@ -104,6 +154,8 @@ class TestAddImportsCodemod(CodemodTest):
import sys
import a.b.c
import defg.hi
import jkl as h
import i.j as k
def foo() -> None:
pass
@ -113,7 +165,15 @@ class TestAddImportsCodemod(CodemodTest):
"""
self.assertCodemod(
before, after, [("a.b.c", None), ("defg.hi", None), ("argparse", None)]
before,
after,
[
("a.b.c", None, None),
("defg.hi", None, None),
("argparse", None, None),
("jkl", None, "h"),
("i.j", None, "k"),
],
)
def test_add_object_simple(self) -> None:
@ -138,7 +198,31 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D")])
self.assertCodemod(before, after, [("a.b.c", "D", None)])
def test_add_object_alias_simple(self) -> None:
"""
Should add object with alias as an import.
"""
before = """
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
from a.b.c import D as E
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D", "E")])
def test_add_future(self) -> None:
"""
@ -167,7 +251,7 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("__future__", "dummy_feature")])
self.assertCodemod(before, after, [("__future__", "dummy_feature", None)])
def test_dont_add_object_simple(self) -> None:
"""
@ -193,7 +277,33 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D")])
self.assertCodemod(before, after, [("a.b.c", "D", None)])
def test_dont_add_object_alias_simple(self) -> None:
"""
Should not add object as an import since it exists.
"""
before = """
from a.b.c import D as E
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
from a.b.c import D as E
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D", "E")])
def test_add_object_modify_simple(self) -> None:
"""
@ -219,7 +329,33 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D")])
self.assertCodemod(before, after, [("a.b.c", "D", None)])
def test_add_object_alias_modify_simple(self) -> None:
"""
Should modify existing import with alias to add new object
"""
before = """
from a.b.c import E, F
def foo() -> None:
pass
def bar() -> int:
return 5
"""
after = """
from a.b.c import D as _, E, F
def foo() -> None:
pass
def bar() -> int:
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D", "_")])
def test_add_object_modify_complex(self) -> None:
"""
@ -227,7 +363,7 @@ class TestAddImportsCodemod(CodemodTest):
"""
before = """
from a.b.c import E, F
from a.b.c import E, F, G as H
from d.e.f import Foo, Bar
def foo() -> None:
@ -237,9 +373,9 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
after = """
from a.b.c import D, E, F
from d.e.f import Foo, Bar
from g.h.i import X, Y, Z
from a.b.c import D, E, F, G as H
from d.e.f import Baz as Qux, Foo, Bar
from g.h.i import V as W, X, Y, Z
def foo() -> None:
pass
@ -252,14 +388,17 @@ class TestAddImportsCodemod(CodemodTest):
before,
after,
[
("a.b.c", "D"),
("a.b.c", "F"),
("d.e.f", "Foo"),
("g.h.i", "Z"),
("g.h.i", "X"),
("d.e.f", "Bar"),
("g.h.i", "Y"),
("a.b.c", "F"),
("a.b.c", "D", None),
("a.b.c", "F", None),
("a.b.c", "G", "H"),
("d.e.f", "Foo", None),
("g.h.i", "Z", None),
("g.h.i", "X", None),
("d.e.f", "Bar", None),
("d.e.f", "Baz", "Qux"),
("g.h.i", "Y", None),
("g.h.i", "V", "W"),
("a.b.c", "F", None),
],
)
@ -273,6 +412,7 @@ class TestAddImportsCodemod(CodemodTest):
import sys
from a.b.c import E, F
from d.e.f import Foo, Bar
import bar as baz
def foo() -> None:
pass
@ -285,7 +425,9 @@ class TestAddImportsCodemod(CodemodTest):
import sys
from a.b.c import D, E, F
from d.e.f import Foo, Bar
import bar as baz
import foo
import qux as quux
from g.h.i import X, Y, Z
def foo() -> None:
@ -299,16 +441,18 @@ class TestAddImportsCodemod(CodemodTest):
before,
after,
[
("a.b.c", "D"),
("a.b.c", "F"),
("d.e.f", "Foo"),
("sys", None),
("g.h.i", "Z"),
("g.h.i", "X"),
("d.e.f", "Bar"),
("g.h.i", "Y"),
("foo", None),
("a.b.c", "F"),
("a.b.c", "D", None),
("a.b.c", "F", None),
("d.e.f", "Foo", None),
("sys", None, None),
("g.h.i", "Z", None),
("g.h.i", "X", None),
("d.e.f", "Bar", None),
("g.h.i", "Y", None),
("foo", None, None),
("a.b.c", "F", None),
("bar", None, "baz"),
("qux", None, "quux"),
],
)
@ -338,7 +482,7 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D")])
self.assertCodemod(before, after, [("a.b.c", "D", None)])
def test_add_import_preserve_doctring_multiples(self) -> None:
"""
@ -367,7 +511,9 @@ class TestAddImportsCodemod(CodemodTest):
return 5
"""
self.assertCodemod(before, after, [("a.b.c", "D"), ("argparse", None)])
self.assertCodemod(
before, after, [("a.b.c", "D", None), ("argparse", None, None)]
)
def test_strict_module_no_imports(self) -> None:
"""
@ -387,7 +533,7 @@ class TestAddImportsCodemod(CodemodTest):
pass
"""
self.assertCodemod(before, after, [("argparse", None)])
self.assertCodemod(before, after, [("argparse", None, None)])
def test_strict_module_with_imports(self) -> None:
"""
@ -411,4 +557,4 @@ class TestAddImportsCodemod(CodemodTest):
pass
"""
self.assertCodemod(before, after, [("argparse", None)])
self.assertCodemod(before, after, [("argparse", None, None)])