From 110095148fdb1fcb1d66e1d3e35da1cb4efc2903 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20M=C3=A9ndez=20Bravo?= Date: Tue, 17 Nov 2020 10:55:44 -0800 Subject: [PATCH] Handle string type references in cast() (#418) * Handle string type references in cast() * Directly visit the first argument of cast() Co-authored-by: Zsolt Dollenstein Co-authored-by: Zsolt Dollenstein --- libcst/metadata/scope_provider.py | 18 ++++++----- libcst/metadata/tests/test_scope_provider.py | 33 +++++++++++++++++--- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 7886f458..043f87e8 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -773,13 +773,19 @@ class ScopeVisitor(cst.CSTVisitor): def visit_Call(self, node: cst.Call) -> Optional[bool]: self.__top_level_attribute_stack.append(None) - qnames = self.scope.get_qualified_names_for(node) - if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames): + qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)} + if "typing.NewType" in qnames or "typing.TypeVar" in qnames: node.func.visit(self) self.__in_type_hint.add(node) for arg in node.args[1:]: arg.visit(self) return False + if "typing.cast" in qnames: + node.func.visit(self) + self.__in_type_hint.add(node) + if len(node.args) > 0: + node.args[0].visit(self) + return False return True def leave_Call(self, original_node: cst.Call) -> None: @@ -814,12 +820,10 @@ class ScopeVisitor(cst.CSTVisitor): return False def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]: - qnames = self.scope.get_qualified_names_for(node.value) - if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames): + qnames = {qn.name for qn in self.scope.get_qualified_names_for(node.value)} + if any(qn.startswith(("typing.", "typing_extensions.")) for qn in qnames): self.__in_type_hint.add(node) - if any( - qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames - ): + if "typing.Literal" in qnames or "typing_extensions.Literal" in qnames: self.__in_ignored_subscript.add(node) return True diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index e54bbff9..36fd19e5 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -1037,23 +1037,24 @@ class ScopeProviderTest(UnitTest): def test_annotation_access(self) -> None: m, scopes = get_scope_metadata_provider( """ - from typing import Literal, NewType, Optional, TypeVar, Callable - from a import A, B, C, D, E, F, G, H, I, J + from typing import Literal, NewType, Optional, TypeVar, Callable, cast + from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2 def x(a: A): pass def y(b: "B"): pass def z(c: Literal["C"]): pass - DType = TypeVar("DType", bound=D) - EType = TypeVar("EType", bound="E") + DType = TypeVar("D2", bound=D) + EType = TypeVar("E2", bound="E") FType = TypeVar("F") - GType = NewType("GType", "Optional[G]") + GType = NewType("G2", "Optional[G]") HType = Optional["H"] IType = Callable[..., I] class Test(Generic[J]): pass + casted = cast("K", "K2") """ ) imp = ensure_type( @@ -1084,6 +1085,10 @@ class ScopeProviderTest(UnitTest): self.assertFalse(references[0].is_annotation) self.assertTrue(references[0].is_type_hint) + assignment = list(scope["D2"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 0) + assignment = list(scope["E"])[0] self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 1) @@ -1091,6 +1096,10 @@ class ScopeProviderTest(UnitTest): self.assertFalse(references[0].is_annotation) self.assertTrue(references[0].is_type_hint) + assignment = list(scope["E2"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 0) + assignment = list(scope["F"])[0] self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 0) @@ -1102,6 +1111,10 @@ class ScopeProviderTest(UnitTest): self.assertFalse(references[0].is_annotation) self.assertTrue(references[0].is_type_hint) + assignment = list(scope["G2"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 0) + assignment = list(scope["H"])[0] self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 1) @@ -1121,6 +1134,16 @@ class ScopeProviderTest(UnitTest): references = list(assignment.references) self.assertFalse(references[0].is_annotation) + assignment = list(scope["K"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertFalse(references[0].is_annotation) + + assignment = list(scope["K2"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 0) + def test_node_of_scopes(self) -> None: m, scopes = get_scope_metadata_provider( """