From ed18112cfaf1d8b5a056ea5ecaf7f5ef327c8429 Mon Sep 17 00:00:00 2001 From: David Peter Date: Fri, 7 Nov 2025 17:46:55 +0100 Subject: [PATCH] [ty] Add support for `Literal`s in implicit type aliases (#21296) ## Summary Add support for `Literal` types in implicit type aliases. part of https://github.com/astral-sh/ty/issues/221 ## Ecosystem analysis This looks good to me, true positives and known problems. ## Test Plan New Markdown tests. --- .../resources/mdtest/annotations/literal.md | 60 ++++++---------- .../resources/mdtest/implicit_type_aliases.md | 72 ++++++++++++++++++- crates/ty_python_semantic/src/types.rs | 57 ++++++++++----- .../src/types/class_base.rs | 3 +- .../src/types/infer/builder.rs | 39 ++++++++-- .../types/infer/builder/type_expression.rs | 16 ++++- 6 files changed, 179 insertions(+), 68 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/literal.md b/crates/ty_python_semantic/resources/mdtest/annotations/literal.md index 897be97e77..0c6a443afa 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/literal.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/literal.md @@ -181,30 +181,20 @@ def _( bool2: Literal[Bool2], multiple: Literal[SingleInt, SingleStr, SingleEnum], ): - # TODO should be `Literal[1]` - reveal_type(single_int) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal["foo"]` - reveal_type(single_str) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[b"bar"]` - reveal_type(single_bytes) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[True]` - reveal_type(single_bool) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `None` - reveal_type(single_none) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[E.A]` - reveal_type(single_enum) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[1, "foo", b"bar", True, E.A] | None` - reveal_type(union_literals) # revealed: @Todo(Inference of subscript on special form) + reveal_type(single_int) # revealed: Literal[1] + reveal_type(single_str) # revealed: Literal["foo"] + reveal_type(single_bytes) # revealed: Literal[b"bar"] + reveal_type(single_bool) # revealed: Literal[True] + reveal_type(single_none) # revealed: None + reveal_type(single_enum) # revealed: Literal[E.A] + reveal_type(union_literals) # revealed: Literal[1, "foo", b"bar", True, E.A] | None # Could also be `E` reveal_type(an_enum1) # revealed: Unknown - # TODO should be `E` - reveal_type(an_enum2) # revealed: @Todo(Inference of subscript on special form) + reveal_type(an_enum2) # revealed: E # Could also be `bool` reveal_type(bool1) # revealed: Unknown - # TODO should be `bool` - reveal_type(bool2) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[1, "foo", E.A]` - reveal_type(multiple) # revealed: @Todo(Inference of subscript on special form) + reveal_type(bool2) # revealed: bool + reveal_type(multiple) # revealed: Literal[1, "foo", E.A] ``` ### Implicit type alias @@ -246,28 +236,18 @@ def _( bool2: Literal[Bool2], multiple: Literal[SingleInt, SingleStr, SingleEnum], ): - # TODO should be `Literal[1]` - reveal_type(single_int) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal["foo"]` - reveal_type(single_str) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[b"bar"]` - reveal_type(single_bytes) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[True]` - reveal_type(single_bool) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `None` - reveal_type(single_none) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[E.A]` - reveal_type(single_enum) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[1, "foo", b"bar", True, E.A] | None` - reveal_type(union_literals) # revealed: @Todo(Inference of subscript on special form) + reveal_type(single_int) # revealed: Literal[1] + reveal_type(single_str) # revealed: Literal["foo"] + reveal_type(single_bytes) # revealed: Literal[b"bar"] + reveal_type(single_bool) # revealed: Literal[True] + reveal_type(single_none) # revealed: None + reveal_type(single_enum) # revealed: Literal[E.A] + reveal_type(union_literals) # revealed: Literal[1, "foo", b"bar", True, E.A] | None reveal_type(an_enum1) # revealed: Unknown - # TODO should be `E` - reveal_type(an_enum2) # revealed: @Todo(Inference of subscript on special form) + reveal_type(an_enum2) # revealed: E reveal_type(bool1) # revealed: Unknown - # TODO should be `bool` - reveal_type(bool2) # revealed: @Todo(Inference of subscript on special form) - # TODO should be `Literal[1, "foo", E.A]` - reveal_type(multiple) # revealed: @Todo(Inference of subscript on special form) + reveal_type(bool2) # revealed: bool + reveal_type(multiple) # revealed: Literal[1, "foo", E.A] ``` ## Shortening unions of literals diff --git a/crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md b/crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md index b557a730f7..504921c317 100644 --- a/crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md @@ -33,7 +33,7 @@ g(None) We also support unions in type aliases: ```py -from typing_extensions import Any, Never +from typing_extensions import Any, Never, Literal from ty_extensions import Unknown IntOrStr = int | str @@ -54,6 +54,8 @@ NeverOrAny = Never | Any AnyOrNever = Any | Never UnknownOrInt = Unknown | int IntOrUnknown = int | Unknown +StrOrZero = str | Literal[0] +ZeroOrStr = Literal[0] | str reveal_type(IntOrStr) # revealed: types.UnionType reveal_type(IntOrStrOrBytes1) # revealed: types.UnionType @@ -73,6 +75,8 @@ reveal_type(NeverOrAny) # revealed: types.UnionType reveal_type(AnyOrNever) # revealed: types.UnionType reveal_type(UnknownOrInt) # revealed: types.UnionType reveal_type(IntOrUnknown) # revealed: types.UnionType +reveal_type(StrOrZero) # revealed: types.UnionType +reveal_type(ZeroOrStr) # revealed: types.UnionType def _( int_or_str: IntOrStr, @@ -93,6 +97,8 @@ def _( any_or_never: AnyOrNever, unknown_or_int: UnknownOrInt, int_or_unknown: IntOrUnknown, + str_or_zero: StrOrZero, + zero_or_str: ZeroOrStr, ): reveal_type(int_or_str) # revealed: int | str reveal_type(int_or_str_or_bytes1) # revealed: int | str | bytes @@ -112,6 +118,8 @@ def _( reveal_type(any_or_never) # revealed: Any reveal_type(unknown_or_int) # revealed: Unknown | int reveal_type(int_or_unknown) # revealed: int | Unknown + reveal_type(str_or_zero) # revealed: str | Literal[0] + reveal_type(zero_or_str) # revealed: Literal[0] | str ``` If a type is unioned with itself in a value expression, the result is just that type. No @@ -255,6 +263,68 @@ def _(list_or_tuple: ListOrTuple[int]): reveal_type(list_or_tuple) # revealed: @Todo(Generic specialization of types.UnionType) ``` +## `Literal`s + +We also support `typing.Literal` in implicit type aliases. + +```py +from typing import Literal +from enum import Enum + +IntLiteral1 = Literal[26] +IntLiteral2 = Literal[0x1A] +IntLiterals = Literal[-1, 0, 1] +NestedLiteral = Literal[Literal[1]] +StringLiteral = Literal["a"] +BytesLiteral = Literal[b"b"] +BoolLiteral = Literal[True] +MixedLiterals = Literal[1, "a", True, None] + +class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + +EnumLiteral = Literal[Color.RED] + +def _( + int_literal1: IntLiteral1, + int_literal2: IntLiteral2, + int_literals: IntLiterals, + nested_literal: NestedLiteral, + string_literal: StringLiteral, + bytes_literal: BytesLiteral, + bool_literal: BoolLiteral, + mixed_literals: MixedLiterals, + enum_literal: EnumLiteral, +): + reveal_type(int_literal1) # revealed: Literal[26] + reveal_type(int_literal2) # revealed: Literal[26] + reveal_type(int_literals) # revealed: Literal[-1, 0, 1] + reveal_type(nested_literal) # revealed: Literal[1] + reveal_type(string_literal) # revealed: Literal["a"] + reveal_type(bytes_literal) # revealed: Literal[b"b"] + reveal_type(bool_literal) # revealed: Literal[True] + reveal_type(mixed_literals) # revealed: Literal[1, "a", True] | None + reveal_type(enum_literal) # revealed: Literal[Color.RED] +``` + +We reject invalid uses: + +```py +# error: [invalid-type-form] "Type arguments for `Literal` must be `None`, a literal value (int, bool, str, or bytes), or an enum member" +LiteralInt = Literal[int] + +reveal_type(LiteralInt) # revealed: Unknown + +def _(weird: LiteralInt): + reveal_type(weird) # revealed: Unknown + +# error: [invalid-type-form] "`Literal[26]` is not a generic class" +def _(weird: IntLiteral1[int]): + reveal_type(weird) # revealed: Unknown +``` + ## Stringified annotations? From the [typing spec on type aliases](https://typing.python.org/en/latest/spec/aliases.html): diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index bc75895833..6c9cdefa20 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -6451,9 +6451,9 @@ impl<'db> Type<'db> { invalid_expressions: smallvec::smallvec_inline![InvalidTypeExpression::Generic], fallback_type: Type::unknown(), }), - KnownInstanceType::UnionType(union_type) => { + KnownInstanceType::UnionType(list) => { let mut builder = UnionBuilder::new(db); - for element in union_type.elements(db) { + for element in list.elements(db) { builder = builder.add(element.in_type_expression( db, scope_id, @@ -6462,6 +6462,7 @@ impl<'db> Type<'db> { } Ok(builder.build()) } + KnownInstanceType::Literal(list) => Ok(list.to_union(db)), }, Type::SpecialForm(special_form) => match special_form { @@ -7675,7 +7676,10 @@ pub enum KnownInstanceType<'db> { /// A single instance of `types.UnionType`, which stores the left- and /// right-hand sides of a PEP 604 union. - UnionType(UnionTypeInstance<'db>), + UnionType(TypeList<'db>), + + /// A single instance of `typing.Literal` + Literal(TypeList<'db>), } fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( @@ -7702,9 +7706,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( visitor.visit_type(db, default_ty); } } - KnownInstanceType::UnionType(union_type) => { - for element in union_type.elements(db) { - visitor.visit_type(db, element); + KnownInstanceType::UnionType(list) | KnownInstanceType::Literal(list) => { + for element in list.elements(db) { + visitor.visit_type(db, *element); } } } @@ -7743,7 +7747,8 @@ impl<'db> KnownInstanceType<'db> { // Nothing to normalize Self::ConstraintSet(set) } - Self::UnionType(union_type) => Self::UnionType(union_type.normalized_impl(db, visitor)), + Self::UnionType(list) => Self::UnionType(list.normalized_impl(db, visitor)), + Self::Literal(list) => Self::Literal(list.normalized_impl(db, visitor)), } } @@ -7762,6 +7767,7 @@ impl<'db> KnownInstanceType<'db> { Self::Field(_) => KnownClass::Field, Self::ConstraintSet(_) => KnownClass::ConstraintSet, Self::UnionType(_) => KnownClass::UnionType, + Self::Literal(_) => KnownClass::GenericAlias, } } @@ -7842,6 +7848,7 @@ impl<'db> KnownInstanceType<'db> { ) } KnownInstanceType::UnionType(_) => f.write_str("types.UnionType"), + KnownInstanceType::Literal(_) => f.write_str("typing.Literal"), } } } @@ -8972,32 +8979,46 @@ impl<'db> TypeVarBoundOrConstraints<'db> { } } -/// An instance of `types.UnionType`. +/// A salsa-interned list of types. /// /// # Ordering /// Ordering is based on the context's salsa-assigned id and not on its values. /// The id may change between runs, or when the context was garbage collected and recreated. #[salsa::interned(debug)] #[derive(PartialOrd, Ord)] -pub struct UnionTypeInstance<'db> { - left: Type<'db>, - right: Type<'db>, +pub struct TypeList<'db> { + #[returns(deref)] + elements: Box<[Type<'db>]>, } -impl get_size2::GetSize for UnionTypeInstance<'_> {} +impl get_size2::GetSize for TypeList<'_> {} -impl<'db> UnionTypeInstance<'db> { - pub(crate) fn elements(self, db: &'db dyn Db) -> [Type<'db>; 2] { - [self.left(db), self.right(db)] +impl<'db> TypeList<'db> { + pub(crate) fn from_elements( + db: &'db dyn Db, + elements: impl IntoIterator>, + ) -> TypeList<'db> { + TypeList::new(db, elements.into_iter().collect::>()) + } + + pub(crate) fn singleton(db: &'db dyn Db, element: Type<'db>) -> TypeList<'db> { + TypeList::from_elements(db, [element]) } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - UnionTypeInstance::new( + TypeList::new( db, - self.left(db).normalized_impl(db, visitor), - self.right(db).normalized_impl(db, visitor), + self.elements(db) + .iter() + .map(|ty| ty.normalized_impl(db, visitor)) + .collect::>(), ) } + + /// Turn this list of types `[T1, T2, ...]` into a union type `T1 | T2 | ...`. + pub(crate) fn to_union(self, db: &'db dyn Db) -> Type<'db> { + UnionType::from_elements(db, self.elements(db)) + } } /// Error returned if a type is not awaitable. diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 071d4b92b7..caddc88567 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -168,7 +168,8 @@ impl<'db> ClassBase<'db> { | KnownInstanceType::Deprecated(_) | KnownInstanceType::Field(_) | KnownInstanceType::ConstraintSet(_) - | KnownInstanceType::UnionType(_) => None, + | KnownInstanceType::UnionType(_) + | KnownInstanceType::Literal(_) => None, }, Type::SpecialForm(special_form) => match special_form { diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 8dc2b244cf..3b1142b89b 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -103,10 +103,10 @@ use crate::types::{ DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, + TypeAliasType, TypeAndQualifiers, TypeContext, TypeList, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, UnionType, - UnionTypeInstance, binding_type, todo_type, + binding_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -8754,19 +8754,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::SubclassOf(..) | Type::GenericAlias(..) | Type::SpecialForm(_) - | Type::KnownInstance(KnownInstanceType::UnionType(_)), + | Type::KnownInstance( + KnownInstanceType::UnionType(_) | KnownInstanceType::Literal(_), + ), Type::ClassLiteral(..) | Type::SubclassOf(..) | Type::GenericAlias(..) | Type::SpecialForm(_) - | Type::KnownInstance(KnownInstanceType::UnionType(_)), + | Type::KnownInstance( + KnownInstanceType::UnionType(_) | KnownInstanceType::Literal(_), + ), ast::Operator::BitOr, ) if Program::get(self.db()).python_version(self.db()) >= PythonVersion::PY310 => { if left_ty.is_equivalent_to(self.db(), right_ty) { Some(left_ty) } else { Some(Type::KnownInstance(KnownInstanceType::UnionType( - UnionTypeInstance::new(self.db(), left_ty, right_ty), + TypeList::from_elements(self.db(), [left_ty, right_ty]), ))) } } @@ -8791,7 +8795,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { && instance.has_known_class(self.db(), KnownClass::NoneType) => { Some(Type::KnownInstance(KnownInstanceType::UnionType( - UnionTypeInstance::new(self.db(), left_ty, right_ty), + TypeList::from_elements(self.db(), [left_ty, right_ty]), ))) } @@ -9924,6 +9928,29 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); } } + if value_ty == Type::SpecialForm(SpecialFormType::Literal) { + match self.infer_literal_parameter_type(slice) { + Ok(result) => { + return Type::KnownInstance(KnownInstanceType::Literal(TypeList::singleton( + self.db(), + result, + ))); + } + Err(nodes) => { + for node in nodes { + let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, node) + else { + continue; + }; + builder.into_diagnostic( + "Type arguments for `Literal` must be `None`, \ + a literal value (int, bool, str, or bytes), or an enum member", + ); + } + return Type::unknown(); + } + } + } let slice_ty = self.infer_expression(slice, TypeContext::default()); let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index c6b2bbbef0..d091487ce7 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -814,6 +814,16 @@ impl<'db> TypeInferenceBuilder<'db, '_> { self.infer_type_expression(slice); todo_type!("Generic specialization of types.UnionType") } + KnownInstanceType::Literal(ty) => { + self.infer_type_expression(slice); + if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) { + builder.into_diagnostic(format_args!( + "`{ty}` is not a generic class", + ty = ty.to_union(self.db()).display(self.db()) + )); + } + Type::unknown() + } }, Type::Dynamic(DynamicType::Todo(_)) => { self.infer_type_expression(slice); @@ -1367,7 +1377,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } } - fn infer_literal_parameter_type<'param>( + pub(crate) fn infer_literal_parameter_type<'param>( &mut self, parameters: &'param ast::Expr, ) -> Result, Vec<&'param ast::Expr>> { @@ -1435,7 +1445,6 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // enum members and aliases to literal types ast::Expr::Name(_) | ast::Expr::Attribute(_) => { let subscript_ty = self.infer_expression(parameters, TypeContext::default()); - // TODO handle implicit type aliases also match subscript_ty { // type aliases to literal types Type::KnownInstance(KnownInstanceType::TypeAliasType(type_alias)) => { @@ -1444,6 +1453,9 @@ impl<'db> TypeInferenceBuilder<'db, '_> { return Ok(value_ty); } } + Type::KnownInstance(KnownInstanceType::Literal(list)) => { + return Ok(list.to_union(self.db())); + } // `Literal[SomeEnum.Member]` Type::EnumLiteral(_) => { return Ok(subscript_ty);