[ty] linear variance inference for PEP-695 type parameters (#18713)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Implement linear-time variance inference for type variables
(https://github.com/astral-sh/ty/issues/488).

Inspired by Martin Huschenbett's [PyCon 2025
Talk](https://www.youtube.com/watch?v=7uixlNTOY4s&t=9705s).

## Test Plan

update tests, add new tests, including for mutually recursive classes

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin 2025-08-19 20:54:09 -04:00 committed by GitHub
parent 656fc335f2
commit 33030b34cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1088 additions and 95 deletions

View file

@ -60,6 +60,7 @@ use crate::types::mro::{Mro, MroError, MroIterator};
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature};
use crate::types::tuple::TupleSpec;
use crate::types::variance::{TypeVarVariance, VarianceInferable};
use crate::unpack::EvaluationMode;
pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic;
use crate::{Db, FxOrderMap, FxOrderSet, Module, Program};
@ -92,6 +93,7 @@ mod subclass_of;
mod tuple;
mod type_ordering;
mod unpacker;
mod variance;
mod visitor;
mod definition;
@ -322,6 +324,29 @@ fn class_lookup_cycle_initial<'db>(
Place::bound(Type::Never).into()
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn variance_cycle_recover<'db, T>(
_db: &'db dyn Db,
_value: &TypeVarVariance,
count: u32,
_self: T,
_typevar: BoundTypeVarInstance<'db>,
) -> salsa::CycleRecoveryAction<TypeVarVariance> {
assert!(
count <= 2,
"Should only be able to cycle at most twice: there are only three levels in the lattice, each cycle should move us one"
);
salsa::CycleRecoveryAction::Iterate
}
fn variance_cycle_initial<'db, T>(
_db: &'db dyn Db,
_self: T,
_typevar: BoundTypeVarInstance<'db>,
) -> TypeVarVariance {
TypeVarVariance::Bivariant
}
/// Meta data for `Type::Todo`, which represents a known limitation in ty.
#[cfg(debug_assertions)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
@ -755,7 +780,7 @@ impl<'db> Type<'db> {
Name::new_static("T_all"),
None,
None,
variance,
Some(variance),
None,
TypeVarKind::Pep695,
),
@ -963,7 +988,6 @@ impl<'db> Type<'db> {
.expect("Expected a Type::FunctionLiteral variant")
}
#[cfg(test)]
pub(crate) const fn is_function_literal(&self) -> bool {
matches!(self, Type::FunctionLiteral(..))
}
@ -5533,7 +5557,10 @@ impl<'db> Type<'db> {
ast::name::Name::new_static("Self"),
Some(class_definition),
Some(TypeVarBoundOrConstraints::UpperBound(instance).into()),
TypeVarVariance::Invariant,
// According to the [spec], we can consider `Self`
// equivalent to an invariant type variable
// [spec]: https://typing.python.org/en/latest/spec/generics.html#self
Some(TypeVarVariance::Invariant),
None,
TypeVarKind::TypingSelf,
);
@ -6315,6 +6342,99 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}
impl<'db> VarianceInferable<'db> for Type<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
tracing::debug!(
"Checking variance of '{tvar}' in `{ty:?}`",
tvar = typevar.typevar(db).name(db),
ty = self.display(db),
);
let v = match self {
Type::ClassLiteral(class_literal) => class_literal.variance_of(db, typevar),
Type::FunctionLiteral(function_type) => {
// TODO: do we need to replace self?
function_type.signature(db).variance_of(db, typevar)
}
Type::BoundMethod(method_type) => {
// TODO: do we need to replace self?
method_type
.function(db)
.signature(db)
.variance_of(db, typevar)
}
Type::NominalInstance(nominal_instance_type) => {
nominal_instance_type.variance_of(db, typevar)
}
Type::GenericAlias(generic_alias) => generic_alias.variance_of(db, typevar),
Type::Callable(callable_type) => callable_type.signatures(db).variance_of(db, typevar),
Type::TypeVar(other_typevar) | Type::NonInferableTypeVar(other_typevar)
if other_typevar == typevar =>
{
// type variables are covariant in themselves
TypeVarVariance::Covariant
}
Type::ProtocolInstance(protocol_instance_type) => {
protocol_instance_type.variance_of(db, typevar)
}
Type::Union(union_type) => union_type
.elements(db)
.iter()
.map(|ty| ty.variance_of(db, typevar))
.collect(),
Type::Intersection(intersection_type) => intersection_type
.positive(db)
.iter()
.map(|ty| ty.variance_of(db, typevar))
.chain(intersection_type.negative(db).iter().map(|ty| {
ty.with_polarity(TypeVarVariance::Contravariant)
.variance_of(db, typevar)
}))
.collect(),
Type::PropertyInstance(property_instance_type) => property_instance_type
.getter(db)
.iter()
.chain(&property_instance_type.setter(db))
.map(|ty| ty.variance_of(db, typevar))
.collect(),
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
Type::Dynamic(_)
| Type::Never
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::IntLiteral(_)
| Type::BooleanLiteral(_)
| Type::StringLiteral(_)
| Type::EnumLiteral(_)
| Type::LiteralString
| Type::BytesLiteral(_)
| Type::SpecialForm(_)
| Type::KnownInstance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BoundSuper(_)
| Type::TypeVar(_)
| Type::NonInferableTypeVar(_)
| Type::TypeIs(_)
| Type::TypedDict(_)
| Type::TypeAlias(_) => TypeVarVariance::Bivariant,
};
tracing::debug!(
"Result of variance of '{tvar}' in `{ty:?}` is `{v:?}`",
tvar = typevar.typevar(db).name(db),
ty = self.display(db),
);
v
}
}
/// A mapping that can be applied to a type, producing another type. This is applied inductively to
/// the components of complex types.
///
@ -6972,8 +7092,8 @@ pub struct TypeVarInstance<'db> {
/// instead (to evaluate any lazy bound or constraints).
_bound_or_constraints: Option<TypeVarBoundOrConstraintsEvaluation<'db>>,
/// The variance of the TypeVar
variance: TypeVarVariance,
/// The explicitly specified variance of the TypeVar
explicit_variance: Option<TypeVarVariance>,
/// The default type for this TypeVar, if any. Don't use this field directly, use the
/// `default_type` method instead (to evaluate any lazy default).
@ -7065,7 +7185,7 @@ impl<'db> TypeVarInstance<'db> {
.lazy_constraints(db)
.map(|constraints| constraints.normalized_impl(db, visitor).into()),
}),
self.variance(db),
self.explicit_variance(db),
self._default(db).and_then(|default| match default {
TypeVarDefaultEvaluation::Eager(ty) => Some(ty.normalized_impl(db, visitor).into()),
TypeVarDefaultEvaluation::Lazy => self
@ -7093,7 +7213,7 @@ impl<'db> TypeVarInstance<'db> {
.lazy_constraints(db)
.map(|constraints| constraints.materialize(db, variance).into()),
}),
self.variance(db),
self.explicit_variance(db),
self._default(db).and_then(|default| match default {
TypeVarDefaultEvaluation::Eager(ty) => Some(ty.materialize(db, variance).into()),
TypeVarDefaultEvaluation::Lazy => self
@ -7118,7 +7238,7 @@ impl<'db> TypeVarInstance<'db> {
Name::new(format!("{}'instance", self.name(db))),
None,
Some(bound_or_constraints.into()),
self.variance(db),
self.explicit_variance(db),
None,
self.kind(db),
))
@ -7187,6 +7307,33 @@ pub struct BoundTypeVarInstance<'db> {
// The Salsa heap is tracked separately.
impl get_size2::GetSize for BoundTypeVarInstance<'_> {}
impl<'db> BoundTypeVarInstance<'db> {
pub(crate) fn variance_with_polarity(
self,
db: &'db dyn Db,
polarity: TypeVarVariance,
) -> TypeVarVariance {
let _span = tracing::trace_span!("variance_with_polarity").entered();
match self.typevar(db).explicit_variance(db) {
Some(explicit_variance) => explicit_variance.compose(polarity),
None => match self.binding_context(db) {
BindingContext::Definition(definition) => {
let type_inference = infer_definition_types(db, definition);
type_inference
.binding_type(definition)
.with_polarity(polarity)
.variance_of(db, self)
}
BindingContext::Synthetic => TypeVarVariance::Invariant,
},
}
}
pub(crate) fn variance(self, db: &'db dyn Db) -> TypeVarVariance {
self.variance_with_polarity(db, TypeVarVariance::Covariant)
}
}
fn walk_bound_type_var_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
@ -7250,28 +7397,6 @@ impl<'db> BoundTypeVarInstance<'db> {
}
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub enum TypeVarVariance {
Invariant,
Covariant,
Contravariant,
Bivariant,
}
impl TypeVarVariance {
/// Flips the polarity of the variance.
///
/// Covariant becomes contravariant, contravariant becomes covariant, others remain unchanged.
pub(crate) const fn flip(self) -> Self {
match self {
TypeVarVariance::Invariant => TypeVarVariance::Invariant,
TypeVarVariance::Covariant => TypeVarVariance::Contravariant,
TypeVarVariance::Contravariant => TypeVarVariance::Covariant,
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
}
}
}
/// Whether a typevar default is eagerly specified or lazily evaluated.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub enum TypeVarDefaultEvaluation<'db> {

View file

@ -13,8 +13,10 @@ use crate::module_resolver::KnownModule;
use crate::semantic_index::definition::{Definition, DefinitionState};
use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::scope::NodeWithScopeKind;
use crate::semantic_index::symbol::Symbol;
use crate::semantic_index::{
BindingWithConstraints, DeclarationWithConstraint, SemanticIndex, attribute_declarations,
attribute_scopes,
};
use crate::types::context::InferContext;
use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_TYPE};
@ -28,8 +30,8 @@ use crate::types::{
ApplyTypeMappingVisitor, BareTypeAliasType, Binding, BoundSuperError, BoundSuperType,
CallableType, DataclassParams, DeprecatedInstance, HasRelationToVisitor, KnownInstanceType,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, declaration_type,
infer_definition_types, todo_type,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, VarianceInferable,
declaration_type, infer_definition_types, todo_type,
};
use crate::{
Db, FxIndexMap, FxOrderSet, Program,
@ -314,6 +316,51 @@ impl<'db> From<GenericAlias<'db>> for Type<'db> {
}
}
#[salsa::tracked]
impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
#[salsa::tracked]
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
let origin = self.origin(db);
let specialization = self.specialization(db);
// if the class is the thing defining the variable, then it can
// reference it without it being applied to the specialization
std::iter::once(origin.variance_of(db, typevar))
.chain(
specialization
.generic_context(db)
.variables(db)
.iter()
.zip(specialization.types(db))
.map(|(generic_typevar, ty)| {
if let Some(explicit_variance) =
generic_typevar.typevar(db).explicit_variance(db)
{
ty.with_polarity(explicit_variance).variance_of(db, typevar)
} else {
// `with_polarity` composes the passed variance with the
// inferred one. The inference is done lazily, as we can
// sometimes determine the result just from the passed
// variance. This operation is commutative, so we could
// infer either first. We choose to make the `ClassLiteral`
// variance lazy, as it is known to be expensive, requiring
// that we traverse all members.
//
// If salsa let us look at the cache, we could check first
// to see if the class literal query was already run.
let typevar_variance_in_substituted_type = ty.variance_of(db, typevar);
origin
.with_polarity(typevar_variance_in_substituted_type)
.variance_of(db, *generic_typevar)
}
}),
)
.collect()
}
}
/// Represents a class type, which might be a non-generic class, or a specialization of a generic
/// class.
#[derive(
@ -1136,6 +1183,15 @@ impl<'db> From<ClassType<'db>> for Type<'db> {
}
}
impl<'db> VarianceInferable<'db> for ClassType<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
match self {
Self::NonGeneric(class) => class.variance_of(db, typevar),
Self::Generic(generic) => generic.variance_of(db, typevar),
}
}
}
/// A filter that describes which methods are considered when looking for implicit attribute assignments
/// in [`ClassLiteral::implicit_attribute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -3060,6 +3116,126 @@ impl<'db> From<ClassLiteral<'db>> for ClassType<'db> {
}
}
#[salsa::tracked]
impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
#[salsa::tracked(cycle_fn=crate::types::variance_cycle_recover, cycle_initial=crate::types::variance_cycle_initial)]
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
let typevar_in_generic_context = self
.generic_context(db)
.is_some_and(|generic_context| generic_context.variables(db).contains(&typevar));
if !typevar_in_generic_context {
return TypeVarVariance::Bivariant;
}
let class_body_scope = self.body_scope(db);
let file = class_body_scope.file(db);
let index = semantic_index(db, file);
let explicit_bases_variances = self
.explicit_bases(db)
.iter()
.map(|class| class.variance_of(db, typevar));
let default_attribute_variance = {
let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self);
// Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses
// their field types in contravariant position, thus meaning a frozen dataclass must
// still be invariant in its field types. Other synthesized methods on dataclasses are
// not considered here, since they don't use field types in their signatures. TODO:
// ideally we'd have a single source of truth for information about synthesized
// methods, so we just look them up normally and don't hardcode this knowledge here.
let is_frozen_dataclass = Program::get(db).python_version(db) <= PythonVersion::PY312
&& self
.dataclass_params(db)
.is_some_and(|params| params.contains(DataclassParams::FROZEN));
if is_namedtuple || is_frozen_dataclass {
TypeVarVariance::Covariant
} else {
TypeVarVariance::Invariant
}
};
let init_name: &Name = &"__init__".into();
let new_name: &Name = &"__new__".into();
let use_def_map = index.use_def_map(class_body_scope.file_scope_id(db));
let table = place_table(db, class_body_scope);
let attribute_places_and_qualifiers =
use_def_map
.all_end_of_scope_symbol_declarations()
.map(|(symbol_id, declarations)| {
let place_and_qual =
place_from_declarations(db, declarations).ignore_conflicting_declarations();
(symbol_id, place_and_qual)
})
.chain(use_def_map.all_end_of_scope_symbol_bindings().map(
|(symbol_id, bindings)| (symbol_id, place_from_bindings(db, bindings).into()),
))
.filter_map(|(symbol_id, place_and_qual)| {
if let Some(name) = table.place(symbol_id).as_symbol().map(Symbol::name) {
(![init_name, new_name].contains(&name))
.then_some((name.to_string(), place_and_qual))
} else {
None
}
});
// Dataclasses can have some additional synthesized methods (`__eq__`, `__hash__`,
// `__lt__`, etc.) but none of these will have field types type variables in their signatures, so we
// don't need to consider them for variance.
let attribute_names = attribute_scopes(db, self.body_scope(db))
.flat_map(|function_scope_id| {
index
.place_table(function_scope_id)
.members()
.filter_map(|member| member.as_instance_attribute())
.filter(|name| *name != init_name && *name != new_name)
.map(std::string::ToString::to_string)
.collect::<Vec<_>>()
})
.dedup();
let attribute_variances = attribute_names
.map(|name| {
let place_and_quals = self.own_instance_member(db, &name);
(name, place_and_quals)
})
.chain(attribute_places_and_qualifiers)
.dedup()
.filter_map(|(name, place_and_qual)| {
place_and_qual.place.ignore_possibly_unbound().map(|ty| {
let variance = if place_and_qual
.qualifiers
// `CLASS_VAR || FINAL` is really `all()`, but
// we want to be robust against new qualifiers
.intersects(TypeQualifiers::CLASS_VAR | TypeQualifiers::FINAL)
// We don't allow mutation of methods or properties
|| ty.is_function_literal()
|| ty.is_property_instance()
// Underscore-prefixed attributes are assumed not to be externally mutated
|| name.starts_with('_')
{
// CLASS_VAR: class vars generally shouldn't contain the
// type variable, but they could if it's a
// callable type. They can't be mutated on instances.
//
// FINAL: final attributes are immutable, and thus covariant
TypeVarVariance::Covariant
} else {
default_attribute_variance
};
ty.with_polarity(variance).variance_of(db, typevar)
})
});
attribute_variances
.chain(explicit_bases_variances)
.collect()
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(super) enum InheritanceCycle {
/// The class is cyclically defined and is a participant in the cycle.
@ -4673,7 +4849,7 @@ impl KnownClass {
&target.id,
Some(containing_assignment),
bound_or_constraint,
variance,
Some(variance),
default.map(Into::into),
TypeVarKind::Legacy,
),

View file

@ -549,12 +549,7 @@ impl<'db> Specialization<'db> {
.into_iter()
.zip(self.types(db))
.map(|(bound_typevar, vartype)| {
let variance = match bound_typevar.typevar(db).variance(db) {
TypeVarVariance::Invariant => TypeVarVariance::Invariant,
TypeVarVariance::Covariant => variance,
TypeVarVariance::Contravariant => variance.flip(),
TypeVarVariance::Bivariant => unreachable!(),
};
let variance = bound_typevar.variance_with_polarity(db, variance);
vartype.materialize(db, variance)
})
.collect();
@ -599,7 +594,7 @@ impl<'db> Specialization<'db> {
// - contravariant: verify that other_type <: self_type
// - invariant: verify that self_type <: other_type AND other_type <: self_type
// - bivariant: skip, can't make subtyping/assignability false
let compatible = match bound_typevar.typevar(db).variance(db) {
let compatible = match bound_typevar.variance(db) {
TypeVarVariance::Invariant => match relation {
TypeRelation::Subtyping => self_type.is_equivalent_to(db, *other_type),
TypeRelation::Assignability => {
@ -639,7 +634,7 @@ impl<'db> Specialization<'db> {
// - contravariant: verify that other_type == self_type
// - invariant: verify that self_type == other_type
// - bivariant: skip, can't make equivalence false
let compatible = match bound_typevar.typevar(db).variance(db) {
let compatible = match bound_typevar.variance(db) {
TypeVarVariance::Invariant
| TypeVarVariance::Covariant
| TypeVarVariance::Contravariant => self_type.is_equivalent_to(db, *other_type),

View file

@ -124,8 +124,8 @@ use crate::types::{
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm,
Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType,
TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation,
TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionBuilder,
UnionType, binding_type, todo_type,
TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type,
todo_type,
};
use crate::unpack::{EvaluationMode, Unpack, UnpackPosition};
use crate::util::diagnostics::format_enumeration;
@ -3470,7 +3470,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&name.id,
Some(definition),
bound_or_constraint,
TypeVarVariance::Invariant, // TODO: infer this
None,
default.as_deref().map(|_| TypeVarDefaultEvaluation::Lazy),
TypeVarKind::Pep695,
)));

View file

@ -12,7 +12,7 @@ use crate::types::protocol_class::walk_protocol_interface;
use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::{
ApplyTypeMappingVisitor, ClassBase, DynamicType, HasRelationToVisitor, IsDisjointVisitor,
NormalizedVisitor, TypeMapping, TypeRelation, TypeTransformer,
NormalizedVisitor, TypeMapping, TypeRelation, TypeTransformer, VarianceInferable,
};
use crate::{Db, FxOrderSet};
@ -406,6 +406,12 @@ pub(crate) struct SliceLiteral {
pub(crate) step: Option<i32>,
}
impl<'db> VarianceInferable<'db> for NominalInstanceType<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.class(db).variance_of(db, typevar)
}
}
/// A `ProtocolInstanceType` represents the set of all possible runtime objects
/// that conform to the interface described by a certain protocol.
#[derive(
@ -593,6 +599,12 @@ impl<'db> ProtocolInstanceType<'db> {
}
}
impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.inner.variance_of(db, typevar)
}
}
/// An enumeration of the two kinds of protocol types: those that originate from a class
/// definition in source code, and those that are synthesized from a set of members.
#[derive(
@ -618,12 +630,23 @@ impl<'db> Protocol<'db> {
}
}
impl<'db> VarianceInferable<'db> for Protocol<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
match self {
Protocol::FromClass(class_type) => class_type.variance_of(db, typevar),
Protocol::Synthesized(synthesized_protocol_type) => {
synthesized_protocol_type.variance_of(db, typevar)
}
}
}
}
mod synthesized_protocol {
use crate::semantic_index::definition::Definition;
use crate::types::protocol_class::ProtocolInterface;
use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, NormalizedVisitor, TypeMapping,
TypeVarVariance,
TypeVarVariance, VarianceInferable,
};
use crate::{Db, FxOrderSet};
@ -676,4 +699,14 @@ mod synthesized_protocol {
self.0
}
}
impl<'db> VarianceInferable<'db> for SynthesizedProtocolType<'db> {
fn variance_of(
self,
db: &'db dyn Db,
typevar: BoundTypeVarInstance<'db>,
) -> TypeVarVariance {
self.0.variance_of(db, typevar)
}
}
}

