add QualifiedName and QualifiedNameSource enum

This commit is contained in:
Jimmy Lai 2019-09-27 14:04:44 -07:00 committed by jimmylai
parent 881348df27
commit 9fcf84cf2c
2 changed files with 50 additions and 22 deletions

View file

@ -11,6 +11,7 @@ from ast import literal_eval
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
Collection,
Dict,
@ -98,7 +99,15 @@ class BuiltinAssignment(BaseAssignment):
pass
CURRENT_MODULE_PREFIX = "__current_module__"
class QualifiedNameSource(Enum):
IMPORT = auto()
LOCAL = auto()
@dataclass(frozen=True)
class QualifiedName:
name: str
source: QualifiedNameSource
class _QualifiedNameUtil:
@ -123,7 +132,7 @@ class _QualifiedNameUtil:
def find_qualified_name_for_import_alike(
assignment_node: Union[cst.Import, cst.ImportFrom],
name_parts: List[str],
results: Set[str],
results: Set[QualifiedName],
) -> None:
module = ""
if isinstance(assignment_node, cst.ImportFrom):
@ -139,11 +148,16 @@ class _QualifiedNameUtil:
if module:
real_name = f"{module}.{real_name}"
if real_name:
results.add(".".join([real_name, *name_parts[1:]]))
results.add(
QualifiedName(
".".join([real_name, *name_parts[1:]]),
QualifiedNameSource.IMPORT,
)
)
@staticmethod
def find_qualified_name_for_non_import(
assignment: Assignment, name_parts: List[str], results: Set[str]
assignment: Assignment, name_parts: List[str], results: Set[QualifiedName]
) -> None:
scope = assignment.scope
name_prefixes = []
@ -157,7 +171,9 @@ class _QualifiedNameUtil:
scope = scope.parent
results.add(
".".join([CURRENT_MODULE_PREFIX, *name_prefixes[::-1], *name_parts])
QualifiedName(
".".join([*name_prefixes[::-1], *name_parts]), QualifiedNameSource.LOCAL
)
)
@ -211,7 +227,7 @@ class Scope(abc.ABC):
def record_nonlocal_overwrite(self, name: str) -> None:
...
def get_fully_qualified_names_for(self, node: cst.CSTNode) -> Collection[str]:
def get_qualified_names_for(self, node: cst.CSTNode) -> Collection[QualifiedName]:
results = set()
full_name = _QualifiedNameUtil.get_full_name_for(node)
if full_name is None:
@ -232,8 +248,11 @@ class Scope(abc.ABC):
assignment, parts, results
)
elif isinstance(assignment, BuiltinAssignment):
results.add(f"builtins.{assignment.name}")
# TODO: add support to other type of assignment
results.add(
QualifiedName(
f"builtins.{assignment.name}", QualifiedNameSource.IMPORT
)
)
return results

View file

@ -11,12 +11,13 @@ from typing import Mapping, Tuple, cast
import libcst as cst
from libcst import ensure_type
from libcst.metadata.scope_provider import (
CURRENT_MODULE_PREFIX,
Assignment,
ClassScope,
ComprehensionScope,
FunctionScope,
GlobalScope,
QualifiedName,
QualifiedNameSource,
Scope,
ScopeProvider,
)
@ -612,8 +613,10 @@ class ScopeProviderTest(UnitTest):
f = ensure_type(m.body[1], cst.FunctionDef)
scope_of_module = scopes[m]
self.assertEqual(
scope_of_module.get_fully_qualified_names_for(ensure_type(f.returns, cst.Annotation).annotation),
{"a.b.c"},
scope_of_module.get_qualified_names_for(
ensure_type(f.returns, cst.Annotation).annotation
),
{QualifiedName("a.b.c", QualifiedNameSource.IMPORT)},
)
c_call = ensure_type(
@ -621,16 +624,22 @@ class ScopeProviderTest(UnitTest):
).value
scope_of_f = scopes[c_call]
self.assertIsInstance(scope_of_f, FunctionScope)
self.assertEqual(scope_of_f.get_fully_qualified_names_for(c_call), {"a.b.c"})
self.assertEqual(scope_of_f.get_fully_qualified_names_for(c_call), {"a.b.c"})
self.assertEqual(
scope_of_f.get_qualified_names_for(c_call),
{QualifiedName("a.b.c", QualifiedNameSource.IMPORT)},
)
self.assertEqual(
scope_of_f.get_qualified_names_for(c_call),
{QualifiedName("a.b.c", QualifiedNameSource.IMPORT)},
)
f_call = ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertEqual(
scope_of_module.get_fully_qualified_names_for(f_call),
{f"{CURRENT_MODULE_PREFIX}.f"},
scope_of_module.get_qualified_names_for(f_call),
{QualifiedName("f", QualifiedNameSource.LOCAL)},
)
d_name = (
ensure_type(
@ -640,8 +649,8 @@ class ScopeProviderTest(UnitTest):
.target
)
self.assertEqual(
scope_of_f.get_fully_qualified_names_for(d_name),
{f"{CURRENT_MODULE_PREFIX}.f.d"},
scope_of_f.get_qualified_names_for(d_name),
{QualifiedName("f.d", QualifiedNameSource.LOCAL)},
)
d_subscript = (
ensure_type(
@ -651,19 +660,19 @@ class ScopeProviderTest(UnitTest):
.target
)
self.assertEqual(
scope_of_f.get_fully_qualified_names_for(d_subscript),
{f"{CURRENT_MODULE_PREFIX}.f.d"},
scope_of_f.get_qualified_names_for(d_subscript),
{QualifiedName("f.d", QualifiedNameSource.LOCAL)},
)
for builtin in ["map", "int", "dict"]:
self.assertEqual(
scope_of_f.get_fully_qualified_names_for(cst.Name(value=builtin)),
{f"builtins.{builtin}"},
scope_of_f.get_qualified_names_for(cst.Name(value=builtin)),
{QualifiedName(f"builtins.{builtin}", QualifiedNameSource.IMPORT)},
f"Test builtin: {builtin}.",
)
self.assertEqual(
scope_of_module.get_fully_qualified_names_for(cst.Name(value="d")),
scope_of_module.get_qualified_names_for(cst.Name(value="d")),
set(),
"Test variable d in global scope.",
)