mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
[scope] keep track of assignment/access ordering (#413)
This commit is contained in:
parent
90df5a6a37
commit
2ef730292b
2 changed files with 177 additions and 9 deletions
|
|
@ -36,6 +36,25 @@ from libcst.metadata.expression_context_provider import (
|
|||
)
|
||||
|
||||
|
||||
_ASSIGNMENT_LIKE_NODES = (
|
||||
cst.AnnAssign,
|
||||
cst.AsName,
|
||||
cst.Assign,
|
||||
cst.AugAssign,
|
||||
cst.ClassDef,
|
||||
cst.CompFor,
|
||||
cst.For,
|
||||
cst.FunctionDef,
|
||||
cst.Global,
|
||||
cst.Import,
|
||||
cst.ImportFrom,
|
||||
cst.NamedExpr,
|
||||
cst.Nonlocal,
|
||||
cst.Parameters,
|
||||
cst.WithItem,
|
||||
)
|
||||
|
||||
|
||||
@add_slots
|
||||
@dataclass(frozen=False)
|
||||
class Access:
|
||||
|
|
@ -68,6 +87,7 @@ class Access:
|
|||
is_type_hint: bool
|
||||
|
||||
__assignments: Set["BaseAssignment"]
|
||||
__index: int
|
||||
|
||||
def __init__(
|
||||
self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
|
||||
|
|
@ -77,6 +97,7 @@ class Access:
|
|||
self.is_annotation = is_annotation
|
||||
self.is_type_hint = is_type_hint
|
||||
self.__assignments = set()
|
||||
self.__index = scope._assignment_count
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
|
@ -86,11 +107,25 @@ class Access:
|
|||
"""Return all assignments of the access."""
|
||||
return self.__assignments
|
||||
|
||||
def record_assignment(self, assignment: "BaseAssignment") -> None:
|
||||
self.__assignments.add(assignment)
|
||||
@property
|
||||
def _index(self) -> int:
|
||||
return self.__index
|
||||
|
||||
def record_assignments(self, assignments: Set["BaseAssignment"]) -> None:
|
||||
self.__assignments |= assignments
|
||||
def record_assignment(self, assignment: "BaseAssignment") -> None:
|
||||
if assignment.scope != self.scope or assignment._index < self.__index:
|
||||
self.__assignments.add(assignment)
|
||||
|
||||
def record_assignments(self, name: str) -> None:
|
||||
assignments = self.scope[name]
|
||||
# filter out assignments that happened later than this access
|
||||
previous_assignments = {
|
||||
assignment
|
||||
for assignment in assignments
|
||||
if assignment.scope != self.scope or assignment._index < self.__index
|
||||
}
|
||||
if not previous_assignments and assignments:
|
||||
previous_assignments = self.scope.parent[name]
|
||||
self.__assignments |= previous_assignments
|
||||
|
||||
|
||||
class BaseAssignment(abc.ABC):
|
||||
|
|
@ -109,10 +144,22 @@ class BaseAssignment(abc.ABC):
|
|||
self.__accesses = set()
|
||||
|
||||
def record_access(self, access: Access) -> None:
|
||||
self.__accesses.add(access)
|
||||
if access.scope != self.scope or self._index < access._index:
|
||||
self.__accesses.add(access)
|
||||
|
||||
def record_accesses(self, accesses: Set[Access]) -> None:
|
||||
self.__accesses |= accesses
|
||||
later_accesses = {
|
||||
access
|
||||
for access in accesses
|
||||
if access.scope != self.scope or self._index < access._index
|
||||
}
|
||||
self.__accesses |= later_accesses
|
||||
earlier_accesses = accesses - later_accesses
|
||||
if earlier_accesses and self.scope.parent != self.scope:
|
||||
# Accesses "earlier" than the relevant assignment should be attached
|
||||
# to assignments of the same name in the parent
|
||||
for shadowed_assignment in self.scope.parent[self.name]:
|
||||
shadowed_assignment.record_accesses(earlier_accesses)
|
||||
|
||||
@property
|
||||
def references(self) -> Collection[Access]:
|
||||
|
|
@ -123,6 +170,11 @@ class BaseAssignment(abc.ABC):
|
|||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
@property
|
||||
def _index(self) -> int:
|
||||
"""Return an integer that represents the order of assignments in `scope`"""
|
||||
return -1
|
||||
|
||||
|
||||
class Assignment(BaseAssignment):
|
||||
"""An assignment records the name, CSTNode and its accesses."""
|
||||
|
|
@ -130,11 +182,19 @@ class Assignment(BaseAssignment):
|
|||
#: The node of assignment, it could be a :class:`~libcst.Import`, :class:`~libcst.ImportFrom`,
|
||||
#: :class:`~libcst.Name`, :class:`~libcst.FunctionDef`, or :class:`~libcst.ClassDef`.
|
||||
node: cst.CSTNode
|
||||
__index: int
|
||||
|
||||
def __init__(self, name: str, scope: "Scope", node: cst.CSTNode) -> None:
|
||||
def __init__(
|
||||
self, name: str, scope: "Scope", node: cst.CSTNode, index: int
|
||||
) -> None:
|
||||
self.node = node
|
||||
self.__index = index
|
||||
super().__init__(name, scope)
|
||||
|
||||
@property
|
||||
def _index(self) -> int:
|
||||
return self.__index
|
||||
|
||||
|
||||
# even though we don't override the constructor.
|
||||
class BuiltinAssignment(BaseAssignment):
|
||||
|
|
@ -318,6 +378,7 @@ class Scope(abc.ABC):
|
|||
globals: "GlobalScope"
|
||||
_assignments: MutableMapping[str, Set[BaseAssignment]]
|
||||
_accesses: MutableMapping[str, Set[Access]]
|
||||
_assignment_count: int
|
||||
|
||||
def __init__(self, parent: "Scope") -> None:
|
||||
super().__init__()
|
||||
|
|
@ -325,9 +386,12 @@ class Scope(abc.ABC):
|
|||
self.globals = parent.globals
|
||||
self._assignments = defaultdict(set)
|
||||
self._accesses = defaultdict(set)
|
||||
self._assignment_count = 0
|
||||
|
||||
def record_assignment(self, name: str, node: cst.CSTNode) -> None:
|
||||
self._assignments[name].add(Assignment(name=name, scope=self, node=node))
|
||||
self._assignments[name].add(
|
||||
Assignment(name=name, scope=self, node=node, index=self._assignment_count)
|
||||
)
|
||||
|
||||
def record_access(self, name: str, access: Access) -> None:
|
||||
self._accesses[name].add(access)
|
||||
|
|
@ -934,7 +998,7 @@ class ScopeVisitor(cst.CSTVisitor):
|
|||
break
|
||||
|
||||
scope_name_accesses[(access.scope, name)].add(access)
|
||||
access.record_assignments(access.scope[name])
|
||||
access.record_assignments(name)
|
||||
access.scope.record_access(name, access)
|
||||
|
||||
for (scope, name), accesses in scope_name_accesses.items():
|
||||
|
|
@ -945,6 +1009,8 @@ class ScopeVisitor(cst.CSTVisitor):
|
|||
|
||||
def on_leave(self, original_node: cst.CSTNode) -> None:
|
||||
self.provider.set_metadata(original_node, self.scope)
|
||||
if isinstance(original_node, _ASSIGNMENT_LIKE_NODES):
|
||||
self.scope._assignment_count += 1
|
||||
super().on_leave(original_node)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1329,3 +1329,105 @@ class ScopeProviderTest(UnitTest):
|
|||
)
|
||||
}
|
||||
self.assertEqual(names, {"a.b.c", "a.b", "a"})
|
||||
|
||||
def test_ordering(self) -> None:
|
||||
m, scopes = get_scope_metadata_provider(
|
||||
"""
|
||||
from a import b
|
||||
class X:
|
||||
x = b
|
||||
b = b
|
||||
y = b
|
||||
"""
|
||||
)
|
||||
global_scope = scopes[m]
|
||||
import_stmt = ensure_type(
|
||||
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
|
||||
)
|
||||
first_assignment = list(global_scope.assignments)[0]
|
||||
assert isinstance(first_assignment, cst.metadata.Assignment)
|
||||
self.assertEqual(first_assignment.node, import_stmt)
|
||||
global_refs = list(first_assignment.references)
|
||||
self.assertEqual(len(global_refs), 2)
|
||||
class_def = ensure_type(m.body[1], cst.ClassDef)
|
||||
x = ensure_type(
|
||||
ensure_type(class_def.body.body[0], cst.SimpleStatementLine).body[0],
|
||||
cst.Assign,
|
||||
)
|
||||
self.assertEqual(x.value, global_refs[0].node)
|
||||
class_b = ensure_type(
|
||||
ensure_type(class_def.body.body[1], cst.SimpleStatementLine).body[0],
|
||||
cst.Assign,
|
||||
)
|
||||
self.assertEqual(class_b.value, global_refs[1].node)
|
||||
|
||||
class_accesses = list(scopes[x].accesses)
|
||||
self.assertEqual(len(class_accesses), 3)
|
||||
self.assertIn(
|
||||
class_b.targets[0].target,
|
||||
[
|
||||
ref.node
|
||||
for acc in class_accesses
|
||||
for ref in acc.referents
|
||||
if isinstance(ref, Assignment)
|
||||
],
|
||||
)
|
||||
y = ensure_type(
|
||||
ensure_type(class_def.body.body[2], cst.SimpleStatementLine).body[0],
|
||||
cst.Assign,
|
||||
)
|
||||
self.assertIn(y.value, [access.node for access in class_accesses])
|
||||
|
||||
def test_ordering_between_scopes(self) -> None:
|
||||
m, scopes = get_scope_metadata_provider(
|
||||
"""
|
||||
def f(a):
|
||||
print(a)
|
||||
print(b)
|
||||
a = 1
|
||||
b = 1
|
||||
"""
|
||||
)
|
||||
f = cst.ensure_type(m.body[0], cst.FunctionDef)
|
||||
a_param = f.params.params[0].name
|
||||
a_param_assignment = list(scopes[a_param]["a"])[0]
|
||||
a_param_refs = list(a_param_assignment.references)
|
||||
first_print = cst.ensure_type(
|
||||
cst.ensure_type(
|
||||
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
|
||||
cst.Expr,
|
||||
).value,
|
||||
cst.Call,
|
||||
)
|
||||
second_print = cst.ensure_type(
|
||||
cst.ensure_type(
|
||||
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
|
||||
cst.Expr,
|
||||
).value,
|
||||
cst.Call,
|
||||
)
|
||||
self.assertEqual(
|
||||
first_print.args[0].value,
|
||||
a_param_refs[0].node,
|
||||
)
|
||||
a_global = (
|
||||
cst.ensure_type(
|
||||
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
|
||||
)
|
||||
.targets[0]
|
||||
.target
|
||||
)
|
||||
a_global_assignment = list(scopes[a_global]["a"])[0]
|
||||
a_global_refs = list(a_global_assignment.references)
|
||||
self.assertEqual(a_global_refs, [])
|
||||
b_global = (
|
||||
cst.ensure_type(
|
||||
cst.ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Assign
|
||||
)
|
||||
.targets[0]
|
||||
.target
|
||||
)
|
||||
b_global_assignment = list(scopes[b_global]["b"])[0]
|
||||
b_global_refs = list(b_global_assignment.references)
|
||||
self.assertEqual(len(b_global_refs), 1)
|
||||
self.assertEqual(b_global_refs[0].node, second_print.args[0].value)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue