[ty] Propagate specializations to ancestor base classes (#17892)

@AlexWaygood discovered that even though we've been propagating
specializations to _parent_ base classes correctly, we haven't been
passing them on to _grandparent_ base classes:
https://github.com/astral-sh/ruff/pull/17832#issuecomment-2854360969

```py
class Bar[T]:
    x: T

class Baz[T](Bar[T]): ...
class Spam[T](Baz[T]): ...

reveal_type(Spam[int]().x) # revealed: `T`, but should be `int`
```

This PR updates the MRO machinery to apply the current specialization
when starting to iterate the MRO of each base class.
This commit is contained in:
Douglas Creager 2025-05-06 14:25:21 -04:00 committed by GitHub
parent 8152ba7cb7
commit 9085f18353
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 156 additions and 52 deletions

View file

@ -4948,6 +4948,19 @@ impl<'db> Type<'db> {
}
}
#[must_use]
pub fn apply_optional_specialization(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> Type<'db> {
if let Some(specialization) = specialization {
self.apply_specialization(db, specialization)
} else {
self
}
}
/// Applies a specialization to this type, replacing any typevars with the types that they are
/// specialized to.
///
@ -7979,7 +7992,9 @@ pub enum SuperOwnerKind<'db> {
impl<'db> SuperOwnerKind<'db> {
fn iter_mro(self, db: &'db dyn Db) -> impl Iterator<Item = ClassBase<'db>> {
match self {
SuperOwnerKind::Dynamic(dynamic) => Either::Left(ClassBase::Dynamic(dynamic).mro(db)),
SuperOwnerKind::Dynamic(dynamic) => {
Either::Left(ClassBase::Dynamic(dynamic).mro(db, None))
}
SuperOwnerKind::Class(class) => Either::Right(class.iter_mro(db)),
SuperOwnerKind::Instance(instance) => Either::Right(instance.class().iter_mro(db)),
}
@ -8106,7 +8121,7 @@ impl<'db> BoundSuperType<'db> {
mro_iter: impl Iterator<Item = ClassBase<'db>>,
) -> impl Iterator<Item = ClassBase<'db>> {
let Some(pivot_class) = self.pivot_class(db).into_class() else {
return Either::Left(ClassBase::Dynamic(DynamicType::Unknown).mro(db));
return Either::Left(ClassBase::Dynamic(DynamicType::Unknown).mro(db, None));
};
let mut pivot_found = false;

View file

@ -191,6 +191,26 @@ impl<'db> ClassType<'db> {
}
}
/// Returns the class literal and specialization for this class, with an additional
/// specialization applied if the class is generic.
pub(crate) fn class_literal_specialized(
self,
db: &'db dyn Db,
additional_specialization: Option<Specialization<'db>>,
) -> (ClassLiteral<'db>, Option<Specialization<'db>>) {
match self {
Self::NonGeneric(non_generic) => (non_generic, None),
Self::Generic(generic) => (
generic.origin(db),
Some(
generic
.specialization(db)
.apply_optional_specialization(db, additional_specialization),
),
),
}
}
pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
let (class_literal, _) = self.class_literal(db);
class_literal.name(db)
@ -206,13 +226,6 @@ impl<'db> ClassType<'db> {
class_literal.definition(db)
}
fn specialize_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> {
match self {
Self::NonGeneric(_) => ty,
Self::Generic(generic) => ty.apply_specialization(db, generic.specialization(db)),
}
}
/// Return `true` if this class represents `known_class`
pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool {
self.known(db) == Some(known_class)
@ -249,6 +262,18 @@ impl<'db> ClassType<'db> {
class_literal.iter_mro(db, specialization)
}
/// Iterate over the method resolution order ("MRO") of the class, optionally applying an
/// additional specialization to it if the class is generic.
pub(super) fn iter_mro_specialized(
self,
db: &'db dyn Db,
additional_specialization: Option<Specialization<'db>>,
) -> MroIterator<'db> {
let (class_literal, specialization) =
self.class_literal_specialized(db, additional_specialization);
class_literal.iter_mro(db, specialization)
}
/// Is this class final?
pub(super) fn is_final(self, db: &'db dyn Db) -> bool {
let (class_literal, _) = self.class_literal(db);
@ -372,8 +397,10 @@ impl<'db> ClassType<'db> {
/// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred.
pub(super) fn metaclass(self, db: &'db dyn Db) -> Type<'db> {
let (class_literal, _) = self.class_literal(db);
self.specialize_type(db, class_literal.metaclass(db))
let (class_literal, specialization) = self.class_literal(db);
class_literal
.metaclass(db)
.apply_optional_specialization(db, specialization)
}
/// Return a type representing "the set of all instances of the metaclass of this class".
@ -396,9 +423,7 @@ impl<'db> ClassType<'db> {
policy: MemberLookupPolicy,
) -> SymbolAndQualifiers<'db> {
let (class_literal, specialization) = self.class_literal(db);
class_literal
.class_member_inner(db, specialization, name, policy)
.map_type(|ty| self.specialize_type(db, ty))
class_literal.class_member_inner(db, specialization, name, policy)
}
/// Returns the inferred type of the class member named `name`. Only bound members
@ -411,7 +436,7 @@ impl<'db> ClassType<'db> {
let (class_literal, specialization) = self.class_literal(db);
class_literal
.own_class_member(db, specialization, name)
.map_type(|ty| self.specialize_type(db, ty))
.map_type(|ty| ty.apply_optional_specialization(db, specialization))
}
/// Returns the `name` attribute of an instance of this class.
@ -424,16 +449,16 @@ impl<'db> ClassType<'db> {
let (class_literal, specialization) = self.class_literal(db);
class_literal
.instance_member(db, specialization, name)
.map_type(|ty| self.specialize_type(db, ty))
.map_type(|ty| ty.apply_optional_specialization(db, specialization))
}
/// A helper function for `instance_member` that looks up the `name` attribute only on
/// this class, not on its superclasses.
fn own_instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
let (class_literal, _) = self.class_literal(db);
let (class_literal, specialization) = self.class_literal(db);
class_literal
.own_instance_member(db, name)
.map_type(|ty| self.specialize_type(db, ty))
.map_type(|ty| ty.apply_optional_specialization(db, specialization))
}
}

View file

@ -1,4 +1,4 @@
use crate::types::generics::GenericContext;
use crate::types::generics::{GenericContext, Specialization};
use crate::types::{
todo_type, ClassType, DynamicType, KnownClass, KnownInstanceType, MroIterator, Type,
};
@ -202,8 +202,35 @@ impl<'db> ClassBase<'db> {
}
}
pub(crate) fn apply_specialization(
self,
db: &'db dyn Db,
specialization: Specialization<'db>,
) -> Self {
match self {
Self::Class(class) => Self::Class(class.apply_specialization(db, specialization)),
Self::Dynamic(_) | Self::Generic(_) | Self::Protocol => self,
}
}
pub(crate) fn apply_optional_specialization(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> Self {
if let Some(specialization) = specialization {
self.apply_specialization(db, specialization)
} else {
self
}
}
/// Iterate over the MRO of this base
pub(super) fn mro(self, db: &'db dyn Db) -> impl Iterator<Item = ClassBase<'db>> {
pub(super) fn mro(
self,
db: &'db dyn Db,
additional_specialization: Option<Specialization<'db>>,
) -> impl Iterator<Item = ClassBase<'db>> {
match self {
ClassBase::Protocol => {
ClassBaseMroIterator::length_3(db, self, ClassBase::Generic(None))
@ -214,7 +241,9 @@ impl<'db> ClassBase<'db> {
ClassBase::Dynamic(_) | ClassBase::Generic(_) => {
ClassBaseMroIterator::length_2(db, self)
}
ClassBase::Class(class) => ClassBaseMroIterator::from_class(db, class),
ClassBase::Class(class) => {
ClassBaseMroIterator::from_class(db, class, additional_specialization)
}
}
}
}
@ -263,8 +292,12 @@ impl<'db> ClassBaseMroIterator<'db> {
}
/// Iterate over the MRO of an arbitrary class. The MRO may be of any length.
fn from_class(db: &'db dyn Db, class: ClassType<'db>) -> Self {
ClassBaseMroIterator::FromClass(class.iter_mro(db))
fn from_class(
db: &'db dyn Db,
class: ClassType<'db>,
additional_specialization: Option<Specialization<'db>>,
) -> Self {
ClassBaseMroIterator::FromClass(class.iter_mro_specialized(db, additional_specialization))
}
}

View file

@ -202,6 +202,19 @@ impl<'db> Specialization<'db> {
Specialization::new(db, self.generic_context(db), types)
}
/// Applies an optional specialization to this specialization.
pub(crate) fn apply_optional_specialization(
self,
db: &'db dyn Db,
other: Option<Specialization<'db>>,
) -> Self {
if let Some(other) = other {
self.apply_specialization(db, other)
} else {
self
}
}
/// Combines two specializations of the same generic context. If either specialization maps a
/// typevar to `Type::Unknown`, the other specialization's mapping is used. If both map the
/// typevar to a known type, those types are unioned together.

View file

@ -80,12 +80,14 @@ impl<'db> Mro<'db> {
));
}
let class_type = class.apply_optional_specialization(db, specialization);
match class_bases {
// `builtins.object` is the special case:
// the only class in Python that has an MRO with length <2
[] if class.is_object(db) => Ok(Self::from([
// object is not generic, so the default specialization should be a no-op
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
ClassBase::Class(class_type),
])),
// All other classes in Python have an MRO with length >=2.
@ -102,7 +104,7 @@ impl<'db> Mro<'db> {
// (<class '__main__.Foo'>, <class 'object'>)
// ```
[] => Ok(Self::from([
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
ClassBase::Class(class_type),
ClassBase::object(db),
])),
@ -114,11 +116,9 @@ impl<'db> Mro<'db> {
[single_base] => ClassBase::try_from_type(db, *single_base).map_or_else(
|| Err(MroErrorKind::InvalidBases(Box::from([(0, *single_base)]))),
|single_base| {
Ok(std::iter::once(ClassBase::Class(
class.apply_optional_specialization(db, specialization),
))
.chain(single_base.mro(db))
.collect())
Ok(std::iter::once(ClassBase::Class(class_type))
.chain(single_base.mro(db, specialization))
.collect())
},
),
@ -142,13 +142,16 @@ impl<'db> Mro<'db> {
return Err(MroErrorKind::InvalidBases(invalid_bases.into_boxed_slice()));
}
let mut seqs = vec![VecDeque::from([ClassBase::Class(
class.apply_optional_specialization(db, specialization),
)])];
let mut seqs = vec![VecDeque::from([ClassBase::Class(class_type)])];
for base in &valid_bases {
seqs.push(base.mro(db).collect());
seqs.push(base.mro(db, specialization).collect());
}
seqs.push(valid_bases.iter().copied().collect());
seqs.push(
valid_bases
.iter()
.map(|base| base.apply_optional_specialization(db, specialization))
.collect(),
);
c3_merge(seqs).ok_or_else(|| {
let mut seen_bases = FxHashSet::default();