[scope] add node to FunctionScope and ClassScope

This commit is contained in:
Jimmy Lai 2019-10-23 11:56:08 -07:00 committed by jimmylai
parent bb72a0556b
commit 8d66bbdbbc
2 changed files with 52 additions and 6 deletions

View file

@ -526,7 +526,12 @@ class FunctionScope(LocalScope):
When a function is defined, it creates a FunctionScope.
"""
pass
#: The :class:`~libcst.FunctionDef` node defines the current scope.
node: cst.FunctionDef
def __init__(self, parent: Scope, name: str, node: cst.FunctionDef) -> None:
super().__init__(parent, name)
self.node = node
# even though we don't override the constructor.
@ -535,6 +540,13 @@ class ClassScope(LocalScope):
When a class is defined, it creates a ClassScope.
"""
#: The :class:`~libcst.ClassDef` node defines the current scope.
node: cst.ClassDef
def __init__(self, parent: Scope, name: str, node: cst.ClassDef) -> None:
super().__init__(parent, name)
self.node = node
def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None:
"""
Forward the assignment to parent.
@ -584,10 +596,17 @@ class ScopeVisitor(cst.CSTVisitor):
@contextmanager
def _new_scope(
self, kind: Type[LocalScope], name: Optional[str] = None
self,
kind: Type[LocalScope],
name: Optional[str] = None,
node: Optional[cst.CSTNode] = None,
) -> Iterator[None]:
parent_scope = self.scope
self.scope = kind(parent_scope, name)
if node and kind in (FunctionScope, ClassScope):
# pyre-ignore: pyre don't know FunctionScope and ClassScope take node arg.
self.scope = kind(parent_scope, name, node)
else:
self.scope = kind(parent_scope, name)
try:
yield
finally:
@ -656,7 +675,9 @@ class ScopeVisitor(cst.CSTVisitor):
self.scope.record_assignment(node.name.value, node)
self.provider.set_metadata(node.name, self.scope)
with self._new_scope(FunctionScope, _NameUtil.get_full_name_for(node.name)):
with self._new_scope(
FunctionScope, _NameUtil.get_full_name_for(node.name), node
):
node.params.visit(self)
node.body.visit(self)
@ -682,7 +703,7 @@ class ScopeVisitor(cst.CSTVisitor):
return False
def visit_Lambda(self, node: cst.Lambda) -> Optional[bool]:
with self._new_scope(FunctionScope):
with self._new_scope(FunctionScope, name=None, node=node):
node.params.visit(self)
node.body.visit(self)
@ -731,7 +752,7 @@ class ScopeVisitor(cst.CSTVisitor):
for keyword in node.keywords:
keyword.visit(self)
with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name)):
with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name), node):
for statement in node.body.body:
statement.visit(self)

View file

@ -880,3 +880,28 @@ class ScopeProviderTest(UnitTest):
)
self.assertEqual(len(set(scopes.values())), 3)
def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f1():
target()
class C:
attr = target()
"""
)
f1 = ensure_type(m.body[0], cst.FunctionDef)
target_call = ensure_type(
ensure_type(f1.body.body[0], cst.SimpleStatementLine).body[0], cst.Expr
).value
f1_scope = scopes[target_call]
self.assertIsInstance(f1_scope, FunctionScope)
self.assertEqual(cast(FunctionScope, f1_scope).node, f1)
c = ensure_type(m.body[1], cst.ClassDef)
target_call_2 = ensure_type(
ensure_type(c.body.body[0], cst.SimpleStatementLine).body[0], cst.Assign
).value
c_scope = scopes[target_call_2]
self.assertIsInstance(c_scope, ClassScope)
self.assertEqual(cast(ClassScope, c_scope).node, c)