View file

@ -18,7 +18,7 @@ use crate::{
types::{
BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, IsDisjointVisitor,
KnownFunction, NormalizedVisitor, PropertyInstanceType, Signature, Type, TypeMapping,
TypeQualifiers, TypeRelation, TypeTransformer,
TypeQualifiers, TypeRelation, TypeTransformer, VarianceInferable,
signatures::{Parameter, Parameters},
},
};
@ -301,6 +301,15 @@ impl<'db> ProtocolInterface<'db> {
}
}
impl<'db> VarianceInferable<'db> for ProtocolInterface<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.members(db)
// TODO do we need to switch on member kind?
.map(|member| member.ty().variance_of(db, typevar))
.collect()
}
}
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update, get_size2::GetSize)]
pub(super) struct ProtocolMemberData<'db> {
kind: ProtocolMemberKind<'db>,

View file

@ -20,7 +20,7 @@ use crate::semantic_index::definition::Definition;
use crate::types::generics::{GenericContext, walk_generic_context};
use crate::types::{
BindingContext, BoundTypeVarInstance, KnownClass, NormalizedVisitor, TypeMapping, TypeRelation,
todo_type,
VarianceInferable, todo_type,
};
use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name};
@ -223,6 +223,16 @@ impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> {
}
}
impl<'db> VarianceInferable<'db> for &CallableSignature<'db> {
// TODO: possibly need to replace self
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.overloads
.iter()
.map(|signature| signature.variance_of(db, typevar))
.collect()
}
}
/// The signature of one of the overloads of a callable.
#[derive(Clone, Debug, salsa::Update, get_size2::GetSize)]
pub struct Signature<'db> {
@ -982,6 +992,28 @@ impl std::hash::Hash for Signature<'_> {
}
}
impl<'db> VarianceInferable<'db> for &Signature<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
tracing::debug!(
"Checking variance of `{tvar}` in `{self:?}`",
tvar = typevar.typevar(db).name(db)
);
itertools::chain(
self.parameters
.iter()
.filter_map(|parameter| match parameter.form {
ParameterForm::Type => None,
ParameterForm::Value => parameter.annotated_type().map(|ty| {
ty.with_polarity(TypeVarVariance::Contravariant)
.variance_of(db, typevar)
}),
}),
self.return_ty.map(|ty| ty.variance_of(db, typevar)),
)
.collect()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)]
pub(crate) struct Parameters<'db> {
// TODO: use SmallVec here once invariance bug is fixed

