diff --git a/libcst/_types.py b/libcst/_types.py index 8df90ee5..24055a5c 100644 --- a/libcst/_types.py +++ b/libcst/_types.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING, TypeVar +from pathlib import PurePath +from typing import TYPE_CHECKING, TypeVar, Union if TYPE_CHECKING: from libcst._nodes.base import CSTNode # noqa: F401 @@ -12,3 +13,4 @@ if TYPE_CHECKING: CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode") CSTNodeT_co = TypeVar("CSTNodeT_co", bound="CSTNode", covariant=True) +StrPath = Union[str, PurePath] diff --git a/libcst/helpers/module.py b/libcst/helpers/module.py index f9961807..2ff5ef00 100644 --- a/libcst/helpers/module.py +++ b/libcst/helpers/module.py @@ -9,6 +9,7 @@ from pathlib import PurePath from typing import List, Optional from libcst import Comment, EmptyLine, ImportFrom, Module +from libcst._types import StrPath from libcst.helpers.expression import get_full_name_for_node @@ -130,7 +131,9 @@ class ModuleNameAndPackage: package: str -def calculate_module_and_package(repo_root: str, filename: str) -> ModuleNameAndPackage: +def calculate_module_and_package( + repo_root: StrPath, filename: StrPath +) -> ModuleNameAndPackage: # Given an absolute repo_root and an absolute filename, calculate the # python module name for the file. relative_filename = PurePath(filename).relative_to(repo_root) diff --git a/libcst/metadata/full_repo_manager.py b/libcst/metadata/full_repo_manager.py index 6a7c1e9a..83bb6e83 100644 --- a/libcst/metadata/full_repo_manager.py +++ b/libcst/metadata/full_repo_manager.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Collection, Dict, List, Mapping, TYPE_CHECKING import libcst as cst +from libcst._types import StrPath from libcst.metadata.wrapper import MetadataWrapper if TYPE_CHECKING: @@ -17,7 +18,7 @@ if TYPE_CHECKING: class FullRepoManager: def __init__( self, - repo_root_dir: str, + repo_root_dir: StrPath, paths: Collection[str], providers: Collection["ProviderT"], timeout: int = 5, diff --git a/libcst/metadata/name_provider.py b/libcst/metadata/name_provider.py index 60d8763e..1868fa66 100644 --- a/libcst/metadata/name_provider.py +++ b/libcst/metadata/name_provider.py @@ -114,7 +114,7 @@ class FullyQualifiedNameProvider(BatchableMetadataProvider[Collection[QualifiedN def gen_cache( cls, root_path: Path, paths: List[str], timeout: Optional[int] = None ) -> Mapping[str, ModuleNameAndPackage]: - cache = {path: calculate_module_and_package(".", path) for path in paths} + cache = {path: calculate_module_and_package(root_path, path) for path in paths} return cache def __init__(self, cache: ModuleNameAndPackage) -> None: diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index 9f381368..c1ba59bc 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -560,11 +560,14 @@ class FullyQualifiedNameProviderTest(UnitTest): class FullyQualifiedNameIntegrationTest(UnitTest): def test_with_full_repo_manager(self) -> None: with TemporaryDirectory() as dir: - fname = "pkg/mod.py" - (Path(dir) / "pkg").mkdir() - (Path(dir) / fname).touch() - mgr = FullRepoManager(dir, [fname], [FullyQualifiedNameProvider]) - wrapper = mgr.get_metadata_wrapper_for_path(fname) + root = Path(dir) + file_path = root / "pkg/mod.py" + file_path.parent.mkdir() + file_path.touch() + + file_path_str = file_path.as_posix() + mgr = FullRepoManager(root, [file_path_str], [FullyQualifiedNameProvider]) + wrapper = mgr.get_metadata_wrapper_for_path(file_path_str) fqnames = wrapper.resolve(FullyQualifiedNameProvider) (mod, names) = next(iter(fqnames.items())) self.assertIsInstance(mod, cst.Module)