Handle string type references in cast() (#418)

* Handle string type references in cast()

* Directly visit the first argument of cast()

Co-authored-by: Zsolt Dollenstein <zsol.zsol@gmail.com>

Co-authored-by: Zsolt Dollenstein <zsol.zsol@gmail.com>
This commit is contained in:
Germán Méndez Bravo 2020-11-17 10:55:44 -08:00 committed by GitHub
parent 2ef730292b
commit 110095148f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 12 deletions

View file

@ -773,13 +773,19 @@ class ScopeVisitor(cst.CSTVisitor):
def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
qnames = self.scope.get_qualified_names_for(node)
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)}
if "typing.NewType" in qnames or "typing.TypeVar" in qnames:
node.func.visit(self)
self.__in_type_hint.add(node)
for arg in node.args[1:]:
arg.visit(self)
return False
if "typing.cast" in qnames:
node.func.visit(self)
self.__in_type_hint.add(node)
if len(node.args) > 0:
node.args[0].visit(self)
return False
return True
def leave_Call(self, original_node: cst.Call) -> None:
@ -814,12 +820,10 @@ class ScopeVisitor(cst.CSTVisitor):
return False
def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node.value)
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node.value)}
if any(qn.startswith(("typing.", "typing_extensions.")) for qn in qnames):
self.__in_type_hint.add(node)
if any(
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
):
if "typing.Literal" in qnames or "typing_extensions.Literal" in qnames:
self.__in_ignored_subscript.add(node)
return True

View file

@ -1037,23 +1037,24 @@ class ScopeProviderTest(UnitTest):
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, NewType, Optional, TypeVar, Callable
from a import A, B, C, D, E, F, G, H, I, J
from typing import Literal, NewType, Optional, TypeVar, Callable, cast
from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2
def x(a: A):
pass
def y(b: "B"):
pass
def z(c: Literal["C"]):
pass
DType = TypeVar("DType", bound=D)
EType = TypeVar("EType", bound="E")
DType = TypeVar("D2", bound=D)
EType = TypeVar("E2", bound="E")
FType = TypeVar("F")
GType = NewType("GType", "Optional[G]")
GType = NewType("G2", "Optional[G]")
HType = Optional["H"]
IType = Callable[..., I]
class Test(Generic[J]):
pass
casted = cast("K", "K2")
"""
)
imp = ensure_type(
@ -1084,6 +1085,10 @@ class ScopeProviderTest(UnitTest):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
assignment = list(scope["D2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
@ -1091,6 +1096,10 @@ class ScopeProviderTest(UnitTest):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
assignment = list(scope["E2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
@ -1102,6 +1111,10 @@ class ScopeProviderTest(UnitTest):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
assignment = list(scope["G2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
@ -1121,6 +1134,16 @@ class ScopeProviderTest(UnitTest):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
assignment = list(scope["K"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
assignment = list(scope["K2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""