View file

@ -2,6 +2,7 @@ use ruff_python_ast::name::Name;
use crate::place::PlaceAndQualifiers;
use crate::semantic_index::definition::Definition;
use crate::types::variance::VarianceInferable;
use crate::types::{
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassType, DynamicType,
HasRelationToVisitor, KnownClass, MemberLookupPolicy, NormalizedVisitor, Type, TypeMapping,
@ -103,7 +104,7 @@ impl<'db> SubclassOfType<'db> {
)
.into(),
),
variance,
Some(variance),
None,
TypeVarKind::Pep695,
),
@ -215,6 +216,15 @@ impl<'db> SubclassOfType<'db> {
}
}
impl<'db> VarianceInferable<'db> for SubclassOfType<'db> {
fn variance_of(self, db: &dyn Db, typevar: BoundTypeVarInstance<'_>) -> TypeVarVariance {
match self.subclass_of {
SubclassOfInner::Dynamic(_) => TypeVarVariance::Bivariant,
SubclassOfInner::Class(class) => class.variance_of(db, typevar),
}
}
}
/// An enumeration of the different kinds of `type[]` types that a [`SubclassOfType`] can represent:
///
/// 1. A "subclass of a class": `type[C]` for any class object `C`

View file

@ -0,0 +1,138 @@
use crate::{Db, types::BoundTypeVarInstance};
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub enum TypeVarVariance {
Invariant,
Covariant,
Contravariant,
Bivariant,
}
impl TypeVarVariance {
pub const fn bottom() -> Self {
TypeVarVariance::Bivariant
}
pub const fn top() -> Self {
TypeVarVariance::Invariant
}
// supremum
#[must_use]
pub(crate) const fn join(self, other: Self) -> Self {
use TypeVarVariance::{Bivariant, Contravariant, Covariant, Invariant};
match (self, other) {
(Invariant, _) | (_, Invariant) => Invariant,
(Covariant, Covariant) => Covariant,
(Contravariant, Contravariant) => Contravariant,
(Covariant, Contravariant) | (Contravariant, Covariant) => Invariant,
(Bivariant, other) | (other, Bivariant) => other,
}
}
/// Compose two variances: useful for combining use-site and definition-site variances, e.g.
/// `C[D[T]]` or function argument/return position variances.
///
/// `other` is a thunk to avoid unnecessary computation when `self` is `Bivariant`.
///
/// Based on the variance composition/transformation operator in
/// <https://people.cs.umass.edu/~yannis/variance-extended2011.pdf>, page 5
///
/// While their operation would have `compose(Invariant, Bivariant) ==
/// Invariant`, we instead have it evaluate to `Bivariant`. This is a valid
/// choice, as discussed on that same page, where type equality is semantic
/// rather than syntactic. To see that this holds for our setting consider
/// the type
/// ```python
/// type ConstantInt[T] = int
/// ```
/// We would say `ConstantInt[str]` = `ConstantInt[float]`, so we qualify as
/// using semantic equivalence.
#[must_use]
pub(crate) fn compose(self, other: Self) -> Self {
self.compose_thunk(|| other)
}
/// Like `compose`, but takes `other` as a thunk to avoid unnecessary
/// computation when `self` is `Bivariant`.
#[must_use]
pub(crate) fn compose_thunk<F>(self, other: F) -> Self
where
F: FnOnce() -> Self,
{
match self {
TypeVarVariance::Covariant => other(),
TypeVarVariance::Contravariant => other().flip(),
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
TypeVarVariance::Invariant => {
if TypeVarVariance::Bivariant == other() {
TypeVarVariance::Bivariant
} else {
TypeVarVariance::Invariant
}
}
}
}
/// Flips the polarity of the variance.
///
/// Covariant becomes contravariant, contravariant becomes covariant, others remain unchanged.
pub(crate) const fn flip(self) -> Self {
match self {
TypeVarVariance::Invariant => TypeVarVariance::Invariant,
TypeVarVariance::Covariant => TypeVarVariance::Contravariant,
TypeVarVariance::Contravariant => TypeVarVariance::Covariant,
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
}
}
}
impl std::iter::FromIterator<Self> for TypeVarVariance {
fn from_iter<T: IntoIterator<Item = Self>>(iter: T) -> Self {
use std::ops::ControlFlow;
// TODO: use `into_value` when control_flow_into_value is stable
let (ControlFlow::Break(variance) | ControlFlow::Continue(variance)) = iter
.into_iter()
.try_fold(TypeVarVariance::Bivariant, |acc, variance| {
let supremum = acc.join(variance);
match supremum {
// short circuit at top
TypeVarVariance::Invariant => ControlFlow::Break(supremum),
TypeVarVariance::Bivariant
| TypeVarVariance::Covariant
| TypeVarVariance::Contravariant => ControlFlow::Continue(supremum),
}
});
variance
}
}
pub(crate) trait VarianceInferable<'db>: Sized {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance;
fn with_polarity(self, polarity: TypeVarVariance) -> WithPolarity<Self> {
WithPolarity {
variance_inferable: self,
polarity,
}
}
}
pub(crate) struct WithPolarity<T> {
variance_inferable: T,
polarity: TypeVarVariance,
}
impl<'db, T> VarianceInferable<'db> for WithPolarity<T>
where
T: VarianceInferable<'db>,
{
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
let WithPolarity {
variance_inferable,
polarity,
} = self;
polarity.compose_thunk(|| variance_inferable.variance_of(db, typevar))
}
}