[ty] don't eagerly unpack aliases in user-authored unions (#20055)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Add a subtly different test case for recursive PEP 695 type aliases,
which does require that we relax our union simplification, so we don't
eagerly unpack aliases from user-provided union annotations.

## Test Plan

Added mdtest.
This commit is contained in:
Carl Meyer 2025-08-26 16:29:45 -07:00 committed by GitHub
parent a60fb3f2c8
commit 9ab276b345
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 197 additions and 57 deletions

View file

@ -9389,6 +9389,21 @@ impl<'db> UnionType<'db> {
.build()
}
/// Create a union from a list of elements without unpacking type aliases.
pub(crate) fn from_elements_leave_aliases<I, T>(db: &'db dyn Db, elements: I) -> Type<'db>
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
elements
.into_iter()
.fold(
UnionBuilder::new(db).unpack_aliases(false),
|builder, element| builder.add(element.into()),
)
.build()
}
/// A fallible version of [`UnionType::from_elements`].
///
/// If all items in `elements` are `Some()`, the result of unioning all elements is returned.

View file

@ -207,6 +207,7 @@ const MAX_UNION_LITERALS: usize = 200;
pub(crate) struct UnionBuilder<'db> {
elements: Vec<UnionElement<'db>>,
db: &'db dyn Db,
unpack_aliases: bool,
}
impl<'db> UnionBuilder<'db> {
@ -214,9 +215,15 @@ impl<'db> UnionBuilder<'db> {
Self {
db,
elements: vec![],
unpack_aliases: true,
}
}
pub(crate) fn unpack_aliases(mut self, val: bool) -> Self {
self.unpack_aliases = val;
self
}
pub(crate) fn is_empty(&self) -> bool {
self.elements.is_empty()
}
@ -236,17 +243,29 @@ impl<'db> UnionBuilder<'db> {
/// Adds a type to this union.
pub(crate) fn add_in_place(&mut self, ty: Type<'db>) {
self.add_in_place_impl(ty, &mut vec![]);
}
pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
match ty {
Type::Union(union) => {
let new_elements = union.elements(self.db);
self.elements.reserve(new_elements.len());
for element in new_elements {
self.add_in_place(*element);
self.add_in_place_impl(*element, seen_aliases);
}
}
// Adding `Never` to a union is a no-op.
Type::Never => {}
Type::TypeAlias(alias) => self.add_in_place(alias.value_type(self.db)),
Type::TypeAlias(alias) if self.unpack_aliases => {
if seen_aliases.contains(&ty) {
// Union contains itself recursively via a type alias. This is an error, just
// leave out the recursive alias. TODO surface this error.
} else {
seen_aliases.push(ty);
self.add_in_place_impl(alias.value_type(self.db), seen_aliases);
}
}
// If adding a string literal, look for an existing `UnionElement::StringLiterals` to
// add it to, or an existing element that is a super-type of string literals, which
// means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
@ -260,7 +279,7 @@ impl<'db> UnionBuilder<'db> {
UnionElement::StringLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
let replace_with = KnownClass::Str.to_instance(self.db);
self.add_in_place(replace_with);
self.add_in_place_impl(replace_with, seen_aliases);
return;
}
found = Some(literals);
@ -305,7 +324,7 @@ impl<'db> UnionBuilder<'db> {
UnionElement::BytesLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
let replace_with = KnownClass::Bytes.to_instance(self.db);
self.add_in_place(replace_with);
self.add_in_place_impl(replace_with, seen_aliases);
return;
}
found = Some(literals);
@ -350,7 +369,7 @@ impl<'db> UnionBuilder<'db> {
UnionElement::IntLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
let replace_with = KnownClass::Int.to_instance(self.db);
self.add_in_place(replace_with);
self.add_in_place_impl(replace_with, seen_aliases);
return;
}
found = Some(literals);
@ -404,7 +423,10 @@ impl<'db> UnionBuilder<'db> {
.is_none();
if all_members_are_in_union {
self.add_in_place(enum_member_to_add.enum_class_instance(self.db));
self.add_in_place_impl(
enum_member_to_add.enum_class_instance(self.db),
seen_aliases,
);
} else if !self
.elements
.iter()
@ -426,8 +448,17 @@ impl<'db> UnionBuilder<'db> {
None
};
// If an alias gets here, it means we aren't unpacking aliases, and we also
// shouldn't try to simplify aliases out of the union, because that will require
// unpacking them.
let should_simplify_full = !matches!(ty, Type::TypeAlias(_));
let mut to_remove = SmallVec::<[usize; 2]>::new();
let ty_negated = ty.negate(self.db);
let ty_negated = if should_simplify_full {
ty.negate(self.db)
} else {
Type::Never // won't be used
};
for (index, element) in self.elements.iter_mut().enumerate() {
let element_type = match element.try_reduce(self.db, ty) {
@ -446,30 +477,37 @@ impl<'db> UnionBuilder<'db> {
return;
}
};
if Some(element_type) == bool_pair {
self.add_in_place(KnownClass::Bool.to_instance(self.db));
if ty == element_type {
return;
}
if ty.is_equivalent_to(self.db, element_type)
|| ty.is_subtype_of(self.db, element_type)
{
return;
} else if element_type.is_subtype_of(self.db, ty) {
to_remove.push(index);
} else if ty_negated.is_subtype_of(self.db, element_type) {
// We add `ty` to the union. We just checked that `~ty` is a subtype of an
// existing `element`. This also means that `~ty | ty` is a subtype of
// `element | ty`, because both elements in the first union are subtypes of
// the corresponding elements in the second union. But `~ty | ty` is just
// `object`. Since `object` is a subtype of `element | ty`, we can only
// conclude that `element | ty` must be `object` (object has no other
// supertypes). This means we can simplify the whole union to just
// `object`, since all other potential elements would also be subtypes of
// `object`.
self.collapse_to_object();
if Some(element_type) == bool_pair {
self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases);
return;
}
if should_simplify_full && !matches!(element_type, Type::TypeAlias(_)) {
if ty.is_equivalent_to(self.db, element_type)
|| ty.is_subtype_of(self.db, element_type)
{
return;
} else if element_type.is_subtype_of(self.db, ty) {
to_remove.push(index);
} else if ty_negated.is_subtype_of(self.db, element_type) {
// We add `ty` to the union. We just checked that `~ty` is a subtype of an
// existing `element`. This also means that `~ty | ty` is a subtype of
// `element | ty`, because both elements in the first union are subtypes of
// the corresponding elements in the second union. But `~ty | ty` is just
// `object`. Since `object` is a subtype of `element | ty`, we can only
// conclude that `element | ty` must be `object` (object has no other
// supertypes). This means we can simplify the whole union to just
// `object`, since all other potential elements would also be subtypes of
// `object`.
self.collapse_to_object();
return;
}
}
}
if let Some((&first, rest)) = to_remove.split_first() {
self.elements[first] = UnionElement::Type(ty);
@ -541,11 +579,27 @@ impl<'db> IntersectionBuilder<'db> {
}
}
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
pub(crate) fn add_positive(self, ty: Type<'db>) -> Self {
self.add_positive_impl(ty, &mut vec![])
}
pub(crate) fn add_positive_impl(
mut self,
ty: Type<'db>,
seen_aliases: &mut Vec<Type<'db>>,
) -> Self {
match ty {
Type::TypeAlias(alias) => {
if seen_aliases.contains(&ty) {
// Recursive alias, add it without expanding to avoid infinite recursion.
for inner in &mut self.intersections {
inner.positive.insert(ty);
}
return self;
}
seen_aliases.push(ty);
let value_type = alias.value_type(self.db);
self.add_positive(value_type)
self.add_positive_impl(value_type, seen_aliases)
}
Type::Union(union) => {
// Distribute ourself over this union: for each union element, clone ourself and
@ -559,7 +613,7 @@ impl<'db> IntersectionBuilder<'db> {
union
.elements(self.db)
.iter()
.map(|elem| self.clone().add_positive(*elem))
.map(|elem| self.clone().add_positive_impl(*elem, seen_aliases))
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
@ -569,10 +623,10 @@ impl<'db> IntersectionBuilder<'db> {
Type::Intersection(other) => {
let db = self.db;
for pos in other.positive(db) {
self = self.add_positive(*pos);
self = self.add_positive_impl(*pos, seen_aliases);
}
for neg in other.negative(db) {
self = self.add_negative(*neg);
self = self.add_negative_impl(*neg, seen_aliases);
}
self
}
@ -600,12 +654,15 @@ impl<'db> IntersectionBuilder<'db> {
// `UnionBuilder` because we would simplify the union to just the enum instance
// and end up in this branch again.
let db = self.db;
self.add_positive(Type::Union(UnionType::new(
db,
enum_member_literals(db, instance.class(db).class_literal(db).0, None)
.expect("Calling `enum_member_literals` on an enum class")
.collect::<Box<[_]>>(),
)))
self.add_positive_impl(
Type::Union(UnionType::new(
db,
enum_member_literals(db, instance.class(db).class_literal(db).0, None)
.expect("Calling `enum_member_literals` on an enum class")
.collect::<Box<[_]>>(),
)),
seen_aliases,
)
} else {
for inner in &mut self.intersections {
inner.add_positive(self.db, ty);
@ -624,7 +681,15 @@ impl<'db> IntersectionBuilder<'db> {
}
}
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
pub(crate) fn add_negative(self, ty: Type<'db>) -> Self {
self.add_negative_impl(ty, &mut vec![])
}
pub(crate) fn add_negative_impl(
mut self,
ty: Type<'db>,
seen_aliases: &mut Vec<Type<'db>>,
) -> Self {
let contains_enum = |enum_instance| {
self.intersections
.iter()
@ -635,12 +700,20 @@ impl<'db> IntersectionBuilder<'db> {
// See comments above in `add_positive`; this is just the negated version.
match ty {
Type::TypeAlias(alias) => {
if seen_aliases.contains(&ty) {
// Recursive alias, add it without expanding to avoid infinite recursion.
for inner in &mut self.intersections {
inner.negative.insert(ty);
}
return self;
}
seen_aliases.push(ty);
let value_type = alias.value_type(self.db);
self.add_negative(value_type)
self.add_negative_impl(value_type, seen_aliases)
}
Type::Union(union) => {
for elem in union.elements(self.db) {
self = self.add_negative(*elem);
self = self.add_negative_impl(*elem, seen_aliases);
}
self
}
@ -656,13 +729,19 @@ impl<'db> IntersectionBuilder<'db> {
.positive(self.db)
.iter()
// we negate all the positive constraints while distributing
.map(|elem| self.clone().add_negative(*elem));
.map(|elem| {
self.clone()
.add_negative_impl(*elem, &mut seen_aliases.clone())
});
let negative_side = intersection
.negative(self.db)
.iter()
// all negative constraints end up becoming positive constraints
.map(|elem| self.clone().add_positive(*elem));
.map(|elem| {
self.clone()
.add_positive_impl(*elem, &mut seen_aliases.clone())
});
positive_side.chain(negative_side).fold(
IntersectionBuilder::empty(self.db),
@ -676,15 +755,18 @@ impl<'db> IntersectionBuilder<'db> {
if contains_enum(enum_literal.enum_class_instance(self.db)) =>
{
let db = self.db;
self.add_positive(UnionType::from_elements(
db,
enum_member_literals(
self.add_positive_impl(
UnionType::from_elements(
db,
enum_literal.enum_class(db),
Some(enum_literal.name(db)),
)
.expect("Calling `enum_member_literals` on an enum class"),
))
enum_member_literals(
db,
enum_literal.enum_class(db),
Some(enum_literal.name(db)),
)
.expect("Calling `enum_member_literals` on an enum class"),
),
seen_aliases,
)
}
_ => {
for inner in &mut self.intersections {

View file

@ -9776,7 +9776,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
ast::Operator::BitOr => {
let left_ty = self.infer_type_expression(&binary.left);
let right_ty = self.infer_type_expression(&binary.right);
UnionType::from_elements(self.db(), [left_ty, right_ty])
UnionType::from_elements_leave_aliases(self.db(), [left_ty, right_ty])
}
// anything else is an invalid annotation:
op => {
@ -10288,7 +10288,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
}
ast::Expr::BinOp(binary) if binary.op == ast::Operator::BitOr => {
let union_ty = UnionType::from_elements(
let union_ty = UnionType::from_elements_leave_aliases(
self.db(),
[
self.infer_subclass_of_type_expression(&binary.left),
@ -10314,7 +10314,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let parameters_ty = match self.infer_expression(value) {
Type::SpecialForm(SpecialFormType::Union) => match &**parameters {
ast::Expr::Tuple(tuple) => {
let ty = UnionType::from_elements(
let ty = UnionType::from_elements_leave_aliases(
self.db(),
tuple
.iter()
@ -10548,11 +10548,11 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
},
SpecialFormType::Optional => {
let param_type = self.infer_type_expression(arguments_slice);
UnionType::from_elements(db, [param_type, Type::none(db)])
UnionType::from_elements_leave_aliases(db, [param_type, Type::none(db)])
}
SpecialFormType::Union => match arguments_slice {
ast::Expr::Tuple(t) => {
let union_ty = UnionType::from_elements(
let union_ty = UnionType::from_elements_leave_aliases(
db,
t.iter().map(|elt| self.infer_type_expression(elt)),
);