[ty] Fix normalization of unions containing instances parameterized with unions (#18112)

This commit is contained in:
Alex Waygood 2025-05-14 22:48:33 -04:00 committed by GitHub
parent 9aa6330bb1
commit c3a4992ae9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 49 additions and 17 deletions

View file

@ -73,5 +73,8 @@ def f(x: Any, y: Unknown, z: Any | str | int):
c = cast(Unknown, y)
reveal_type(c) # revealed: Unknown
d = cast(str | int | Any, z) # error: [redundant-cast]
d = cast(Unknown, x)
reveal_type(d) # revealed: Unknown
e = cast(str | int | Any, z) # error: [redundant-cast]
```

View file

@ -118,6 +118,23 @@ class R: ...
static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R]))
```
## Unions containing generic instances parameterized by unions
```toml
[environment]
python-version = "3.12"
```
```py
from ty_extensions import is_equivalent_to, static_assert
class A: ...
class B: ...
class Foo[T]: ...
static_assert(is_equivalent_to(A | Foo[A | B], Foo[B | A] | A))
```
## Callable
### Equivalent

View file

@ -985,15 +985,15 @@ impl<'db> Type<'db> {
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
Type::Callable(callable) => Type::Callable(callable.normalized(db)),
Type::ProtocolInstance(protocol) => protocol.normalized(db),
Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)),
Type::Dynamic(_) => Type::any(),
Type::LiteralString
| Type::NominalInstance(_)
| Type::PropertyInstance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BooleanLiteral(_)
| Type::BytesLiteral(_)
| Type::StringLiteral(_)
| Type::Dynamic(_)
| Type::Never
| Type::FunctionLiteral(_)
| Type::MethodWrapper(_)
@ -1007,10 +1007,7 @@ impl<'db> Type<'db> {
| Type::IntLiteral(_)
| Type::BoundSuper(_)
| Type::SubclassOf(_) => self,
Type::GenericAlias(generic) => {
let specialization = generic.specialization(db).normalized(db);
Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization))
}
Type::GenericAlias(generic) => Type::GenericAlias(generic.normalized(db)),
Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
Type::TypeVar(TypeVarInstance::new(

View file

@ -164,6 +164,10 @@ pub struct GenericAlias<'db> {
}
impl<'db> GenericAlias<'db> {
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self::new(db, self.origin(db), self.specialization(db).normalized(db))
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
self.origin(db).definition(db)
}
@ -207,6 +211,13 @@ pub enum ClassType<'db> {
#[salsa::tracked]
impl<'db> ClassType<'db> {
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
match self {
Self::NonGeneric(_) => self,
Self::Generic(generic) => Self::Generic(generic.normalized(db)),
}
}
/// Returns the class literal and specialization for this class. For a non-generic class, this
/// is the class itself. For a generic alias, this is the alias's origin.
pub(crate) fn class_literal(

View file

@ -5048,16 +5048,16 @@ impl<'db> TypeInferenceBuilder<'db> {
overload.parameter_types()
{
let db = self.db();
if (source_type.is_equivalent_to(db, *casted_type)
|| source_type.normalized(db)
== casted_type.normalized(db))
&& !source_type.any_over_type(db, &|ty| {
matches!(
ty,
Type::Dynamic(dynamic)
if dynamic != DynamicType::Any
)
})
let contains_unknown_or_todo = |ty| matches!(ty, Type::Dynamic(dynamic) if dynamic != DynamicType::Any);
if source_type.is_equivalent_to(db, *casted_type)
|| (source_type.normalized(db)
== casted_type.normalized(db)
&& !casted_type.any_over_type(db, &|ty| {
contains_unknown_or_todo(ty)
})
&& !source_type.any_over_type(db, &|ty| {
contains_unknown_or_todo(ty)
}))
{
if let Some(builder) = self
.context

View file

@ -75,6 +75,10 @@ impl<'db> NominalInstanceType<'db> {
}
}
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self::from_class(self.class.normalized(db))
}
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
// N.B. The subclass relation is fully static
self.class.is_subclass_of(db, other.class)