diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index e5bf3389..f5e58058 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -22,6 +22,8 @@ from typing import ( Tuple, Type, Union, + Mapping, + Generator, ) import libcst as cst @@ -56,21 +58,21 @@ class BaseAssignment(abc.ABC): #: The scope associates to assignment. scope: "Scope" - __accesses: List[Access] + __accesses: Set[Access] def __init__(self, name: str, scope: "Scope") -> None: self.name = name self.scope = scope - self.__accesses = [] + self.__accesses = set() def record_access(self, access: Access) -> None: - self.__accesses.append(access) + self.__accesses.add(access) @property - def accesses(self) -> Tuple[Access, ...]: + def references(self) -> Collection[Access]: """Return all accesses of the assignment.""" # we don't want to publicly expose the mutable version of this - return tuple(self.__accesses) + return set(self.__accesses) class Assignment(BaseAssignment): @@ -97,6 +99,50 @@ class BuiltinAssignment(BaseAssignment): pass +class Assignments: + """A container to provide all assignments in a scope.""" + + def __init__(self, assignments: Mapping[str, Collection[BaseAssignment]]) -> None: + self._assignments = assignments + + def __iter__(self) -> Generator[BaseAssignment, None, None]: + """Iterate through all assignments by ``for i in scope.assignments``.""" + for assignments in self._assignments.values(): + for assignment in assignments: + yield assignment + + def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[BaseAssignment]: + """Get assignments given a name str or :class:`~libcst.CSTNode` by ``scope.assignments[node]``""" + name = _NameUtil.get_name_for(node) + return set(self._assignments[name]) if name in self._assignments else set() + + def __contains__(self, node: Union[str, cst.CSTNode]) -> bool: + """Check if a name str or :class:`~libcst.CSTNode` has any assignment by ``node in scope.assignments``""" + return len(self[node]) > 0 + + +class Accesses: + """A container to provide all accesses in a scope.""" + + def __init__(self, accesses: Mapping[str, Collection[Access]]) -> None: + self._accesses = accesses + + def __iter__(self) -> Generator[Access, None, None]: + """Iterate through all accesses by ``for i in scope.accesses``.""" + for accesses in self._accesses.values(): + for access in accesses: + yield access + + def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[Access]: + """Get accesses given a name str or :class:`~libcst.CSTNode` by ``scope.accesses[node]``""" + name = _NameUtil.get_name_for(node) + return self._accesses[name] if name in self._accesses else set() + + def __contains__(self, node: Union[str, cst.CSTNode]) -> bool: + """Check if a name str or :class:`~libcst.CSTNode` has any access by ``node in scope.accesses``""" + return len(self[node]) > 0 + + class QualifiedNameSource(Enum): IMPORT = auto() BUILTIN = auto() @@ -109,21 +155,34 @@ class QualifiedName: source: QualifiedNameSource -class _QualifiedNameUtil: +class _NameUtil: @staticmethod def get_full_name_for(node: cst.CSTNode) -> Optional[str]: if isinstance(node, cst.Name): return node.value elif isinstance(node, cst.Attribute): - return ( - f"{_QualifiedNameUtil.get_full_name_for(node.value)}.{node.attr.value}" - ) + return f"{_NameUtil.get_full_name_for(node.value)}.{node.attr.value}" elif isinstance(node, cst.Call): - return _QualifiedNameUtil.get_full_name_for(node.func) + return _NameUtil.get_full_name_for(node.func) elif isinstance(node, cst.Subscript): - return _QualifiedNameUtil.get_full_name_for(node.value) + return _NameUtil.get_full_name_for(node.value) elif isinstance(node, (cst.FunctionDef, cst.ClassDef)): - return _QualifiedNameUtil.get_full_name_for(node.name) + return _NameUtil.get_full_name_for(node.name) + return None + + @staticmethod + def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]: + """A helper function to retrieve simple name str from a CSTNode or str""" + if isinstance(node, cst.Name): + return node.value + elif isinstance(node, str): + return node + elif isinstance(node, cst.Call): + return _NameUtil.get_name_for(node.func) + elif isinstance(node, cst.Subscript): + return _NameUtil.get_name_for(node.value) + elif isinstance(node, (cst.FunctionDef, cst.ClassDef)): + return _NameUtil.get_name_for(node.name) return None @staticmethod @@ -136,11 +195,11 @@ class _QualifiedNameUtil: module_attr = assignment_node.module if module_attr: # TODO: for relative import, keep the relative Dot in the qualified name - module = _QualifiedNameUtil.get_full_name_for(module_attr) + module = _NameUtil.get_full_name_for(module_attr) import_names = assignment_node.names if not isinstance(import_names, cst.ImportStar): for name in import_names: - real_name = _QualifiedNameUtil.get_full_name_for(name.name) + real_name = _NameUtil.get_full_name_for(name.name) as_name = real_name if name and name.asname: name_asname = name.asname @@ -260,7 +319,7 @@ class Scope(abc.ABC): ``List[Union[int, str]]``. """ results = set() - full_name = _QualifiedNameUtil.get_full_name_for(node) + full_name = _NameUtil.get_full_name_for(node) if full_name is None: return results parts = full_name.split(".") @@ -268,11 +327,11 @@ class Scope(abc.ABC): if isinstance(assignment, Assignment): assignment_node = assignment.node if isinstance(assignment_node, (cst.Import, cst.ImportFrom)): - results |= _QualifiedNameUtil.find_qualified_name_for_import_alike( + results |= _NameUtil.find_qualified_name_for_import_alike( assignment_node, parts ) else: - results |= _QualifiedNameUtil.find_qualified_name_for_non_import( + results |= _NameUtil.find_qualified_name_for_non_import( assignment, parts ) elif isinstance(assignment, BuiltinAssignment): @@ -473,9 +532,7 @@ class ScopeVisitor(cst.CSTVisitor): self.scope.record_assignment(node.name.value, node) self.provider.set_metadata(node.name, self.scope) - with self._new_scope( - FunctionScope, _QualifiedNameUtil.get_full_name_for(node.name) - ): + with self._new_scope(FunctionScope, _NameUtil.get_full_name_for(node.name)): node.params.visit(self) node.body.visit(self) @@ -550,9 +607,7 @@ class ScopeVisitor(cst.CSTVisitor): for keyword in node.keywords: keyword.visit(self) - with self._new_scope( - ClassScope, _QualifiedNameUtil.get_full_name_for(node.name) - ): + with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name)): for statement in node.body.body: statement.visit(self) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 26936a35..66b8a4f4 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -67,21 +67,24 @@ class ScopeProviderTest(UnitTest): global_foo_assignments = scope_of_module["foo"] self.assertEqual(len(global_foo_assignments), 1) foo_assignment = global_foo_assignments[0] - self.assertEqual(len(foo_assignment.accesses), 2) + self.assertEqual(len(foo_assignment.references), 2) fn1_call_arg = ensure_type( ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr ).value, cst.Call, ).args[0] - self.assertEqual(foo_assignment.accesses[0].node, fn1_call_arg.value) + fn2_call_arg = ensure_type( ensure_type( ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr ).value, cst.Call, ).args[0] - self.assertEqual(foo_assignment.accesses[1].node, fn2_call_arg.value) + self.assertEqual( + {access.node for access in foo_assignment.references}, + {fn1_call_arg.value, fn2_call_arg.value}, + ) func_body = ensure_type(m.body[3], cst.FunctionDef).body func_foo_statement = func_body.body[0] scope_of_func_statement = scopes[func_foo_statement] @@ -89,7 +92,7 @@ class ScopeProviderTest(UnitTest): func_foo_assignments = scope_of_func_statement["foo"] self.assertEqual(len(func_foo_assignments), 1) foo_assignment = func_foo_assignments[0] - self.assertEqual(len(foo_assignment.accesses), 1) + self.assertEqual(len(foo_assignment.references), 1) fn3_call_arg = ensure_type( ensure_type( ensure_type(func_body.body[1], cst.SimpleStatementLine).body[0], @@ -97,7 +100,9 @@ class ScopeProviderTest(UnitTest): ).value, cst.Call, ).args[0] - self.assertEqual(foo_assignment.accesses[0].node, fn3_call_arg.value) + self.assertEqual( + {access.node for access in foo_assignment.references}, {fn3_call_arg.value} + ) wrapper = MetadataWrapper(cst.parse_module("from a import b\n")) wrapper.visit(DependentVisitor())