From 9fcf84cf2cb1b6b7e7ba2894bdfb744de5c263dc Mon Sep 17 00:00:00 2001 From: Jimmy Lai Date: Fri, 27 Sep 2019 14:04:44 -0700 Subject: [PATCH] add QualifiedName and QualifiedNameSource enum --- libcst/metadata/scope_provider.py | 35 +++++++++++++----- libcst/metadata/tests/test_scope_provider.py | 37 ++++++++++++-------- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index bc809d09..7f55c787 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -11,6 +11,7 @@ from ast import literal_eval from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum, auto from typing import ( Collection, Dict, @@ -98,7 +99,15 @@ class BuiltinAssignment(BaseAssignment): pass -CURRENT_MODULE_PREFIX = "__current_module__" +class QualifiedNameSource(Enum): + IMPORT = auto() + LOCAL = auto() + + +@dataclass(frozen=True) +class QualifiedName: + name: str + source: QualifiedNameSource class _QualifiedNameUtil: @@ -123,7 +132,7 @@ class _QualifiedNameUtil: def find_qualified_name_for_import_alike( assignment_node: Union[cst.Import, cst.ImportFrom], name_parts: List[str], - results: Set[str], + results: Set[QualifiedName], ) -> None: module = "" if isinstance(assignment_node, cst.ImportFrom): @@ -139,11 +148,16 @@ class _QualifiedNameUtil: if module: real_name = f"{module}.{real_name}" if real_name: - results.add(".".join([real_name, *name_parts[1:]])) + results.add( + QualifiedName( + ".".join([real_name, *name_parts[1:]]), + QualifiedNameSource.IMPORT, + ) + ) @staticmethod def find_qualified_name_for_non_import( - assignment: Assignment, name_parts: List[str], results: Set[str] + assignment: Assignment, name_parts: List[str], results: Set[QualifiedName] ) -> None: scope = assignment.scope name_prefixes = [] @@ -157,7 +171,9 @@ class _QualifiedNameUtil: scope = scope.parent results.add( - ".".join([CURRENT_MODULE_PREFIX, *name_prefixes[::-1], *name_parts]) + QualifiedName( + ".".join([*name_prefixes[::-1], *name_parts]), QualifiedNameSource.LOCAL + ) ) @@ -211,7 +227,7 @@ class Scope(abc.ABC): def record_nonlocal_overwrite(self, name: str) -> None: ... - def get_fully_qualified_names_for(self, node: cst.CSTNode) -> Collection[str]: + def get_qualified_names_for(self, node: cst.CSTNode) -> Collection[QualifiedName]: results = set() full_name = _QualifiedNameUtil.get_full_name_for(node) if full_name is None: @@ -232,8 +248,11 @@ class Scope(abc.ABC): assignment, parts, results ) elif isinstance(assignment, BuiltinAssignment): - results.add(f"builtins.{assignment.name}") - # TODO: add support to other type of assignment + results.add( + QualifiedName( + f"builtins.{assignment.name}", QualifiedNameSource.IMPORT + ) + ) return results diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 2bfbaaef..11f8ab8b 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -11,12 +11,13 @@ from typing import Mapping, Tuple, cast import libcst as cst from libcst import ensure_type from libcst.metadata.scope_provider import ( - CURRENT_MODULE_PREFIX, Assignment, ClassScope, ComprehensionScope, FunctionScope, GlobalScope, + QualifiedName, + QualifiedNameSource, Scope, ScopeProvider, ) @@ -612,8 +613,10 @@ class ScopeProviderTest(UnitTest): f = ensure_type(m.body[1], cst.FunctionDef) scope_of_module = scopes[m] self.assertEqual( - scope_of_module.get_fully_qualified_names_for(ensure_type(f.returns, cst.Annotation).annotation), - {"a.b.c"}, + scope_of_module.get_qualified_names_for( + ensure_type(f.returns, cst.Annotation).annotation + ), + {QualifiedName("a.b.c", QualifiedNameSource.IMPORT)}, ) c_call = ensure_type( @@ -621,16 +624,22 @@ class ScopeProviderTest(UnitTest): ).value scope_of_f = scopes[c_call] self.assertIsInstance(scope_of_f, FunctionScope) - self.assertEqual(scope_of_f.get_fully_qualified_names_for(c_call), {"a.b.c"}) - self.assertEqual(scope_of_f.get_fully_qualified_names_for(c_call), {"a.b.c"}) + self.assertEqual( + scope_of_f.get_qualified_names_for(c_call), + {QualifiedName("a.b.c", QualifiedNameSource.IMPORT)}, + ) + self.assertEqual( + scope_of_f.get_qualified_names_for(c_call), + {QualifiedName("a.b.c", QualifiedNameSource.IMPORT)}, + ) f_call = ensure_type( ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr ).value self.assertIsInstance(scope_of_module, GlobalScope) self.assertEqual( - scope_of_module.get_fully_qualified_names_for(f_call), - {f"{CURRENT_MODULE_PREFIX}.f"}, + scope_of_module.get_qualified_names_for(f_call), + {QualifiedName("f", QualifiedNameSource.LOCAL)}, ) d_name = ( ensure_type( @@ -640,8 +649,8 @@ class ScopeProviderTest(UnitTest): .target ) self.assertEqual( - scope_of_f.get_fully_qualified_names_for(d_name), - {f"{CURRENT_MODULE_PREFIX}.f.d"}, + scope_of_f.get_qualified_names_for(d_name), + {QualifiedName("f.d", QualifiedNameSource.LOCAL)}, ) d_subscript = ( ensure_type( @@ -651,19 +660,19 @@ class ScopeProviderTest(UnitTest): .target ) self.assertEqual( - scope_of_f.get_fully_qualified_names_for(d_subscript), - {f"{CURRENT_MODULE_PREFIX}.f.d"}, + scope_of_f.get_qualified_names_for(d_subscript), + {QualifiedName("f.d", QualifiedNameSource.LOCAL)}, ) for builtin in ["map", "int", "dict"]: self.assertEqual( - scope_of_f.get_fully_qualified_names_for(cst.Name(value=builtin)), - {f"builtins.{builtin}"}, + scope_of_f.get_qualified_names_for(cst.Name(value=builtin)), + {QualifiedName(f"builtins.{builtin}", QualifiedNameSource.IMPORT)}, f"Test builtin: {builtin}.", ) self.assertEqual( - scope_of_module.get_fully_qualified_names_for(cst.Name(value="d")), + scope_of_module.get_qualified_names_for(cst.Name(value="d")), set(), "Test variable d in global scope.", )