diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 291d0452..ba3cd700 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Sequence, Set, Tuple, Union import libcst as cst +from libcst import matchers as m from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._add_imports import AddImportsVisitor @@ -133,6 +134,8 @@ class TypeCollector(cst.CSTVisitor): def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript: slice = node.slice + if m.matches(node.value, m.Name(value="Type")): + return node if isinstance(slice, list): new_slice = [] for item in slice: @@ -163,7 +166,7 @@ class TypeCollector(cst.CSTVisitor): return cst.Annotation(annotation=attr) if isinstance(annotation, cst.Subscript): value = annotation.value - if isinstance(value, cst.Name) and value.value == "Type": + if m.matches(value, m.Name(value="Type")): return returns return cst.Annotation(annotation=self._handle_Subscript(annotation)) else: diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 2e4027b0..90e7b58c 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -624,6 +624,33 @@ class TestApplyAnnotationsVisitor(CodemodTest): return [] """, ), + ( + """ + from typing import Dict + + example: Dict[str, Type[foo.Example]] = ... + """, + """ + from typing import Type + + def foo() -> Type[foo.Example]: + class Example: + pass + return Example + + example = { "test": foo() } + """, + """ + from typing import Dict, Type + + def foo() -> Type[foo.Example]: + class Example: + pass + return Example + + example: Dict[str, Type[foo.Example]] = { "test": foo() } + """, + ), ) ) def test_annotate_functions(self, stub: str, before: str, after: str) -> None: