mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-22 08:12:17 +00:00
[ty] impl VarianceInferable
for KnownInstanceType
(#20924)
## Summary Derived from #20900 Implement `VarianceInferable` for `KnownInstanceType` (especially for `KnownInstanceType::TypeAliasType`). The variance of a type alias matches its value type. In normal usage, type aliases are expanded to value types, so the variance of a type alias can be obtained without implementing this. However, for example, if we want to display the variance when hovering over a type alias, we need to be able to obtain the variance of the type alias itself (cf. #20900). ## Test Plan I couldn't come up with a way to test this in mdtest, so I'm testing it in a test submodule at the end of `types.rs`. I also added a test to `mdtest/generics/pep695/variance.md`, but it passes without the changes in this PR.
This commit is contained in:
parent
6e7ff07065
commit
e4384fc212
2 changed files with 183 additions and 9 deletions
|
@ -790,6 +790,65 @@ static_assert(not is_assignable_to(C[B], C[A]))
|
||||||
static_assert(not is_assignable_to(C[A], C[B]))
|
static_assert(not is_assignable_to(C[A], C[B]))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Type aliases
|
||||||
|
|
||||||
|
The variance of the type alias matches the variance of the value type (RHS type).
|
||||||
|
|
||||||
|
```py
|
||||||
|
from ty_extensions import static_assert, is_subtype_of
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
class Covariant[T]:
|
||||||
|
def get(self) -> T:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
type CovariantLiteral1 = Covariant[Literal[1]]
|
||||||
|
type CovariantInt = Covariant[int]
|
||||||
|
type MyCovariant[T] = Covariant[T]
|
||||||
|
|
||||||
|
static_assert(is_subtype_of(CovariantLiteral1, CovariantInt))
|
||||||
|
static_assert(is_subtype_of(MyCovariant[Literal[1]], MyCovariant[int]))
|
||||||
|
|
||||||
|
class Contravariant[T]:
|
||||||
|
def set(self, value: T):
|
||||||
|
pass
|
||||||
|
|
||||||
|
type ContravariantLiteral1 = Contravariant[Literal[1]]
|
||||||
|
type ContravariantInt = Contravariant[int]
|
||||||
|
type MyContravariant[T] = Contravariant[T]
|
||||||
|
|
||||||
|
static_assert(is_subtype_of(ContravariantInt, ContravariantLiteral1))
|
||||||
|
static_assert(is_subtype_of(MyContravariant[int], MyContravariant[Literal[1]]))
|
||||||
|
|
||||||
|
class Invariant[T]:
|
||||||
|
def get(self) -> T:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def set(self, value: T):
|
||||||
|
pass
|
||||||
|
|
||||||
|
type InvariantLiteral1 = Invariant[Literal[1]]
|
||||||
|
type InvariantInt = Invariant[int]
|
||||||
|
type MyInvariant[T] = Invariant[T]
|
||||||
|
|
||||||
|
static_assert(not is_subtype_of(InvariantInt, InvariantLiteral1))
|
||||||
|
static_assert(not is_subtype_of(InvariantLiteral1, InvariantInt))
|
||||||
|
static_assert(not is_subtype_of(MyInvariant[Literal[1]], MyInvariant[int]))
|
||||||
|
static_assert(not is_subtype_of(MyInvariant[int], MyInvariant[Literal[1]]))
|
||||||
|
|
||||||
|
class Bivariant[T]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
type BivariantLiteral1 = Bivariant[Literal[1]]
|
||||||
|
type BivariantInt = Bivariant[int]
|
||||||
|
type MyBivariant[T] = Bivariant[T]
|
||||||
|
|
||||||
|
static_assert(is_subtype_of(BivariantInt, BivariantLiteral1))
|
||||||
|
static_assert(is_subtype_of(BivariantLiteral1, BivariantInt))
|
||||||
|
static_assert(is_subtype_of(MyBivariant[Literal[1]], MyBivariant[int]))
|
||||||
|
static_assert(is_subtype_of(MyBivariant[int], MyBivariant[Literal[1]]))
|
||||||
|
```
|
||||||
|
|
||||||
## Inheriting from generic classes with inferred variance
|
## Inheriting from generic classes with inferred variance
|
||||||
|
|
||||||
When inheriting from a generic class with our type variable substituted in, we count its occurrences
|
When inheriting from a generic class with our type variable substituted in, we count its occurrences
|
||||||
|
|
|
@ -55,9 +55,10 @@ use crate::types::function::{
|
||||||
DataclassTransformerFlags, DataclassTransformerParams, FunctionSpans, FunctionType,
|
DataclassTransformerFlags, DataclassTransformerParams, FunctionSpans, FunctionType,
|
||||||
KnownFunction,
|
KnownFunction,
|
||||||
};
|
};
|
||||||
|
pub(crate) use crate::types::generics::GenericContext;
|
||||||
use crate::types::generics::{
|
use crate::types::generics::{
|
||||||
GenericContext, InferableTypeVars, PartialSpecialization, Specialization, bind_typevar,
|
InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, typing_self,
|
||||||
typing_self, walk_generic_context,
|
walk_generic_context,
|
||||||
};
|
};
|
||||||
use crate::types::infer::infer_unpack_types;
|
use crate::types::infer::infer_unpack_types;
|
||||||
use crate::types::mro::{Mro, MroError, MroIterator};
|
use crate::types::mro::{Mro, MroError, MroIterator};
|
||||||
|
@ -7274,6 +7275,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
|
||||||
.collect(),
|
.collect(),
|
||||||
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
|
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
|
||||||
Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar),
|
Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar),
|
||||||
|
Type::KnownInstance(known_instance) => known_instance.variance_of(db, typevar),
|
||||||
Type::Dynamic(_)
|
Type::Dynamic(_)
|
||||||
| Type::Never
|
| Type::Never
|
||||||
| Type::WrapperDescriptor(_)
|
| Type::WrapperDescriptor(_)
|
||||||
|
@ -7288,7 +7290,6 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
|
||||||
| Type::LiteralString
|
| Type::LiteralString
|
||||||
| Type::BytesLiteral(_)
|
| Type::BytesLiteral(_)
|
||||||
| Type::SpecialForm(_)
|
| Type::SpecialForm(_)
|
||||||
| Type::KnownInstance(_)
|
|
||||||
| Type::AlwaysFalsy
|
| Type::AlwaysFalsy
|
||||||
| Type::AlwaysTruthy
|
| Type::AlwaysTruthy
|
||||||
| Type::BoundSuper(_)
|
| Type::BoundSuper(_)
|
||||||
|
@ -7495,6 +7496,17 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'db> VarianceInferable<'db> for KnownInstanceType<'db> {
|
||||||
|
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
|
||||||
|
match self {
|
||||||
|
KnownInstanceType::TypeAliasType(type_alias) => {
|
||||||
|
type_alias.raw_value_type(db).variance_of(db, typevar)
|
||||||
|
}
|
||||||
|
_ => TypeVarVariance::Bivariant,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'db> KnownInstanceType<'db> {
|
impl<'db> KnownInstanceType<'db> {
|
||||||
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
|
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
|
||||||
match self {
|
match self {
|
||||||
|
@ -10693,14 +10705,10 @@ impl<'db> PEP695TypeAliasType<'db> {
|
||||||
semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node)
|
semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The RHS type of a PEP-695 style type alias with specialization applied.
|
||||||
#[salsa::tracked(cycle_fn=value_type_cycle_recover, cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
|
#[salsa::tracked(cycle_fn=value_type_cycle_recover, cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
|
||||||
pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> {
|
pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||||
let scope = self.rhs_scope(db);
|
let value_type = self.raw_value_type(db);
|
||||||
let module = parsed_module(db, scope.file(db)).load(db);
|
|
||||||
let type_alias_stmt_node = scope.node(db).expect_type_alias();
|
|
||||||
let definition = self.definition(db);
|
|
||||||
let value_type =
|
|
||||||
definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value);
|
|
||||||
|
|
||||||
if let Some(generic_context) = self.generic_context(db) {
|
if let Some(generic_context) = self.generic_context(db) {
|
||||||
let specialization = self
|
let specialization = self
|
||||||
|
@ -10713,6 +10721,25 @@ impl<'db> PEP695TypeAliasType<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The RHS type of a PEP-695 style type alias with *no* specialization applied.
|
||||||
|
///
|
||||||
|
/// ## Warning
|
||||||
|
///
|
||||||
|
/// This uses the semantic index to find the definition of the type alias. This means that if the
|
||||||
|
/// calling query is not in the same file as this type alias is defined in, then this will create
|
||||||
|
/// a cross-module dependency directly on the full AST which will lead to cache
|
||||||
|
/// over-invalidation.
|
||||||
|
/// This method also calls the type inference functions, and since type aliases can have recursive structures,
|
||||||
|
/// we should be careful not to create infinite recursions in this method (or make it tracked if necessary).
|
||||||
|
pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||||
|
let scope = self.rhs_scope(db);
|
||||||
|
let module = parsed_module(db, scope.file(db)).load(db);
|
||||||
|
let type_alias_stmt_node = scope.node(db).expect_type_alias();
|
||||||
|
let definition = self.definition(db);
|
||||||
|
|
||||||
|
definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn apply_specialization(
|
pub(crate) fn apply_specialization(
|
||||||
self,
|
self,
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
|
@ -10892,6 +10919,13 @@ impl<'db> TypeAliasType<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||||
|
match self {
|
||||||
|
TypeAliasType::PEP695(type_alias) => type_alias.raw_value_type(db),
|
||||||
|
TypeAliasType::ManualPEP695(type_alias) => type_alias.value(db),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn as_pep_695_type_alias(self) -> Option<PEP695TypeAliasType<'db>> {
|
pub(crate) fn as_pep_695_type_alias(self) -> Option<PEP695TypeAliasType<'db>> {
|
||||||
match self {
|
match self {
|
||||||
TypeAliasType::PEP695(type_alias) => Some(type_alias),
|
TypeAliasType::PEP695(type_alias) => Some(type_alias),
|
||||||
|
@ -11724,4 +11758,85 @@ pub(crate) mod tests {
|
||||||
.build();
|
.build();
|
||||||
assert_eq!(intersection.display(&db).to_string(), "Never");
|
assert_eq!(intersection.display(&db).to_string(), "Never");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn type_alias_variance() {
|
||||||
|
use crate::db::tests::TestDb;
|
||||||
|
use crate::place::global_symbol;
|
||||||
|
|
||||||
|
fn get_type_alias<'db>(db: &'db TestDb, name: &str) -> PEP695TypeAliasType<'db> {
|
||||||
|
let module = ruff_db::files::system_path_to_file(db, "/src/a.py").unwrap();
|
||||||
|
let ty = global_symbol(db, module, name).place.expect_type();
|
||||||
|
let Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(
|
||||||
|
type_alias,
|
||||||
|
))) = ty
|
||||||
|
else {
|
||||||
|
panic!("Expected `{name}` to be a type alias");
|
||||||
|
};
|
||||||
|
type_alias
|
||||||
|
}
|
||||||
|
fn get_bound_typevar<'db>(
|
||||||
|
db: &'db TestDb,
|
||||||
|
type_alias: PEP695TypeAliasType<'db>,
|
||||||
|
) -> BoundTypeVarInstance<'db> {
|
||||||
|
let generic_context = type_alias.generic_context(db).unwrap();
|
||||||
|
generic_context.variables(db).next().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut db = setup_db();
|
||||||
|
db.write_dedented(
|
||||||
|
"/src/a.py",
|
||||||
|
r#"
|
||||||
|
class Covariant[T]:
|
||||||
|
def get(self) -> T:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
class Contravariant[T]:
|
||||||
|
def set(self, value: T):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Invariant[T]:
|
||||||
|
def get(self) -> T:
|
||||||
|
raise ValueError
|
||||||
|
def set(self, value: T):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Bivariant[T]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
type CovariantAlias[T] = Covariant[T]
|
||||||
|
type ContravariantAlias[T] = Contravariant[T]
|
||||||
|
type InvariantAlias[T] = Invariant[T]
|
||||||
|
type BivariantAlias[T] = Bivariant[T]
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let covariant = get_type_alias(&db, "CovariantAlias");
|
||||||
|
assert_eq!(
|
||||||
|
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(covariant))
|
||||||
|
.variance_of(&db, get_bound_typevar(&db, covariant)),
|
||||||
|
TypeVarVariance::Covariant
|
||||||
|
);
|
||||||
|
|
||||||
|
let contravariant = get_type_alias(&db, "ContravariantAlias");
|
||||||
|
assert_eq!(
|
||||||
|
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(contravariant))
|
||||||
|
.variance_of(&db, get_bound_typevar(&db, contravariant)),
|
||||||
|
TypeVarVariance::Contravariant
|
||||||
|
);
|
||||||
|
|
||||||
|
let invariant = get_type_alias(&db, "InvariantAlias");
|
||||||
|
assert_eq!(
|
||||||
|
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(invariant))
|
||||||
|
.variance_of(&db, get_bound_typevar(&db, invariant)),
|
||||||
|
TypeVarVariance::Invariant
|
||||||
|
);
|
||||||
|
|
||||||
|
let bivariant = get_type_alias(&db, "BivariantAlias");
|
||||||
|
assert_eq!(
|
||||||
|
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(bivariant))
|
||||||
|
.variance_of(&db, get_bound_typevar(&db, bivariant)),
|
||||||
|
TypeVarVariance::Bivariant
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue