From e6b84fc11a73a3f1579f3946006ca8358dd0fdf3 Mon Sep 17 00:00:00 2001 From: Jimmy Lai Date: Wed, 30 Oct 2019 22:30:46 -0700 Subject: [PATCH] [scope] add qualified name support for imported attribute --- libcst/metadata/scope_provider.py | 25 ++++++++------- libcst/metadata/tests/test_scope_provider.py | 33 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 2c0023fa..bb3e830f 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -238,7 +238,7 @@ class _NameUtil: @staticmethod def find_qualified_name_for_import_alike( - assignment_node: Union[cst.Import, cst.ImportFrom], name_parts: List[str] + assignment_node: Union[cst.Import, cst.ImportFrom], full_name: str ) -> Set[QualifiedName]: module = "" results = set() @@ -256,13 +256,16 @@ class _NameUtil: name_asname = name.asname if name_asname: as_name = cst.ensure_type(name_asname.name, cst.Name).value - if as_name == name_parts[0]: + if as_name and full_name.startswith(as_name): if module: real_name = f"{module}.{real_name}" if real_name: + remaining_name = full_name.split(as_name)[1].lstrip(".") results.add( QualifiedName( - ".".join([real_name, *name_parts[1:]]), + f"{real_name}.{remaining_name}" + if remaining_name + else real_name, QualifiedNameSource.IMPORT, ) ) @@ -270,7 +273,7 @@ class _NameUtil: @staticmethod def find_qualified_name_for_non_import( - assignment: Assignment, name_parts: List[str] + assignment: Assignment, remaining_name: str ) -> Set[QualifiedName]: scope = assignment.scope name_prefixes = [] @@ -287,12 +290,10 @@ class _NameUtil: raise Exception(f"Unexpected Scope: {scope}") scope = scope.parent - return { - QualifiedName( - ".".join([*reversed(name_prefixes), *name_parts]), - QualifiedNameSource.LOCAL, - ) - } + parts = [*reversed(name_prefixes)] + if remaining_name: + parts.append(remaining_name) + return {QualifiedName(".".join(parts), QualifiedNameSource.LOCAL)} class Scope(abc.ABC): @@ -439,11 +440,11 @@ class Scope(abc.ABC): assignment_node = assignment.node if isinstance(assignment_node, (cst.Import, cst.ImportFrom)): results |= _NameUtil.find_qualified_name_for_import_alike( - assignment_node, parts + assignment_node, full_name ) else: results |= _NameUtil.find_qualified_name_for_non_import( - assignment, parts + assignment, full_name ) elif isinstance(assignment, BuiltinAssignment): results.add( diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 828276c7..0bf1c74d 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -905,3 +905,36 @@ class ScopeProviderTest(UnitTest): c_scope = scopes[target_call_2] self.assertIsInstance(c_scope, ClassScope) self.assertEqual(cast(ClassScope, c_scope).node, c) + + def test_with_statement(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + import unittest.mock + + with unittest.mock.patch("something") as obj: + obj.f1() + + unittest.mock + """ + ) + import_ = ensure_type(m.body[0], cst.SimpleStatementLine).body[0] + assignments = scopes[import_]["unittest"] + self.assertEqual(len(assignments), 1) + self.assertEqual(cast(Assignment, list(assignments)[0]).node, import_) + with_ = ensure_type(m.body[1], cst.With) + fn_call = with_.items[0].item + self.assertEqual( + scopes[fn_call].get_qualified_names_for(fn_call), + { + QualifiedName( + name="unittest.mock.patch", source=QualifiedNameSource.IMPORT + ) + }, + ) + mock = ensure_type( + ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr + ).value + self.assertEqual( + scopes[fn_call].get_qualified_names_for(mock), + {QualifiedName(name="unittest.mock", source=QualifiedNameSource.IMPORT)}, + )