rename Assignment.accesses as Assignment.references and add Accesses/Assignments

This commit is contained in:
Jimmy Lai 2019-10-04 15:44:37 -07:00 committed by jimmylai
parent 6520d25816
commit 120fd04ce2
2 changed files with 88 additions and 28 deletions

View file

@ -22,6 +22,8 @@ from typing import (
Tuple,
Type,
Union,
Mapping,
Generator,
)
import libcst as cst
@ -56,21 +58,21 @@ class BaseAssignment(abc.ABC):
#: The scope associates to assignment.
scope: "Scope"
__accesses: List[Access]
__accesses: Set[Access]
def __init__(self, name: str, scope: "Scope") -> None:
self.name = name
self.scope = scope
self.__accesses = []
self.__accesses = set()
def record_access(self, access: Access) -> None:
self.__accesses.append(access)
self.__accesses.add(access)
@property
def accesses(self) -> Tuple[Access, ...]:
def references(self) -> Collection[Access]:
"""Return all accesses of the assignment."""
# we don't want to publicly expose the mutable version of this
return tuple(self.__accesses)
return set(self.__accesses)
class Assignment(BaseAssignment):
@ -97,6 +99,50 @@ class BuiltinAssignment(BaseAssignment):
pass
class Assignments:
"""A container to provide all assignments in a scope."""
def __init__(self, assignments: Mapping[str, Collection[BaseAssignment]]) -> None:
self._assignments = assignments
def __iter__(self) -> Generator[BaseAssignment, None, None]:
"""Iterate through all assignments by ``for i in scope.assignments``."""
for assignments in self._assignments.values():
for assignment in assignments:
yield assignment
def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[BaseAssignment]:
"""Get assignments given a name str or :class:`~libcst.CSTNode` by ``scope.assignments[node]``"""
name = _NameUtil.get_name_for(node)
return set(self._assignments[name]) if name in self._assignments else set()
def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
"""Check if a name str or :class:`~libcst.CSTNode` has any assignment by ``node in scope.assignments``"""
return len(self[node]) > 0
class Accesses:
"""A container to provide all accesses in a scope."""
def __init__(self, accesses: Mapping[str, Collection[Access]]) -> None:
self._accesses = accesses
def __iter__(self) -> Generator[Access, None, None]:
"""Iterate through all accesses by ``for i in scope.accesses``."""
for accesses in self._accesses.values():
for access in accesses:
yield access
def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[Access]:
"""Get accesses given a name str or :class:`~libcst.CSTNode` by ``scope.accesses[node]``"""
name = _NameUtil.get_name_for(node)
return self._accesses[name] if name in self._accesses else set()
def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
"""Check if a name str or :class:`~libcst.CSTNode` has any access by ``node in scope.accesses``"""
return len(self[node]) > 0
class QualifiedNameSource(Enum):
IMPORT = auto()
BUILTIN = auto()
@ -109,21 +155,34 @@ class QualifiedName:
source: QualifiedNameSource
class _QualifiedNameUtil:
class _NameUtil:
@staticmethod
def get_full_name_for(node: cst.CSTNode) -> Optional[str]:
if isinstance(node, cst.Name):
return node.value
elif isinstance(node, cst.Attribute):
return (
f"{_QualifiedNameUtil.get_full_name_for(node.value)}.{node.attr.value}"
)
return f"{_NameUtil.get_full_name_for(node.value)}.{node.attr.value}"
elif isinstance(node, cst.Call):
return _QualifiedNameUtil.get_full_name_for(node.func)
return _NameUtil.get_full_name_for(node.func)
elif isinstance(node, cst.Subscript):
return _QualifiedNameUtil.get_full_name_for(node.value)
return _NameUtil.get_full_name_for(node.value)
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
return _QualifiedNameUtil.get_full_name_for(node.name)
return _NameUtil.get_full_name_for(node.name)
return None
@staticmethod
def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
"""A helper function to retrieve simple name str from a CSTNode or str"""
if isinstance(node, cst.Name):
return node.value
elif isinstance(node, str):
return node
elif isinstance(node, cst.Call):
return _NameUtil.get_name_for(node.func)
elif isinstance(node, cst.Subscript):
return _NameUtil.get_name_for(node.value)
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
return _NameUtil.get_name_for(node.name)
return None
@staticmethod
@ -136,11 +195,11 @@ class _QualifiedNameUtil:
module_attr = assignment_node.module
if module_attr:
# TODO: for relative import, keep the relative Dot in the qualified name
module = _QualifiedNameUtil.get_full_name_for(module_attr)
module = _NameUtil.get_full_name_for(module_attr)
import_names = assignment_node.names
if not isinstance(import_names, cst.ImportStar):
for name in import_names:
real_name = _QualifiedNameUtil.get_full_name_for(name.name)
real_name = _NameUtil.get_full_name_for(name.name)
as_name = real_name
if name and name.asname:
name_asname = name.asname
@ -260,7 +319,7 @@ class Scope(abc.ABC):
``List[Union[int, str]]``.
"""
results = set()
full_name = _QualifiedNameUtil.get_full_name_for(node)
full_name = _NameUtil.get_full_name_for(node)
if full_name is None:
return results
parts = full_name.split(".")
@ -268,11 +327,11 @@ class Scope(abc.ABC):
if isinstance(assignment, Assignment):
assignment_node = assignment.node
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
results |= _QualifiedNameUtil.find_qualified_name_for_import_alike(
results |= _NameUtil.find_qualified_name_for_import_alike(
assignment_node, parts
)
else:
results |= _QualifiedNameUtil.find_qualified_name_for_non_import(
results |= _NameUtil.find_qualified_name_for_non_import(
assignment, parts
)
elif isinstance(assignment, BuiltinAssignment):
@ -473,9 +532,7 @@ class ScopeVisitor(cst.CSTVisitor):
self.scope.record_assignment(node.name.value, node)
self.provider.set_metadata(node.name, self.scope)
with self._new_scope(
FunctionScope, _QualifiedNameUtil.get_full_name_for(node.name)
):
with self._new_scope(FunctionScope, _NameUtil.get_full_name_for(node.name)):
node.params.visit(self)
node.body.visit(self)
@ -550,9 +607,7 @@ class ScopeVisitor(cst.CSTVisitor):
for keyword in node.keywords:
keyword.visit(self)
with self._new_scope(
ClassScope, _QualifiedNameUtil.get_full_name_for(node.name)
):
with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name)):
for statement in node.body.body:
statement.visit(self)

View file

@ -67,21 +67,24 @@ class ScopeProviderTest(UnitTest):
global_foo_assignments = scope_of_module["foo"]
self.assertEqual(len(global_foo_assignments), 1)
foo_assignment = global_foo_assignments[0]
self.assertEqual(len(foo_assignment.accesses), 2)
self.assertEqual(len(foo_assignment.references), 2)
fn1_call_arg = ensure_type(
ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
).args[0]
self.assertEqual(foo_assignment.accesses[0].node, fn1_call_arg.value)
fn2_call_arg = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
).args[0]
self.assertEqual(foo_assignment.accesses[1].node, fn2_call_arg.value)
self.assertEqual(
{access.node for access in foo_assignment.references},
{fn1_call_arg.value, fn2_call_arg.value},
)
func_body = ensure_type(m.body[3], cst.FunctionDef).body
func_foo_statement = func_body.body[0]
scope_of_func_statement = scopes[func_foo_statement]
@ -89,7 +92,7 @@ class ScopeProviderTest(UnitTest):
func_foo_assignments = scope_of_func_statement["foo"]
self.assertEqual(len(func_foo_assignments), 1)
foo_assignment = func_foo_assignments[0]
self.assertEqual(len(foo_assignment.accesses), 1)
self.assertEqual(len(foo_assignment.references), 1)
fn3_call_arg = ensure_type(
ensure_type(
ensure_type(func_body.body[1], cst.SimpleStatementLine).body[0],
@ -97,7 +100,9 @@ class ScopeProviderTest(UnitTest):
).value,
cst.Call,
).args[0]
self.assertEqual(foo_assignment.accesses[0].node, fn3_call_arg.value)
self.assertEqual(
{access.node for access in foo_assignment.references}, {fn3_call_arg.value}
)
wrapper = MetadataWrapper(cst.parse_module("from a import b\n"))
wrapper.visit(DependentVisitor())