mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
rename Assignment.accesses as Assignment.references and add Accesses/Assignments
This commit is contained in:
parent
6520d25816
commit
120fd04ce2
2 changed files with 88 additions and 28 deletions
|
|
@ -22,6 +22,8 @@ from typing import (
|
|||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
Mapping,
|
||||
Generator,
|
||||
)
|
||||
|
||||
import libcst as cst
|
||||
|
|
@ -56,21 +58,21 @@ class BaseAssignment(abc.ABC):
|
|||
|
||||
#: The scope associates to assignment.
|
||||
scope: "Scope"
|
||||
__accesses: List[Access]
|
||||
__accesses: Set[Access]
|
||||
|
||||
def __init__(self, name: str, scope: "Scope") -> None:
|
||||
self.name = name
|
||||
self.scope = scope
|
||||
self.__accesses = []
|
||||
self.__accesses = set()
|
||||
|
||||
def record_access(self, access: Access) -> None:
|
||||
self.__accesses.append(access)
|
||||
self.__accesses.add(access)
|
||||
|
||||
@property
|
||||
def accesses(self) -> Tuple[Access, ...]:
|
||||
def references(self) -> Collection[Access]:
|
||||
"""Return all accesses of the assignment."""
|
||||
# we don't want to publicly expose the mutable version of this
|
||||
return tuple(self.__accesses)
|
||||
return set(self.__accesses)
|
||||
|
||||
|
||||
class Assignment(BaseAssignment):
|
||||
|
|
@ -97,6 +99,50 @@ class BuiltinAssignment(BaseAssignment):
|
|||
pass
|
||||
|
||||
|
||||
class Assignments:
|
||||
"""A container to provide all assignments in a scope."""
|
||||
|
||||
def __init__(self, assignments: Mapping[str, Collection[BaseAssignment]]) -> None:
|
||||
self._assignments = assignments
|
||||
|
||||
def __iter__(self) -> Generator[BaseAssignment, None, None]:
|
||||
"""Iterate through all assignments by ``for i in scope.assignments``."""
|
||||
for assignments in self._assignments.values():
|
||||
for assignment in assignments:
|
||||
yield assignment
|
||||
|
||||
def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[BaseAssignment]:
|
||||
"""Get assignments given a name str or :class:`~libcst.CSTNode` by ``scope.assignments[node]``"""
|
||||
name = _NameUtil.get_name_for(node)
|
||||
return set(self._assignments[name]) if name in self._assignments else set()
|
||||
|
||||
def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
|
||||
"""Check if a name str or :class:`~libcst.CSTNode` has any assignment by ``node in scope.assignments``"""
|
||||
return len(self[node]) > 0
|
||||
|
||||
|
||||
class Accesses:
|
||||
"""A container to provide all accesses in a scope."""
|
||||
|
||||
def __init__(self, accesses: Mapping[str, Collection[Access]]) -> None:
|
||||
self._accesses = accesses
|
||||
|
||||
def __iter__(self) -> Generator[Access, None, None]:
|
||||
"""Iterate through all accesses by ``for i in scope.accesses``."""
|
||||
for accesses in self._accesses.values():
|
||||
for access in accesses:
|
||||
yield access
|
||||
|
||||
def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[Access]:
|
||||
"""Get accesses given a name str or :class:`~libcst.CSTNode` by ``scope.accesses[node]``"""
|
||||
name = _NameUtil.get_name_for(node)
|
||||
return self._accesses[name] if name in self._accesses else set()
|
||||
|
||||
def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
|
||||
"""Check if a name str or :class:`~libcst.CSTNode` has any access by ``node in scope.accesses``"""
|
||||
return len(self[node]) > 0
|
||||
|
||||
|
||||
class QualifiedNameSource(Enum):
|
||||
IMPORT = auto()
|
||||
BUILTIN = auto()
|
||||
|
|
@ -109,21 +155,34 @@ class QualifiedName:
|
|||
source: QualifiedNameSource
|
||||
|
||||
|
||||
class _QualifiedNameUtil:
|
||||
class _NameUtil:
|
||||
@staticmethod
|
||||
def get_full_name_for(node: cst.CSTNode) -> Optional[str]:
|
||||
if isinstance(node, cst.Name):
|
||||
return node.value
|
||||
elif isinstance(node, cst.Attribute):
|
||||
return (
|
||||
f"{_QualifiedNameUtil.get_full_name_for(node.value)}.{node.attr.value}"
|
||||
)
|
||||
return f"{_NameUtil.get_full_name_for(node.value)}.{node.attr.value}"
|
||||
elif isinstance(node, cst.Call):
|
||||
return _QualifiedNameUtil.get_full_name_for(node.func)
|
||||
return _NameUtil.get_full_name_for(node.func)
|
||||
elif isinstance(node, cst.Subscript):
|
||||
return _QualifiedNameUtil.get_full_name_for(node.value)
|
||||
return _NameUtil.get_full_name_for(node.value)
|
||||
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
|
||||
return _QualifiedNameUtil.get_full_name_for(node.name)
|
||||
return _NameUtil.get_full_name_for(node.name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
|
||||
"""A helper function to retrieve simple name str from a CSTNode or str"""
|
||||
if isinstance(node, cst.Name):
|
||||
return node.value
|
||||
elif isinstance(node, str):
|
||||
return node
|
||||
elif isinstance(node, cst.Call):
|
||||
return _NameUtil.get_name_for(node.func)
|
||||
elif isinstance(node, cst.Subscript):
|
||||
return _NameUtil.get_name_for(node.value)
|
||||
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
|
||||
return _NameUtil.get_name_for(node.name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -136,11 +195,11 @@ class _QualifiedNameUtil:
|
|||
module_attr = assignment_node.module
|
||||
if module_attr:
|
||||
# TODO: for relative import, keep the relative Dot in the qualified name
|
||||
module = _QualifiedNameUtil.get_full_name_for(module_attr)
|
||||
module = _NameUtil.get_full_name_for(module_attr)
|
||||
import_names = assignment_node.names
|
||||
if not isinstance(import_names, cst.ImportStar):
|
||||
for name in import_names:
|
||||
real_name = _QualifiedNameUtil.get_full_name_for(name.name)
|
||||
real_name = _NameUtil.get_full_name_for(name.name)
|
||||
as_name = real_name
|
||||
if name and name.asname:
|
||||
name_asname = name.asname
|
||||
|
|
@ -260,7 +319,7 @@ class Scope(abc.ABC):
|
|||
``List[Union[int, str]]``.
|
||||
"""
|
||||
results = set()
|
||||
full_name = _QualifiedNameUtil.get_full_name_for(node)
|
||||
full_name = _NameUtil.get_full_name_for(node)
|
||||
if full_name is None:
|
||||
return results
|
||||
parts = full_name.split(".")
|
||||
|
|
@ -268,11 +327,11 @@ class Scope(abc.ABC):
|
|||
if isinstance(assignment, Assignment):
|
||||
assignment_node = assignment.node
|
||||
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
|
||||
results |= _QualifiedNameUtil.find_qualified_name_for_import_alike(
|
||||
results |= _NameUtil.find_qualified_name_for_import_alike(
|
||||
assignment_node, parts
|
||||
)
|
||||
else:
|
||||
results |= _QualifiedNameUtil.find_qualified_name_for_non_import(
|
||||
results |= _NameUtil.find_qualified_name_for_non_import(
|
||||
assignment, parts
|
||||
)
|
||||
elif isinstance(assignment, BuiltinAssignment):
|
||||
|
|
@ -473,9 +532,7 @@ class ScopeVisitor(cst.CSTVisitor):
|
|||
self.scope.record_assignment(node.name.value, node)
|
||||
self.provider.set_metadata(node.name, self.scope)
|
||||
|
||||
with self._new_scope(
|
||||
FunctionScope, _QualifiedNameUtil.get_full_name_for(node.name)
|
||||
):
|
||||
with self._new_scope(FunctionScope, _NameUtil.get_full_name_for(node.name)):
|
||||
node.params.visit(self)
|
||||
node.body.visit(self)
|
||||
|
||||
|
|
@ -550,9 +607,7 @@ class ScopeVisitor(cst.CSTVisitor):
|
|||
for keyword in node.keywords:
|
||||
keyword.visit(self)
|
||||
|
||||
with self._new_scope(
|
||||
ClassScope, _QualifiedNameUtil.get_full_name_for(node.name)
|
||||
):
|
||||
with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name)):
|
||||
for statement in node.body.body:
|
||||
statement.visit(self)
|
||||
|
||||
|
|
|
|||
|
|
@ -67,21 +67,24 @@ class ScopeProviderTest(UnitTest):
|
|||
global_foo_assignments = scope_of_module["foo"]
|
||||
self.assertEqual(len(global_foo_assignments), 1)
|
||||
foo_assignment = global_foo_assignments[0]
|
||||
self.assertEqual(len(foo_assignment.accesses), 2)
|
||||
self.assertEqual(len(foo_assignment.references), 2)
|
||||
fn1_call_arg = ensure_type(
|
||||
ensure_type(
|
||||
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr
|
||||
).value,
|
||||
cst.Call,
|
||||
).args[0]
|
||||
self.assertEqual(foo_assignment.accesses[0].node, fn1_call_arg.value)
|
||||
|
||||
fn2_call_arg = ensure_type(
|
||||
ensure_type(
|
||||
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
|
||||
).value,
|
||||
cst.Call,
|
||||
).args[0]
|
||||
self.assertEqual(foo_assignment.accesses[1].node, fn2_call_arg.value)
|
||||
self.assertEqual(
|
||||
{access.node for access in foo_assignment.references},
|
||||
{fn1_call_arg.value, fn2_call_arg.value},
|
||||
)
|
||||
func_body = ensure_type(m.body[3], cst.FunctionDef).body
|
||||
func_foo_statement = func_body.body[0]
|
||||
scope_of_func_statement = scopes[func_foo_statement]
|
||||
|
|
@ -89,7 +92,7 @@ class ScopeProviderTest(UnitTest):
|
|||
func_foo_assignments = scope_of_func_statement["foo"]
|
||||
self.assertEqual(len(func_foo_assignments), 1)
|
||||
foo_assignment = func_foo_assignments[0]
|
||||
self.assertEqual(len(foo_assignment.accesses), 1)
|
||||
self.assertEqual(len(foo_assignment.references), 1)
|
||||
fn3_call_arg = ensure_type(
|
||||
ensure_type(
|
||||
ensure_type(func_body.body[1], cst.SimpleStatementLine).body[0],
|
||||
|
|
@ -97,7 +100,9 @@ class ScopeProviderTest(UnitTest):
|
|||
).value,
|
||||
cst.Call,
|
||||
).args[0]
|
||||
self.assertEqual(foo_assignment.accesses[0].node, fn3_call_arg.value)
|
||||
self.assertEqual(
|
||||
{access.node for access in foo_assignment.references}, {fn3_call_arg.value}
|
||||
)
|
||||
|
||||
wrapper = MetadataWrapper(cst.parse_module("from a import b\n"))
|
||||
wrapper.visit(DependentVisitor())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue