[scope] add qualified name support for imported attribute

This commit is contained in:
Jimmy Lai 2019-10-30 22:30:46 -07:00 committed by jimmylai
parent 17fa6add8a
commit e6b84fc11a
2 changed files with 46 additions and 12 deletions

View file

@ -238,7 +238,7 @@ class _NameUtil:
@staticmethod
def find_qualified_name_for_import_alike(
assignment_node: Union[cst.Import, cst.ImportFrom], name_parts: List[str]
assignment_node: Union[cst.Import, cst.ImportFrom], full_name: str
) -> Set[QualifiedName]:
module = ""
results = set()
@ -256,13 +256,16 @@ class _NameUtil:
name_asname = name.asname
if name_asname:
as_name = cst.ensure_type(name_asname.name, cst.Name).value
if as_name == name_parts[0]:
if as_name and full_name.startswith(as_name):
if module:
real_name = f"{module}.{real_name}"
if real_name:
remaining_name = full_name.split(as_name)[1].lstrip(".")
results.add(
QualifiedName(
".".join([real_name, *name_parts[1:]]),
f"{real_name}.{remaining_name}"
if remaining_name
else real_name,
QualifiedNameSource.IMPORT,
)
)
@ -270,7 +273,7 @@ class _NameUtil:
@staticmethod
def find_qualified_name_for_non_import(
assignment: Assignment, name_parts: List[str]
assignment: Assignment, remaining_name: str
) -> Set[QualifiedName]:
scope = assignment.scope
name_prefixes = []
@ -287,12 +290,10 @@ class _NameUtil:
raise Exception(f"Unexpected Scope: {scope}")
scope = scope.parent
return {
QualifiedName(
".".join([*reversed(name_prefixes), *name_parts]),
QualifiedNameSource.LOCAL,
)
}
parts = [*reversed(name_prefixes)]
if remaining_name:
parts.append(remaining_name)
return {QualifiedName(".".join(parts), QualifiedNameSource.LOCAL)}
class Scope(abc.ABC):
@ -439,11 +440,11 @@ class Scope(abc.ABC):
assignment_node = assignment.node
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
results |= _NameUtil.find_qualified_name_for_import_alike(
assignment_node, parts
assignment_node, full_name
)
else:
results |= _NameUtil.find_qualified_name_for_non_import(
assignment, parts
assignment, full_name
)
elif isinstance(assignment, BuiltinAssignment):
results.add(

View file

@ -905,3 +905,36 @@ class ScopeProviderTest(UnitTest):
c_scope = scopes[target_call_2]
self.assertIsInstance(c_scope, ClassScope)
self.assertEqual(cast(ClassScope, c_scope).node, c)
def test_with_statement(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import unittest.mock
with unittest.mock.patch("something") as obj:
obj.f1()
unittest.mock
"""
)
import_ = ensure_type(m.body[0], cst.SimpleStatementLine).body[0]
assignments = scopes[import_]["unittest"]
self.assertEqual(len(assignments), 1)
self.assertEqual(cast(Assignment, list(assignments)[0]).node, import_)
with_ = ensure_type(m.body[1], cst.With)
fn_call = with_.items[0].item
self.assertEqual(
scopes[fn_call].get_qualified_names_for(fn_call),
{
QualifiedName(
name="unittest.mock.patch", source=QualifiedNameSource.IMPORT
)
},
)
mock = ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value
self.assertEqual(
scopes[fn_call].get_qualified_names_for(mock),
{QualifiedName(name="unittest.mock", source=QualifiedNameSource.IMPORT)},
)