From ba54bb0075a72f83b4de5bb8d252350672f80afe Mon Sep 17 00:00:00 2001 From: Jimmy Lai Date: Wed, 23 Oct 2019 12:37:30 -0700 Subject: [PATCH] add node to all LocalScope --- libcst/metadata/scope_provider.py | 41 +++++++++++-------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index b84c1edf..2c0023fa 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -494,9 +494,15 @@ class LocalScope(Scope, abc.ABC): #: Name of function. Used as qualified name. name: Optional[str] - def __init__(self, parent: Scope, name: Optional[str] = None) -> None: + #: The :class:`~libcst.CSTNode` node defines the current scope. + node: cst.CSTNode + + def __init__( + self, parent: Scope, node: cst.CSTNode, name: Optional[str] = None + ) -> None: super().__init__(parent) self.name = name + self.node = node self._scope_overwrites = {} def record_global_overwrite(self, name: str) -> None: @@ -526,12 +532,7 @@ class FunctionScope(LocalScope): When a function is defined, it creates a FunctionScope. """ - #: The :class:`~libcst.FunctionDef` node defines the current scope. - node: cst.FunctionDef - - def __init__(self, parent: Scope, name: str, node: cst.FunctionDef) -> None: - super().__init__(parent, name) - self.node = node + pass # even though we don't override the constructor. @@ -540,13 +541,6 @@ class ClassScope(LocalScope): When a class is defined, it creates a ClassScope. """ - #: The :class:`~libcst.ClassDef` node defines the current scope. - node: cst.ClassDef - - def __init__(self, parent: Scope, name: str, node: cst.ClassDef) -> None: - super().__init__(parent, name) - self.node = node - def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None: """ Forward the assignment to parent. @@ -596,17 +590,10 @@ class ScopeVisitor(cst.CSTVisitor): @contextmanager def _new_scope( - self, - kind: Type[LocalScope], - name: Optional[str] = None, - node: Optional[cst.CSTNode] = None, + self, kind: Type[LocalScope], node: cst.CSTNode, name: Optional[str] = None ) -> Iterator[None]: parent_scope = self.scope - if node and kind in (FunctionScope, ClassScope): - # pyre-ignore: pyre don't know FunctionScope and ClassScope take node arg. - self.scope = kind(parent_scope, name, node) - else: - self.scope = kind(parent_scope, name) + self.scope = kind(parent_scope, node, name) try: yield finally: @@ -676,7 +663,7 @@ class ScopeVisitor(cst.CSTVisitor): self.provider.set_metadata(node.name, self.scope) with self._new_scope( - FunctionScope, _NameUtil.get_full_name_for(node.name), node + FunctionScope, node, _NameUtil.get_full_name_for(node.name) ): node.params.visit(self) node.body.visit(self) @@ -703,7 +690,7 @@ class ScopeVisitor(cst.CSTVisitor): return False def visit_Lambda(self, node: cst.Lambda) -> Optional[bool]: - with self._new_scope(FunctionScope, name=None, node=node): + with self._new_scope(FunctionScope, node): node.params.visit(self) node.body.visit(self) @@ -752,7 +739,7 @@ class ScopeVisitor(cst.CSTVisitor): for keyword in node.keywords: keyword.visit(self) - with self._new_scope(ClassScope, _NameUtil.get_full_name_for(node.name), node): + with self._new_scope(ClassScope, node, _NameUtil.get_full_name_for(node.name)): for statement in node.body.body: statement.visit(self) @@ -842,7 +829,7 @@ class ScopeVisitor(cst.CSTVisitor): for_in = node.for_in for_in.iter.visit(self) self.provider.set_metadata(for_in, self.scope) - with self._new_scope(ComprehensionScope): + with self._new_scope(ComprehensionScope, node): for_in.target.visit(self) for condition in for_in.ifs: condition.visit(self)