From e4384fc212df5ea5e70ad939dc18ba47454e5bbf Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Sat, 18 Oct 2025 04:12:19 +0900 Subject: [PATCH] [ty] impl `VarianceInferable` for `KnownInstanceType` (#20924) ## Summary Derived from #20900 Implement `VarianceInferable` for `KnownInstanceType` (especially for `KnownInstanceType::TypeAliasType`). The variance of a type alias matches its value type. In normal usage, type aliases are expanded to value types, so the variance of a type alias can be obtained without implementing this. However, for example, if we want to display the variance when hovering over a type alias, we need to be able to obtain the variance of the type alias itself (cf. #20900). ## Test Plan I couldn't come up with a way to test this in mdtest, so I'm testing it in a test submodule at the end of `types.rs`. I also added a test to `mdtest/generics/pep695/variance.md`, but it passes without the changes in this PR. --- .../mdtest/generics/pep695/variance.md | 59 ++++++++ crates/ty_python_semantic/src/types.rs | 133 ++++++++++++++++-- 2 files changed, 183 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md index 4c96a3c4f4..7dc9392b21 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md @@ -790,6 +790,65 @@ static_assert(not is_assignable_to(C[B], C[A])) static_assert(not is_assignable_to(C[A], C[B])) ``` +## Type aliases + +The variance of the type alias matches the variance of the value type (RHS type). + +```py +from ty_extensions import static_assert, is_subtype_of +from typing import Literal + +class Covariant[T]: + def get(self) -> T: + raise ValueError + +type CovariantLiteral1 = Covariant[Literal[1]] +type CovariantInt = Covariant[int] +type MyCovariant[T] = Covariant[T] + +static_assert(is_subtype_of(CovariantLiteral1, CovariantInt)) +static_assert(is_subtype_of(MyCovariant[Literal[1]], MyCovariant[int])) + +class Contravariant[T]: + def set(self, value: T): + pass + +type ContravariantLiteral1 = Contravariant[Literal[1]] +type ContravariantInt = Contravariant[int] +type MyContravariant[T] = Contravariant[T] + +static_assert(is_subtype_of(ContravariantInt, ContravariantLiteral1)) +static_assert(is_subtype_of(MyContravariant[int], MyContravariant[Literal[1]])) + +class Invariant[T]: + def get(self) -> T: + raise ValueError + + def set(self, value: T): + pass + +type InvariantLiteral1 = Invariant[Literal[1]] +type InvariantInt = Invariant[int] +type MyInvariant[T] = Invariant[T] + +static_assert(not is_subtype_of(InvariantInt, InvariantLiteral1)) +static_assert(not is_subtype_of(InvariantLiteral1, InvariantInt)) +static_assert(not is_subtype_of(MyInvariant[Literal[1]], MyInvariant[int])) +static_assert(not is_subtype_of(MyInvariant[int], MyInvariant[Literal[1]])) + +class Bivariant[T]: + pass + +type BivariantLiteral1 = Bivariant[Literal[1]] +type BivariantInt = Bivariant[int] +type MyBivariant[T] = Bivariant[T] + +static_assert(is_subtype_of(BivariantInt, BivariantLiteral1)) +static_assert(is_subtype_of(BivariantLiteral1, BivariantInt)) +static_assert(is_subtype_of(MyBivariant[Literal[1]], MyBivariant[int])) +static_assert(is_subtype_of(MyBivariant[int], MyBivariant[Literal[1]])) +``` + ## Inheriting from generic classes with inferred variance When inheriting from a generic class with our type variable substituted in, we count its occurrences diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 53c4f427aa..477bf9aa11 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -55,9 +55,10 @@ use crate::types::function::{ DataclassTransformerFlags, DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction, }; +pub(crate) use crate::types::generics::GenericContext; use crate::types::generics::{ - GenericContext, InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, - typing_self, walk_generic_context, + InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, typing_self, + walk_generic_context, }; use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; @@ -7274,6 +7275,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> { .collect(), Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar), Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar), + Type::KnownInstance(known_instance) => known_instance.variance_of(db, typevar), Type::Dynamic(_) | Type::Never | Type::WrapperDescriptor(_) @@ -7288,7 +7290,6 @@ impl<'db> VarianceInferable<'db> for Type<'db> { | Type::LiteralString | Type::BytesLiteral(_) | Type::SpecialForm(_) - | Type::KnownInstance(_) | Type::AlwaysFalsy | Type::AlwaysTruthy | Type::BoundSuper(_) @@ -7495,6 +7496,17 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( } } +impl<'db> VarianceInferable<'db> for KnownInstanceType<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + match self { + KnownInstanceType::TypeAliasType(type_alias) => { + type_alias.raw_value_type(db).variance_of(db, typevar) + } + _ => TypeVarVariance::Bivariant, + } + } +} + impl<'db> KnownInstanceType<'db> { fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { @@ -10693,14 +10705,10 @@ impl<'db> PEP695TypeAliasType<'db> { semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node) } + /// The RHS type of a PEP-695 style type alias with specialization applied. #[salsa::tracked(cycle_fn=value_type_cycle_recover, cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> { - let scope = self.rhs_scope(db); - let module = parsed_module(db, scope.file(db)).load(db); - let type_alias_stmt_node = scope.node(db).expect_type_alias(); - let definition = self.definition(db); - let value_type = - definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value); + let value_type = self.raw_value_type(db); if let Some(generic_context) = self.generic_context(db) { let specialization = self @@ -10713,6 +10721,25 @@ impl<'db> PEP695TypeAliasType<'db> { } } + /// The RHS type of a PEP-695 style type alias with *no* specialization applied. + /// + /// ## Warning + /// + /// This uses the semantic index to find the definition of the type alias. This means that if the + /// calling query is not in the same file as this type alias is defined in, then this will create + /// a cross-module dependency directly on the full AST which will lead to cache + /// over-invalidation. + /// This method also calls the type inference functions, and since type aliases can have recursive structures, + /// we should be careful not to create infinite recursions in this method (or make it tracked if necessary). + pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.rhs_scope(db); + let module = parsed_module(db, scope.file(db)).load(db); + let type_alias_stmt_node = scope.node(db).expect_type_alias(); + let definition = self.definition(db); + + definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value) + } + pub(crate) fn apply_specialization( self, db: &'db dyn Db, @@ -10892,6 +10919,13 @@ impl<'db> TypeAliasType<'db> { } } + pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> { + match self { + TypeAliasType::PEP695(type_alias) => type_alias.raw_value_type(db), + TypeAliasType::ManualPEP695(type_alias) => type_alias.value(db), + } + } + pub(crate) fn as_pep_695_type_alias(self) -> Option> { match self { TypeAliasType::PEP695(type_alias) => Some(type_alias), @@ -11724,4 +11758,85 @@ pub(crate) mod tests { .build(); assert_eq!(intersection.display(&db).to_string(), "Never"); } + + #[test] + fn type_alias_variance() { + use crate::db::tests::TestDb; + use crate::place::global_symbol; + + fn get_type_alias<'db>(db: &'db TestDb, name: &str) -> PEP695TypeAliasType<'db> { + let module = ruff_db::files::system_path_to_file(db, "/src/a.py").unwrap(); + let ty = global_symbol(db, module, name).place.expect_type(); + let Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::PEP695( + type_alias, + ))) = ty + else { + panic!("Expected `{name}` to be a type alias"); + }; + type_alias + } + fn get_bound_typevar<'db>( + db: &'db TestDb, + type_alias: PEP695TypeAliasType<'db>, + ) -> BoundTypeVarInstance<'db> { + let generic_context = type_alias.generic_context(db).unwrap(); + generic_context.variables(db).next().unwrap() + } + + let mut db = setup_db(); + db.write_dedented( + "/src/a.py", + r#" +class Covariant[T]: + def get(self) -> T: + raise ValueError + +class Contravariant[T]: + def set(self, value: T): + pass + +class Invariant[T]: + def get(self) -> T: + raise ValueError + def set(self, value: T): + pass + +class Bivariant[T]: + pass + +type CovariantAlias[T] = Covariant[T] +type ContravariantAlias[T] = Contravariant[T] +type InvariantAlias[T] = Invariant[T] +type BivariantAlias[T] = Bivariant[T] +"#, + ) + .unwrap(); + let covariant = get_type_alias(&db, "CovariantAlias"); + assert_eq!( + KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(covariant)) + .variance_of(&db, get_bound_typevar(&db, covariant)), + TypeVarVariance::Covariant + ); + + let contravariant = get_type_alias(&db, "ContravariantAlias"); + assert_eq!( + KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(contravariant)) + .variance_of(&db, get_bound_typevar(&db, contravariant)), + TypeVarVariance::Contravariant + ); + + let invariant = get_type_alias(&db, "InvariantAlias"); + assert_eq!( + KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(invariant)) + .variance_of(&db, get_bound_typevar(&db, invariant)), + TypeVarVariance::Invariant + ); + + let bivariant = get_type_alias(&db, "BivariantAlias"); + assert_eq!( + KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(bivariant)) + .variance_of(&db, get_bound_typevar(&db, bivariant)), + TypeVarVariance::Bivariant + ); + } }