[red-knot] Corrections and improvements to intersection simplification (#15475)

This commit is contained in:
Alex Waygood 2025-01-14 18:15:38 +00:00 committed by GitHub
parent 5ed7b55b15
commit bcf0a715c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 248 additions and 35 deletions

View file

@ -671,6 +671,13 @@ impl<'db> Type<'db> {
.expect("Expected a Type::IntLiteral variant")
}
pub const fn into_instance(self) -> Option<InstanceType<'db>> {
match self {
Type::Instance(instance_type) => Some(instance_type),
_ => None,
}
}
pub const fn into_known_instance(self) -> Option<KnownInstanceType<'db>> {
match self {
Type::KnownInstance(known_instance) => Some(known_instance),
@ -2557,6 +2564,10 @@ pub enum KnownClass {
}
impl<'db> KnownClass {
pub const fn is_bool(self) -> bool {
matches!(self, Self::Bool)
}
pub const fn as_str(&self) -> &'static str {
match self {
Self::Bool => "bool",

View file

@ -30,8 +30,6 @@ use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType};
use crate::{Db, FxOrderSet};
use smallvec::SmallVec;
use super::Truthiness;
pub(crate) struct UnionBuilder<'db> {
elements: Vec<Type<'db>>,
db: &'db dyn Db,
@ -248,7 +246,12 @@ struct InnerIntersectionBuilder<'db> {
impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a positive type to this intersection.
fn add_positive(&mut self, db: &'db dyn Db, new_positive: Type<'db>) {
fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) {
if new_positive == Type::AlwaysTruthy && self.positive.contains(&Type::LiteralString) {
self.add_negative(db, Type::string_literal(db, ""));
return;
}
if let Type::Intersection(other) = new_positive {
for pos in other.positive(db) {
self.add_positive(db, *pos);
@ -257,25 +260,74 @@ impl<'db> InnerIntersectionBuilder<'db> {
self.add_negative(db, *neg);
}
} else {
// ~Literal[True] & bool = Literal[False]
// ~AlwaysTruthy & bool = Literal[False]
if let Type::Instance(InstanceType { class }) = new_positive {
if class.is_known(db, KnownClass::Bool) {
if let Some(new_type) = self
.negative
.iter()
.find(|element| {
element.is_boolean_literal()
| matches!(element, Type::AlwaysFalsy | Type::AlwaysTruthy)
})
.map(|element| {
Type::BooleanLiteral(element.bool(db) != Truthiness::AlwaysTrue)
})
{
*self = Self::default();
self.positive.insert(new_type);
return;
let addition_is_bool_instance = new_positive
.into_instance()
.and_then(|instance| instance.class.known(db))
.is_some_and(KnownClass::is_bool);
for (index, existing_positive) in self.positive.iter().enumerate() {
match existing_positive {
// `AlwaysTruthy & bool` -> `Literal[True]`
Type::AlwaysTruthy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(true);
}
// `AlwaysFalsy & bool` -> `Literal[False]`
Type::AlwaysFalsy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(false);
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::AlwaysFalsy if new_positive.is_literal_string() => {
new_positive = Type::string_literal(db, "");
}
Type::Instance(InstanceType { class })
if class.is_known(db, KnownClass::Bool) =>
{
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
}
// `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if new_positive == Type::AlwaysFalsy => {
new_positive = Type::string_literal(db, "");
}
_ => continue,
}
self.positive.swap_remove_index(index);
break;
}
if addition_is_bool_instance {
for (index, existing_negative) in self.negative.iter().enumerate() {
match existing_negative {
// `bool & ~Literal[False]` -> `Literal[True]`
// `bool & ~Literal[True]` -> `Literal[False]`
Type::BooleanLiteral(bool_value) => {
new_positive = Type::BooleanLiteral(!bool_value);
}
// `bool & ~AlwaysTruthy` -> `Literal[False]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(false);
}
// `bool & ~AlwaysFalsy` -> `Literal[True]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(true);
}
_ => continue,
}
self.negative.swap_remove_index(index);
break;
}
} else if new_positive.is_literal_string() {
if self.negative.swap_remove(&Type::AlwaysTruthy) {
new_positive = Type::string_literal(db, "");
}
}
@ -298,8 +350,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
return;
}
}
for index in to_remove.iter().rev() {
self.positive.swap_remove_index(*index);
for index in to_remove.into_iter().rev() {
self.positive.swap_remove_index(index);
}
let mut to_remove = SmallVec::<[usize; 1]>::new();
@ -315,8 +367,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
to_remove.push(index);
}
}
for index in to_remove.iter().rev() {
self.negative.swap_remove_index(*index);
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}
self.positive.insert(new_positive);
@ -325,6 +377,14 @@ impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a negative type to this intersection.
fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) {
let contains_bool = || {
self.positive
.iter()
.filter_map(|ty| ty.into_instance())
.filter_map(|instance| instance.class.known(db))
.any(KnownClass::is_bool)
};
match new_negative {
Type::Intersection(inter) => {
for pos in inter.positive(db) {
@ -348,15 +408,23 @@ impl<'db> InnerIntersectionBuilder<'db> {
// simplify the representation.
self.add_positive(db, ty);
}
// bool & ~Literal[True] = Literal[False]
// bool & ~AlwaysTruthy = Literal[False]
Type::BooleanLiteral(_) | Type::AlwaysFalsy | Type::AlwaysTruthy
if self.positive.contains(&KnownClass::Bool.to_instance(db)) =>
{
*self = Self::default();
self.positive.insert(Type::BooleanLiteral(
new_negative.bool(db) != Truthiness::AlwaysTrue,
));
// `bool & ~AlwaysTruthy` -> `bool & Literal[False]`
// `bool & ~Literal[True]` -> `bool & Literal[False]`
Type::AlwaysTruthy | Type::BooleanLiteral(true) if contains_bool() => {
self.add_positive(db, Type::BooleanLiteral(false));
}
// `LiteralString & ~AlwaysTruthy` -> `LiteralString & Literal[""]`
Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `bool & ~AlwaysFalsy` -> `bool & Literal[True]`
// `bool & ~Literal[False]` -> `bool & Literal[True]`
Type::AlwaysFalsy | Type::BooleanLiteral(false) if contains_bool() => {
self.add_positive(db, Type::BooleanLiteral(true));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::AlwaysFalsy if self.positive.contains(&Type::LiteralString) => {
self.add_negative(db, Type::string_literal(db, ""));
}
_ => {
let mut to_remove = SmallVec::<[usize; 1]>::new();