From 2e8a0c6df7ab8142dd9ae828f2d6642ac2d81db8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20M=C3=A9ndez=20Bravo?= Date: Tue, 4 Aug 2020 17:33:22 -0700 Subject: [PATCH] Fix dotted names (#358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary When importing things like `import os.path` and using it as `os.path.join("a", "b").lower()`, references ended up being in the `["os"]` assignment instead of `["os.path"]`. This fixes the problem by updating the dotted names generator in the scope provider· ## Test Plan ``` tox -e py37 ``` Co-authored-by: Germán Méndez Bravo --- libcst/metadata/scope_provider.py | 33 ++++++---- libcst/metadata/tests/test_scope_provider.py | 65 ++++++++++++++++++++ 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index bb0adc56..fdda406a 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -610,18 +610,27 @@ def _gen_dotted_names( yield node.value, node else: value = node.value - if not isinstance(value, (cst.Attribute, cst.Name)): - # this is not an import - return - name_values = _gen_dotted_names(value) - try: - next_name, next_node = next(name_values) - except StopIteration: - return - else: - yield f"{next_name}.{node.attr.value}", node - yield next_name, next_node - yield from name_values + if isinstance(value, cst.Call): + value = value.func + if isinstance(value, (cst.Attribute, cst.Name)): + name_values = _gen_dotted_names(value) + try: + next_name, next_node = next(name_values) + except StopIteration: + return + else: + yield next_name, next_node + yield from name_values + elif isinstance(value, (cst.Attribute, cst.Name)): + name_values = _gen_dotted_names(value) + try: + next_name, next_node = next(name_values) + except StopIteration: + return + else: + yield f"{next_name}.{node.attr.value}", node + yield next_name, next_node + yield from name_values class ScopeVisitor(cst.CSTVisitor): diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 31c1ac1b..a4f24591 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -21,6 +21,7 @@ from libcst.metadata.scope_provider import ( QualifiedNameSource, Scope, ScopeProvider, + _gen_dotted_names, ) from libcst.testing.utils import UnitTest, data_provider @@ -203,6 +204,39 @@ class ScopeProviderTest(UnitTest): self.assertEqual(list(scope_of_module["x.y"])[0].references, set()) self.assertEqual(scope_of_module.accesses["x.y"], set()) + def test_dotted_import_with_call_access(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + import os.path + os.path.join("A", "B").lower() + """ + ) + scope_of_module = scopes[m] + first_statement = ensure_type(m.body[1], cst.SimpleStatementLine) + attr = ensure_type( + ensure_type( + ensure_type( + ensure_type( + ensure_type(first_statement.body[0], cst.Expr).value, cst.Call + ).func, + cst.Attribute, + ).value, + cst.Call, + ).func, + cst.Attribute, + ).value + self.assertTrue("os.path" in scope_of_module) + self.assertTrue("os" in scope_of_module) + + os_path_join_assignment = cast(Assignment, list(scope_of_module["os.path"])[0]) + os_path_join_assignment_references = list(os_path_join_assignment.references) + self.assertNotEqual(len(os_path_join_assignment_references), 0) + os_path_join_access = os_path_join_assignment_references[0] + self.assertEqual(scope_of_module.accesses["os"], set()) + self.assertEqual(scope_of_module.accesses["os.path"], {os_path_join_access}) + self.assertEqual(scope_of_module.accesses["os.path.join"], set()) + self.assertEqual(os_path_join_access.node, attr) + def test_import_from(self) -> None: m, scopes = get_scope_metadata_provider( """ @@ -1153,3 +1187,34 @@ class ScopeProviderTest(UnitTest): scope.get_qualified_names_for("doesnt_exist") self.assertEqual(len(scope._assignments), assignments_len_before) self.assertEqual(len(scope._accesses), accesses_len_before) + + def test_gen_dotted_names(self) -> None: + names = {name for name, node in _gen_dotted_names(cst.Name(value="a"))} + self.assertEqual(names, {"a"}) + + names = { + name + for name, node in _gen_dotted_names( + cst.Attribute(value=cst.Name(value="a"), attr=cst.Name(value="b")) + ) + } + self.assertEqual(names, {"a.b", "a"}) + + names = { + name + for name, node in _gen_dotted_names( + cst.Attribute( + value=cst.Call( + func=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="a"), attr=cst.Name(value="b") + ), + attr=cst.Name(value="c"), + ), + args=[], + ), + attr=cst.Name(value="d"), + ) + ) + } + self.assertEqual(names, {"a.b.c", "a.b", "a"})