Fix dotted names (#358)

## Summary

When importing things like `import os.path` and using it as `os.path.join("a", "b").lower()`,
references ended up being in the `["os"]` assignment instead of `["os.path"]`.
This fixes the problem by updating the dotted names generator in the scope provider·

## Test Plan

```
tox -e py37
```

Co-authored-by: Germán Méndez Bravo <kronuz@fb.com>
This commit is contained in:
Germán Méndez Bravo 2020-08-04 17:33:22 -07:00 committed by GitHub
parent ffc4c93c82
commit 2e8a0c6df7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 12 deletions

View file

@ -610,18 +610,27 @@ def _gen_dotted_names(
yield node.value, node
else:
value = node.value
if not isinstance(value, (cst.Attribute, cst.Name)):
# this is not an import
return
name_values = _gen_dotted_names(value)
try:
next_name, next_node = next(name_values)
except StopIteration:
return
else:
yield f"{next_name}.{node.attr.value}", node
yield next_name, next_node
yield from name_values
if isinstance(value, cst.Call):
value = value.func
if isinstance(value, (cst.Attribute, cst.Name)):
name_values = _gen_dotted_names(value)
try:
next_name, next_node = next(name_values)
except StopIteration:
return
else:
yield next_name, next_node
yield from name_values
elif isinstance(value, (cst.Attribute, cst.Name)):
name_values = _gen_dotted_names(value)
try:
next_name, next_node = next(name_values)
except StopIteration:
return
else:
yield f"{next_name}.{node.attr.value}", node
yield next_name, next_node
yield from name_values
class ScopeVisitor(cst.CSTVisitor):

View file

@ -21,6 +21,7 @@ from libcst.metadata.scope_provider import (
QualifiedNameSource,
Scope,
ScopeProvider,
_gen_dotted_names,
)
from libcst.testing.utils import UnitTest, data_provider
@ -203,6 +204,39 @@ class ScopeProviderTest(UnitTest):
self.assertEqual(list(scope_of_module["x.y"])[0].references, set())
self.assertEqual(scope_of_module.accesses["x.y"], set())
def test_dotted_import_with_call_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import os.path
os.path.join("A", "B").lower()
"""
)
scope_of_module = scopes[m]
first_statement = ensure_type(m.body[1], cst.SimpleStatementLine)
attr = ensure_type(
ensure_type(
ensure_type(
ensure_type(
ensure_type(first_statement.body[0], cst.Expr).value, cst.Call
).func,
cst.Attribute,
).value,
cst.Call,
).func,
cst.Attribute,
).value
self.assertTrue("os.path" in scope_of_module)
self.assertTrue("os" in scope_of_module)
os_path_join_assignment = cast(Assignment, list(scope_of_module["os.path"])[0])
os_path_join_assignment_references = list(os_path_join_assignment.references)
self.assertNotEqual(len(os_path_join_assignment_references), 0)
os_path_join_access = os_path_join_assignment_references[0]
self.assertEqual(scope_of_module.accesses["os"], set())
self.assertEqual(scope_of_module.accesses["os.path"], {os_path_join_access})
self.assertEqual(scope_of_module.accesses["os.path.join"], set())
self.assertEqual(os_path_join_access.node, attr)
def test_import_from(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
@ -1153,3 +1187,34 @@ class ScopeProviderTest(UnitTest):
scope.get_qualified_names_for("doesnt_exist")
self.assertEqual(len(scope._assignments), assignments_len_before)
self.assertEqual(len(scope._accesses), accesses_len_before)
def test_gen_dotted_names(self) -> None:
names = {name for name, node in _gen_dotted_names(cst.Name(value="a"))}
self.assertEqual(names, {"a"})
names = {
name
for name, node in _gen_dotted_names(
cst.Attribute(value=cst.Name(value="a"), attr=cst.Name(value="b"))
)
}
self.assertEqual(names, {"a.b", "a"})
names = {
name
for name, node in _gen_dotted_names(
cst.Attribute(
value=cst.Call(
func=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="a"), attr=cst.Name(value="b")
),
attr=cst.Name(value="c"),
),
args=[],
),
attr=cst.Name(value="d"),
)
)
}
self.assertEqual(names, {"a.b.c", "a.b", "a"})