[ty] ensure union normalization really normalizes (#20147)

## Summary

Now that we have `Type::TypeAlias`, which can wrap a union, and the
possibility of unions including non-unpacked type aliases (which is
necessary to support recursive type aliases), we can no longer assume in
`UnionType::normalized_impl` that normalizing each element of an
existing union will result in a set of elements that we can order and
then place raw into `UnionType` to create a normalized union. It's now
possible for those elements to themselves include union types (unpacked
from an alias). So instead, we need to feed those elements into the full
`UnionBuilder` (with alias-unpacking turned on) to flatten/normalize
them, and then order them.

## Test Plan

Added mdtest.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Carl Meyer 2025-08-29 09:02:35 -07:00 committed by GitHub
parent 5a608f7366
commit 8223fea062
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 15 deletions

View file

@ -120,6 +120,23 @@ def f(x: IntOrStr, y: str | bytes):
reveal_type(z) # revealed: (int & ~AlwaysFalsy) | str | bytes reveal_type(z) # revealed: (int & ~AlwaysFalsy) | str | bytes
``` ```
## Multiple layers of union aliases
```py
class A: ...
class B: ...
class C: ...
class D: ...
type W = A | B
type X = C | D
type Y = W | X
from ty_extensions import is_equivalent_to, static_assert
static_assert(is_equivalent_to(Y, A | B | C | D))
```
## `TypeAliasType` properties ## `TypeAliasType` properties
Two `TypeAliasType`s are distinct and disjoint, even if they refer to the same type Two `TypeAliasType`s are distinct and disjoint, even if they refer to the same type

View file

@ -1118,9 +1118,7 @@ impl<'db> Type<'db> {
#[must_use] #[must_use]
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
match self { match self {
Type::Union(union) => { Type::Union(union) => visitor.visit(self, || union.normalized_impl(db, visitor)),
visitor.visit(self, || Type::Union(union.normalized_impl(db, visitor)))
}
Type::Intersection(intersection) => visitor.visit(self, || { Type::Intersection(intersection) => visitor.visit(self, || {
Type::Intersection(intersection.normalized_impl(db, visitor)) Type::Intersection(intersection.normalized_impl(db, visitor))
}), }),
@ -1887,14 +1885,14 @@ impl<'db> Type<'db> {
} }
(Type::TypeAlias(self_alias), _) => { (Type::TypeAlias(self_alias), _) => {
let self_alias_ty = self_alias.value_type(db); let self_alias_ty = self_alias.value_type(db).normalized(db);
visitor.visit((self_alias_ty, other), || { visitor.visit((self_alias_ty, other), || {
self_alias_ty.is_equivalent_to_impl(db, other, visitor) self_alias_ty.is_equivalent_to_impl(db, other, visitor)
}) })
} }
(_, Type::TypeAlias(other_alias)) => { (_, Type::TypeAlias(other_alias)) => {
let other_alias_ty = other_alias.value_type(db); let other_alias_ty = other_alias.value_type(db).normalized(db);
visitor.visit((self, other_alias_ty), || { visitor.visit((self, other_alias_ty), || {
self.is_equivalent_to_impl(db, other_alias_ty, visitor) self.is_equivalent_to_impl(db, other_alias_ty, visitor)
}) })
@ -7697,7 +7695,17 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
TypeVarBoundOrConstraints::UpperBound(bound.normalized_impl(db, visitor)) TypeVarBoundOrConstraints::UpperBound(bound.normalized_impl(db, visitor))
} }
TypeVarBoundOrConstraints::Constraints(constraints) => { TypeVarBoundOrConstraints::Constraints(constraints) => {
TypeVarBoundOrConstraints::Constraints(constraints.normalized_impl(db, visitor)) // Constraints are a non-normalized union by design (it's not really a union at
// all, we are just using a union to store the types). Normalize the types but not
// the containing union.
TypeVarBoundOrConstraints::Constraints(UnionType::new(
db,
constraints
.elements(db)
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect::<Box<_>>(),
))
} }
} }
} }
@ -9654,18 +9662,25 @@ impl<'db> UnionType<'db> {
/// ///
/// See [`Type::normalized`] for more details. /// See [`Type::normalized`] for more details.
#[must_use] #[must_use]
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { pub(crate) fn normalized(self, db: &'db dyn Db) -> Type<'db> {
self.normalized_impl(db, &NormalizedVisitor::default()) self.normalized_impl(db, &NormalizedVisitor::default())
} }
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { pub(crate) fn normalized_impl(
let mut new_elements: Vec<Type<'db>> = self self,
.elements(db) db: &'db dyn Db,
visitor: &NormalizedVisitor<'db>,
) -> Type<'db> {
self.elements(db)
.iter() .iter()
.map(|element| element.normalized_impl(db, visitor)) .map(|ty| ty.normalized_impl(db, visitor))
.collect(); .fold(
new_elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); UnionBuilder::new(db)
UnionType::new(db, new_elements.into_boxed_slice()) .order_elements(true)
.unpack_aliases(true),
UnionBuilder::add,
)
.build()
} }
pub(crate) fn is_equivalent_to_impl<C: Constraints<'db>>( pub(crate) fn is_equivalent_to_impl<C: Constraints<'db>>(
@ -9687,7 +9702,7 @@ impl<'db> UnionType<'db> {
let sorted_self = self.normalized(db); let sorted_self = self.normalized(db);
if sorted_self == other { if sorted_self == Type::Union(other) {
return C::always_satisfiable(db); return C::always_satisfiable(db);
} }

View file

@ -38,6 +38,7 @@
//! unnecessary `is_subtype_of` checks. //! unnecessary `is_subtype_of` checks.
use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::type_ordering::union_or_intersection_elements_ordering;
use crate::types::{ use crate::types::{
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
TypeVarBoundOrConstraints, UnionType, TypeVarBoundOrConstraints, UnionType,
@ -211,6 +212,7 @@ pub(crate) struct UnionBuilder<'db> {
elements: Vec<UnionElement<'db>>, elements: Vec<UnionElement<'db>>,
db: &'db dyn Db, db: &'db dyn Db,
unpack_aliases: bool, unpack_aliases: bool,
order_elements: bool,
} }
impl<'db> UnionBuilder<'db> { impl<'db> UnionBuilder<'db> {
@ -219,6 +221,7 @@ impl<'db> UnionBuilder<'db> {
db, db,
elements: vec![], elements: vec![],
unpack_aliases: true, unpack_aliases: true,
order_elements: false,
} }
} }
@ -227,6 +230,11 @@ impl<'db> UnionBuilder<'db> {
self self
} }
pub(crate) fn order_elements(mut self, val: bool) -> Self {
self.order_elements = val;
self
}
pub(crate) fn is_empty(&self) -> bool { pub(crate) fn is_empty(&self) -> bool {
self.elements.is_empty() self.elements.is_empty()
} }
@ -545,6 +553,9 @@ impl<'db> UnionBuilder<'db> {
UnionElement::Type(ty) => types.push(ty), UnionElement::Type(ty) => types.push(ty),
} }
} }
if self.order_elements {
types.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(self.db, l, r));
}
match types.len() { match types.len() {
0 => None, 0 => None,
1 => Some(types[0]), 1 => Some(types[0]),