From 9085f18353d5426451dca00c1d9d4375d266bd1a Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Tue, 6 May 2025 14:25:21 -0400 Subject: [PATCH] [ty] Propagate specializations to ancestor base classes (#17892) @AlexWaygood discovered that even though we've been propagating specializations to _parent_ base classes correctly, we haven't been passing them on to _grandparent_ base classes: https://github.com/astral-sh/ruff/pull/17832#issuecomment-2854360969 ```py class Bar[T]: x: T class Baz[T](Bar[T]): ... class Spam[T](Baz[T]): ... reveal_type(Spam[int]().x) # revealed: `T`, but should be `int` ``` This PR updates the MRO machinery to apply the current specialization when starting to iterate the MRO of each base class. --- .../annotations/stdlib_typing_aliases.md | 10 ++-- .../mdtest/generics/legacy/classes.md | 25 +++++--- .../mdtest/generics/pep695/classes.md | 14 +++-- crates/ty_python_semantic/src/types.rs | 19 ++++++- crates/ty_python_semantic/src/types/class.rs | 57 +++++++++++++------ .../src/types/class_base.rs | 43 ++++++++++++-- .../ty_python_semantic/src/types/generics.rs | 13 +++++ crates/ty_python_semantic/src/types/mro.rs | 27 +++++---- 8 files changed, 156 insertions(+), 52 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/stdlib_typing_aliases.md b/crates/ty_python_semantic/resources/mdtest/annotations/stdlib_typing_aliases.md index fe9c54ddd0..5765db4b4d 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/stdlib_typing_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/stdlib_typing_aliases.md @@ -91,7 +91,7 @@ reveal_type(ListSubclass.__mro__) class DictSubclass(typing.Dict): ... # TODO: generic protocols -# revealed: tuple[Literal[DictSubclass], Literal[dict[Unknown, Unknown]], Literal[MutableMapping[_KT, _VT]], Literal[Mapping[_KT, _VT]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] +# revealed: tuple[Literal[DictSubclass], Literal[dict[Unknown, Unknown]], Literal[MutableMapping[Unknown, Unknown]], Literal[Mapping[Unknown, Unknown]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] reveal_type(DictSubclass.__mro__) class SetSubclass(typing.Set): ... @@ -113,19 +113,19 @@ reveal_type(FrozenSetSubclass.__mro__) class ChainMapSubclass(typing.ChainMap): ... # TODO: generic protocols -# revealed: tuple[Literal[ChainMapSubclass], Literal[ChainMap[Unknown, Unknown]], Literal[MutableMapping[_KT, _VT]], Literal[Mapping[_KT, _VT]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] +# revealed: tuple[Literal[ChainMapSubclass], Literal[ChainMap[Unknown, Unknown]], Literal[MutableMapping[Unknown, Unknown]], Literal[Mapping[Unknown, Unknown]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] reveal_type(ChainMapSubclass.__mro__) class CounterSubclass(typing.Counter): ... # TODO: Should be (CounterSubclass, Counter, dict, MutableMapping, Mapping, Collection, Sized, Iterable, Container, Generic, object) -# revealed: tuple[Literal[CounterSubclass], Literal[Counter[Unknown]], Literal[dict[_T, int]], Literal[MutableMapping[_KT, _VT]], Literal[Mapping[_KT, _VT]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], typing.Generic[_T], Literal[object]] +# revealed: tuple[Literal[CounterSubclass], Literal[Counter[Unknown]], Literal[dict[Unknown, int]], Literal[MutableMapping[Unknown, int]], Literal[Mapping[Unknown, int]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], typing.Generic[_T], Literal[object]] reveal_type(CounterSubclass.__mro__) class DefaultDictSubclass(typing.DefaultDict): ... # TODO: Should be (DefaultDictSubclass, defaultdict, dict, MutableMapping, Mapping, Collection, Sized, Iterable, Container, Generic, object) -# revealed: tuple[Literal[DefaultDictSubclass], Literal[defaultdict[Unknown, Unknown]], Literal[dict[_KT, _VT]], Literal[MutableMapping[_KT, _VT]], Literal[Mapping[_KT, _VT]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] +# revealed: tuple[Literal[DefaultDictSubclass], Literal[defaultdict[Unknown, Unknown]], Literal[dict[Unknown, Unknown]], Literal[MutableMapping[Unknown, Unknown]], Literal[Mapping[Unknown, Unknown]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] reveal_type(DefaultDictSubclass.__mro__) class DequeSubclass(typing.Deque): ... @@ -137,6 +137,6 @@ reveal_type(DequeSubclass.__mro__) class OrderedDictSubclass(typing.OrderedDict): ... # TODO: Should be (OrderedDictSubclass, OrderedDict, dict, MutableMapping, Mapping, Collection, Sized, Iterable, Container, Generic, object) -# revealed: tuple[Literal[OrderedDictSubclass], Literal[OrderedDict[Unknown, Unknown]], Literal[dict[_KT, _VT]], Literal[MutableMapping[_KT, _VT]], Literal[Mapping[_KT, _VT]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] +# revealed: tuple[Literal[OrderedDictSubclass], Literal[OrderedDict[Unknown, Unknown]], Literal[dict[Unknown, Unknown]], Literal[MutableMapping[Unknown, Unknown]], Literal[Mapping[Unknown, Unknown]], Literal[Collection], Literal[Iterable], Literal[Container], @Todo(`Protocol[]` subscript), typing.Generic, typing.Generic[_KT, _VT_co], Literal[object]] reveal_type(OrderedDictSubclass.__mro__) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md index b078d38d0e..0d2fd6ae02 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md @@ -340,16 +340,27 @@ propagate through: from typing import Generic, TypeVar T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") -class Base(Generic[T]): - x: T | None = None +class Parent(Generic[T]): + x: T -class ExplicitlyGenericSub(Base[T], Generic[T]): ... -class ImplicitlyGenericSub(Base[T]): ... +class ExplicitlyGenericChild(Parent[U], Generic[U]): ... +class ExplicitlyGenericGrandchild(ExplicitlyGenericChild[V], Generic[V]): ... +class ExplicitlyGenericGreatgrandchild(ExplicitlyGenericGrandchild[W], Generic[W]): ... +class ImplicitlyGenericChild(Parent[U]): ... +class ImplicitlyGenericGrandchild(ImplicitlyGenericChild[V]): ... +class ImplicitlyGenericGreatgrandchild(ImplicitlyGenericGrandchild[W]): ... -reveal_type(Base[int].x) # revealed: int | None -reveal_type(ExplicitlyGenericSub[int].x) # revealed: int | None -reveal_type(ImplicitlyGenericSub[int].x) # revealed: int | None +reveal_type(Parent[int]().x) # revealed: int +reveal_type(ExplicitlyGenericChild[int]().x) # revealed: int +reveal_type(ImplicitlyGenericChild[int]().x) # revealed: int +reveal_type(ExplicitlyGenericGrandchild[int]().x) # revealed: int +reveal_type(ImplicitlyGenericGrandchild[int]().x) # revealed: int +reveal_type(ExplicitlyGenericGreatgrandchild[int]().x) # revealed: int +reveal_type(ImplicitlyGenericGreatgrandchild[int]().x) # revealed: int ``` ## Generic methods diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index 39bb9cfb62..89ebf00124 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -285,13 +285,17 @@ When a generic subclass fills its superclass's type parameter with one of its ow propagate through: ```py -class Base[T]: - x: T | None = None +class Parent[T]: + x: T -class Sub[U](Base[U]): ... +class Child[U](Parent[U]): ... +class Grandchild[V](Child[V]): ... +class Greatgrandchild[W](Child[W]): ... -reveal_type(Base[int].x) # revealed: int | None -reveal_type(Sub[int].x) # revealed: int | None +reveal_type(Parent[int]().x) # revealed: int +reveal_type(Child[int]().x) # revealed: int +reveal_type(Grandchild[int]().x) # revealed: int +reveal_type(Greatgrandchild[int]().x) # revealed: int ``` ## Generic methods diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index e1ce107159..2165f72026 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4948,6 +4948,19 @@ impl<'db> Type<'db> { } } + #[must_use] + pub fn apply_optional_specialization( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> Type<'db> { + if let Some(specialization) = specialization { + self.apply_specialization(db, specialization) + } else { + self + } + } + /// Applies a specialization to this type, replacing any typevars with the types that they are /// specialized to. /// @@ -7979,7 +7992,9 @@ pub enum SuperOwnerKind<'db> { impl<'db> SuperOwnerKind<'db> { fn iter_mro(self, db: &'db dyn Db) -> impl Iterator> { match self { - SuperOwnerKind::Dynamic(dynamic) => Either::Left(ClassBase::Dynamic(dynamic).mro(db)), + SuperOwnerKind::Dynamic(dynamic) => { + Either::Left(ClassBase::Dynamic(dynamic).mro(db, None)) + } SuperOwnerKind::Class(class) => Either::Right(class.iter_mro(db)), SuperOwnerKind::Instance(instance) => Either::Right(instance.class().iter_mro(db)), } @@ -8106,7 +8121,7 @@ impl<'db> BoundSuperType<'db> { mro_iter: impl Iterator>, ) -> impl Iterator> { let Some(pivot_class) = self.pivot_class(db).into_class() else { - return Either::Left(ClassBase::Dynamic(DynamicType::Unknown).mro(db)); + return Either::Left(ClassBase::Dynamic(DynamicType::Unknown).mro(db, None)); }; let mut pivot_found = false; diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 30f24abe95..3a6038a415 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -191,6 +191,26 @@ impl<'db> ClassType<'db> { } } + /// Returns the class literal and specialization for this class, with an additional + /// specialization applied if the class is generic. + pub(crate) fn class_literal_specialized( + self, + db: &'db dyn Db, + additional_specialization: Option>, + ) -> (ClassLiteral<'db>, Option>) { + match self { + Self::NonGeneric(non_generic) => (non_generic, None), + Self::Generic(generic) => ( + generic.origin(db), + Some( + generic + .specialization(db) + .apply_optional_specialization(db, additional_specialization), + ), + ), + } + } + pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { let (class_literal, _) = self.class_literal(db); class_literal.name(db) @@ -206,13 +226,6 @@ impl<'db> ClassType<'db> { class_literal.definition(db) } - fn specialize_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { - match self { - Self::NonGeneric(_) => ty, - Self::Generic(generic) => ty.apply_specialization(db, generic.specialization(db)), - } - } - /// Return `true` if this class represents `known_class` pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { self.known(db) == Some(known_class) @@ -249,6 +262,18 @@ impl<'db> ClassType<'db> { class_literal.iter_mro(db, specialization) } + /// Iterate over the method resolution order ("MRO") of the class, optionally applying an + /// additional specialization to it if the class is generic. + pub(super) fn iter_mro_specialized( + self, + db: &'db dyn Db, + additional_specialization: Option>, + ) -> MroIterator<'db> { + let (class_literal, specialization) = + self.class_literal_specialized(db, additional_specialization); + class_literal.iter_mro(db, specialization) + } + /// Is this class final? pub(super) fn is_final(self, db: &'db dyn Db) -> bool { let (class_literal, _) = self.class_literal(db); @@ -372,8 +397,10 @@ impl<'db> ClassType<'db> { /// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred. pub(super) fn metaclass(self, db: &'db dyn Db) -> Type<'db> { - let (class_literal, _) = self.class_literal(db); - self.specialize_type(db, class_literal.metaclass(db)) + let (class_literal, specialization) = self.class_literal(db); + class_literal + .metaclass(db) + .apply_optional_specialization(db, specialization) } /// Return a type representing "the set of all instances of the metaclass of this class". @@ -396,9 +423,7 @@ impl<'db> ClassType<'db> { policy: MemberLookupPolicy, ) -> SymbolAndQualifiers<'db> { let (class_literal, specialization) = self.class_literal(db); - class_literal - .class_member_inner(db, specialization, name, policy) - .map_type(|ty| self.specialize_type(db, ty)) + class_literal.class_member_inner(db, specialization, name, policy) } /// Returns the inferred type of the class member named `name`. Only bound members @@ -411,7 +436,7 @@ impl<'db> ClassType<'db> { let (class_literal, specialization) = self.class_literal(db); class_literal .own_class_member(db, specialization, name) - .map_type(|ty| self.specialize_type(db, ty)) + .map_type(|ty| ty.apply_optional_specialization(db, specialization)) } /// Returns the `name` attribute of an instance of this class. @@ -424,16 +449,16 @@ impl<'db> ClassType<'db> { let (class_literal, specialization) = self.class_literal(db); class_literal .instance_member(db, specialization, name) - .map_type(|ty| self.specialize_type(db, ty)) + .map_type(|ty| ty.apply_optional_specialization(db, specialization)) } /// A helper function for `instance_member` that looks up the `name` attribute only on /// this class, not on its superclasses. fn own_instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { - let (class_literal, _) = self.class_literal(db); + let (class_literal, specialization) = self.class_literal(db); class_literal .own_instance_member(db, name) - .map_type(|ty| self.specialize_type(db, ty)) + .map_type(|ty| ty.apply_optional_specialization(db, specialization)) } } diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 28f1484d31..f127410441 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -1,4 +1,4 @@ -use crate::types::generics::GenericContext; +use crate::types::generics::{GenericContext, Specialization}; use crate::types::{ todo_type, ClassType, DynamicType, KnownClass, KnownInstanceType, MroIterator, Type, }; @@ -202,8 +202,35 @@ impl<'db> ClassBase<'db> { } } + pub(crate) fn apply_specialization( + self, + db: &'db dyn Db, + specialization: Specialization<'db>, + ) -> Self { + match self { + Self::Class(class) => Self::Class(class.apply_specialization(db, specialization)), + Self::Dynamic(_) | Self::Generic(_) | Self::Protocol => self, + } + } + + pub(crate) fn apply_optional_specialization( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> Self { + if let Some(specialization) = specialization { + self.apply_specialization(db, specialization) + } else { + self + } + } + /// Iterate over the MRO of this base - pub(super) fn mro(self, db: &'db dyn Db) -> impl Iterator> { + pub(super) fn mro( + self, + db: &'db dyn Db, + additional_specialization: Option>, + ) -> impl Iterator> { match self { ClassBase::Protocol => { ClassBaseMroIterator::length_3(db, self, ClassBase::Generic(None)) @@ -214,7 +241,9 @@ impl<'db> ClassBase<'db> { ClassBase::Dynamic(_) | ClassBase::Generic(_) => { ClassBaseMroIterator::length_2(db, self) } - ClassBase::Class(class) => ClassBaseMroIterator::from_class(db, class), + ClassBase::Class(class) => { + ClassBaseMroIterator::from_class(db, class, additional_specialization) + } } } } @@ -263,8 +292,12 @@ impl<'db> ClassBaseMroIterator<'db> { } /// Iterate over the MRO of an arbitrary class. The MRO may be of any length. - fn from_class(db: &'db dyn Db, class: ClassType<'db>) -> Self { - ClassBaseMroIterator::FromClass(class.iter_mro(db)) + fn from_class( + db: &'db dyn Db, + class: ClassType<'db>, + additional_specialization: Option>, + ) -> Self { + ClassBaseMroIterator::FromClass(class.iter_mro_specialized(db, additional_specialization)) } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index a436f880af..47ba9c5ea2 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -202,6 +202,19 @@ impl<'db> Specialization<'db> { Specialization::new(db, self.generic_context(db), types) } + /// Applies an optional specialization to this specialization. + pub(crate) fn apply_optional_specialization( + self, + db: &'db dyn Db, + other: Option>, + ) -> Self { + if let Some(other) = other { + self.apply_specialization(db, other) + } else { + self + } + } + /// Combines two specializations of the same generic context. If either specialization maps a /// typevar to `Type::Unknown`, the other specialization's mapping is used. If both map the /// typevar to a known type, those types are unioned together. diff --git a/crates/ty_python_semantic/src/types/mro.rs b/crates/ty_python_semantic/src/types/mro.rs index 4cd27d6b31..4b883bbf0d 100644 --- a/crates/ty_python_semantic/src/types/mro.rs +++ b/crates/ty_python_semantic/src/types/mro.rs @@ -80,12 +80,14 @@ impl<'db> Mro<'db> { )); } + let class_type = class.apply_optional_specialization(db, specialization); + match class_bases { // `builtins.object` is the special case: // the only class in Python that has an MRO with length <2 [] if class.is_object(db) => Ok(Self::from([ // object is not generic, so the default specialization should be a no-op - ClassBase::Class(class.apply_optional_specialization(db, specialization)), + ClassBase::Class(class_type), ])), // All other classes in Python have an MRO with length >=2. @@ -102,7 +104,7 @@ impl<'db> Mro<'db> { // (, ) // ``` [] => Ok(Self::from([ - ClassBase::Class(class.apply_optional_specialization(db, specialization)), + ClassBase::Class(class_type), ClassBase::object(db), ])), @@ -114,11 +116,9 @@ impl<'db> Mro<'db> { [single_base] => ClassBase::try_from_type(db, *single_base).map_or_else( || Err(MroErrorKind::InvalidBases(Box::from([(0, *single_base)]))), |single_base| { - Ok(std::iter::once(ClassBase::Class( - class.apply_optional_specialization(db, specialization), - )) - .chain(single_base.mro(db)) - .collect()) + Ok(std::iter::once(ClassBase::Class(class_type)) + .chain(single_base.mro(db, specialization)) + .collect()) }, ), @@ -142,13 +142,16 @@ impl<'db> Mro<'db> { return Err(MroErrorKind::InvalidBases(invalid_bases.into_boxed_slice())); } - let mut seqs = vec![VecDeque::from([ClassBase::Class( - class.apply_optional_specialization(db, specialization), - )])]; + let mut seqs = vec![VecDeque::from([ClassBase::Class(class_type)])]; for base in &valid_bases { - seqs.push(base.mro(db).collect()); + seqs.push(base.mro(db, specialization).collect()); } - seqs.push(valid_bases.iter().copied().collect()); + seqs.push( + valid_bases + .iter() + .map(|base| base.apply_optional_specialization(db, specialization)) + .collect(), + ); c3_merge(seqs).ok_or_else(|| { let mut seen_bases = FxHashSet::default();