[red-knot] Ensure differently ordered unions and intersections are understood as equivalent even inside arbitrarily nested tuples (#15740)

## Summary

On `main`, red-knot:
- Considers `P | Q` equivalent to `Q | P`
- Considered `tuple[P | Q]` equivalent to `tuple[Q | P]`
- Considers `tuple[P | tuple[P | Q]]` equivalent to `tuple[tuple[Q | P]
| P]`
- ‼️ Does _not_ consider `tuple[tuple[P | Q]]` equivalent to
`tuple[tuple[Q | P]]`

The key difference for the last one of these is that the union appears
inside a tuple that is directly nested inside another tuple.

This PR fixes this so that differently ordered unions are considered
equivalent even when they appear inside arbitrarily nested tuple types.

## Test Plan

- Added mdtests that fails on `main`
- Checked that all property tests continue to pass with this PR
This commit is contained in:
Alex Waygood 2025-01-25 16:39:07 +00:00 committed by GitHub
parent a77a32b7d4
commit f85ea1bf46
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 85 additions and 18 deletions

View file

@ -84,4 +84,26 @@ static_assert(
)
```
## Unions containing tuples containing tuples containing unions (etc.)
```py
from knot_extensions import is_equivalent_to, static_assert, Intersection
class P: ...
class Q: ...
static_assert(
is_equivalent_to(
tuple[tuple[tuple[P | Q]]] | P,
tuple[tuple[tuple[Q | P]]] | P,
)
)
static_assert(
is_equivalent_to(
tuple[tuple[tuple[tuple[tuple[Intersection[P, Q]]]]]],
tuple[tuple[tuple[tuple[tuple[Intersection[Q, P]]]]]],
)
)
```
[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent

View file

@ -811,6 +811,35 @@ impl<'db> Type<'db> {
}
}
/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
match self {
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
Type::Intersection(intersection) => {
Type::Intersection(intersection.to_sorted_intersection(db))
}
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions(db)),
Type::LiteralString
| Type::Instance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BooleanLiteral(_)
| Type::SliceLiteral(_)
| Type::BytesLiteral(_)
| Type::StringLiteral(_)
| Type::Dynamic(_)
| Type::Never
| Type::FunctionLiteral(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::KnownInstance(_)
| Type::IntLiteral(_)
| Type::SubclassOf(_) => self,
}
}
/// Return true if this type is a [subtype of] type `target`.
///
/// This method returns `false` if either `self` or `other` is not fully static.
@ -1154,7 +1183,7 @@ impl<'db> Type<'db> {
left.is_equivalent_to(db, right)
}
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
_ => self.is_fully_static(db) && other.is_fully_static(db) && self == other,
_ => self == other && self.is_fully_static(db) && other.is_fully_static(db),
}
}
@ -4352,12 +4381,11 @@ impl<'db> UnionType<'db> {
/// Create a new union type with the elements sorted according to a canonical ordering.
#[must_use]
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
let mut new_elements = self.elements(db).to_vec();
for element in &mut new_elements {
if let Type::Intersection(intersection) = element {
intersection.sort(db);
}
}
let mut new_elements: Vec<Type<'db>> = self
.elements(db)
.iter()
.map(|element| element.with_sorted_unions(db))
.collect();
new_elements.sort_unstable_by(union_elements_ordering);
UnionType::new(db, new_elements.into_boxed_slice())
}
@ -4453,19 +4481,24 @@ impl<'db> IntersectionType<'db> {
/// according to a canonical ordering.
#[must_use]
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
let mut positive = self.positive(db).clone();
positive.sort_unstable_by(union_elements_ordering);
fn normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
) -> FxOrderSet<Type<'db>> {
let mut elements: FxOrderSet<Type<'db>> = elements
.iter()
.map(|ty| ty.with_sorted_unions(db))
.collect();
let mut negative = self.negative(db).clone();
negative.sort_unstable_by(union_elements_ordering);
elements.sort_unstable_by(union_elements_ordering);
elements
}
IntersectionType::new(db, positive, negative)
}
/// Perform an in-place sort of this [`IntersectionType`] instance
/// according to a canonical ordering.
fn sort(&mut self, db: &'db dyn Db) {
*self = self.to_sorted_intersection(db);
IntersectionType::new(
db,
normalized_set(db, self.positive(db)),
normalized_set(db, self.negative(db)),
)
}
pub fn is_fully_static(self, db: &'db dyn Db) -> bool {
@ -4608,6 +4641,18 @@ impl<'db> TupleType<'db> {
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
}
/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
let elements: Box<[Type<'db>]> = self
.elements(db)
.iter()
.map(|ty| ty.with_sorted_unions(db))
.collect();
TupleType::new(db, elements)
}
pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
let self_elements = self.elements(db);
let other_elements = other.elements(db);