tie accesses from string annotation to the string node (#483)

This commit is contained in:
Zsolt Dollenstein 2021-05-12 14:50:15 +01:00 committed by GitHub
parent d1606b7077
commit 4d2ccc54b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 105 additions and 7 deletions

View file

@ -74,9 +74,10 @@ class Access:
#: The node of the access. A name is an access when the expression context is
#: :attr:`ExpressionContext.LOAD`. This is usually the name node representing the
#: access, except for dotted imports, when it might be the attribute that
#: represents the most specific part of the imported symbol.
node: Union[cst.Name, cst.Attribute]
#: access, except for: 1) dotted imports, when it might be the attribute that
#: represents the most specific part of the imported symbol; and 2) string
#: annotations, when it is the entire string literal
node: Union[cst.Name, cst.Attribute, cst.BaseString]
#: The scope of the access. Note that a access could be in a child scope of its
#: assignment.
@ -422,7 +423,7 @@ class Scope(abc.ABC):
@abc.abstractmethod
def __contains__(self, name: str) -> bool:
""" Check if the name str exist in current scope by ``name in scope``. """
"""Check if the name str exist in current scope by ``name in scope``."""
...
@abc.abstractmethod
@ -775,18 +776,26 @@ def _is_assignment(node: cst.CSTNode, assignment_node: cst.CSTNode) -> bool:
return False
@dataclass(frozen=True)
class DeferredAccess:
access: Access
enclosing_attribute: Optional[cst.Attribute]
enclosing_string_annotation: Optional[cst.BaseString]
class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can makes this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = []
self.__deferred_accesses: List[DeferredAccess] = []
self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None]
self.__in_annotation: Set[
Union[cst.Call, cst.Annotation, cst.Subscript]
] = set()
self.__in_type_hint: Set[Union[cst.Call, cst.Annotation, cst.Subscript]] = set()
self.__in_ignored_subscript: Set[cst.Subscript] = set()
self.__last_string_annotation: Optional[cst.BaseString] = None
self.__ignore_annotation: int = 0
@contextmanager
@ -887,8 +896,13 @@ class ScopeVisitor(cst.CSTVisitor):
) and not self.__in_ignored_subscript:
value = node.evaluated_value
if value:
top_level_annotation = self.__last_string_annotation is None
if top_level_annotation:
self.__last_string_annotation = node
mod = cst.parse_module(value)
mod.visit(self)
if top_level_annotation:
self.__last_string_annotation = None
return True
return False
@ -920,7 +934,11 @@ class ScopeVisitor(cst.CSTVisitor):
is_type_hint=bool(self.__in_type_hint),
)
self.__deferred_accesses.append(
(access, self.__top_level_attribute_stack[-1])
DeferredAccess(
access=access,
enclosing_attribute=self.__top_level_attribute_stack[-1],
enclosing_string_annotation=self.__last_string_annotation,
)
)
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
@ -1074,7 +1092,12 @@ class ScopeVisitor(cst.CSTVisitor):
# In worst case, all accesses (m) and assignments (n) refer to the same name,
# the time complexity is O(m x n), this optimizes it as O(m + n).
scope_name_accesses = defaultdict(set)
for (access, enclosing_attribute) in self.__deferred_accesses:
for def_access in self.__deferred_accesses:
access, enclosing_attribute, enclosing_string_annotation = (
def_access.access,
def_access.enclosing_attribute,
def_access.enclosing_string_annotation,
)
name = ensure_type(access.node, cst.Name).value
if enclosing_attribute is not None:
# if _gen_dotted_names doesn't generate any values, fall back to
@ -1085,6 +1108,9 @@ class ScopeVisitor(cst.CSTVisitor):
name = attr_name
break
if enclosing_string_annotation is not None:
access.node = enclosing_string_annotation
scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(name)
access.scope.record_access(name, access)

View file

@ -1082,6 +1082,10 @@ class ScopeProviderTest(UnitTest):
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "B")
assignment = list(scope["C"])[0]
self.assertIsInstance(assignment, Assignment)
@ -1104,6 +1108,10 @@ class ScopeProviderTest(UnitTest):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "E")
assignment = list(scope["E2"])[0]
self.assertIsInstance(assignment, Assignment)
@ -1119,6 +1127,10 @@ class ScopeProviderTest(UnitTest):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "Optional[G]")
assignment = list(scope["G2"])[0]
self.assertIsInstance(assignment, Assignment)
@ -1130,6 +1142,10 @@ class ScopeProviderTest(UnitTest):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "H")
assignment = list(scope["I"])[0]
self.assertIsInstance(assignment, Assignment)
@ -1148,6 +1164,10 @@ class ScopeProviderTest(UnitTest):
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "K")
assignment = list(scope["K2"])[0]
self.assertIsInstance(assignment, Assignment)
@ -1157,12 +1177,64 @@ class ScopeProviderTest(UnitTest):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "L")
assignment = list(scope["M"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
def test_insane_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
r"""
from typing import TypeVar
from a import G
TypeVar("G2", bound="Optional[\"G\"]")
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
bound = call.args[1].value
scope = scopes[imp]
assignment = next(iter(scope["G"]))
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
self.assertEqual(list(assignment.references)[0].node, bound)
def test_dotted_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
r"""
from typing import TypeVar
import a.G
TypeVar("G2", bound="a.G")
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Import
)
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
bound = call.args[1].value
scope = scopes[imp]
assignment = next(iter(scope["a.G"]))
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
self.assertEqual(list(assignment.references)[0].node, bound)
def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""