[red-knot] Fix more edge cases for intersection simplification with LiteralString and AlwaysTruthy/AlwaysFalsy (#15496)

This commit is contained in:
Alex Waygood 2025-01-15 15:02:41 +00:00 committed by GitHub
parent 8712438aec
commit 55a7f72035
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 130 additions and 109 deletions

View file

@ -680,7 +680,7 @@ simplified, due to the fact that a `LiteralString` inhabitant is known to have `
exactly `str` (and not a subclass of `str`): exactly `str` (and not a subclass of `str`):
```py ```py
from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy, Unknown
from typing_extensions import LiteralString from typing_extensions import LiteralString
def f( def f(
@ -690,6 +690,10 @@ def f(
d: Intersection[LiteralString, Not[AlwaysFalsy]], d: Intersection[LiteralString, Not[AlwaysFalsy]],
e: Intersection[AlwaysFalsy, LiteralString], e: Intersection[AlwaysFalsy, LiteralString],
f: Intersection[Not[AlwaysTruthy], LiteralString], f: Intersection[Not[AlwaysTruthy], LiteralString],
g: Intersection[AlwaysTruthy, LiteralString],
h: Intersection[Not[AlwaysFalsy], LiteralString],
i: Intersection[Unknown, LiteralString, AlwaysFalsy],
j: Intersection[Not[AlwaysTruthy], Unknown, LiteralString],
): ):
reveal_type(a) # revealed: LiteralString & ~Literal[""] reveal_type(a) # revealed: LiteralString & ~Literal[""]
reveal_type(b) # revealed: Literal[""] reveal_type(b) # revealed: Literal[""]
@ -697,6 +701,10 @@ def f(
reveal_type(d) # revealed: LiteralString & ~Literal[""] reveal_type(d) # revealed: LiteralString & ~Literal[""]
reveal_type(e) # revealed: Literal[""] reveal_type(e) # revealed: Literal[""]
reveal_type(f) # revealed: Literal[""] reveal_type(f) # revealed: Literal[""]
reveal_type(g) # revealed: LiteralString & ~Literal[""]
reveal_type(h) # revealed: LiteralString & ~Literal[""]
reveal_type(i) # revealed: Unknown & Literal[""]
reveal_type(j) # revealed: Unknown & Literal[""]
``` ```
## Addition of a type to an intersection with many non-disjoint types ## Addition of a type to an intersection with many non-disjoint types

View file

@ -247,131 +247,144 @@ struct InnerIntersectionBuilder<'db> {
impl<'db> InnerIntersectionBuilder<'db> { impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a positive type to this intersection. /// Adds a positive type to this intersection.
fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) { fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) {
if new_positive == Type::AlwaysTruthy && self.positive.contains(&Type::LiteralString) { match new_positive {
self.add_negative(db, Type::string_literal(db, "")); // `LiteralString & AlwaysTruthy` -> `LiteralString & ~Literal[""]`
return; Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => {
} self.add_negative(db, Type::string_literal(db, ""));
if let Type::Intersection(other) = new_positive {
for pos in other.positive(db) {
self.add_positive(db, *pos);
} }
for neg in other.negative(db) { // `LiteralString & AlwaysFalsy` -> `Literal[""]`
self.add_negative(db, *neg); Type::AlwaysFalsy if self.positive.swap_remove(&Type::LiteralString) => {
self.add_positive(db, Type::string_literal(db, ""));
} }
} else { // `AlwaysTruthy & LiteralString` -> `LiteralString & ~Literal[""]`
let addition_is_bool_instance = new_positive Type::LiteralString if self.positive.swap_remove(&Type::AlwaysTruthy) => {
.into_instance() self.add_positive(db, Type::LiteralString);
.and_then(|instance| instance.class.known(db)) self.add_negative(db, Type::string_literal(db, ""));
.is_some_and(KnownClass::is_bool); }
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
for (index, existing_positive) in self.positive.iter().enumerate() { Type::LiteralString if self.positive.swap_remove(&Type::AlwaysFalsy) => {
match existing_positive { self.add_positive(db, Type::string_literal(db, ""));
// `AlwaysTruthy & bool` -> `Literal[True]` }
Type::AlwaysTruthy if addition_is_bool_instance => { // `LiteralString & ~AlwaysTruthy` -> `LiteralString & AlwaysFalsy` -> `Literal[""]`
new_positive = Type::BooleanLiteral(true); Type::LiteralString if self.negative.swap_remove(&Type::AlwaysTruthy) => {
} self.add_positive(db, Type::string_literal(db, ""));
// `AlwaysFalsy & bool` -> `Literal[False]` }
Type::AlwaysFalsy if addition_is_bool_instance => { // `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
new_positive = Type::BooleanLiteral(false); Type::LiteralString if self.negative.swap_remove(&Type::AlwaysFalsy) => {
} self.add_positive(db, Type::LiteralString);
// `AlwaysFalsy & LiteralString` -> `Literal[""]` self.add_negative(db, Type::string_literal(db, ""));
Type::AlwaysFalsy if new_positive.is_literal_string() => { }
new_positive = Type::string_literal(db, ""); // `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
} Type::Intersection(other) => {
Type::Instance(InstanceType { class }) for pos in other.positive(db) {
if class.is_known(db, KnownClass::Bool) => self.add_positive(db, *pos);
{ }
match new_positive { for neg in other.negative(db) {
// `bool & AlwaysTruthy` -> `Literal[True]` self.add_negative(db, *neg);
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
}
// `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if new_positive == Type::AlwaysFalsy => {
new_positive = Type::string_literal(db, "");
}
_ => continue,
} }
self.positive.swap_remove_index(index);
break;
} }
_ => {
let addition_is_bool_instance = new_positive
.into_instance()
.and_then(|instance| instance.class.known(db))
.is_some_and(KnownClass::is_bool);
if addition_is_bool_instance { for (index, existing_positive) in self.positive.iter().enumerate() {
for (index, existing_negative) in self.negative.iter().enumerate() { match existing_positive {
match existing_negative { // `AlwaysTruthy & bool` -> `Literal[True]`
// `bool & ~Literal[False]` -> `Literal[True]` Type::AlwaysTruthy if addition_is_bool_instance => {
// `bool & ~Literal[True]` -> `Literal[False]` new_positive = Type::BooleanLiteral(true);
Type::BooleanLiteral(bool_value) => {
new_positive = Type::BooleanLiteral(!bool_value);
} }
// `bool & ~AlwaysTruthy` -> `Literal[False]` // `AlwaysFalsy & bool` -> `Literal[False]`
Type::AlwaysTruthy => { Type::AlwaysFalsy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(false); new_positive = Type::BooleanLiteral(false);
} }
// `bool & ~AlwaysFalsy` -> `Literal[True]` Type::Instance(InstanceType { class })
Type::AlwaysFalsy => { if class.is_known(db, KnownClass::Bool) =>
new_positive = Type::BooleanLiteral(true); {
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
} }
_ => continue, _ => continue,
} }
self.negative.swap_remove_index(index); self.positive.swap_remove_index(index);
break; break;
} }
} else if new_positive.is_literal_string() {
if self.negative.swap_remove(&Type::AlwaysTruthy) {
new_positive = Type::string_literal(db, "");
}
}
let mut to_remove = SmallVec::<[usize; 1]>::new(); if addition_is_bool_instance {
for (index, existing_positive) in self.positive.iter().enumerate() { for (index, existing_negative) in self.negative.iter().enumerate() {
// S & T = S if S <: T match existing_negative {
if existing_positive.is_subtype_of(db, new_positive) // `bool & ~Literal[False]` -> `Literal[True]`
|| existing_positive.is_same_gradual_form(new_positive) // `bool & ~Literal[True]` -> `Literal[False]`
{ Type::BooleanLiteral(bool_value) => {
return; new_positive = Type::BooleanLiteral(!bool_value);
}
// `bool & ~AlwaysTruthy` -> `Literal[False]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(false);
}
// `bool & ~AlwaysFalsy` -> `Literal[True]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(true);
}
_ => continue,
}
self.negative.swap_remove_index(index);
break;
}
} }
// same rule, reverse order
if new_positive.is_subtype_of(db, *existing_positive) {
to_remove.push(index);
}
// A & B = Never if A and B are disjoint
if new_positive.is_disjoint_from(db, *existing_positive) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
}
for index in to_remove.into_iter().rev() {
self.positive.swap_remove_index(index);
}
let mut to_remove = SmallVec::<[usize; 1]>::new(); let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_negative) in self.negative.iter().enumerate() { for (index, existing_positive) in self.positive.iter().enumerate() {
// S & ~T = Never if S <: T // S & T = S if S <: T
if new_positive.is_subtype_of(db, *existing_negative) { if existing_positive.is_subtype_of(db, new_positive)
*self = Self::default(); || existing_positive.is_same_gradual_form(new_positive)
self.positive.insert(Type::Never); {
return; return;
}
// same rule, reverse order
if new_positive.is_subtype_of(db, *existing_positive) {
to_remove.push(index);
}
// A & B = Never if A and B are disjoint
if new_positive.is_disjoint_from(db, *existing_positive) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
} }
// A & ~B = A if A and B are disjoint for index in to_remove.into_iter().rev() {
if existing_negative.is_disjoint_from(db, new_positive) { self.positive.swap_remove_index(index);
to_remove.push(index);
} }
}
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}
self.positive.insert(new_positive); let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_negative) in self.negative.iter().enumerate() {
// S & ~T = Never if S <: T
if new_positive.is_subtype_of(db, *existing_negative) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
// A & ~B = A if A and B are disjoint
if existing_negative.is_disjoint_from(db, new_positive) {
to_remove.push(index);
}
}
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}
self.positive.insert(new_positive);
}
} }
} }
@ -438,8 +451,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
return; return;
} }
} }
for index in to_remove.iter().rev() { for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(*index); self.negative.swap_remove_index(index);
} }
for existing_positive in &self.positive { for existing_positive in &self.positive {