Fix enclosing attribute for attributes in call arguments

Fixes enclosed arguments like `c.d` in `x.y(c.d()).z()` were badly being
resolved as `x.y` instead.

This also clarifies the intent in `infer_accesses()` so it no longer shadows
variable `name` and also fixes the case where no node is actually found
in the scope.
This commit is contained in:
Germán Méndez Bravo 2020-08-05 16:26:53 -07:00 committed by Germán Méndez Bravo
parent 2e788a25dc
commit 3e66bdd957
2 changed files with 27 additions and 8 deletions

View file

@ -77,6 +77,16 @@ class RemoveUnusedImportsCommandTest(CodemodTest):
self.assertCodemod(before, after)
def test_enclosed_attributes(self) -> None:
before = """
from a.b import c
import x
def foo() -> None:
x.y(c.d()).z()
"""
self.assertCodemod(before, before)
def test_access_in_assignment(self) -> None:
before = """
from a import b

View file

@ -639,7 +639,7 @@ class ScopeVisitor(cst.CSTVisitor):
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = []
self.__top_level_attribute: Optional[cst.Attribute] = None
self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None]
@contextmanager
def _new_scope(
@ -686,13 +686,19 @@ class ScopeVisitor(cst.CSTVisitor):
return self._visit_import_alike(node)
def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
if self.__top_level_attribute is None:
self.__top_level_attribute = node
if self.__top_level_attribute_stack[-1] is None:
self.__top_level_attribute_stack[-1] = node
node.value.visit(self) # explicitly not visiting attr
if self.__top_level_attribute is node:
self.__top_level_attribute = None
if self.__top_level_attribute_stack[-1] is node:
self.__top_level_attribute_stack[-1] = None
return False
def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
def leave_Call(self, original_node: cst.Call) -> None:
self.__top_level_attribute_stack.pop()
def visit_Name(self, node: cst.Name) -> Optional[bool]:
# not all Name have ExpressionContext
context = self.provider.get_metadata(ExpressionContextProvider, node, None)
@ -700,7 +706,9 @@ class ScopeVisitor(cst.CSTVisitor):
self.scope.record_assignment(node.value, node)
elif context in (ExpressionContext.LOAD, ExpressionContext.DEL):
access = Access(node, self.scope)
self.__deferred_accesses.append((access, self.__top_level_attribute))
self.__deferred_accesses.append(
(access, self.__top_level_attribute_stack[-1])
)
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
@ -842,9 +850,10 @@ class ScopeVisitor(cst.CSTVisitor):
if enclosing_attribute is not None:
# if _gen_dotted_names doesn't generate any values, fall back to
# the original name node above
for name, node in _gen_dotted_names(enclosing_attribute):
if name in access.scope:
for attr_name, node in _gen_dotted_names(enclosing_attribute):
if attr_name in access.scope:
access.node = node
name = attr_name
break
scope_name_accesses[(access.scope, name)].add(access)