mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
277 lines
10 KiB
Python
277 lines
10 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
#
|
|
from typing import Optional
|
|
|
|
import libcst as cst
|
|
from libcst.helpers.common import ensure_type
|
|
from libcst.helpers.module import (
|
|
calculate_module_and_package,
|
|
get_absolute_module_for_import,
|
|
get_absolute_module_for_import_or_raise,
|
|
get_absolute_module_from_package_for_import,
|
|
get_absolute_module_from_package_for_import_or_raise,
|
|
insert_header_comments,
|
|
ModuleNameAndPackage,
|
|
)
|
|
from libcst.testing.utils import data_provider, UnitTest
|
|
|
|
|
|
class ModuleTest(UnitTest):
    """Exercise the helpers imported from ``libcst.helpers.module``."""

    def test_insert_header_comments(self) -> None:
        """Check where ``insert_header_comments`` places new comment lines.

        Every case joins a list of lines into module source, inserts the same
        two comments, and compares the regenerated code against the expected
        text. Cases run in order and the first mismatch fails the test.
        """
        new_comments = ["# INSERT ME", "# AND ME"]
        header = ["# First comment", "# Another one", "# comment 3"]
        # Same header with a whitespace-only line after the first comment.
        spaced_header = header[:1] + [" "] + header[1:]
        blanks = [" ", ""]
        code = ["SOME_VARIABLE = 0"]

        # Each entry is (original source lines, expected source lines).
        cases = (
            # Header comments followed by empty lines
            (
                header + blanks + code,
                header + new_comments + blanks + code,
            ),
            # No comment case
            (
                blanks + code,
                new_comments + blanks + code,
            ),
            # No empty lines case
            (
                header + code,
                header + new_comments + code,
            ),
            # Empty line between comments
            (
                spaced_header + blanks + code,
                spaced_header + new_comments + blanks + code,
            ),
            # No header case
            (
                code,
                new_comments + code,
            ),
        )

        for source_lines, expected_lines in cases:
            module = cst.parse_module("\n".join(source_lines))
            updated = insert_header_comments(module, new_comments)
            self.assertEqual(updated.code, "\n".join(expected_lines))

    @data_provider(
        (
            # Simple imports that are already absolute.
            (None, "from a.b import c", "a.b"),
            ("x.y.z", "from a.b import c", "a.b"),
            # Relative import that can't be resolved due to missing module.
            (None, "from ..w import c", None),
            # Relative import that goes past the module level.
            ("x", "from ...y import z", None),
            ("x.y.z", "from .....w import c", None),
            ("x.y.z", "from ... import c", None),
            # Correct resolution of absolute from relative modules.
            ("x.y.z", "from . import c", "x.y"),
            ("x.y.z", "from .. import c", "x"),
            ("x.y.z", "from .w import c", "x.y.w"),
            ("x.y.z", "from ..w import c", "x.w"),
            ("x.y.z", "from ...w import c", "w"),
        )
    )
    def test_get_absolute_module(
        self,
        module: Optional[str],
        importfrom: str,
        output: Optional[str],
    ) -> None:
        """Resolve an ``ImportFrom`` statement against a dotted module name.

        ``module`` is the current module (or ``None``), ``importfrom`` is the
        statement source, and ``output`` is the expected absolute module, with
        ``None`` meaning the import is expected to be unresolvable.
        """
        statement = ensure_type(
            cst.parse_statement(importfrom), cst.SimpleStatementLine
        )
        assert len(statement.body) == 1, "Unexpected number of statements!"
        import_from = ensure_type(statement.body[0], cst.ImportFrom)

        # The plain variant should hand back the expected value, None included.
        resolved = get_absolute_module_for_import(module, import_from)
        self.assertEqual(resolved, output)

        if output is not None:
            self.assertEqual(
                get_absolute_module_for_import_or_raise(module, import_from),
                output,
            )
            return

        # The *_or_raise variant must raise where the plain one returned None.
        with self.assertRaises(Exception):
            get_absolute_module_for_import_or_raise(module, import_from)

    @data_provider(
        (
            # Simple imports that are already absolute.
            (None, "from a.b import c", "a.b"),
            ("x/y/z.py", "from a.b import c", "a.b"),
            ("x/y/z/__init__.py", "from a.b import c", "a.b"),
            # Relative import that can't be resolved due to missing module.
            (None, "from ..w import c", None),
            # Attempted relative import with no known parent package
            ("__init__.py", "from .y import z", None),
            ("x.py", "from .y import z", None),
            # Relative import that goes past the module level.
            ("x.py", "from ...y import z", None),
            ("x/y/z.py", "from ... import c", None),
            ("x/y/z.py", "from ...w import c", None),
            ("x/y/z/__init__.py", "from .... import c", None),
            ("x/y/z/__init__.py", "from ....w import c", None),
            # Correct resolution of absolute from relative modules.
            ("x/y/z.py", "from . import c", "x.y"),
            ("x/y/z.py", "from .. import c", "x"),
            ("x/y/z.py", "from .w import c", "x.y.w"),
            ("x/y/z.py", "from ..w import c", "x.w"),
            ("x/y/z/__init__.py", "from . import c", "x.y.z"),
            ("x/y/z/__init__.py", "from .. import c", "x.y"),
            ("x/y/z/__init__.py", "from ... import c", "x"),
            ("x/y/z/__init__.py", "from .w import c", "x.y.z.w"),
            ("x/y/z/__init__.py", "from ..w import c", "x.y.w"),
            ("x/y/z/__init__.py", "from ...w import c", "x.w"),
        )
    )
    def test_get_absolute_module_from_package(
        self,
        filename: Optional[str],
        importfrom: str,
        output: Optional[str],
    ) -> None:
        """Resolve an ``ImportFrom`` statement against a file's package.

        The package is derived from ``filename`` relative to ``"."`` via
        ``calculate_module_and_package``; a ``None`` filename means no package
        is known. ``output`` is the expected absolute module, or ``None`` when
        resolution is expected to fail.
        """
        package = (
            None
            if filename is None
            else calculate_module_and_package(".", filename).package
        )

        statement = ensure_type(
            cst.parse_statement(importfrom), cst.SimpleStatementLine
        )
        assert len(statement.body) == 1, "Unexpected number of statements!"
        import_from = ensure_type(statement.body[0], cst.ImportFrom)

        resolved = get_absolute_module_from_package_for_import(package, import_from)
        self.assertEqual(resolved, output)

        if output is not None:
            self.assertEqual(
                get_absolute_module_from_package_for_import_or_raise(
                    package, import_from
                ),
                output,
            )
            return

        # The *_or_raise variant must raise where the plain one returned None.
        with self.assertRaises(Exception):
            get_absolute_module_from_package_for_import_or_raise(package, import_from)

    @data_provider(
        (
            # Nodes without an asname
            (cst.ImportAlias(name=cst.Name("foo")), "foo", None),
            (
                cst.ImportAlias(name=cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
                "foo.bar",
                None,
            ),
            # Nodes with an asname
            (
                cst.ImportAlias(
                    name=cst.Name("foo"), asname=cst.AsName(name=cst.Name("baz"))
                ),
                "foo",
                "baz",
            ),
            (
                cst.ImportAlias(
                    name=cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo.bar",
                "baz",
            ),
        )
    )
    def test_importalias_helpers(
        self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str]
    ) -> None:
        """Check ``evaluated_name`` and ``evaluated_alias`` on ``ImportAlias``.

        ``full_name`` is the expected dotted name; ``alias`` is the expected
        ``as`` name, or ``None`` when the node has no ``asname``.
        """
        actual_name = alias_node.evaluated_name
        self.assertEqual(actual_name, full_name)

        actual_alias = alias_node.evaluated_alias
        self.assertEqual(actual_alias, alias)

    @data_provider(
        (
            # Various files inside the root should give back valid modules.
            (
                "/home/username/root",
                "/home/username/root/file.py",
                ModuleNameAndPackage("file", ""),
            ),
            (
                "/home/username/root/",
                "/home/username/root/file.py",
                ModuleNameAndPackage("file", ""),
            ),
            (
                "/home/username/root/",
                "/home/username/root/some/dir/file.py",
                ModuleNameAndPackage("some.dir.file", "some.dir"),
            ),
            # Various special files inside the root should give back valid modules.
            (
                "/home/username/root/",
                "/home/username/root/some/dir/__init__.py",
                ModuleNameAndPackage("some.dir", "some.dir"),
            ),
            (
                "/home/username/root/",
                "/home/username/root/some/dir/__main__.py",
                ModuleNameAndPackage("some.dir", "some.dir"),
            ),
            (
                "c:/Program Files/",
                "c:/Program Files/some/dir/file.py",
                ModuleNameAndPackage("some.dir.file", "some.dir"),
            ),
            (
                "c:/Program Files/",
                "c:/Program Files/some/dir/__main__.py",
                ModuleNameAndPackage("some.dir", "some.dir"),
            ),
        ),
    )
    def test_calculate_module_and_package(
        self,
        repo_root: str,
        filename: str,
        module_and_package: Optional[ModuleNameAndPackage],
    ) -> None:
        """Check the module name and package computed for a file under a root."""
        computed = calculate_module_and_package(repo_root, filename)
        self.assertEqual(computed, module_and_package)

    @data_provider(
        (
            # Providing a file outside the root should raise an exception
            ("/home/username/root", "/some/dummy/file.py"),
            ("/home/username/root/", "/some/dummy/file.py"),
            ("/home/username/root", "/home/username/file.py"),
            # some windows tests
            (
                "c:/Program Files/",
                "d:/Program Files/some/dir/file.py",
            ),
            (
                "c:/Program Files/other/",
                "c:/Program Files/some/dir/file.py",
            ),
        )
    )
    def test_invalid_module_and_package(
        self,
        repo_root: str,
        filename: str,
    ) -> None:
        """Expect ``ValueError`` when ``filename`` is not under ``repo_root``."""
        self.assertRaises(
            ValueError, calculate_module_and_package, repo_root, filename
        )
|