mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:24:57 +00:00
[ty] don't eagerly unpack aliases in user-authored unions (#20055)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run
## Summary Add a subtly different test case for recursive PEP 695 type aliases, which does require that we relax our union simplification, so we don't eagerly unpack aliases from user-provided union annotations. ## Test Plan Added mdtest.
This commit is contained in:
parent
a60fb3f2c8
commit
9ab276b345
4 changed files with 197 additions and 57 deletions
|
@ -204,13 +204,17 @@ def f(x: OptNestedInt) -> None:
|
|||
### Invalid self-referential
|
||||
|
||||
```py
|
||||
# TODO emit a diagnostic here
|
||||
# TODO emit a diagnostic on these two lines
|
||||
type IntOr = int | IntOr
|
||||
type OrInt = OrInt | int
|
||||
|
||||
def f(x: IntOr):
|
||||
def f(x: IntOr, y: OrInt):
|
||||
reveal_type(x) # revealed: int
|
||||
reveal_type(y) # revealed: int
|
||||
if not isinstance(x, int):
|
||||
reveal_type(x) # revealed: Never
|
||||
if not isinstance(y, int):
|
||||
reveal_type(y) # revealed: Never
|
||||
```
|
||||
|
||||
### Mutually recursive
|
||||
|
@ -234,3 +238,42 @@ from ty_extensions import Intersection
|
|||
def h(x: Intersection[A, B]):
|
||||
reveal_type(x) # revealed: tuple[B] | None
|
||||
```
|
||||
|
||||
### Union inside generic
|
||||
|
||||
#### With old-style union
|
||||
|
||||
```py
|
||||
from typing import Union
|
||||
|
||||
type A = list[Union["A", str]]
|
||||
|
||||
def f(x: A):
|
||||
reveal_type(x) # revealed: list[A | str]
|
||||
for item in x:
|
||||
reveal_type(item) # revealed: list[A | str] | str
|
||||
```
|
||||
|
||||
#### With new-style union
|
||||
|
||||
```py
|
||||
type A = list["A" | str]
|
||||
|
||||
def f(x: A):
|
||||
reveal_type(x) # revealed: list[A | str]
|
||||
for item in x:
|
||||
reveal_type(item) # revealed: list[A | str] | str
|
||||
```
|
||||
|
||||
#### With Optional
|
||||
|
||||
```py
|
||||
from typing import Optional, Union
|
||||
|
||||
type A = list[Optional[Union["A", str]]]
|
||||
|
||||
def f(x: A):
|
||||
reveal_type(x) # revealed: list[A | str | None]
|
||||
for item in x:
|
||||
reveal_type(item) # revealed: list[A | str | None] | str | None
|
||||
```
|
||||
|
|
|
@ -9389,6 +9389,21 @@ impl<'db> UnionType<'db> {
|
|||
.build()
|
||||
}
|
||||
|
||||
/// Create a union from a list of elements without unpacking type aliases.
|
||||
pub(crate) fn from_elements_leave_aliases<I, T>(db: &'db dyn Db, elements: I) -> Type<'db>
|
||||
where
|
||||
I: IntoIterator<Item = T>,
|
||||
T: Into<Type<'db>>,
|
||||
{
|
||||
elements
|
||||
.into_iter()
|
||||
.fold(
|
||||
UnionBuilder::new(db).unpack_aliases(false),
|
||||
|builder, element| builder.add(element.into()),
|
||||
)
|
||||
.build()
|
||||
}
|
||||
|
||||
/// A fallible version of [`UnionType::from_elements`].
|
||||
///
|
||||
/// If all items in `elements` are `Some()`, the result of unioning all elements is returned.
|
||||
|
|
|
@ -207,6 +207,7 @@ const MAX_UNION_LITERALS: usize = 200;
|
|||
pub(crate) struct UnionBuilder<'db> {
|
||||
elements: Vec<UnionElement<'db>>,
|
||||
db: &'db dyn Db,
|
||||
unpack_aliases: bool,
|
||||
}
|
||||
|
||||
impl<'db> UnionBuilder<'db> {
|
||||
|
@ -214,9 +215,15 @@ impl<'db> UnionBuilder<'db> {
|
|||
Self {
|
||||
db,
|
||||
elements: vec![],
|
||||
unpack_aliases: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unpack_aliases(mut self, val: bool) -> Self {
|
||||
self.unpack_aliases = val;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.elements.is_empty()
|
||||
}
|
||||
|
@ -236,17 +243,29 @@ impl<'db> UnionBuilder<'db> {
|
|||
|
||||
/// Adds a type to this union.
|
||||
pub(crate) fn add_in_place(&mut self, ty: Type<'db>) {
|
||||
self.add_in_place_impl(ty, &mut vec![]);
|
||||
}
|
||||
|
||||
pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
|
||||
match ty {
|
||||
Type::Union(union) => {
|
||||
let new_elements = union.elements(self.db);
|
||||
self.elements.reserve(new_elements.len());
|
||||
for element in new_elements {
|
||||
self.add_in_place(*element);
|
||||
self.add_in_place_impl(*element, seen_aliases);
|
||||
}
|
||||
}
|
||||
// Adding `Never` to a union is a no-op.
|
||||
Type::Never => {}
|
||||
Type::TypeAlias(alias) => self.add_in_place(alias.value_type(self.db)),
|
||||
Type::TypeAlias(alias) if self.unpack_aliases => {
|
||||
if seen_aliases.contains(&ty) {
|
||||
// Union contains itself recursively via a type alias. This is an error, just
|
||||
// leave out the recursive alias. TODO surface this error.
|
||||
} else {
|
||||
seen_aliases.push(ty);
|
||||
self.add_in_place_impl(alias.value_type(self.db), seen_aliases);
|
||||
}
|
||||
}
|
||||
// If adding a string literal, look for an existing `UnionElement::StringLiterals` to
|
||||
// add it to, or an existing element that is a super-type of string literals, which
|
||||
// means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
|
||||
|
@ -260,7 +279,7 @@ impl<'db> UnionBuilder<'db> {
|
|||
UnionElement::StringLiterals(literals) => {
|
||||
if literals.len() >= MAX_UNION_LITERALS {
|
||||
let replace_with = KnownClass::Str.to_instance(self.db);
|
||||
self.add_in_place(replace_with);
|
||||
self.add_in_place_impl(replace_with, seen_aliases);
|
||||
return;
|
||||
}
|
||||
found = Some(literals);
|
||||
|
@ -305,7 +324,7 @@ impl<'db> UnionBuilder<'db> {
|
|||
UnionElement::BytesLiterals(literals) => {
|
||||
if literals.len() >= MAX_UNION_LITERALS {
|
||||
let replace_with = KnownClass::Bytes.to_instance(self.db);
|
||||
self.add_in_place(replace_with);
|
||||
self.add_in_place_impl(replace_with, seen_aliases);
|
||||
return;
|
||||
}
|
||||
found = Some(literals);
|
||||
|
@ -350,7 +369,7 @@ impl<'db> UnionBuilder<'db> {
|
|||
UnionElement::IntLiterals(literals) => {
|
||||
if literals.len() >= MAX_UNION_LITERALS {
|
||||
let replace_with = KnownClass::Int.to_instance(self.db);
|
||||
self.add_in_place(replace_with);
|
||||
self.add_in_place_impl(replace_with, seen_aliases);
|
||||
return;
|
||||
}
|
||||
found = Some(literals);
|
||||
|
@ -404,7 +423,10 @@ impl<'db> UnionBuilder<'db> {
|
|||
.is_none();
|
||||
|
||||
if all_members_are_in_union {
|
||||
self.add_in_place(enum_member_to_add.enum_class_instance(self.db));
|
||||
self.add_in_place_impl(
|
||||
enum_member_to_add.enum_class_instance(self.db),
|
||||
seen_aliases,
|
||||
);
|
||||
} else if !self
|
||||
.elements
|
||||
.iter()
|
||||
|
@ -426,8 +448,17 @@ impl<'db> UnionBuilder<'db> {
|
|||
None
|
||||
};
|
||||
|
||||
// If an alias gets here, it means we aren't unpacking aliases, and we also
|
||||
// shouldn't try to simplify aliases out of the union, because that will require
|
||||
// unpacking them.
|
||||
let should_simplify_full = !matches!(ty, Type::TypeAlias(_));
|
||||
|
||||
let mut to_remove = SmallVec::<[usize; 2]>::new();
|
||||
let ty_negated = ty.negate(self.db);
|
||||
let ty_negated = if should_simplify_full {
|
||||
ty.negate(self.db)
|
||||
} else {
|
||||
Type::Never // won't be used
|
||||
};
|
||||
|
||||
for (index, element) in self.elements.iter_mut().enumerate() {
|
||||
let element_type = match element.try_reduce(self.db, ty) {
|
||||
|
@ -446,30 +477,37 @@ impl<'db> UnionBuilder<'db> {
|
|||
return;
|
||||
}
|
||||
};
|
||||
if Some(element_type) == bool_pair {
|
||||
self.add_in_place(KnownClass::Bool.to_instance(self.db));
|
||||
|
||||
if ty == element_type {
|
||||
return;
|
||||
}
|
||||
|
||||
if ty.is_equivalent_to(self.db, element_type)
|
||||
|| ty.is_subtype_of(self.db, element_type)
|
||||
{
|
||||
return;
|
||||
} else if element_type.is_subtype_of(self.db, ty) {
|
||||
to_remove.push(index);
|
||||
} else if ty_negated.is_subtype_of(self.db, element_type) {
|
||||
// We add `ty` to the union. We just checked that `~ty` is a subtype of an
|
||||
// existing `element`. This also means that `~ty | ty` is a subtype of
|
||||
// `element | ty`, because both elements in the first union are subtypes of
|
||||
// the corresponding elements in the second union. But `~ty | ty` is just
|
||||
// `object`. Since `object` is a subtype of `element | ty`, we can only
|
||||
// conclude that `element | ty` must be `object` (object has no other
|
||||
// supertypes). This means we can simplify the whole union to just
|
||||
// `object`, since all other potential elements would also be subtypes of
|
||||
// `object`.
|
||||
self.collapse_to_object();
|
||||
if Some(element_type) == bool_pair {
|
||||
self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases);
|
||||
return;
|
||||
}
|
||||
|
||||
if should_simplify_full && !matches!(element_type, Type::TypeAlias(_)) {
|
||||
if ty.is_equivalent_to(self.db, element_type)
|
||||
|| ty.is_subtype_of(self.db, element_type)
|
||||
{
|
||||
return;
|
||||
} else if element_type.is_subtype_of(self.db, ty) {
|
||||
to_remove.push(index);
|
||||
} else if ty_negated.is_subtype_of(self.db, element_type) {
|
||||
// We add `ty` to the union. We just checked that `~ty` is a subtype of an
|
||||
// existing `element`. This also means that `~ty | ty` is a subtype of
|
||||
// `element | ty`, because both elements in the first union are subtypes of
|
||||
// the corresponding elements in the second union. But `~ty | ty` is just
|
||||
// `object`. Since `object` is a subtype of `element | ty`, we can only
|
||||
// conclude that `element | ty` must be `object` (object has no other
|
||||
// supertypes). This means we can simplify the whole union to just
|
||||
// `object`, since all other potential elements would also be subtypes of
|
||||
// `object`.
|
||||
self.collapse_to_object();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((&first, rest)) = to_remove.split_first() {
|
||||
self.elements[first] = UnionElement::Type(ty);
|
||||
|
@ -541,11 +579,27 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
|
||||
pub(crate) fn add_positive(self, ty: Type<'db>) -> Self {
|
||||
self.add_positive_impl(ty, &mut vec![])
|
||||
}
|
||||
|
||||
pub(crate) fn add_positive_impl(
|
||||
mut self,
|
||||
ty: Type<'db>,
|
||||
seen_aliases: &mut Vec<Type<'db>>,
|
||||
) -> Self {
|
||||
match ty {
|
||||
Type::TypeAlias(alias) => {
|
||||
if seen_aliases.contains(&ty) {
|
||||
// Recursive alias, add it without expanding to avoid infinite recursion.
|
||||
for inner in &mut self.intersections {
|
||||
inner.positive.insert(ty);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
seen_aliases.push(ty);
|
||||
let value_type = alias.value_type(self.db);
|
||||
self.add_positive(value_type)
|
||||
self.add_positive_impl(value_type, seen_aliases)
|
||||
}
|
||||
Type::Union(union) => {
|
||||
// Distribute ourself over this union: for each union element, clone ourself and
|
||||
|
@ -559,7 +613,7 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.map(|elem| self.clone().add_positive(*elem))
|
||||
.map(|elem| self.clone().add_positive_impl(*elem, seen_aliases))
|
||||
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
|
||||
builder.intersections.extend(sub.intersections);
|
||||
builder
|
||||
|
@ -569,10 +623,10 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
Type::Intersection(other) => {
|
||||
let db = self.db;
|
||||
for pos in other.positive(db) {
|
||||
self = self.add_positive(*pos);
|
||||
self = self.add_positive_impl(*pos, seen_aliases);
|
||||
}
|
||||
for neg in other.negative(db) {
|
||||
self = self.add_negative(*neg);
|
||||
self = self.add_negative_impl(*neg, seen_aliases);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
@ -600,12 +654,15 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
// `UnionBuilder` because we would simplify the union to just the enum instance
|
||||
// and end up in this branch again.
|
||||
let db = self.db;
|
||||
self.add_positive(Type::Union(UnionType::new(
|
||||
db,
|
||||
enum_member_literals(db, instance.class(db).class_literal(db).0, None)
|
||||
.expect("Calling `enum_member_literals` on an enum class")
|
||||
.collect::<Box<[_]>>(),
|
||||
)))
|
||||
self.add_positive_impl(
|
||||
Type::Union(UnionType::new(
|
||||
db,
|
||||
enum_member_literals(db, instance.class(db).class_literal(db).0, None)
|
||||
.expect("Calling `enum_member_literals` on an enum class")
|
||||
.collect::<Box<[_]>>(),
|
||||
)),
|
||||
seen_aliases,
|
||||
)
|
||||
} else {
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_positive(self.db, ty);
|
||||
|
@ -624,7 +681,15 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
|
||||
pub(crate) fn add_negative(self, ty: Type<'db>) -> Self {
|
||||
self.add_negative_impl(ty, &mut vec![])
|
||||
}
|
||||
|
||||
pub(crate) fn add_negative_impl(
|
||||
mut self,
|
||||
ty: Type<'db>,
|
||||
seen_aliases: &mut Vec<Type<'db>>,
|
||||
) -> Self {
|
||||
let contains_enum = |enum_instance| {
|
||||
self.intersections
|
||||
.iter()
|
||||
|
@ -635,12 +700,20 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
// See comments above in `add_positive`; this is just the negated version.
|
||||
match ty {
|
||||
Type::TypeAlias(alias) => {
|
||||
if seen_aliases.contains(&ty) {
|
||||
// Recursive alias, add it without expanding to avoid infinite recursion.
|
||||
for inner in &mut self.intersections {
|
||||
inner.negative.insert(ty);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
seen_aliases.push(ty);
|
||||
let value_type = alias.value_type(self.db);
|
||||
self.add_negative(value_type)
|
||||
self.add_negative_impl(value_type, seen_aliases)
|
||||
}
|
||||
Type::Union(union) => {
|
||||
for elem in union.elements(self.db) {
|
||||
self = self.add_negative(*elem);
|
||||
self = self.add_negative_impl(*elem, seen_aliases);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
@ -656,13 +729,19 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
.positive(self.db)
|
||||
.iter()
|
||||
// we negate all the positive constraints while distributing
|
||||
.map(|elem| self.clone().add_negative(*elem));
|
||||
.map(|elem| {
|
||||
self.clone()
|
||||
.add_negative_impl(*elem, &mut seen_aliases.clone())
|
||||
});
|
||||
|
||||
let negative_side = intersection
|
||||
.negative(self.db)
|
||||
.iter()
|
||||
// all negative constraints end up becoming positive constraints
|
||||
.map(|elem| self.clone().add_positive(*elem));
|
||||
.map(|elem| {
|
||||
self.clone()
|
||||
.add_positive_impl(*elem, &mut seen_aliases.clone())
|
||||
});
|
||||
|
||||
positive_side.chain(negative_side).fold(
|
||||
IntersectionBuilder::empty(self.db),
|
||||
|
@ -676,15 +755,18 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
if contains_enum(enum_literal.enum_class_instance(self.db)) =>
|
||||
{
|
||||
let db = self.db;
|
||||
self.add_positive(UnionType::from_elements(
|
||||
db,
|
||||
enum_member_literals(
|
||||
self.add_positive_impl(
|
||||
UnionType::from_elements(
|
||||
db,
|
||||
enum_literal.enum_class(db),
|
||||
Some(enum_literal.name(db)),
|
||||
)
|
||||
.expect("Calling `enum_member_literals` on an enum class"),
|
||||
))
|
||||
enum_member_literals(
|
||||
db,
|
||||
enum_literal.enum_class(db),
|
||||
Some(enum_literal.name(db)),
|
||||
)
|
||||
.expect("Calling `enum_member_literals` on an enum class"),
|
||||
),
|
||||
seen_aliases,
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
for inner in &mut self.intersections {
|
||||
|
|
|
@ -9776,7 +9776,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
ast::Operator::BitOr => {
|
||||
let left_ty = self.infer_type_expression(&binary.left);
|
||||
let right_ty = self.infer_type_expression(&binary.right);
|
||||
UnionType::from_elements(self.db(), [left_ty, right_ty])
|
||||
UnionType::from_elements_leave_aliases(self.db(), [left_ty, right_ty])
|
||||
}
|
||||
// anything else is an invalid annotation:
|
||||
op => {
|
||||
|
@ -10288,7 +10288,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
}
|
||||
}
|
||||
ast::Expr::BinOp(binary) if binary.op == ast::Operator::BitOr => {
|
||||
let union_ty = UnionType::from_elements(
|
||||
let union_ty = UnionType::from_elements_leave_aliases(
|
||||
self.db(),
|
||||
[
|
||||
self.infer_subclass_of_type_expression(&binary.left),
|
||||
|
@ -10314,7 +10314,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
let parameters_ty = match self.infer_expression(value) {
|
||||
Type::SpecialForm(SpecialFormType::Union) => match &**parameters {
|
||||
ast::Expr::Tuple(tuple) => {
|
||||
let ty = UnionType::from_elements(
|
||||
let ty = UnionType::from_elements_leave_aliases(
|
||||
self.db(),
|
||||
tuple
|
||||
.iter()
|
||||
|
@ -10548,11 +10548,11 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
},
|
||||
SpecialFormType::Optional => {
|
||||
let param_type = self.infer_type_expression(arguments_slice);
|
||||
UnionType::from_elements(db, [param_type, Type::none(db)])
|
||||
UnionType::from_elements_leave_aliases(db, [param_type, Type::none(db)])
|
||||
}
|
||||
SpecialFormType::Union => match arguments_slice {
|
||||
ast::Expr::Tuple(t) => {
|
||||
let union_ty = UnionType::from_elements(
|
||||
let union_ty = UnionType::from_elements_leave_aliases(
|
||||
db,
|
||||
t.iter().map(|elt| self.infer_type_expression(elt)),
|
||||
);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue