[ty] Fix normalization of unions containing instances parameterized with unions (#18112)

This commit is contained in:
Alex Waygood 2025-05-14 22:48:33 -04:00 committed by GitHub
parent 9aa6330bb1
commit c3a4992ae9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 49 additions and 17 deletions

View file

@ -73,5 +73,8 @@ def f(x: Any, y: Unknown, z: Any | str | int):
c = cast(Unknown, y) c = cast(Unknown, y)
reveal_type(c) # revealed: Unknown reveal_type(c) # revealed: Unknown
d = cast(str | int | Any, z) # error: [redundant-cast] d = cast(Unknown, x)
reveal_type(d) # revealed: Unknown
e = cast(str | int | Any, z) # error: [redundant-cast]
``` ```

View file

@ -118,6 +118,23 @@ class R: ...
static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R])) static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R]))
``` ```
## Unions containing generic instances parameterized by unions
```toml
[environment]
python-version = "3.12"
```
```py
from ty_extensions import is_equivalent_to, static_assert
class A: ...
class B: ...
class Foo[T]: ...
static_assert(is_equivalent_to(A | Foo[A | B], Foo[B | A] | A))
```
## Callable ## Callable
### Equivalent ### Equivalent

View file

@ -985,15 +985,15 @@ impl<'db> Type<'db> {
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)), Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
Type::Callable(callable) => Type::Callable(callable.normalized(db)), Type::Callable(callable) => Type::Callable(callable.normalized(db)),
Type::ProtocolInstance(protocol) => protocol.normalized(db), Type::ProtocolInstance(protocol) => protocol.normalized(db),
Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)),
Type::Dynamic(_) => Type::any(),
Type::LiteralString Type::LiteralString
| Type::NominalInstance(_)
| Type::PropertyInstance(_) | Type::PropertyInstance(_)
| Type::AlwaysFalsy | Type::AlwaysFalsy
| Type::AlwaysTruthy | Type::AlwaysTruthy
| Type::BooleanLiteral(_) | Type::BooleanLiteral(_)
| Type::BytesLiteral(_) | Type::BytesLiteral(_)
| Type::StringLiteral(_) | Type::StringLiteral(_)
| Type::Dynamic(_)
| Type::Never | Type::Never
| Type::FunctionLiteral(_) | Type::FunctionLiteral(_)
| Type::MethodWrapper(_) | Type::MethodWrapper(_)
@ -1007,10 +1007,7 @@ impl<'db> Type<'db> {
| Type::IntLiteral(_) | Type::IntLiteral(_)
| Type::BoundSuper(_) | Type::BoundSuper(_)
| Type::SubclassOf(_) => self, | Type::SubclassOf(_) => self,
Type::GenericAlias(generic) => { Type::GenericAlias(generic) => Type::GenericAlias(generic.normalized(db)),
let specialization = generic.specialization(db).normalized(db);
Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization))
}
Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) { Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
Type::TypeVar(TypeVarInstance::new( Type::TypeVar(TypeVarInstance::new(

View file

@ -164,6 +164,10 @@ pub struct GenericAlias<'db> {
} }
impl<'db> GenericAlias<'db> { impl<'db> GenericAlias<'db> {
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self::new(db, self.origin(db), self.specialization(db).normalized(db))
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
self.origin(db).definition(db) self.origin(db).definition(db)
} }
@ -207,6 +211,13 @@ pub enum ClassType<'db> {
#[salsa::tracked] #[salsa::tracked]
impl<'db> ClassType<'db> { impl<'db> ClassType<'db> {
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
match self {
Self::NonGeneric(_) => self,
Self::Generic(generic) => Self::Generic(generic.normalized(db)),
}
}
/// Returns the class literal and specialization for this class. For a non-generic class, this /// Returns the class literal and specialization for this class. For a non-generic class, this
/// is the class itself. For a generic alias, this is the alias's origin. /// is the class itself. For a generic alias, this is the alias's origin.
pub(crate) fn class_literal( pub(crate) fn class_literal(

View file

@ -5048,16 +5048,16 @@ impl<'db> TypeInferenceBuilder<'db> {
overload.parameter_types() overload.parameter_types()
{ {
let db = self.db(); let db = self.db();
if (source_type.is_equivalent_to(db, *casted_type) let contains_unknown_or_todo = |ty| matches!(ty, Type::Dynamic(dynamic) if dynamic != DynamicType::Any);
|| source_type.normalized(db) if source_type.is_equivalent_to(db, *casted_type)
== casted_type.normalized(db)) || (source_type.normalized(db)
&& !source_type.any_over_type(db, &|ty| { == casted_type.normalized(db)
matches!( && !casted_type.any_over_type(db, &|ty| {
ty, contains_unknown_or_todo(ty)
Type::Dynamic(dynamic) })
if dynamic != DynamicType::Any && !source_type.any_over_type(db, &|ty| {
) contains_unknown_or_todo(ty)
}) }))
{ {
if let Some(builder) = self if let Some(builder) = self
.context .context

View file

@ -75,6 +75,10 @@ impl<'db> NominalInstanceType<'db> {
} }
} }
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self::from_class(self.class.normalized(db))
}
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool { pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
// N.B. The subclass relation is fully static // N.B. The subclass relation is fully static
self.class.is_subclass_of(db, other.class) self.class.is_subclass_of(db, other.class)