[ty] Normalize tuples of unions as unions of tuples

This commit is contained in:
David Peter 2025-05-26 17:40:17 +02:00
parent 4e68dd96a6
commit 536d8fb000
2 changed files with 52 additions and 7 deletions

View file

@ -106,6 +106,17 @@ static_assert(
)
```
## Tuples containing unions, unions containing tuples
```py
from ty_extensions import is_equivalent_to, static_assert
class A: ...
class B: ...
static_assert(is_equivalent_to(tuple[A | B], tuple[A] | tuple[B]))
```
## Intersections containing tuples containing unions
```py

View file

@ -1013,7 +1013,7 @@ impl<'db> Type<'db> {
match self {
Type::Union(union) => Type::Union(union.normalized(db)),
Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)),
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
Type::Tuple(tuple) => tuple.normalized(db),
Type::Callable(callable) => Type::Callable(callable.normalized(db)),
Type::ProtocolInstance(protocol) => protocol.normalized(db),
Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)),
@ -1709,7 +1709,7 @@ impl<'db> Type<'db> {
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
// TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.
match (self, other) {
match (self.normalized(db), other.normalized(db)) {
(Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right),
(Type::Intersection(left), Type::Intersection(right)) => {
left.is_equivalent_to(db, right)
@ -1756,7 +1756,7 @@ impl<'db> Type<'db> {
return true;
}
match (self, other) {
match (self.normalized(db), other.normalized(db)) {
(Type::Dynamic(_), Type::Dynamic(_)) => true,
(Type::SubclassOf(first), Type::SubclassOf(second)) => {
@ -8712,13 +8712,47 @@ impl<'db> TupleType<'db> {
///
/// See [`Type::normalized`] for more details.
#[must_use]
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
let elements: Box<[Type<'db>]> = self
pub(crate) fn normalized(self, db: &'db dyn Db) -> Type<'db> {
// Collect the normalized elements for each tuple slot.
let normalized_elements: Vec<Vec<Type<'db>>> = self
.elements(db)
.iter()
.map(|ty| ty.normalized(db))
.map(|ty| {
let norm = ty.normalized(db);
if let Type::Union(union) = norm {
union.elements(db).to_vec()
} else {
vec![norm]
}
})
.collect();
TupleType::new(db, elements)
// Compute the cartesian product of all element choices.
let mut product: Vec<Vec<Type<'db>>> = vec![vec![]];
for slot in &normalized_elements {
let mut next = Vec::with_capacity(product.len() * slot.len());
for prefix in &product {
for elem in slot {
let mut new_tuple = prefix.clone();
new_tuple.push(*elem);
next.push(new_tuple);
}
}
product = next;
}
// If only one combination, return a single tuple type.
if product.len() == 1 {
return TupleType::from_elements(db, product.pop().unwrap().into_boxed_slice());
}
// Otherwise, return a union of all possible tuple combinations.
UnionType::from_elements(
db,
product
.into_iter()
.map(|elems| TupleType::from_elements(db, elems.into_boxed_slice())),
)
}
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {