[red-knot] Corrections and improvements to intersection simplification (#15475)

This commit is contained in:
Alex Waygood 2025-01-14 18:15:38 +00:00 committed by GitHub
parent 5ed7b55b15
commit bcf0a715c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 248 additions and 35 deletions

View file

@ -635,6 +635,83 @@ def _(
reveal_type(i8) # revealed: Never reveal_type(i8) # revealed: Never
``` ```
### Simplifications of `bool`, `AlwaysTruthy` and `AlwaysFalsy`
In general, intersections with `AlwaysTruthy` and `AlwaysFalsy` cannot be simplified. Naively, you
might think that `int & AlwaysFalsy` could simplify to `Literal[0]`, but this is not the case: for
example, the `False` constant inhabits the type `int & AlwaysFalsy` (due to the fact that
`False.__class__` is `bool` at runtime, and `bool` subclasses `int`), but `False` does not inhabit
the type `Literal[0]`.
Nonetheless, intersections of `AlwaysFalsy` or `AlwaysTruthy` with `bool` _can_ be simplified, due
to the fact that `bool` is a `@final` class at runtime that cannot be subclassed.
```py
from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy
class P: ...
def f(
a: Intersection[bool, AlwaysTruthy],
b: Intersection[bool, AlwaysFalsy],
c: Intersection[bool, Not[AlwaysTruthy]],
d: Intersection[bool, Not[AlwaysFalsy]],
e: Intersection[bool, AlwaysTruthy, P],
f: Intersection[bool, AlwaysFalsy, P],
g: Intersection[bool, Not[AlwaysTruthy], P],
h: Intersection[bool, Not[AlwaysFalsy], P],
):
reveal_type(a) # revealed: Literal[True]
reveal_type(b) # revealed: Literal[False]
reveal_type(c) # revealed: Literal[False]
reveal_type(d) # revealed: Literal[True]
# `bool & AlwaysTruthy & P` -> `Literal[True] & P` -> `Never`
reveal_type(e) # revealed: Never
reveal_type(f) # revealed: Never
reveal_type(g) # revealed: Never
reveal_type(h) # revealed: Never
```
## Simplification of `LiteralString`, `AlwaysTruthy` and `AlwaysFalsy`
Similarly, intersections of `LiteralString` with `AlwaysTruthy` or `AlwaysFalsy` (or their
negations) can be simplified, due to the fact that a `LiteralString` inhabitant is known to have
`__class__` set to exactly `str` (and not a subclass of `str`):
```py
from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy
from typing_extensions import LiteralString
def f(
a: Intersection[LiteralString, AlwaysTruthy],
b: Intersection[LiteralString, AlwaysFalsy],
c: Intersection[LiteralString, Not[AlwaysTruthy]],
d: Intersection[LiteralString, Not[AlwaysFalsy]],
e: Intersection[AlwaysFalsy, LiteralString],
f: Intersection[Not[AlwaysTruthy], LiteralString],
):
reveal_type(a) # revealed: LiteralString & ~Literal[""]
reveal_type(b) # revealed: Literal[""]
reveal_type(c) # revealed: Literal[""]
reveal_type(d) # revealed: LiteralString & ~Literal[""]
reveal_type(e) # revealed: Literal[""]
reveal_type(f) # revealed: Literal[""]
```
## Addition of a type to an intersection with many non-disjoint types
This slightly strange-looking test is a regression test for a mistake that was nearly made in a PR:
<https://github.com/astral-sh/ruff/pull/15475#discussion_r1915041987>.
```py
from knot_extensions import AlwaysFalsy, Intersection, Unknown
from typing_extensions import Literal
def _(x: Intersection[str, Unknown, AlwaysFalsy, Literal[""]]):
reveal_type(x) # revealed: Unknown & Literal[""]
```
## Non fully-static types ## Non fully-static types
### Negation of dynamic types ### Negation of dynamic types

View file

@ -181,3 +181,43 @@ def _(x: object, y: type[int]):
if isinstance(x, y): if isinstance(x, y):
reveal_type(x) # revealed: int reveal_type(x) # revealed: int
``` ```
## Adding a disjoint element to an existing intersection
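We used to incorrectly infer `Literal` booleans for some of these. `P` and `bool` are disjoint
(`bool` is `@final`, so no class can inherit from both), so narrowing to `bool` must yield `Never`.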
```py
from knot_extensions import Not, Intersection, AlwaysTruthy, AlwaysFalsy
class P: ...
def f(
a: Intersection[P, AlwaysTruthy],
b: Intersection[P, AlwaysFalsy],
c: Intersection[P, Not[AlwaysTruthy]],
d: Intersection[P, Not[AlwaysFalsy]],
):
if isinstance(a, bool):
reveal_type(a) # revealed: Never
else:
# TODO: `bool` is final, so `& ~bool` is redundant here
reveal_type(a) # revealed: P & AlwaysTruthy & ~bool
if isinstance(b, bool):
reveal_type(b) # revealed: Never
else:
# TODO: `bool` is final, so `& ~bool` is redundant here
reveal_type(b) # revealed: P & AlwaysFalsy & ~bool
if isinstance(c, bool):
reveal_type(c) # revealed: Never
else:
# TODO: `bool` is final, so `& ~bool` is redundant here
reveal_type(c) # revealed: P & ~AlwaysTruthy & ~bool
if isinstance(d, bool):
reveal_type(d) # revealed: Never
else:
# TODO: `bool` is final, so `& ~bool` is redundant here
reveal_type(d) # revealed: P & ~AlwaysFalsy & ~bool
```

View file

@ -199,7 +199,7 @@ def f(x: Literal[0, 1], y: Literal["", "hello"]):
reveal_type(y) # revealed: Literal["", "hello"] reveal_type(y) # revealed: Literal["", "hello"]
``` ```
## ControlFlow Merging ## Control Flow Merging
After merging control flows, when we take the union of all constraints applied in each branch, we After merging control flows, when we take the union of all constraints applied in each branch, we
should return to the original state. should return to the original state.
@ -312,3 +312,20 @@ def _(x: type[FalsyClass] | type[TruthyClass]):
reveal_type(x or A()) # revealed: type[TruthyClass] | A reveal_type(x or A()) # revealed: type[TruthyClass] | A
reveal_type(x and A()) # revealed: type[FalsyClass] | A reveal_type(x and A()) # revealed: type[FalsyClass] | A
``` ```
## Truthiness narrowing for `LiteralString`
```py
from typing_extensions import LiteralString
def _(x: LiteralString):
if x:
reveal_type(x) # revealed: LiteralString & ~Literal[""]
else:
reveal_type(x) # revealed: Literal[""]
if not x:
reveal_type(x) # revealed: Literal[""]
else:
reveal_type(x) # revealed: LiteralString & ~Literal[""]
```

View file

@ -671,6 +671,13 @@ impl<'db> Type<'db> {
.expect("Expected a Type::IntLiteral variant") .expect("Expected a Type::IntLiteral variant")
} }
/// Returns the wrapped [`InstanceType`] when this type is the `Type::Instance` variant,
/// and `None` for every other variant.
pub const fn into_instance(self) -> Option<InstanceType<'db>> {
    if let Type::Instance(instance) = self {
        Some(instance)
    } else {
        None
    }
}
pub const fn into_known_instance(self) -> Option<KnownInstanceType<'db>> { pub const fn into_known_instance(self) -> Option<KnownInstanceType<'db>> {
match self { match self {
Type::KnownInstance(known_instance) => Some(known_instance), Type::KnownInstance(known_instance) => Some(known_instance),
@ -2557,6 +2564,10 @@ pub enum KnownClass {
} }
impl<'db> KnownClass { impl<'db> KnownClass {
pub const fn is_bool(self) -> bool {
matches!(self, Self::Bool)
}
pub const fn as_str(&self) -> &'static str { pub const fn as_str(&self) -> &'static str {
match self { match self {
Self::Bool => "bool", Self::Bool => "bool",

View file

@ -30,8 +30,6 @@ use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType};
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderSet};
use smallvec::SmallVec; use smallvec::SmallVec;
use super::Truthiness;
pub(crate) struct UnionBuilder<'db> { pub(crate) struct UnionBuilder<'db> {
elements: Vec<Type<'db>>, elements: Vec<Type<'db>>,
db: &'db dyn Db, db: &'db dyn Db,
@ -248,7 +246,12 @@ struct InnerIntersectionBuilder<'db> {
impl<'db> InnerIntersectionBuilder<'db> { impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a positive type to this intersection. /// Adds a positive type to this intersection.
fn add_positive(&mut self, db: &'db dyn Db, new_positive: Type<'db>) { fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) {
if new_positive == Type::AlwaysTruthy && self.positive.contains(&Type::LiteralString) {
self.add_negative(db, Type::string_literal(db, ""));
return;
}
if let Type::Intersection(other) = new_positive { if let Type::Intersection(other) = new_positive {
for pos in other.positive(db) { for pos in other.positive(db) {
self.add_positive(db, *pos); self.add_positive(db, *pos);
@ -257,25 +260,74 @@ impl<'db> InnerIntersectionBuilder<'db> {
self.add_negative(db, *neg); self.add_negative(db, *neg);
} }
} else { } else {
// ~Literal[True] & bool = Literal[False] let addition_is_bool_instance = new_positive
// ~AlwaysTruthy & bool = Literal[False] .into_instance()
if let Type::Instance(InstanceType { class }) = new_positive { .and_then(|instance| instance.class.known(db))
if class.is_known(db, KnownClass::Bool) { .is_some_and(KnownClass::is_bool);
if let Some(new_type) = self
.negative for (index, existing_positive) in self.positive.iter().enumerate() {
.iter() match existing_positive {
.find(|element| { // `AlwaysTruthy & bool` -> `Literal[True]`
element.is_boolean_literal() Type::AlwaysTruthy if addition_is_bool_instance => {
| matches!(element, Type::AlwaysFalsy | Type::AlwaysTruthy) new_positive = Type::BooleanLiteral(true);
})
.map(|element| {
Type::BooleanLiteral(element.bool(db) != Truthiness::AlwaysTrue)
})
{
*self = Self::default();
self.positive.insert(new_type);
return;
} }
// `AlwaysFalsy & bool` -> `Literal[False]`
Type::AlwaysFalsy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(false);
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::AlwaysFalsy if new_positive.is_literal_string() => {
new_positive = Type::string_literal(db, "");
}
Type::Instance(InstanceType { class })
if class.is_known(db, KnownClass::Bool) =>
{
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
}
// `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if new_positive == Type::AlwaysFalsy => {
new_positive = Type::string_literal(db, "");
}
_ => continue,
}
self.positive.swap_remove_index(index);
break;
}
if addition_is_bool_instance {
for (index, existing_negative) in self.negative.iter().enumerate() {
match existing_negative {
// `bool & ~Literal[False]` -> `Literal[True]`
// `bool & ~Literal[True]` -> `Literal[False]`
Type::BooleanLiteral(bool_value) => {
new_positive = Type::BooleanLiteral(!bool_value);
}
// `bool & ~AlwaysTruthy` -> `Literal[False]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(false);
}
// `bool & ~AlwaysFalsy` -> `Literal[True]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(true);
}
_ => continue,
}
self.negative.swap_remove_index(index);
break;
}
} else if new_positive.is_literal_string() {
if self.negative.swap_remove(&Type::AlwaysTruthy) {
new_positive = Type::string_literal(db, "");
} }
} }
@ -298,8 +350,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
return; return;
} }
} }
for index in to_remove.iter().rev() { for index in to_remove.into_iter().rev() {
self.positive.swap_remove_index(*index); self.positive.swap_remove_index(index);
} }
let mut to_remove = SmallVec::<[usize; 1]>::new(); let mut to_remove = SmallVec::<[usize; 1]>::new();
@ -315,8 +367,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
to_remove.push(index); to_remove.push(index);
} }
} }
for index in to_remove.iter().rev() { for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(*index); self.negative.swap_remove_index(index);
} }
self.positive.insert(new_positive); self.positive.insert(new_positive);
@ -325,6 +377,14 @@ impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a negative type to this intersection. /// Adds a negative type to this intersection.
fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) { fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) {
let contains_bool = || {
self.positive
.iter()
.filter_map(|ty| ty.into_instance())
.filter_map(|instance| instance.class.known(db))
.any(KnownClass::is_bool)
};
match new_negative { match new_negative {
Type::Intersection(inter) => { Type::Intersection(inter) => {
for pos in inter.positive(db) { for pos in inter.positive(db) {
@ -348,15 +408,23 @@ impl<'db> InnerIntersectionBuilder<'db> {
// simplify the representation. // simplify the representation.
self.add_positive(db, ty); self.add_positive(db, ty);
} }
// bool & ~Literal[True] = Literal[False] // `bool & ~AlwaysTruthy` -> `bool & Literal[False]`
// bool & ~AlwaysTruthy = Literal[False] // `bool & ~Literal[True]` -> `bool & Literal[False]`
Type::BooleanLiteral(_) | Type::AlwaysFalsy | Type::AlwaysTruthy Type::AlwaysTruthy | Type::BooleanLiteral(true) if contains_bool() => {
if self.positive.contains(&KnownClass::Bool.to_instance(db)) => self.add_positive(db, Type::BooleanLiteral(false));
{ }
*self = Self::default(); // `LiteralString & ~AlwaysTruthy` -> `LiteralString & Literal[""]`
self.positive.insert(Type::BooleanLiteral( Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => {
new_negative.bool(db) != Truthiness::AlwaysTrue, self.add_positive(db, Type::string_literal(db, ""));
)); }
// `bool & ~AlwaysFalsy` -> `bool & Literal[True]`
// `bool & ~Literal[False]` -> `bool & Literal[True]`
Type::AlwaysFalsy | Type::BooleanLiteral(false) if contains_bool() => {
self.add_positive(db, Type::BooleanLiteral(true));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::AlwaysFalsy if self.positive.contains(&Type::LiteralString) => {
self.add_negative(db, Type::string_literal(db, ""));
} }
_ => { _ => {
let mut to_remove = SmallVec::<[usize; 1]>::new(); let mut to_remove = SmallVec::<[usize; 1]>::new();