[red-knot] Distribute intersections on negation (#13962)

## Summary

This does two things:
- distribute negated intersections when building up intersections (i.e.
going from `A & ~(B & C)` to `(A & ~B) | (A & ~C)`) (fixing #13931)

## Test Plan

`cargo test`
This commit is contained in:
Raphael Gaschignard 2024-10-29 12:56:04 +10:00 committed by GitHub
parent b6847b371e
commit 2fe203292a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 103 additions and 5 deletions

View file

@ -10,6 +10,9 @@ use crate::Db;
enum CoreStdlibModule {
Builtins,
Types,
// the Typing enum is currently only used in tests
#[allow(dead_code)]
Typing,
Typeshed,
TypingExtensions,
}
@ -19,6 +22,7 @@ impl CoreStdlibModule {
let module_name = match self {
Self::Builtins => "builtins",
Self::Types => "types",
Self::Typing => "typing",
Self::Typeshed => "_typeshed",
Self::TypingExtensions => "typing_extensions",
};
@ -63,6 +67,14 @@ pub(crate) fn types_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> {
core_module_symbol_ty(db, CoreStdlibModule::Types, symbol)
}
/// Lookup the type of `symbol` in the `typing` module namespace.
///
/// Returns `Unbound` if the `typing` module isn't available for some reason.
#[inline]
#[allow(dead_code)] // currently only used in tests
pub(crate) fn typing_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> {
core_module_symbol_ty(db, CoreStdlibModule::Typing, symbol)
}
/// Lookup the type of `symbol` in the `_typeshed` module namespace.
///
/// Returns `Unbound` if the `_typeshed` module isn't available for some reason.

View file

@ -1177,6 +1177,7 @@ impl<'db> KnownClass {
Self::GenericAlias | Self::ModuleType | Self::FunctionType => {
types_symbol_ty(db, self.as_str())
}
Self::NoneType => typeshed_symbol_ty(db, self.as_str()),
}
}

View file

@ -177,6 +177,33 @@ impl<'db> IntersectionBuilder<'db> {
self = self.add_negative(*elem);
}
self
} else if let Type::Intersection(intersection) = ty {
// (A | B) & ~(C & ~D)
// -> (A | B) & (~C | D)
// -> ((A | B) & ~C) | ((A | B) & D)
// i.e. if we have an intersection of positive constraints C
// and negative constraints D, then our new intersection
// is (existing & ~C) | (existing & D)
let positive_side = intersection
.positive(self.db)
.iter()
// we negate all the positive constraints while distributing
.map(|elem| self.clone().add_negative(*elem));
let negative_side = intersection
.negative(self.db)
.iter()
// all negative constraints end up becoming positive constraints
.map(|elem| self.clone().add_positive(*elem));
positive_side.chain(negative_side).fold(
IntersectionBuilder::empty(self.db),
|mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
},
)
} else {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
@ -367,6 +394,7 @@ mod tests {
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
use crate::stdlib::typing_symbol_ty;
use crate::types::{KnownClass, StringLiteralType, UnionBuilder};
use crate::ProgramSettings;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
@ -557,18 +585,22 @@ mod tests {
let ta = Type::Any;
let t1 = Type::IntLiteral(1);
let t2 = KnownClass::Int.to_instance(&db);
// i0 = Any & ~Literal[1]
let i0 = IntersectionBuilder::new(&db)
.add_positive(ta)
.add_negative(t1)
.build();
let intersection = IntersectionBuilder::new(&db)
// ta_not_i0 = int & ~(Any & ~Literal[1])
// -> int & (~Any | Literal[1])
// (~Any is equivalent to Any)
// -> (int & Any) | (int & Literal[1])
// -> (int & Any) | Literal[1]
let ta_not_i0 = IntersectionBuilder::new(&db)
.add_positive(t2)
.add_negative(i0)
.build()
.expect_intersection();
.build();
assert_eq!(intersection.pos_vec(&db), &[ta, t1]);
assert_eq!(intersection.neg_vec(&db), &[]);
assert_eq!(ta_not_i0.display(&db).to_string(), "int & Any | Literal[1]");
}
#[test]
@ -591,6 +623,59 @@ mod tests {
assert_eq!(i1.pos_vec(&db), &[ta, t1]);
}
#[test]
fn intersection_negation_distributes_over_union() {
let db = setup_db();
let st = typing_symbol_ty(&db, "Sized").to_instance(&db);
let ht = typing_symbol_ty(&db, "Hashable").to_instance(&db);
// sh_t: Sized & Hashable
let sh_t = IntersectionBuilder::new(&db)
.add_positive(st)
.add_positive(ht)
.build()
.expect_intersection();
assert_eq!(sh_t.pos_vec(&db), &[st, ht]);
assert_eq!(sh_t.neg_vec(&db), &[]);
// ~sh_t => ~Sized | ~Hashable
let not_s_h_t = IntersectionBuilder::new(&db)
.add_negative(Type::Intersection(sh_t))
.build()
.expect_union();
// should have as elements: (~Sized),(~Hashable)
let not_st = st.negate(&db);
let not_ht = ht.negate(&db);
assert_eq!(not_s_h_t.elements(&db), &[not_st, not_ht]);
}
#[test]
fn mixed_intersection_negation_distributes_over_union() {
let db = setup_db();
let it = KnownClass::Int.to_instance(&db);
let st = typing_symbol_ty(&db, "Sized").to_instance(&db);
let ht = typing_symbol_ty(&db, "Hashable").to_instance(&db);
// s_not_h_t: Sized & ~Hashable
let s_not_h_t = IntersectionBuilder::new(&db)
.add_positive(st)
.add_negative(ht)
.build()
.expect_intersection();
assert_eq!(s_not_h_t.pos_vec(&db), &[st]);
assert_eq!(s_not_h_t.neg_vec(&db), &[ht]);
// let's build int & ~(Sized & ~Hashable)
let tt = IntersectionBuilder::new(&db)
.add_positive(it)
.add_negative(Type::Intersection(s_not_h_t))
.build();
// int & ~(Sized & ~Hashable)
// -> int & (~Sized | Hashable)
// -> (int & ~Sized) | (int & Hashable)
assert_eq!(tt.display(&db).to_string(), "int & ~Sized | int & Hashable");
}
#[test]
fn build_intersection_self_negation() {
let db = setup_db();