[scope] keep track of assignment/access ordering (#413)

This commit is contained in:
Zsolt Dollenstein 2020-11-17 17:40:50 +00:00 committed by GitHub
parent 90df5a6a37
commit 2ef730292b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 177 additions and 9 deletions

View file

@ -36,6 +36,25 @@ from libcst.metadata.expression_context_provider import (
)
# Node types that `ScopeVisitor.on_leave` checks with `isinstance`: each time
# the visitor leaves one of these, the current scope's `_assignment_count` is
# incremented. That counter supplies the index stored on every new
# `Assignment` and `Access`, which is how the two are ordered within a scope.
_ASSIGNMENT_LIKE_NODES = (
cst.AnnAssign,
cst.AsName,
cst.Assign,
cst.AugAssign,
cst.ClassDef,
cst.CompFor,
cst.For,
cst.FunctionDef,
cst.Global,
cst.Import,
cst.ImportFrom,
cst.NamedExpr,
cst.Nonlocal,
cst.Parameters,
cst.WithItem,
)
@add_slots
@dataclass(frozen=False)
class Access:
@ -68,6 +87,7 @@ class Access:
is_type_hint: bool
__assignments: Set["BaseAssignment"]
__index: int
def __init__(
self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
@ -77,6 +97,7 @@ class Access:
self.is_annotation = is_annotation
self.is_type_hint = is_type_hint
self.__assignments = set()
self.__index = scope._assignment_count
def __hash__(self) -> int:
return id(self)
@ -86,11 +107,25 @@ class Access:
"""Return all assignments of the access."""
return self.__assignments
def record_assignment(self, assignment: "BaseAssignment") -> None:
self.__assignments.add(assignment)
@property
def _index(self) -> int:
# Read-only view of the private index: the value `scope._assignment_count`
# had when this Access was constructed. Assignments compare their own
# `_index` against it to decide whether they come before this access.
return self.__index
def record_assignments(self, assignments: Set["BaseAssignment"]) -> None:
self.__assignments |= assignments
def record_assignment(self, assignment: "BaseAssignment") -> None:
    """Attach ``assignment`` to this access unless it happens too late.

    An assignment that lives in a different scope is always recorded. One
    from this access's own scope is recorded only when its ``_index`` is
    strictly lower than the index stored on this access.
    """
    if assignment.scope != self.scope:
        # Ordering is only tracked within a single scope.
        self.__assignments.add(assignment)
        return
    if assignment._index < self.__index:
        self.__assignments.add(assignment)
def record_assignments(self, name: str) -> None:
    """Look up ``name`` from this access's scope and record its assignments.

    Candidates come from ``self.scope[name]``. A candidate from another scope
    is always kept; one from this access's own scope is kept only when its
    ``_index`` is strictly lower than this access's index. When candidates
    exist but all of them were filtered out, the lookup is retried on the
    parent scope, so the access binds to whatever the enclosing scope
    provides under that name.

    The retry is skipped when the scope is its own parent (the same situation
    ``BaseAssignment.record_accesses`` guards against). Retrying there would
    query the very same scope and re-add every later assignment that was
    just filtered out.

    :param name: the identifier being accessed.
    """
    assignments = self.scope[name]
    # filter out assignments that happened later than this access
    previous_assignments = {
        assignment
        for assignment in assignments
        if assignment.scope != self.scope or assignment._index < self.__index
    }
    if (
        not previous_assignments
        and assignments
        and self.scope.parent != self.scope
    ):
        # Everything this scope offers comes after the access: defer to the
        # enclosing scope instead.
        previous_assignments = self.scope.parent[name]
    self.__assignments |= previous_assignments
class BaseAssignment(abc.ABC):
@ -109,10 +144,22 @@ class BaseAssignment(abc.ABC):
self.__accesses = set()
def record_access(self, access: Access) -> None:
self.__accesses.add(access)
if access.scope != self.scope or self._index < access._index:
self.__accesses.add(access)
def record_accesses(self, accesses: Set[Access]) -> None:
self.__accesses |= accesses
later_accesses = {
access
for access in accesses
if access.scope != self.scope or self._index < access._index
}
self.__accesses |= later_accesses
earlier_accesses = accesses - later_accesses
if earlier_accesses and self.scope.parent != self.scope:
# Accesses "earlier" than the relevant assignment should be attached
# to assignments of the same name in the parent
for shadowed_assignment in self.scope.parent[self.name]:
shadowed_assignment.record_accesses(earlier_accesses)
@property
def references(self) -> Collection[Access]:
@ -123,6 +170,11 @@ class BaseAssignment(abc.ABC):
def __hash__(self) -> int:
return id(self)
@property
def _index(self) -> int:
"""Return an integer that represents the order of assignments in `scope`"""
# Base-class default. -1 is below any counter value that starts at 0, so the
# `_index < access._index` comparisons treat an assignment that keeps this
# default as coming before such an access.
return -1
class Assignment(BaseAssignment):
"""An assignment records the name, CSTNode and its accesses."""
@ -130,11 +182,19 @@ class Assignment(BaseAssignment):
#: The node of assignment, it could be a :class:`~libcst.Import`, :class:`~libcst.ImportFrom`,
#: :class:`~libcst.Name`, :class:`~libcst.FunctionDef`, or :class:`~libcst.ClassDef`.
node: cst.CSTNode
__index: int
def __init__(self, name: str, scope: "Scope", node: cst.CSTNode) -> None:
def __init__(
self, name: str, scope: "Scope", node: cst.CSTNode, index: int
) -> None:
self.node = node
self.__index = index
super().__init__(name, scope)
@property
def _index(self) -> int:
return self.__index
# even though we don't override the constructor.
class BuiltinAssignment(BaseAssignment):
@ -318,6 +378,7 @@ class Scope(abc.ABC):
globals: "GlobalScope"
_assignments: MutableMapping[str, Set[BaseAssignment]]
_accesses: MutableMapping[str, Set[Access]]
_assignment_count: int
def __init__(self, parent: "Scope") -> None:
super().__init__()
@ -325,9 +386,12 @@ class Scope(abc.ABC):
self.globals = parent.globals
self._assignments = defaultdict(set)
self._accesses = defaultdict(set)
self._assignment_count = 0
def record_assignment(self, name: str, node: cst.CSTNode) -> None:
self._assignments[name].add(Assignment(name=name, scope=self, node=node))
self._assignments[name].add(
Assignment(name=name, scope=self, node=node, index=self._assignment_count)
)
def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)
@ -934,7 +998,7 @@ class ScopeVisitor(cst.CSTVisitor):
break
scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(access.scope[name])
access.record_assignments(name)
access.scope.record_access(name, access)
for (scope, name), accesses in scope_name_accesses.items():
@ -945,6 +1009,8 @@ class ScopeVisitor(cst.CSTVisitor):
def on_leave(self, original_node: cst.CSTNode) -> None:
# Store the visitor's current scope as the metadata of the node being left.
self.provider.set_metadata(original_node, self.scope)
# Leaving an assignment-like node advances the scope's counter, so any
# Assignment or Access created afterwards in this scope gets a higher index.
if isinstance(original_node, _ASSIGNMENT_LIKE_NODES):
self.scope._assignment_count += 1
super().on_leave(original_node)

View file

@ -1329,3 +1329,105 @@ class ScopeProviderTest(UnitTest):
)
}
self.assertEqual(names, {"a.b.c", "a.b", "a"})
def test_ordering(self) -> None:
# Checks ordering inside a class body that rebinds an imported name: the
# values of `x = b` and `b = b` are expected to be references of the
# module-level import, while the class scope still reports three accesses,
# one of which refers to the class-level `b` target.
m, scopes = get_scope_metadata_provider(
"""
from a import b
class X:
x = b
b = b
y = b
"""
)
global_scope = scopes[m]
import_stmt = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
first_assignment = list(global_scope.assignments)[0]
assert isinstance(first_assignment, cst.metadata.Assignment)
self.assertEqual(first_assignment.node, import_stmt)
# NOTE(review): the positional checks on `global_refs[0]` / `global_refs[1]`
# below assume `references` yields accesses in source order -- confirm.
global_refs = list(first_assignment.references)
self.assertEqual(len(global_refs), 2)
class_def = ensure_type(m.body[1], cst.ClassDef)
x = ensure_type(
ensure_type(class_def.body.body[0], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertEqual(x.value, global_refs[0].node)
class_b = ensure_type(
ensure_type(class_def.body.body[1], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertEqual(class_b.value, global_refs[1].node)
# `scopes[x]` is the scope recorded for the first statement of the class body.
class_accesses = list(scopes[x].accesses)
self.assertEqual(len(class_accesses), 3)
# At least one of those accesses must resolve to the class-level `b` target.
self.assertIn(
class_b.targets[0].target,
[
ref.node
for acc in class_accesses
for ref in acc.referents
if isinstance(ref, Assignment)
],
)
y = ensure_type(
ensure_type(class_def.body.body[2], cst.SimpleStatementLine).body[0],
cst.Assign,
)
# The value of `y = b` must be among the class scope's accesses.
self.assertIn(y.value, [access.node for access in class_accesses])
def test_ordering_between_scopes(self) -> None:
# Checks references across a function and the module that defines names after
# it: `print(a)` is expected to reference the parameter `a`, the module-level
# `a = 1` is expected to have no references, and the module-level `b = 1` is
# expected to be referenced exactly once, by `print(b)`.
m, scopes = get_scope_metadata_provider(
"""
def f(a):
print(a)
print(b)
a = 1
b = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
first_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
second_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
# The parameter's first reference is the argument of the first print call.
self.assertEqual(
first_print.args[0].value,
a_param_refs[0].node,
)
a_global = (
cst.ensure_type(
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_global_assignment = list(scopes[a_global]["a"])[0]
a_global_refs = list(a_global_assignment.references)
# The module-level `a` is shadowed by the parameter inside `f`, so nothing
# should reference it.
self.assertEqual(a_global_refs, [])
b_global = (
cst.ensure_type(
cst.ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
b_global_assignment = list(scopes[b_global]["b"])[0]
b_global_refs = list(b_global_assignment.references)
# `b` is not bound inside `f`, so `print(b)` must reference the module-level
# `b` even though that assignment appears later in the module text.
self.assertEqual(len(b_global_refs), 1)
self.assertEqual(b_global_refs[0].node, second_print.args[0].value)