diff --git a/libcst/codemod/visitors/_remove_imports.py b/libcst/codemod/visitors/_remove_imports.py index d1cc8b8c..0023bd22 100644 --- a/libcst/codemod/visitors/_remove_imports.py +++ b/libcst/codemod/visitors/_remove_imports.py @@ -306,20 +306,14 @@ class RemoveImportsVisitor(ContextAwareTransformer): if name_or_alias in self.exported_objects: return True - # number of references to the name - references_count = 0 - # number of imports to the same name - assignments_count = 0 for assignment in scope[name_or_alias]: - if isinstance(assignment, Assignment) and isinstance( - assignment.node, (cst.ImportFrom, cst.Import) + if ( + isinstance(assignment, Assignment) + and isinstance(assignment.node, (cst.ImportFrom, cst.Import)) + and len(assignment.references) > 0 ): - assignments_count += 1 - references_count += len(assignment.references) - - # Remove the import if it's a candidate to remove with no references or - # multiple assignments. - return not (references_count == 0 or assignments_count > 1) + return True + return False def leave_Import( self, original_node: cst.Import, updated_node: cst.Import diff --git a/libcst/codemod/visitors/tests/test_remove_imports.py b/libcst/codemod/visitors/tests/test_remove_imports.py index 4e2801c2..ec8e460c 100644 --- a/libcst/codemod/visitors/tests/test_remove_imports.py +++ b/libcst/codemod/visitors/tests/test_remove_imports.py @@ -457,7 +457,7 @@ class TestRemoveImportsCodemod(CodemodTest): def test_remove_import_multiple_assignments(self) -> None: """ - Should remove import with multiple assignments + Should not remove import with multiple assignments """ before = """ @@ -468,6 +468,7 @@ class TestRemoveImportsCodemod(CodemodTest): bar() """ after = """ + from foo import bar from qux import bar def foo() -> None: @@ -476,6 +477,42 @@ class TestRemoveImportsCodemod(CodemodTest): self.assertCodemod(before, after, [("foo", "bar", None)]) + def test_remove_multiple_imports(self) -> None: + """ + Multiple imports + """ + before = """ + try: + import a + except Exception: + import a + + a.hello() + """ + after = """ + try: + import a + except Exception: + import a + + a.hello() + """ + self.assertCodemod(before, after, [("a", None, None)]) + + before = """ + try: + import a + except Exception: + import a + """ + after = """ + try: + pass + except Exception: + pass + """ + self.assertCodemod(before, after, [("a", None, None)]) + @data_provider( ( # Simple removal, no other uses.