[red-knot] Improve is_disjoint for two intersections (#16636)

## Summary

Background - as a follow up to #16611 I noticed that there's a lot of
code duplicated between the `is_assignable_to` and `is_subtype_of`
functions and considered trying to merge them.

[A subtype and an assignable type are pretty much the
same](https://typing.python.org/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation),
except that subtypes are by definition fully static, so I think we can
replace the whole of `is_subtype_of` with:

```
if !self.is_fully_static(db) || !target.is_fully_static(db) {
    return false;
}
return self.is_assignable_to(target)
```

if we move all of the logic to is_assignable_to and delete duplicate
code. Then we can discuss if it even makes sense to have a separate
is_subtype_of function (I think the answer is yes since it's used by a
bunch of other places, but we may be able to basically rip out the
concept).

Anyways while playing with combining the functions I noticed is that the
handling of Intersections in `is_subtype_of` has a special case for two
intersections, which I didn't include in the last PR - rather I first
handled right hand intersections before left hand, which should properly
handle double intersections (hand-wavy explanation I can justify if
needed - (A & B & C) is assignable to (A & B) because the left is
assignable to both A and B, but none of A, B, or C is assignable to (A &
B)).

I took a look at what breaks if I remove the handling for double
intersections, and the reason it is needed is because is_disjoint does
not properly handle intersections with negative conditions (so instead
`is_subtype_of` basically implements the check correctly).

This PR adds support to is_disjoint for properly checking negative
branches, which also lets us simplify `is_subtype_of`, bringing it in
line with `is_assignable_to`

## Test Plan

Added a bunch of tests, most of which failed before this fix

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Joey Bar 2025-03-12 14:13:04 +02:00 committed by GitHub
parent 11b5cbcd2f
commit b250304ad3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 67 additions and 45 deletions

View file

@ -266,6 +266,10 @@ static_assert(is_assignable_to(Intersection[int, Parent], Intersection[int, Not[
static_assert(not is_assignable_to(int, Not[int]))
static_assert(not is_assignable_to(int, Not[Literal[1]]))
static_assert(is_assignable_to(Not[Parent], Not[Child1]))
static_assert(not is_assignable_to(Not[Parent], Parent))
static_assert(not is_assignable_to(Intersection[Unrelated, Not[Parent]], Parent))
# Intersection with `Any` dominates the left hand side of intersections
static_assert(is_assignable_to(Intersection[Any, Parent], Parent))
static_assert(is_assignable_to(Intersection[Any, Child1], Parent))
@ -277,6 +281,7 @@ static_assert(is_assignable_to(Intersection[Any, Parent, Unrelated], Intersectio
# Even Any & Not[Parent] is assignable to Parent, since it could be Never
static_assert(is_assignable_to(Intersection[Any, Not[Parent]], Parent))
static_assert(is_assignable_to(Intersection[Any, Not[Parent]], Not[Parent]))
# Intersection with `Any` is effectively ignored on the right hand side for the sake of assignment
static_assert(is_assignable_to(Parent, Intersection[Any, Parent]))

View file

@ -16,6 +16,7 @@ static_assert(not is_disjoint_from(bool, object))
static_assert(not is_disjoint_from(Any, bool))
static_assert(not is_disjoint_from(Any, Any))
static_assert(not is_disjoint_from(Any, Not[Any]))
static_assert(not is_disjoint_from(LiteralString, LiteralString))
static_assert(not is_disjoint_from(str, LiteralString))
@ -95,8 +96,8 @@ static_assert(not is_disjoint_from(Literal[1, 2], Literal[2, 3]))
## Intersections
```py
from typing_extensions import Literal, final
from knot_extensions import Intersection, is_disjoint_from, static_assert
from typing_extensions import Literal, final, Any
from knot_extensions import Intersection, is_disjoint_from, static_assert, Not
@final
class P: ...
@ -130,6 +131,27 @@ static_assert(not is_disjoint_from(Y, Z))
static_assert(not is_disjoint_from(Intersection[X, Y], Z))
static_assert(not is_disjoint_from(Intersection[X, Z], Y))
static_assert(not is_disjoint_from(Intersection[Y, Z], X))
# If one side has a positive fully-static element and the other side has a negative of that element, they are disjoint
static_assert(is_disjoint_from(int, Not[int]))
static_assert(is_disjoint_from(Intersection[X, Y, Not[Z]], Intersection[X, Z]))
static_assert(is_disjoint_from(Intersection[X, Not[Literal[1]]], Literal[1]))
class Parent: ...
class Child(Parent): ...
static_assert(not is_disjoint_from(Parent, Child))
static_assert(not is_disjoint_from(Parent, Not[Child]))
static_assert(not is_disjoint_from(Not[Parent], Not[Child]))
static_assert(is_disjoint_from(Not[Parent], Child))
static_assert(is_disjoint_from(Intersection[X, Not[Parent]], Child))
static_assert(is_disjoint_from(Intersection[X, Not[Parent]], Intersection[X, Child]))
static_assert(not is_disjoint_from(Intersection[Any, X], Intersection[Any, Not[Y]]))
static_assert(not is_disjoint_from(Intersection[Any, Not[Y]], Intersection[Any, X]))
static_assert(is_disjoint_from(Intersection[int, Any], Not[int]))
static_assert(is_disjoint_from(Not[int], Intersection[int, Any]))
```
## Special types
@ -152,7 +174,7 @@ static_assert(is_disjoint_from(Never, object))
```py
from typing_extensions import Literal, LiteralString
from knot_extensions import is_disjoint_from, static_assert
from knot_extensions import is_disjoint_from, static_assert, Intersection, Not
static_assert(is_disjoint_from(None, Literal[True]))
static_assert(is_disjoint_from(None, Literal[1]))
@ -165,6 +187,9 @@ static_assert(is_disjoint_from(None, type[object]))
static_assert(not is_disjoint_from(None, None))
static_assert(not is_disjoint_from(None, int | None))
static_assert(not is_disjoint_from(None, object))
static_assert(is_disjoint_from(Intersection[int, Not[str]], None))
static_assert(is_disjoint_from(None, Intersection[int, Not[str]]))
```
### Literals

View file

@ -580,38 +580,9 @@ impl<'db> Type<'db> {
true
}
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
// Check that all target positive values are covered in self positive values
target_intersection
.positive(db)
.iter()
.all(|&target_pos_elem| {
self_intersection
.positive(db)
.iter()
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
})
// Check that all target negative values are excluded in self, either by being
// subtypes of a self negative value or being disjoint from a self positive value.
&& target_intersection
.negative(db)
.iter()
.all(|&target_neg_elem| {
// Is target negative value is subtype of a self negative value
self_intersection.negative(db).iter().any(|&self_neg_elem| {
target_neg_elem.is_subtype_of(db, self_neg_elem)
// Is target negative value is disjoint from a self positive value?
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
self_pos_elem.is_disjoint_from(db, target_neg_elem)
})
})
}
(Type::Intersection(intersection), _) => intersection
.positive(db)
.iter()
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),
// If both sides are intersections we need to handle the right side first
// (A & B & C) is a subtype of (A & B) because the left is a subtype of both A and B,
// but none of A, B, or C is a subtype of (A & B).
(_, Type::Intersection(intersection)) => {
intersection
.positive(db)
@ -623,6 +594,11 @@ impl<'db> Type<'db> {
.all(|&neg_ty| self.is_disjoint_from(db, neg_ty))
}
(Type::Intersection(intersection), _) => intersection
.positive(db)
.iter()
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),
// Note that the definition of `Type::AlwaysFalsy` depends on the return value of `__bool__`.
// If `__bool__` always returns True or False, it can be treated as a subtype of `AlwaysTruthy` or `AlwaysFalsy`, respectively.
(left, Type::AlwaysFalsy) => left.bool(db).is_always_false(),
@ -799,6 +775,10 @@ impl<'db> Type<'db> {
.iter()
.any(|&elem_ty| ty.is_assignable_to(db, elem_ty)),
// If both sides are intersections we need to handle the right side first
// (A & B & C) is assignable to (A & B) because the left is assignable to both A and B,
// but none of A, B, or C is assignable to (A & B).
//
// A type S is assignable to an intersection type T if
// S is assignable to all positive elements of T (e.g. `str & int` is assignable to `str & Any`), and
// S is disjoint from all negative elements of T (e.g. `int` is not assignable to Intersection[int, Not[Literal[1]]]).
@ -995,19 +975,31 @@ impl<'db> Type<'db> {
.iter()
.all(|e| e.is_disjoint_from(db, other)),
(Type::Intersection(intersection), other)
| (other, Type::Intersection(intersection)) => {
if intersection
// If we have two intersections, we test the positive elements of each one against the other intersection
// Negative elements need a positive element on the other side in order to be disjoint.
// This is similar to what would happen if we tried to build a new intersection that combines the two
(Type::Intersection(self_intersection), Type::Intersection(other_intersection)) => {
self_intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
{
true
} else {
// TODO we can do better here. For example:
// X & ~Literal[1] is disjoint from Literal[1]
false
|| other_intersection
.positive(db)
.iter()
.any(|p: &Type<'_>| p.is_disjoint_from(db, self))
}
(Type::Intersection(intersection), other)
| (other, Type::Intersection(intersection)) => {
intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
// A & B & Not[C] is disjoint from C
|| intersection
.negative(db)
.iter()
.any(|&neg_ty| other.is_subtype_of(db, neg_ty))
}
// any single-valued type is disjoint from another single-valued type