Only remove trailing comma if the last alias is removed (#334)

This commit is contained in:
Zsolt Dollenstein 2020-07-09 16:37:56 +01:00 committed by GitHub
parent 8523852d05
commit 9d3bb11eb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 4 deletions

View file

@ -83,3 +83,10 @@ class RemoveUnusedImportsCommandTest(CodemodTest):
b(0)[x] = False
"""
self.assertCodemod(before, before)
def test_no_formatting_if_no_unused_imports(self) -> None:
before = """
from m import (a, b,)
a(b, 'look at these ugly quotes')
"""
self.assertCodemod(before, before)

View file

@ -343,17 +343,22 @@ class RemoveImportsVisitor(ContextAwareTransformer):
names_to_keep.append(import_alias)
continue
# no changes
if names_to_keep == original_node.names:
return updated_node
# Now, either remove this statement or remove the imports we are
# deleting from this statement.
if len(names_to_keep) == 0:
return cst.RemoveFromParent()
else:
if names_to_keep[-1] != original_node.names[-1]:
# Remove trailing comma in order to not mess up import statements.
names_to_keep = [
*names_to_keep[:-1],
names_to_keep[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT),
]
return updated_node.with_changes(names=names_to_keep)
return updated_node.with_changes(names=names_to_keep)
def leave_ImportFrom(
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
@ -399,14 +404,19 @@ class RemoveImportsVisitor(ContextAwareTransformer):
names_to_keep.append(import_alias)
continue
# no changes
if names_to_keep == names:
return updated_node
# Now, either remove this statement or remove the imports we are
# deleting from this statement.
if len(names_to_keep) == 0:
return cst.RemoveFromParent()
else:
if names_to_keep[-1] != names[-1]:
# Remove trailing comma in order to not mess up import statements.
names_to_keep = [
*names_to_keep[:-1],
names_to_keep[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT),
]
return updated_node.with_changes(names=names_to_keep)
return updated_node.with_changes(names=names_to_keep)

View file

@ -797,3 +797,17 @@ class TestRemoveImportsCodemod(CodemodTest):
after,
RemoveImportTransformer(CodemodContext()).transform_module(module).code,
)
def test_remove_comma(self) -> None:
"""
Trailing commas should be removed if and only if the last alias is removed.
"""
before = """
from m import (a, b,)
import x, y
"""
after = """
from m import (b,)
import x
"""
self.assertCodemod(before, after, [("m", "a", None), ("y", None, None)])