[ty] Check base classes when determining subtyping etc for generic aliases (#17927)

#17897 added variance handling for legacy typevars — but they were only
being considered when checking generic aliases of the same class:

```py
class A: ...
class B(A): ...

class C[T]: ...

static_assert(is_subtype_of(C[B], C[A]))
```

and not for generic subclasses:

```py
class D[U](C[U]): ...

static_assert(is_subtype_of(D[B], C[A]))
```

Now we check those too!

Closes https://github.com/astral-sh/ty/issues/101
This commit is contained in:
Douglas Creager 2025-05-07 15:21:11 -04:00 committed by GitHub
parent ce0800fccf
commit 2cf5cba7ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 363 additions and 98 deletions

View file

@ -1515,6 +1515,10 @@ impl<'db> Type<'db> {
false
}
(Type::SliceLiteral(_), _) => KnownClass::Slice
.to_instance(db)
.is_assignable_to(db, target),
(Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => {
self_function_literal
.into_callable_type(db)

View file

@ -290,109 +290,82 @@ impl<'db> ClassType<'db> {
})
}
/// If `self` and `other` are generic aliases of the same generic class, returns their
/// corresponding specializations.
fn compatible_specializations(
self,
db: &'db dyn Db,
other: ClassType<'db>,
) -> Option<(Specialization<'db>, Specialization<'db>)> {
match (self, other) {
(ClassType::Generic(self_generic), ClassType::Generic(other_generic)) => {
if self_generic.origin(db) == other_generic.origin(db) {
Some((
self_generic.specialization(db),
other_generic.specialization(db),
))
} else {
None
}
}
_ => None,
}
}
/// Return `true` if `other` is present in this class's MRO.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
// `is_subclass_of` is checking the subtype relation, in which gradual types do not
// participate, so we should not return `True` if we find `Any/Unknown` in the MRO.
if self.iter_mro(db).contains(&ClassBase::Class(other)) {
return true;
}
self.iter_mro(db).any(|base| {
match base {
// `is_subclass_of` is checking the subtype relation, in which gradual types do not
// participate.
ClassBase::Dynamic(_) => false,
// `self` is a subclass of `other` if they are both generic aliases of the same generic
// class, and their specializations are compatible, taking into account the variance of the
// class's typevars.
if let Some((self_specialization, other_specialization)) =
self.compatible_specializations(db, other)
{
if self_specialization.is_subtype_of(db, other_specialization) {
return true;
// Protocol and Generic are not represented by a ClassType.
ClassBase::Protocol | ClassBase::Generic(_) => false,
ClassBase::Class(base) => match (base, other) {
(ClassType::NonGeneric(base), ClassType::NonGeneric(other)) => base == other,
(ClassType::Generic(base), ClassType::Generic(other)) => {
base.origin(db) == other.origin(db)
&& base
.specialization(db)
.is_subtype_of(db, other.specialization(db))
}
(ClassType::Generic(_), ClassType::NonGeneric(_))
| (ClassType::NonGeneric(_), ClassType::Generic(_)) => false,
},
}
}
false
})
}
pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
if self == other {
return true;
}
match (self, other) {
(ClassType::NonGeneric(this), ClassType::NonGeneric(other)) => this == other,
(ClassType::NonGeneric(_), _) | (_, ClassType::NonGeneric(_)) => false,
// `self` is equivalent to `other` if they are both generic aliases of the same generic
// class, and their specializations are compatible, taking into account the variance of the
// class's typevars.
if let Some((self_specialization, other_specialization)) =
self.compatible_specializations(db, other)
{
if self_specialization.is_equivalent_to(db, other_specialization) {
return true;
(ClassType::Generic(this), ClassType::Generic(other)) => {
this.origin(db) == other.origin(db)
&& this
.specialization(db)
.is_equivalent_to(db, other.specialization(db))
}
}
false
}
pub(super) fn is_assignable_to(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
if self.is_subclass_of(db, other) {
return true;
}
self.iter_mro(db).any(|base| {
match base {
ClassBase::Dynamic(DynamicType::Any | DynamicType::Unknown) => !other.is_final(db),
ClassBase::Dynamic(_) => false,
// `self` is assignable to `other` if they are both generic aliases of the same generic
// class, and their specializations are compatible, taking into account the variance of the
// class's typevars.
if let Some((self_specialization, other_specialization)) =
self.compatible_specializations(db, other)
{
if self_specialization.is_assignable_to(db, other_specialization) {
return true;
// Protocol and Generic are not represented by a ClassType.
ClassBase::Protocol | ClassBase::Generic(_) => false,
ClassBase::Class(base) => match (base, other) {
(ClassType::NonGeneric(base), ClassType::NonGeneric(other)) => base == other,
(ClassType::Generic(base), ClassType::Generic(other)) => {
base.origin(db) == other.origin(db)
&& base
.specialization(db)
.is_assignable_to(db, other.specialization(db))
}
(ClassType::Generic(_), ClassType::NonGeneric(_))
| (ClassType::NonGeneric(_), ClassType::Generic(_)) => false,
},
}
}
if self.is_subclass_of_any_or_unknown(db) && !other.is_final(db) {
return true;
}
false
})
}
pub(super) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
if self == other {
return true;
}
match (self, other) {
(ClassType::NonGeneric(this), ClassType::NonGeneric(other)) => this == other,
(ClassType::NonGeneric(_), _) | (_, ClassType::NonGeneric(_)) => false,
// `self` is equivalent to `other` if they are both generic aliases of the same generic
// class, and their specializations are compatible, taking into account the variance of the
// class's typevars.
if let Some((self_specialization, other_specialization)) =
self.compatible_specializations(db, other)
{
if self_specialization.is_gradual_equivalent_to(db, other_specialization) {
return true;
(ClassType::Generic(this), ClassType::Generic(other)) => {
this.origin(db) == other.origin(db)
&& this
.specialization(db)
.is_gradual_equivalent_to(db, other.specialization(db))
}
}
false
}
/// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred.