mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Handle string type references in cast() (#418)
* Handle string type references in cast() * Directly visit the first argument of cast() Co-authored-by: Zsolt Dollenstein <zsol.zsol@gmail.com> Co-authored-by: Zsolt Dollenstein <zsol.zsol@gmail.com>
This commit is contained in:
parent
2ef730292b
commit
110095148f
2 changed files with 39 additions and 12 deletions
|
|
@ -773,13 +773,19 @@ class ScopeVisitor(cst.CSTVisitor):
|
|||
|
||||
def visit_Call(self, node: cst.Call) -> Optional[bool]:
|
||||
self.__top_level_attribute_stack.append(None)
|
||||
qnames = self.scope.get_qualified_names_for(node)
|
||||
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
|
||||
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)}
|
||||
if "typing.NewType" in qnames or "typing.TypeVar" in qnames:
|
||||
node.func.visit(self)
|
||||
self.__in_type_hint.add(node)
|
||||
for arg in node.args[1:]:
|
||||
arg.visit(self)
|
||||
return False
|
||||
if "typing.cast" in qnames:
|
||||
node.func.visit(self)
|
||||
self.__in_type_hint.add(node)
|
||||
if len(node.args) > 0:
|
||||
node.args[0].visit(self)
|
||||
return False
|
||||
return True
|
||||
|
||||
def leave_Call(self, original_node: cst.Call) -> None:
|
||||
|
|
@ -814,12 +820,10 @@ class ScopeVisitor(cst.CSTVisitor):
|
|||
return False
|
||||
|
||||
def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
|
||||
qnames = self.scope.get_qualified_names_for(node.value)
|
||||
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
|
||||
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node.value)}
|
||||
if any(qn.startswith(("typing.", "typing_extensions.")) for qn in qnames):
|
||||
self.__in_type_hint.add(node)
|
||||
if any(
|
||||
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
|
||||
):
|
||||
if "typing.Literal" in qnames or "typing_extensions.Literal" in qnames:
|
||||
self.__in_ignored_subscript.add(node)
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -1037,23 +1037,24 @@ class ScopeProviderTest(UnitTest):
|
|||
def test_annotation_access(self) -> None:
|
||||
m, scopes = get_scope_metadata_provider(
|
||||
"""
|
||||
from typing import Literal, NewType, Optional, TypeVar, Callable
|
||||
from a import A, B, C, D, E, F, G, H, I, J
|
||||
from typing import Literal, NewType, Optional, TypeVar, Callable, cast
|
||||
from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2
|
||||
def x(a: A):
|
||||
pass
|
||||
def y(b: "B"):
|
||||
pass
|
||||
def z(c: Literal["C"]):
|
||||
pass
|
||||
DType = TypeVar("DType", bound=D)
|
||||
EType = TypeVar("EType", bound="E")
|
||||
DType = TypeVar("D2", bound=D)
|
||||
EType = TypeVar("E2", bound="E")
|
||||
FType = TypeVar("F")
|
||||
GType = NewType("GType", "Optional[G]")
|
||||
GType = NewType("G2", "Optional[G]")
|
||||
HType = Optional["H"]
|
||||
IType = Callable[..., I]
|
||||
|
||||
class Test(Generic[J]):
|
||||
pass
|
||||
casted = cast("K", "K2")
|
||||
"""
|
||||
)
|
||||
imp = ensure_type(
|
||||
|
|
@ -1084,6 +1085,10 @@ class ScopeProviderTest(UnitTest):
|
|||
self.assertFalse(references[0].is_annotation)
|
||||
self.assertTrue(references[0].is_type_hint)
|
||||
|
||||
assignment = list(scope["D2"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 0)
|
||||
|
||||
assignment = list(scope["E"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 1)
|
||||
|
|
@ -1091,6 +1096,10 @@ class ScopeProviderTest(UnitTest):
|
|||
self.assertFalse(references[0].is_annotation)
|
||||
self.assertTrue(references[0].is_type_hint)
|
||||
|
||||
assignment = list(scope["E2"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 0)
|
||||
|
||||
assignment = list(scope["F"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 0)
|
||||
|
|
@ -1102,6 +1111,10 @@ class ScopeProviderTest(UnitTest):
|
|||
self.assertFalse(references[0].is_annotation)
|
||||
self.assertTrue(references[0].is_type_hint)
|
||||
|
||||
assignment = list(scope["G2"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 0)
|
||||
|
||||
assignment = list(scope["H"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 1)
|
||||
|
|
@ -1121,6 +1134,16 @@ class ScopeProviderTest(UnitTest):
|
|||
references = list(assignment.references)
|
||||
self.assertFalse(references[0].is_annotation)
|
||||
|
||||
assignment = list(scope["K"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 1)
|
||||
references = list(assignment.references)
|
||||
self.assertFalse(references[0].is_annotation)
|
||||
|
||||
assignment = list(scope["K2"])[0]
|
||||
self.assertIsInstance(assignment, Assignment)
|
||||
self.assertEqual(len(assignment.references), 0)
|
||||
|
||||
def test_node_of_scopes(self) -> None:
|
||||
m, scopes = get_scope_metadata_provider(
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue