From 8d66bbdbbcb48ec2885d04a2986f8777e31cff4b Mon Sep 17 00:00:00 2001 From: Jimmy Lai Date: Wed, 23 Oct 2019 11:56:08 -0700 Subject: [PATCH] [scope] add node to FunctionScope and ClassScope --- libcst/metadata/scope_provider.py | 33 ++++++++++++++++---- libcst/metadata/tests/test_scope_provider.py | 25 +++++++++++++++ 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index c5ea55b8..b84c1edf 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -526,7 +526,12 @@ class FunctionScope(LocalScope): When a function is defined, it creates a FunctionScope. """ - pass + #: The :class:`~libcst.FunctionDef` node defines the current scope. + node: cst.FunctionDef + + def __init__(self, parent: Scope, name: str, node: cst.FunctionDef) -> None: + super().__init__(parent, name) + self.node = node # even though we don't override the constructor. @@ -535,6 +540,13 @@ class ClassScope(LocalScope): When a class is defined, it creates a ClassScope. """ + #: The :class:`~libcst.ClassDef` node defines the current scope. + node: cst.ClassDef + + def __init__(self, parent: Scope, name: str, node: cst.ClassDef) -> None: + super().__init__(parent, name) + self.node = node + def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None: """ Forward the assignment to parent. @@ -584,10 +596,17 @@ class ScopeVisitor(cst.CSTVisitor): @contextmanager def _new_scope( - self, kind: Type[LocalScope], name: Optional[str] = None + self, + kind: Type[LocalScope], + name: Optional[str] = None, + node: Optional[cst.CSTNode] = None, ) -> Iterator[None]: parent_scope = self.scope - self.scope = kind(parent_scope, name) + if node and kind in (FunctionScope, ClassScope): + # pyre-ignore: pyre don't know FunctionScope and ClassScope take node arg. + self.scope = kind(parent_scope, name, node) + else: + self.scope = kind(parent_scope, name) try: yield finally: @@ -656,7 +675,9 @@ class ScopeVisitor(cst.CSTVisitor): self.scope.record_assignment(node.name.value, node) self.provider.set_metadata(node.name, self.scope) - with self._new_scope(FunctionScope, _NameUtil.get_full_name_for(node.name)): + with self._new_scope( + FunctionScope, _NameUtil.get_full_name_for(node.name), node + ): node.params.visit(self) node.body.visit(self) @@ -682,7 +703,7 @@ class ScopeVisitor(cst.CSTVisitor): return False def visit_Lambda(self, node: cst.Lambda) -> Optional[bool]: - with self._new_scope(FunctionScope): + with self._new_scope(FunctionScope, name=None, node=node): node.params.visit(self) node.body.visit(self) @@ -731,7 +752,7 @@ class ScopeVisitor(cst.CSTVisitor): for keyword in node.keywords: keyword.visit(self) - with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name)): + with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name), node): for statement in node.body.body: statement.visit(self) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index e32e1c60..ef6c1c2c 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -880,3 +880,28 @@ class ScopeProviderTest(UnitTest): ) self.assertEqual(len(set(scopes.values())), 3) + + def test_node_of_scopes(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + def f1(): + target() + + class C: + attr = target() + """ + ) + f1 = ensure_type(m.body[0], cst.FunctionDef) + target_call = ensure_type( + ensure_type(f1.body.body[0], cst.SimpleStatementLine).body[0], cst.Expr + ).value + f1_scope = scopes[target_call] + self.assertIsInstance(f1_scope, FunctionScope) + self.assertEqual(cast(FunctionScope, f1_scope).node, f1) + c = ensure_type(m.body[1], cst.ClassDef) + target_call_2 = ensure_type( + ensure_type(c.body.body[0], cst.SimpleStatementLine).body[0], cst.Assign + ).value + c_scope = scopes[target_call_2] + self.assertIsInstance(c_scope, ClassScope) + self.assertEqual(cast(ClassScope, c_scope).node, c)