diff --git a/libcst/codemod/commands/tests/test_remove_unused_imports.py b/libcst/codemod/commands/tests/test_remove_unused_imports.py index c685f5c8..23b1c727 100644 --- a/libcst/codemod/commands/tests/test_remove_unused_imports.py +++ b/libcst/codemod/commands/tests/test_remove_unused_imports.py @@ -77,6 +77,16 @@ class RemoveUnusedImportsCommandTest(CodemodTest): self.assertCodemod(before, after) + def test_enclosed_attributes(self) -> None: + before = """ + from a.b import c + import x + + def foo() -> None: + x.y(c.d()).z() + """ + self.assertCodemod(before, before) + def test_access_in_assignment(self) -> None: before = """ from a import b diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index fdda406a..3802c3c9 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -639,7 +639,7 @@ class ScopeVisitor(cst.CSTVisitor): self.provider: ScopeProvider = provider self.scope: Scope = GlobalScope() self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = [] - self.__top_level_attribute: Optional[cst.Attribute] = None + self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None] @contextmanager def _new_scope( @@ -686,13 +686,19 @@ class ScopeVisitor(cst.CSTVisitor): return self._visit_import_alike(node) def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: - if self.__top_level_attribute is None: - self.__top_level_attribute = node + if self.__top_level_attribute_stack[-1] is None: + self.__top_level_attribute_stack[-1] = node node.value.visit(self) # explicitly not visiting attr - if self.__top_level_attribute is node: - self.__top_level_attribute = None + if self.__top_level_attribute_stack[-1] is node: + self.__top_level_attribute_stack[-1] = None return False + def visit_Call(self, node: cst.Call) -> Optional[bool]: + self.__top_level_attribute_stack.append(None) + + def leave_Call(self, original_node: cst.Call) -> None: + self.__top_level_attribute_stack.pop() + def visit_Name(self, node: cst.Name) -> Optional[bool]: # not all Name have ExpressionContext context = self.provider.get_metadata(ExpressionContextProvider, node, None) @@ -700,7 +706,9 @@ class ScopeVisitor(cst.CSTVisitor): self.scope.record_assignment(node.value, node) elif context in (ExpressionContext.LOAD, ExpressionContext.DEL): access = Access(node, self.scope) - self.__deferred_accesses.append((access, self.__top_level_attribute)) + self.__deferred_accesses.append( + (access, self.__top_level_attribute_stack[-1]) + ) def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: self.scope.record_assignment(node.name.value, node) @@ -842,9 +850,10 @@ class ScopeVisitor(cst.CSTVisitor): if enclosing_attribute is not None: # if _gen_dotted_names doesn't generate any values, fall back to # the original name node above - for name, node in _gen_dotted_names(enclosing_attribute): - if name in access.scope: + for attr_name, node in _gen_dotted_names(enclosing_attribute): + if attr_name in access.scope: access.node = node + name = attr_name break scope_name_accesses[(access.scope, name)].add(access)