diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md b/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md index 3239c9c2f1..fb3254f87a 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md @@ -305,9 +305,9 @@ range. ```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[((SubSub ≤ T@_ ≤ Base) ∧ ¬(Sub ≤ T@_ ≤ Base))] + # revealed: ty_extensions.ConstraintSet[(¬(Sub ≤ T@_ ≤ Base) ∧ (SubSub ≤ T@_ ≤ Base))] reveal_type(range_constraint(SubSub, T, Base) & negated_range_constraint(Sub, T, Super)) - # revealed: ty_extensions.ConstraintSet[((SubSub ≤ T@_ ≤ Super) ∧ ¬(Sub ≤ T@_ ≤ Base))] + # revealed: ty_extensions.ConstraintSet[(¬(Sub ≤ T@_ ≤ Base) ∧ (SubSub ≤ T@_ ≤ Super))] reveal_type(range_constraint(SubSub, T, Super) & negated_range_constraint(Sub, T, Base)) ``` @@ -339,9 +339,9 @@ Otherwise, the union cannot be simplified. ```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[(¬(Sub ≤ T@_ ≤ Base) ∧ ¬(Base ≤ T@_ ≤ Super))] + # revealed: ty_extensions.ConstraintSet[(¬(Base ≤ T@_ ≤ Super) ∧ ¬(Sub ≤ T@_ ≤ Base))] reveal_type(negated_range_constraint(Sub, T, Base) & negated_range_constraint(Base, T, Super)) - # revealed: ty_extensions.ConstraintSet[(¬(SubSub ≤ T@_ ≤ Sub) ∧ ¬(Base ≤ T@_ ≤ Super))] + # revealed: ty_extensions.ConstraintSet[(¬(Base ≤ T@_ ≤ Super) ∧ ¬(SubSub ≤ T@_ ≤ Sub))] reveal_type(negated_range_constraint(SubSub, T, Sub) & negated_range_constraint(Base, T, Super)) # revealed: ty_extensions.ConstraintSet[(¬(SubSub ≤ T@_ ≤ Sub) ∧ ¬(Unrelated ≤ T@_))] reveal_type(negated_range_constraint(SubSub, T, Sub) & negated_range_constraint(Unrelated, T, object)) @@ -385,7 +385,7 @@ We cannot simplify the union of constraints that refer to different typevars. def _[T, U]() -> None: # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ (Sub ≤ U@_ ≤ Base)] reveal_type(range_constraint(Sub, T, Base) | range_constraint(Sub, U, Base)) - # revealed: ty_extensions.ConstraintSet[¬(Sub ≤ T@_ ≤ Base) ∨ ¬(Sub ≤ U@_ ≤ Base)] + # revealed: ty_extensions.ConstraintSet[¬(Sub ≤ U@_ ≤ Base) ∨ ¬(Sub ≤ T@_ ≤ Base)] reveal_type(negated_range_constraint(Sub, T, Base) | negated_range_constraint(Sub, U, Base)) ``` @@ -417,9 +417,9 @@ Otherwise, the union cannot be simplified. ```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ (Base ≤ T@_ ≤ Super)] + # revealed: ty_extensions.ConstraintSet[(Base ≤ T@_ ≤ Super) ∨ (Sub ≤ T@_ ≤ Base)] reveal_type(range_constraint(Sub, T, Base) | range_constraint(Base, T, Super)) - # revealed: ty_extensions.ConstraintSet[(SubSub ≤ T@_ ≤ Sub) ∨ (Base ≤ T@_ ≤ Super)] + # revealed: ty_extensions.ConstraintSet[(Base ≤ T@_ ≤ Super) ∨ (SubSub ≤ T@_ ≤ Sub)] reveal_type(range_constraint(SubSub, T, Sub) | range_constraint(Base, T, Super)) # revealed: ty_extensions.ConstraintSet[(SubSub ≤ T@_ ≤ Sub) ∨ (Unrelated ≤ T@_)] reveal_type(range_constraint(SubSub, T, Sub) | range_constraint(Unrelated, T, object)) @@ -488,9 +488,9 @@ range. ```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[¬(SubSub ≤ T@_ ≤ Base) ∨ (Sub ≤ T@_ ≤ Base)] + # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ ¬(SubSub ≤ T@_ ≤ Base)] reveal_type(negated_range_constraint(SubSub, T, Base) | range_constraint(Sub, T, Super)) - # revealed: ty_extensions.ConstraintSet[¬(SubSub ≤ T@_ ≤ Super) ∨ (Sub ≤ T@_ ≤ Base)] + # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ ¬(SubSub ≤ T@_ ≤ Super)] reveal_type(negated_range_constraint(SubSub, T, Super) | range_constraint(Sub, T, Base)) ``` @@ -562,3 +562,42 @@ def _[T]() -> None: # revealed: ty_extensions.ConstraintSet[always] reveal_type(constraint | ~constraint) ``` + +### Negation of constraints involving two variables + +```py +from typing import final, Never +from ty_extensions import range_constraint + +class Base: ... + +@final +class Unrelated: ... + +def _[T, U]() -> None: + # revealed: ty_extensions.ConstraintSet[¬(U@_ ≤ Base) ∨ ¬(T@_ ≤ Base)] + reveal_type(~(range_constraint(Never, T, Base) & range_constraint(Never, U, Base))) +``` + +The union of a constraint and its negation should always be satisfiable. + +```py +def _[T, U]() -> None: + c1 = range_constraint(Never, T, Base) & range_constraint(Never, U, Base) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(c1 | ~c1) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(~c1 | c1) + + c2 = range_constraint(Unrelated, T, object) & range_constraint(Unrelated, U, object) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(c2 | ~c2) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(~c2 | c2) + + union = c1 | c2 + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(union | ~union) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(~union | union) +``` diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 7087f01301..cb4e504196 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -14,26 +14,16 @@ //! This module provides the machinery for representing the "under what constraints" part of that //! question. //! -//! An individual constraint restricts the specialization of a single typevar. You can then build -//! up more complex constraint sets using union, intersection, and negation operations. We use a -//! disjunctive normal form (DNF) representation, just like we do for types: a [constraint -//! set][ConstraintSet] is the union of zero or more [clauses][ConstraintClause], each of which is -//! the intersection of zero or more [individual constraints][ConstrainedTypeVar]. Note that the -//! constraint set that contains no clauses is never satisfiable (`⋃ {} = 0`); and the constraint -//! set that contains a single clause, where that clause contains no constraints, is always -//! satisfiable (`⋃ {⋂ {}} = 1`). -//! -//! An individual constraint consists of a _positive range_ and zero or more _negative holes_. The -//! positive range and each negative hole consists of a lower and upper bound. A type is within a -//! lower and upper bound if it is a supertype of the lower bound and a subtype of the upper bound. -//! The typevar can specialize to any type that is within the positive range, and is not within any -//! of the negative holes. (You can think of the constraint as the set of types that are within the -//! positive range, with the negative holes "removed" from that set.) +//! An individual constraint restricts the specialization of a single typevar to be within a +//! particular lower and upper bound. (A type is within a lower and upper bound if it is a +//! supertype of the lower bound and a subtype of the upper bound.) You can then build up more +//! complex constraint sets using union, intersection, and negation operations. We use a [binary +//! decision diagram][bdd] (BDD) to represent a constraint set. //! //! Note that all lower and upper bounds in a constraint must be fully static. We take the bottom //! and top materializations of the types to remove any gradual forms if needed. //! -//! NOTE: This module is currently in a transitional state. We've added the DNF [`ConstraintSet`] +//! NOTE: This module is currently in a transitional state. We've added the BDD [`ConstraintSet`] //! representation, and updated all of our property checks to build up a constraint set and then //! check whether it is ever or always satisfiable, as appropriate. We are not yet inferring //! specializations from those constraints. @@ -60,23 +50,18 @@ //! constraint `(int ≤ T ≤ int) ∪ (str ≤ T ≤ str)`. When the lower and upper bounds are the same, //! the constraint says that the typevar must specialize to that _exact_ type, not to a subtype or //! supertype of it. +//! +//! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram +use std::cmp::Ordering; use std::fmt::Display; -use itertools::{EitherOrBoth, Itertools}; -use smallvec::{SmallVec, smallvec}; +use itertools::Itertools; +use rustc_hash::FxHashSet; use crate::Db; use crate::types::{BoundTypeVarInstance, IntersectionType, Type, UnionType}; -fn comparable<'db>(db: &'db dyn Db, left: Type<'db>, right: Type<'db>) -> bool { - left.is_subtype_of(db, right) || right.is_subtype_of(db, left) -} - -fn incomparable<'db>(db: &'db dyn Db, left: Type<'db>, right: Type<'db>) -> bool { - !comparable(db, left, right) -} - /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { /// Returns a constraint set that is always satisfiable if the option is `None`; otherwise @@ -154,7 +139,7 @@ where ) -> ConstraintSet<'db> { let mut result = ConstraintSet::always(); for child in self { - if result.intersect(db, &f(child)).is_never_satisfied() { + if result.intersect(db, f(child)).is_never_satisfied() { return result; } } @@ -164,66 +149,55 @@ where /// A set of constraints under which a type property holds. /// -/// We use a DNF representation, so a set contains a list of zero or more -/// [clauses][ConstraintClause], each of which is an intersection of zero or more -/// [constraints][ConstrainedTypeVar]. -/// /// This is called a "set of constraint sets", and denoted _𝒮_, in [[POPL2015][]]. /// -/// ### Invariants -/// -/// - The clauses are simplified as much as possible — there are no two clauses in the set that can -/// be simplified into a single clause. -/// /// [POPL2015]: https://doi.org/10.1145/2676726.2676991 -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] pub struct ConstraintSet<'db> { - // NOTE: We use 2 here because there are a couple of places where we create unions of 2 clauses - // as temporary values — in particular when negating a constraint — and this lets us avoid - // spilling the temporary value to the heap. - clauses: SmallVec<[ConstraintClause<'db>; 2]>, + /// The BDD representing this constraint set + node: Node<'db>, } impl<'db> ConstraintSet<'db> { fn never() -> Self { Self { - clauses: smallvec![], + node: Node::AlwaysFalse, } } fn always() -> Self { - Self::singleton(ConstraintClause::always()) + Self { + node: Node::AlwaysTrue, + } } /// Returns whether this constraint set never holds - pub(crate) fn is_never_satisfied(&self) -> bool { - self.clauses.is_empty() + pub(crate) fn is_never_satisfied(self) -> bool { + self.node.is_never_satisfied() } /// Returns whether this constraint set always holds - pub(crate) fn is_always_satisfied(&self) -> bool { - self.clauses.len() == 1 && self.clauses[0].is_always() + pub(crate) fn is_always_satisfied(self) -> bool { + self.node.is_always_satisfied() } /// Updates this constraint set to hold the union of itself and another constraint set. - pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> &Self { - self.union_set(db, other); - self + pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> Self { + self.node = self.node.or(db, other.node).simplify(db); + *self } /// Updates this constraint set to hold the intersection of itself and another constraint set. - pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: &Self) -> &Self { - self.intersect_set(db, other); - self + pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: Self) -> Self { + self.node = self.node.and(db, other.node).simplify(db); + *self } /// Returns the negation of this constraint set. - pub(crate) fn negate(&self, db: &'db dyn Db) -> Self { - let mut result = Self::always(); - for clause in &self.clauses { - result.intersect_set(db, &clause.negate(db)); + pub(crate) fn negate(self, db: &'db dyn Db) -> Self { + Self { + node: self.node.negate(db).simplify(db), } - result } /// Returns the intersection of this constraint set and another. The other constraint set is @@ -231,7 +205,7 @@ impl<'db> ConstraintSet<'db> { /// constraint set is already saturated. pub(crate) fn and(mut self, db: &'db dyn Db, other: impl FnOnce() -> Self) -> Self { if !self.is_never_satisfied() { - self.intersect(db, &other()); + self.intersect(db, other()); } self } @@ -246,13 +220,6 @@ impl<'db> ConstraintSet<'db> { self } - /// Returns a constraint set that contains a single clause. - fn singleton(clause: ConstraintClause<'db>) -> Self { - Self { - clauses: smallvec![clause], - } - } - pub(crate) fn range( db: &'db dyn Db, lower: Type<'db>, @@ -261,10 +228,9 @@ impl<'db> ConstraintSet<'db> { ) -> Self { let lower = lower.bottom_materialization(db); let upper = upper.top_materialization(db); - let constraint = Constraint::range(db, lower, upper).constrain(typevar); - let mut result = Self::never(); - result.union_constraint(db, constraint); - result + Self { + node: ConstrainedTypeVar::new_node(db, lower, typevar, upper), + } } pub(crate) fn negated_range( @@ -273,130 +239,11 @@ impl<'db> ConstraintSet<'db> { typevar: BoundTypeVarInstance<'db>, upper: Type<'db>, ) -> Self { - let lower = lower.bottom_materialization(db); - let upper = upper.top_materialization(db); - let constraint = Constraint::negated_range(db, lower, upper).constrain(typevar); - let mut result = Self::never(); - result.union_constraint(db, constraint); - result + Self::range(db, lower, typevar, upper).negate(db) } - /// Updates this set to be the union of itself and a constraint. - fn union_constraint( - &mut self, - db: &'db dyn Db, - constraint: Satisfiable>, - ) { - self.union_clause(db, constraint.map(ConstraintClause::singleton)); - } - - /// Updates this set to be the union of itself and a clause. To maintain the invariants of this - /// type, we must simplify this clause against all existing clauses, if possible. - fn union_clause(&mut self, db: &'db dyn Db, clause: Satisfiable>) { - let mut clause = match clause { - // If the new constraint can always be satisfied, that causes this whole set to be - // always satisfied too. - Satisfiable::Always => { - self.clauses.clear(); - self.clauses.push(ConstraintClause::always()); - return; - } - - // If the new clause can never satisfied, then the set does not change. - Satisfiable::Never => return, - - Satisfiable::Constrained(clause) => clause, - }; - - // Naively, we would just append the new clause to the set's list of clauses. But that - // doesn't ensure that the clauses are simplified with respect to each other. So instead, - // we iterate through the list of existing clauses, and try to simplify the new clause - // against each one in turn. (We can assume that the existing clauses are already - // simplified with respect to each other, since we can assume that the invariant holds upon - // entry to this method.) - let mut existing_clauses = std::mem::take(&mut self.clauses).into_iter(); - for existing in existing_clauses.by_ref() { - // Try to simplify the new clause against an existing clause. - match existing.simplify_clauses(db, clause) { - Simplifiable::NeverSatisfiable => { - // If two clauses cancel out to 0, that does NOT cause the entire set to become - // 0. We need to keep whatever clauses have already been added to the result, - // and also need to copy over any later clauses that we hadn't processed yet. - self.clauses.extend(existing_clauses); - return; - } - - Simplifiable::AlwaysSatisfiable => { - // If two clauses cancel out to 1, that makes the entire set 1, and all - // existing clauses are simplified away. - self.clauses.clear(); - self.clauses.push(ConstraintClause::always()); - return; - } - - Simplifiable::NotSimplified(existing, c) => { - // We couldn't simplify the new clause relative to this existing clause, so add - // the existing clause to the result. Continue trying to simplify the new - // clause against the later existing clauses. - self.clauses.push(existing); - clause = c; - } - - Simplifiable::Simplified(c) => { - // We were able to simplify the new clause relative to this existing clause. - // Don't add it to the result yet; instead, try to simplify the result further - // against later existing clauses. - clause = c; - } - } - } - - // If we fall through then we need to add the new clause to the clause list (either because - // we couldn't simplify it with anything, or because we did without it canceling out). - self.clauses.push(clause); - } - - /// Updates this set to be the union of itself and another set. - fn union_set(&mut self, db: &'db dyn Db, other: Self) { - for clause in other.clauses { - self.union_clause(db, Satisfiable::Constrained(clause)); - } - } - - /// Updates this set to be the intersection of itself and another set. - fn intersect_set(&mut self, db: &'db dyn Db, other: &Self) { - // This is the distributive law: - // (A ∪ B) ∩ (C ∪ D ∪ E) = (A ∩ C) ∪ (A ∩ D) ∪ (A ∩ E) ∪ (B ∩ C) ∪ (B ∩ D) ∪ (B ∩ E) - let self_clauses = std::mem::take(&mut self.clauses); - for self_clause in &self_clauses { - for other_clause in &other.clauses { - self.union_clause(db, self_clause.intersect_clause(db, other_clause)); - } - } - } - - pub(crate) fn display(&self, db: &'db dyn Db) -> impl Display { - struct DisplayConstraintSet<'a, 'db> { - set: &'a ConstraintSet<'db>, - db: &'db dyn Db, - } - - impl Display for DisplayConstraintSet<'_, '_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.set.clauses.is_empty() { - return f.write_str("0"); - } - for (i, clause) in self.set.clauses.iter().enumerate() { - if i > 0 { - f.write_str(" ∨ ")?; - } - clause.display(self.db).fmt(f)?; - } - Ok(()) - } - } - - DisplayConstraintSet { set: self, db } + pub(crate) fn display(self, db: &'db dyn Db) -> impl Display { + self.node.display(db) } } @@ -406,823 +253,68 @@ impl From for ConstraintSet<'_> { } } -/// The intersection of zero or more individual constraints. -/// -/// This is called a "constraint set", and denoted _C_, in [[POPL2015][]]. -/// -/// [POPL2015]: https://doi.org/10.1145/2676726.2676991 -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct ConstraintClause<'db> { - // NOTE: We use 1 here because most clauses only mention a single typevar. - constraints: SmallVec<[ConstrainedTypeVar<'db>; 1]>, -} - -impl<'db> ConstraintClause<'db> { - fn new(constraints: SmallVec<[ConstrainedTypeVar<'db>; 1]>) -> Satisfiable { - if constraints.is_empty() { - Satisfiable::Always - } else { - Satisfiable::Constrained(Self { constraints }) - } - } - - /// Returns the clause that is always satisfiable. - fn always() -> Self { - Self { - constraints: smallvec![], - } - } - - /// Returns a clause containing a single constraint. - fn singleton(constraint: ConstrainedTypeVar<'db>) -> Self { - Self { - constraints: smallvec![constraint], - } - } - - /// Returns whether this constraint is always satisfiable. - fn is_always(&self) -> bool { - self.constraints.is_empty() - } - - fn is_satisfiable(&self) -> Satisfiable<()> { - if self.is_always() { - Satisfiable::Always - } else { - Satisfiable::Constrained(()) - } - } - - /// Updates this clause to be the intersection of itself and an individual constraint. Returns - /// a flag indicating whether the updated clause is never, always, or sometimes satisfied. - fn intersect_constraint( - &mut self, - db: &'db dyn Db, - constraint: Satisfiable>, - ) -> Satisfiable<()> { - let mut constraint = match constraint { - // If the new constraint cannot be satisfied, that causes this whole clause to be - // unsatisfiable too. - Satisfiable::Never => return Satisfiable::Never, - - // If the new constraint can always satisfied, then the clause does not change. It was - // not always satisfiable before, and so it still isn't. - Satisfiable::Always => return Satisfiable::Constrained(()), - - Satisfiable::Constrained(constraint) => constraint, - }; - - // Naively, we would just append the new constraint to the clauses's list of constraints. - // But that doesn't ensure that the constraints are simplified with respect to each other. - // So instead, we iterate through the list of existing constraints, and try to simplify the - // new constraint against each one in turn. (We can assume that the existing constraints - // are already simplified with respect to each other, since we can assume that the - // invariant holds upon entry to this method.) - let mut existing_constraints = std::mem::take(&mut self.constraints).into_iter(); - for existing in existing_constraints.by_ref() { - // Try to simplify the new constraint against an existing constraint. - match existing.intersect(db, &constraint) { - Some(Satisfiable::Never) => { - // If two constraints cancel out to 0, that makes the entire clause 0, and all - // existing constraints are simplified away. - return Satisfiable::Never; - } - - Some(Satisfiable::Always) => { - // If two constraints cancel out to 1, that does NOT cause the entire clause to - // become 1. We need to keep whatever constraints have already been added to - // the result, and also need to copy over any later constraints that we hadn't - // processed yet. - self.constraints.extend(existing_constraints); - return self.is_satisfiable(); - } - - None => { - // We couldn't simplify the new constraint relative to this existing - // constraint, so add the existing constraint to the result. Continue trying to - // simplify the new constraint against the later existing constraints. - self.constraints.push(existing.clone()); - } - - Some(Satisfiable::Constrained(simplified)) => { - // We were able to simplify the new constraint relative to this existing - // constraint. Don't add it to the result yet; instead, try to simplify the - // result further against later existing constraints. - constraint = simplified; - } - } - } - - // If we fall through then we need to add the new constraint to the constraint list (either - // because we couldn't simplify it with anything, or because we did without it canceling - // out). - self.constraints.push(constraint); - self.is_satisfiable() - } - - /// Returns the intersection of this clause with another. - fn intersect_clause(&self, db: &'db dyn Db, other: &Self) -> Satisfiable { - // Add each `other` constraint in turn. Short-circuit if the result ever becomes 0. - let mut result = self.clone(); - for constraint in &other.constraints { - match result.intersect_constraint(db, Satisfiable::Constrained(constraint.clone())) { - Satisfiable::Never => return Satisfiable::Never, - Satisfiable::Always | Satisfiable::Constrained(()) => {} - } - } - if result.is_always() { - Satisfiable::Always - } else { - Satisfiable::Constrained(result) - } - } - - /// Tries to simplify the union of two clauses into a single clause, if possible. - fn simplify_clauses(self, db: &'db dyn Db, other: Self) -> Simplifiable { - // Saturation - // - // If either clause is always satisfiable, the union is too. (`1 ∪ C₂ = 1`, `C₁ ∪ 1 = 1`) - // - // ```py - // class A[T]: ... - // - // class C1[U]: - // # T can specialize to any type, so this is "always satisfiable", or `1` - // x: A[U] - // - // class C2[V: int]: - // # `T ≤ int` - // x: A[V] - // - // class Saturation[U, V: int]: - // # `1 ∪ (T ≤ int)` - // # simplifies via saturation to - // # `T ≤ int` - // x: A[U] | A[V] - // ``` - if self.is_always() || other.is_always() { - return Simplifiable::Simplified(Self::always()); - } - - // Subsumption - // - // If either clause subsumes (is "smaller than") the other, then the union simplifies to - // the "bigger" clause (the one being subsumed): - // - // - `A ∩ B` must be at least as large as `A ∩ B ∩ C` - // - Therefore, `(A ∩ B) ∪ (A ∩ B ∩ C) = (A ∩ B)` - // - // (Note that possibly counterintuitively, "bigger" here means _fewer_ constraints in the - // intersection, since intersecting more things can only make the result smaller.) - // - // ```py - // class A[T, U, V]: ... - // - // class C1[X: int, Y: str, Z]: - // # `(T ≤ int ∩ U ≤ str)` - // x: A[X, Y, Z] - // - // class C2[X: int, Y: str, Z: bytes]: - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X, Y, Z] - // - // class Subsumption[X1: int, Y1: str, Z2, X2: int, Y2: str, Z2: bytes]: - // # `(T ≤ int ∩ U ≤ str) ∪ (T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // # simplifies via subsumption to - // # `(T ≤ int ∩ U ≤ str)` - // x: A[X1, Y1, Z2] | A[X2, Y2, Z2] - // ``` - // - // TODO: Consider checking both directions in one pass, possibly via a tri-valued return - // value. - if self.subsumes_via_intersection(db, &other) { - return Simplifiable::Simplified(other); - } - if other.subsumes_via_intersection(db, &self) { - return Simplifiable::Simplified(self); - } - - // Distribution - // - // If the two clauses constrain the same typevar in an "overlapping" way, we can factor - // that out: - // - // (A₁ ∩ B ∩ C) ∪ (A₂ ∩ B ∩ C) = (A₁ ∪ A₂) ∩ B ∩ C - // - // ```py - // class A[T, U, V]: ... - // - // class C1[X: int, Y: str, Z: bytes]: - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X, Y, Z] - // - // class C2[X: bool, Y: str, Z: bytes]: - // # `(T ≤ bool ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X, Y, Z] - // - // class Distribution[X1: int, Y1: str, Z2: bytes, X2: bool, Y2: str, Z2: bytes]: - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes) ∪ (T ≤ bool ∩ U ≤ str ∩ V ≤ bytes)` - // # simplifies via distribution to - // # `(T ≤ int ∪ T ≤ bool) ∩ U ≤ str ∩ V ≤ bytes)` - // # which (because `bool ≤ int`) is equivalent to - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X1, Y1, Z2] | A[X2, Y2, Z2] - // ``` - if let Some(simplified) = self.simplifies_via_distribution(db, &other) { - return simplified; - } - - // Can't be simplified - Simplifiable::NotSimplified(self, other) - } - - /// Returns whether this clause subsumes `other` via intersection — that is, if the - /// intersection of `self` and `other` is `self`. - fn subsumes_via_intersection(&self, db: &'db dyn Db, other: &Self) -> bool { - // See the notes in `simplify_clauses` for more details on subsumption, including Python - // examples that cause it to fire. - - if self.constraints.len() != other.constraints.len() { - return false; - } - - let pairwise = (self.constraints.iter()) - .merge_join_by(&other.constraints, |a, b| a.typevar.cmp(&b.typevar)); - for pair in pairwise { - match pair { - // `other` contains a constraint whose typevar doesn't appear in `self`, so `self` - // cannot be smaller. - EitherOrBoth::Right(_) => return false, - - // `self` contains a constraint whose typevar doesn't appear in `other`. `self` - // might be smaller, but we still have to check the remaining constraints. - EitherOrBoth::Left(_) => continue, - - // Both clauses contain a constraint with this typevar; verify that the constraint - // in `self` is smaller. - EitherOrBoth::Both(self_constraint, other_constraint) => { - if !self_constraint.subsumes(db, other_constraint) { - return false; - } - } - } - } - true - } - - /// If the union of two clauses is simpler than either of the individual clauses, returns the - /// union. This happens when they mention the same set of typevars and the constraints for all - /// but one typevar are identical. Moreover, for the other typevar, the union of the - /// constraints for that typevar simplifies to (a) a single constraint, or (b) two constraints - /// where one of them is smaller than before. That is, - /// - /// ```text - /// (A₁ ∩ B ∩ C) ∪ (A₂ ∩ B ∩ C) = A₁₂ ∩ B ∩ C - /// or (A₁' ∪ A₂) ∩ B ∩ C - /// or (A₁ ∪ A₂') ∩ B ∩ C - /// ``` - /// - /// where `B` and `C` are the constraints that are identical for all but one typevar, and `A₁` - /// and `A₂` are the constraints for the other typevar; and where `A₁ ∪ A₂` either simplifies - /// to a single constraint (`A₁₂`), or to a union where either `A₁` or `A₂` becomes smaller - /// (`A₁'` or `A₂'`, respectively). - /// - /// Otherwise returns `None`. - fn simplifies_via_distribution( - &self, - db: &'db dyn Db, - other: &Self, - ) -> Option> { - // See the notes in `simplify_clauses` for more details on distribution, including Python - // examples that cause it to fire. - - if self.constraints.len() != other.constraints.len() { - return None; - } - - // Verify that the constraints for precisely one typevar simplify, and the constraints for - // all other typevars are identical. Remember the index of the typevar whose constraints - // simplify. - let mut simplified_index = None; - let pairwise = (self.constraints.iter()) - .merge_join_by(&other.constraints, |a, b| a.typevar.cmp(&b.typevar)); - for (index, pair) in pairwise.enumerate() { - match pair { - // If either clause contains a constraint whose typevar doesn't appear in the - // other, the clauses don't simplify. - EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => return None, - - EitherOrBoth::Both(self_constraint, other_constraint) => { - if self_constraint == other_constraint { - continue; - } - let Some(union_constraint) = - self_constraint.simplified_union(db, other_constraint) - else { - // The constraints for this typevar are not identical, nor do they - // simplify. - return None; - }; - if simplified_index - .replace((index, union_constraint)) - .is_some() - { - // More than one constraint simplify, which doesn't allow the clause as a - // whole to simplify. - return None; - } - } - } - } - - let Some((index, union_constraint)) = simplified_index else { - // We never found a typevar whose constraints simplify. - return None; - }; - let mut constraints = self.constraints.clone(); - match union_constraint { - Simplifiable::NeverSatisfiable => { - panic!("unioning two non-never constraints should not be never") - } - Simplifiable::AlwaysSatisfiable => { - // If the simplified union of constraints is Always, then we can remove this typevar - // from the constraint completely. - constraints.remove(index); - Some(Simplifiable::from_one(Self::new(constraints))) - } - Simplifiable::Simplified(union_constraint) => { - constraints[index] = union_constraint; - Some(Simplifiable::from_one(Self::new(constraints))) - } - Simplifiable::NotSimplified(left, right) => { - let mut left_constraints = constraints.clone(); - let mut right_constraints = constraints; - left_constraints[index] = left; - right_constraints[index] = right; - Some(Simplifiable::from_union( - Self::new(left_constraints), - Self::new(right_constraints), - )) - } - } - } - - /// Returns the negation of this clause. The result is a set since negating an intersection - /// produces a union. - fn negate(&self, db: &'db dyn Db) -> ConstraintSet<'db> { - let mut result = ConstraintSet::never(); - for constraint in &self.constraints { - constraint.negate_into(db, &mut result); - } - result - } - - fn display(&self, db: &'db dyn Db) -> impl Display { - struct DisplayConstraintClause<'a, 'db> { - clause: &'a ConstraintClause<'db>, - db: &'db dyn Db, - } - - impl Display for DisplayConstraintClause<'_, '_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.clause.constraints.is_empty() { - return f.write_str("1"); - } - - let clause_count: usize = (self.clause.constraints.iter()) - .map(ConstrainedTypeVar::clause_count) - .sum(); - if clause_count > 1 { - f.write_str("(")?; - } - for (i, constraint) in self.clause.constraints.iter().enumerate() { - if i > 0 { - f.write_str(" ∧ ")?; - } - constraint.display(self.db).fmt(f)?; - } - if clause_count > 1 { - f.write_str(")")?; - } - Ok(()) - } - } - - DisplayConstraintClause { clause: self, db } - } -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] +/// An individual constraint in a constraint set. This restricts a single typevar to be within a +/// lower and upper bound. +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(PartialOrd, Ord)] pub(crate) struct ConstrainedTypeVar<'db> { typevar: BoundTypeVarInstance<'db>, - constraint: Constraint<'db>, -} - -impl<'db> ConstrainedTypeVar<'db> { - fn clause_count(&self) -> usize { - self.constraint.clause_count() - } - - /// Returns the intersection of this individual constraint and another, or `None` if the two - /// constraints do not refer to the same typevar (and therefore cannot be simplified to a - /// single constraint). - fn intersect(&self, db: &'db dyn Db, other: &Self) -> Option> { - if self.typevar != other.typevar { - return None; - } - Some( - self.constraint - .intersect(db, &other.constraint) - .map(|constraint| constraint.constrain(self.typevar)), - ) - } - - /// Returns the union of this individual constraint and another, if it can be simplified to a - /// union of two constraints or fewer. Returns `None` if the union cannot be simplified that - /// much. - fn simplified_union(&self, db: &'db dyn Db, other: &Self) -> Option> { - if self.typevar != other.typevar { - return None; - } - self.constraint - .simplified_union(db, &other.constraint) - .map(|constraint| constraint.map(|constraint| constraint.constrain(self.typevar))) - } - - /// Adds the negation of this individual constraint to a constraint set. - fn negate_into(&self, db: &'db dyn Db, set: &mut ConstraintSet<'db>) { - self.constraint.negate_into(db, self.typevar, set); - } - - /// Returns whether `self` has tighter bounds than `other` — that is, if the intersection of - /// `self` and `other` is `self`. - fn subsumes(&self, db: &'db dyn Db, other: &Self) -> bool { - debug_assert_eq!(self.typevar, other.typevar); - match self.intersect(db, other) { - Some(Satisfiable::Constrained(intersection)) => intersection == *self, - _ => false, - } - } - - fn display(&self, db: &'db dyn Db) -> impl Display { - self.constraint.display(db, self.typevar.display(db)) - } -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct Constraint<'db> { - positive: RangeConstraint<'db>, - negative: SmallVec<[NegatedRangeConstraint<'db>; 1]>, -} - -impl<'db> Constraint<'db> { - fn constrain(self, typevar: BoundTypeVarInstance<'db>) -> ConstrainedTypeVar<'db> { - ConstrainedTypeVar { - typevar, - constraint: self, - } - } - - fn clause_count(&self) -> usize { - usize::from(!self.positive.is_always()) + self.negative.len() - } - - fn satisfiable(self, db: &'db dyn Db) -> Satisfiable { - if self.positive.is_always() && self.negative.is_empty() { - return Satisfiable::Always; - } - if (self.negative.iter()).any(|negative| negative.hole.contains(db, &self.positive)) { - return Satisfiable::Never; - } - Satisfiable::Constrained(self) - } - - fn intersect(&self, db: &'db dyn Db, other: &Constraint<'db>) -> Satisfiable> { - let Some(positive) = self.positive.intersect(db, &other.positive) else { - // If the positive intersection is empty, none of the negative holes matter, since - // there are no types for the holes to remove. - return Satisfiable::Never; - }; - - // The negative portion of the intersection is given by - // - // ¬(s₁ ≤ α ≤ t₁) ∧ ¬(s₂ ≤ α ≤ t₂) = ¬((s₁ ≤ α ≤ t₁) ∨ (s₂ ≤ α ≤ t₂)) - // - // That is, we union together the holes from `self` and `other`. If any of the holes - // entirely contain another, we can simplify those two down to the larger hole. We use the - // same trick as above in `union_clause` and `intersect_constraint` to look for pairs that - // we can simplify. - // - // We also want to clip each negative hole to the minimum range that overlaps with the - // positive range. We'll do that now to all of the holes from `self`, and we'll do that to - // holes from `other` below when we try to simplify them. - let mut previous: SmallVec<[NegatedRangeConstraint<'db>; 1]> = SmallVec::new(); - let mut current: SmallVec<_> = (self.negative.iter()) - .filter_map(|negative| negative.clip_to_positive(db, &positive)) - .collect(); - for other_negative in &other.negative { - let Some(mut other_negative) = other_negative.clip_to_positive(db, &positive) else { - continue; - }; - std::mem::swap(&mut previous, &mut current); - let mut previous_negative = previous.iter(); - for self_negative in previous_negative.by_ref() { - match self_negative.intersect_negative(db, &other_negative) { - None => { - // We couldn't simplify the new hole relative to this existing holes, so - // add the existing hole to the result. Continue trying to simplify the new - // hole against the later existing holes. - current.push(self_negative.clone()); - } - - Some(union) => { - // We were able to simplify the new hole relative to this existing hole. - // Don't add it to the result yet; instead, try to simplify the result - // further against later existing holes. - other_negative = union.clone(); - } - } - } - - // If we fall through then we need to add the new hole to the hole list (either because - // we couldn't simplify it with anything, or because we did without it canceling out). - current.push(other_negative); - } - - let result = Self { - positive, - negative: current, - }; - result.satisfiable(db) - } - - fn simplified_union( - &self, - db: &'db dyn Db, - other: &Constraint<'db>, - ) -> Option>> { - // (ap ∧ ¬an₁ ∧ ¬an₂ ∧ ...) ∨ (bp ∧ ¬bn₁ ∧ ¬bn₂ ∧ ...) - // = (ap ∨ bp) ∧ (ap ∨ ¬bn₁) ∧ (ap ∨ ¬bn₂) ∧ ... - // ∧ (¬an₁ ∨ bp) ∧ (¬an₁ ∨ ¬bn₁) ∧ (¬an₁ ∨ ¬bn₂) ∧ ... - // ∧ (¬an₂ ∨ bp) ∧ (¬an₂ ∨ ¬bn₁) ∧ (¬an₂ ∨ ¬bn₂) ∧ ... - // - // We use a helper type to build up the result of the union of two constraints, since we - // need to calculate the Cartesian product of the the positive and negative portions of the - // two inputs. We cannot use `ConstraintSet` for this, since it would try to invoke the - // `simplify_union` logic, which this method is part of the definition of! So we have to - // reproduce some of that logic here, in a simplified form since we know we're only ever - // looking at pairs of individual constraints at a time. - - struct Results<'db> { - next: Vec>, - results: Vec>, - } - - impl<'db> Results<'db> { - fn new(constraint: Constraint<'db>) -> Results<'db> { - Results { - next: vec![], - results: vec![constraint], - } - } - - fn flip(&mut self) { - std::mem::swap(&mut self.next, &mut self.results); - self.next.clear(); - } - - /// Adds a constraint by intersecting it with any currently pending results. - fn add_constraint(&mut self, db: &'db dyn Db, constraint: &Constraint<'db>) { - self.next.extend(self.results.iter().filter_map(|result| { - match result.intersect(db, constraint) { - Satisfiable::Never => None, - Satisfiable::Always => Some(Constraint { - positive: RangeConstraint::always(), - negative: smallvec![], - }), - Satisfiable::Constrained(constraint) => Some(constraint), - } - })); - } - - /// Adds a single negative range constraint to the pending results. - fn add_negated_range( - &mut self, - db: &'db dyn Db, - negative: Option>, - ) { - let negative = match negative { - Some(negative) => Constraint { - positive: RangeConstraint::always(), - negative: smallvec![negative], - }, - // If the intersection of these two holes is empty, then they don't remove - // anything from the final union. - None => return, - }; - self.add_constraint(db, &negative); - self.flip(); - } - - /// Adds a possibly simplified constraint to the pending results. If the parameter has - /// been simplified to a single constraint, it is intersected with each currently - /// pending result. If it could not be simplified (i.e., it is the union of two - /// constraints), then we duplicate any pending results, so that we can _separately_ - /// intersect each non-simplified constraint with the results. - fn add_simplified_constraint( - &mut self, - db: &'db dyn Db, - constraints: Simplifiable>, - ) { - match constraints { - Simplifiable::NeverSatisfiable => { - self.results.clear(); - return; - } - Simplifiable::AlwaysSatisfiable => { - return; - } - Simplifiable::Simplified(constraint) => { - self.add_constraint(db, &constraint); - } - Simplifiable::NotSimplified(first, second) => { - self.add_constraint(db, &first); - self.add_constraint(db, &second); - } - } - self.flip(); - } - - /// If there are two or fewer final results, translates them into a [`Simplifiable`] - /// result. Otherwise returns `None`, indicating that the union cannot be simplified - /// enough for our purposes. - fn into_result(self, db: &'db dyn Db) -> Option>> { - let mut results = self.results.into_iter(); - let Some(first) = results.next() else { - return Some(Simplifiable::NeverSatisfiable); - }; - let Some(second) = results.next() else { - return Some(Simplifiable::from_one(first.satisfiable(db))); - }; - if results.next().is_some() { - return None; - } - Some(Simplifiable::from_union( - first.satisfiable(db), - second.satisfiable(db), - )) - } - } - - let mut results = match self.positive.union(db, &other.positive) { - Some(positive) => Results::new(Constraint { - positive: positive.clone(), - negative: smallvec![], - }), - None => return None, - }; - for other_negative in &other.negative { - results.add_simplified_constraint( - db, - self.positive.union_negated_range(db, other_negative), - ); - } - for self_negative in &self.negative { - // Reverse the results here so that we always add items from `self` first. This ensures - // that the output we produce is ordered consistently with the input we receive. - results.add_simplified_constraint( - db, - other - .positive - .union_negated_range(db, self_negative) - .reverse(), - ); - } - for self_negative in &self.negative { - for other_negative in &other.negative { - results.add_negated_range(db, self_negative.union_negative(db, other_negative)); - } - } - - results.into_result(db) - } - - fn negate_into( - &self, - db: &'db dyn Db, - typevar: BoundTypeVarInstance<'db>, - set: &mut ConstraintSet<'db>, - ) { - for negative in &self.negative { - set.union_constraint( - db, - Constraint::range(db, negative.hole.lower, negative.hole.upper).constrain(typevar), - ); - } - set.union_constraint( - db, - Constraint::negated_range(db, self.positive.lower, self.positive.upper) - .constrain(typevar), - ); - } - - fn display(&self, db: &'db dyn Db, typevar: impl Display) -> impl Display { - struct DisplayConstraint<'a, 'db, D> { - constraint: &'a Constraint<'db>, - typevar: D, - db: &'db dyn Db, - } - - impl Display for DisplayConstraint<'_, '_, D> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut first = true; - if !self.constraint.positive.is_always() { - (self.constraint.positive) - .display(self.db, &self.typevar) - .fmt(f)?; - first = false; - } - for negative in &self.constraint.negative { - if first { - first = false; - } else { - f.write_str(" ∧ ")?; - } - negative.display(self.db, &self.typevar).fmt(f)?; - } - Ok(()) - } - } - - DisplayConstraint { - constraint: self, - typevar, - db, - } - } -} - -impl<'db> Satisfiable> { - fn constrain(self, typevar: BoundTypeVarInstance<'db>) -> Satisfiable> { - self.map(|constraint| constraint.constrain(typevar)) - } -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct RangeConstraint<'db> { lower: Type<'db>, upper: Type<'db>, } -impl<'db> Constraint<'db> { +// The Salsa heap is tracked separately. +impl get_size2::GetSize for ConstrainedTypeVar<'_> {} + +#[salsa::tracked] +impl<'db> ConstrainedTypeVar<'db> { /// Returns a new range constraint. /// /// Panics if `lower` and `upper` are not both fully static. - fn range(db: &'db dyn Db, lower: Type<'db>, upper: Type<'db>) -> Satisfiable> { + fn new_node( + db: &'db dyn Db, + lower: Type<'db>, + typevar: BoundTypeVarInstance<'db>, + upper: Type<'db>, + ) -> Node<'db> { debug_assert_eq!(lower, lower.bottom_materialization(db)); debug_assert_eq!(upper, upper.top_materialization(db)); // If `lower ≰ upper`, then the constraint cannot be satisfied, since there is no type that // is both greater than `lower`, and less than `upper`. if !lower.is_subtype_of(db, upper) { - return Satisfiable::Never; + return Node::AlwaysFalse; } // If the requested constraint is `Never ≤ T ≤ object`, then the typevar can be specialized // to _any_ type, and the constraint does nothing. - let positive = RangeConstraint { lower, upper }; - if positive.is_always() { - return Satisfiable::Always; + if lower.is_never() && upper.is_object() { + return Node::AlwaysTrue; } - Satisfiable::Constrained(Constraint { - positive, - negative: smallvec![], - }) + Node::new_constraint(db, ConstrainedTypeVar::new(db, typevar, lower, upper)) + } + fn when_true(self) -> ConstraintAssignment<'db> { + ConstraintAssignment::Positive(self) } -} -impl<'db> RangeConstraint<'db> { - fn always() -> Self { - Self { - lower: Type::Never, - upper: Type::object(), + fn when_false(self) -> ConstraintAssignment<'db> { + ConstraintAssignment::Negative(self) + } + + fn contains(self, db: &'db dyn Db, other: Self) -> bool { + if self.typevar(db) != other.typevar(db) { + return false; } - } - - fn contains(&self, db: &'db dyn Db, other: &RangeConstraint<'db>) -> bool { - self.lower.is_subtype_of(db, other.lower) && other.upper.is_subtype_of(db, self.upper) - } - - fn is_always(&self) -> bool { - self.lower.is_never() && self.upper.is_object() + self.lower(db).is_subtype_of(db, other.lower(db)) + && other.upper(db).is_subtype_of(db, self.upper(db)) } /// Returns the intersection of two range constraints, or `None` if the intersection is empty. - fn intersect(&self, db: &'db dyn Db, other: &RangeConstraint<'db>) -> Option { + fn intersect(self, db: &'db dyn Db, other: Self) -> Option { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) - let lower = UnionType::from_elements(db, [self.lower, other.lower]); - let upper = IntersectionType::from_elements(db, [self.upper, other.upper]); + let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); + let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]); // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. @@ -1230,274 +322,1029 @@ impl<'db> RangeConstraint<'db> { return None; } - Some(Self { lower, upper }) + Some(Self::new(db, self.typevar(db), lower, upper)) } - /// Returns the union of two range constraints if it can be simplified to a single constraint. - /// Otherwise returns `None`. - fn union(&self, db: &'db dyn Db, other: &RangeConstraint<'db>) -> Option { - // When one of the bounds is entirely contained within the other, the union simplifies to - // the larger bounds. - if self.lower.is_subtype_of(db, other.lower) && other.upper.is_subtype_of(db, self.upper) { - return Some(self.clone()); - } - if other.lower.is_subtype_of(db, self.lower) && self.upper.is_subtype_of(db, other.upper) { - return Some(other.clone()); - } - - // Otherwise the result cannot be simplified. - None + fn display(self, db: &'db dyn Db) -> impl Display { + self.display_inner(db, false) } - /// Returns the union of a positive range with a negative hole. - fn union_negated_range( - &self, - db: &'db dyn Db, - negated: &NegatedRangeConstraint<'db>, - ) -> Simplifiable> { - // If the positive range completely contains the negative range, then the union is always - // satisfied. - if self.contains(db, &negated.hole) { - return Simplifiable::AlwaysSatisfiable; - } - - // If the positive range is disjoint from the negative range, the positive range doesn't - // add anything; the union is the negative range. - if incomparable(db, self.lower, negated.hole.upper) - || incomparable(db, negated.hole.lower, self.upper) - { - return Simplifiable::from_one(Constraint::negated_range( - db, - negated.hole.lower, - negated.hole.upper, - )); - } - - // Otherwise we clip the positive constraint to the minimum range that overlaps with the - // negative range. - Simplifiable::from_union( - Constraint::range( - db, - UnionType::from_elements(db, [self.lower, negated.hole.lower]), - IntersectionType::from_elements(db, [self.upper, negated.hole.upper]), - ), - Constraint::negated_range(db, negated.hole.lower, negated.hole.upper), - ) + fn display_negated(self, db: &'db dyn Db) -> impl Display { + self.display_inner(db, true) } - fn display(&self, db: &'db dyn Db, typevar: impl Display) -> impl Display { - struct DisplayRangeConstraint<'a, 'db, D> { - constraint: &'a RangeConstraint<'db>, - typevar: D, + fn display_inner(self, db: &'db dyn Db, negated: bool) -> impl Display { + struct DisplayConstrainedTypeVar<'db> { + constraint: ConstrainedTypeVar<'db>, + negated: bool, db: &'db dyn Db, } - impl Display for DisplayRangeConstraint<'_, '_, D> { + impl Display for DisplayConstrainedTypeVar<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if (self.constraint.lower).is_equivalent_to(self.db, self.constraint.upper) { + let lower = self.constraint.lower(self.db); + let upper = self.constraint.upper(self.db); + if lower.is_equivalent_to(self.db, upper) { return write!( f, - "({} = {})", - &self.typevar, - self.constraint.lower.display(self.db) + "({} {} {})", + self.constraint.typevar(self.db).display(self.db), + if self.negated { "≠" } else { "=" }, + lower.display(self.db) ); } - f.write_str("(")?; - if !self.constraint.lower.is_never() { - write!(f, "{} ≤ ", self.constraint.lower.display(self.db))?; + if self.negated { + f.write_str("¬")?; } - self.typevar.fmt(f)?; - if !self.constraint.upper.is_object() { - write!(f, " ≤ {}", self.constraint.upper.display(self.db))?; + f.write_str("(")?; + if !lower.is_never() { + write!(f, "{} ≤ ", lower.display(self.db))?; + } + self.constraint.typevar(self.db).display(self.db).fmt(f)?; + if !upper.is_object() { + write!(f, " ≤ {}", upper.display(self.db))?; } f.write_str(")") } } - DisplayRangeConstraint { + DisplayConstrainedTypeVar { constraint: self, - typevar, + negated, db, } } } -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct NegatedRangeConstraint<'db> { - hole: RangeConstraint<'db>, +/// A BDD node. +/// +/// The "variables" of a constraint set BDD are individual constraints, represented by an interned +/// [`ConstrainedTypeVar`]. +/// +/// Terminal nodes (`false` and `true`) have their own dedicated enum variants. The +/// [`Interior`][InteriorNode] variant represents interior nodes. +/// +/// BDD nodes are _reduced_, which means that there are no duplicate nodes (which we handle via +/// Salsa interning), and that there are no redundant nodes, with `if_true` and `if_false` edges +/// that point at the same node. +/// +/// BDD nodes are also _ordered_, meaning that every path from the root of a BDD to a terminal node +/// visits variables in the same order. [`ConstrainedTypeVar`]s are interned, so we can use the IDs +/// that salsa assigns to define this order. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] +enum Node<'db> { + AlwaysFalse, + AlwaysTrue, + Interior(InteriorNode<'db>), } -impl<'db> Constraint<'db> { - /// Returns a new negated range constraint. +impl<'db> Node<'db> { + /// Creates a new BDD node, ensuring that it is fully reduced. + fn new( + db: &'db dyn Db, + constraint: ConstrainedTypeVar<'db>, + if_true: Node<'db>, + if_false: Node<'db>, + ) -> Self { + debug_assert!( + (if_true.root_constraint(db)) + .is_none_or(|root_constraint| root_constraint > constraint) + ); + debug_assert!( + (if_false.root_constraint(db)) + .is_none_or(|root_constraint| root_constraint > constraint) + ); + if if_true == if_false { + return if_true; + } + Self::Interior(InteriorNode::new(db, constraint, if_true, if_false)) + } + + /// Creates a new BDD node for an individual constraint. (The BDD will evaluate to `true` when + /// the constraint holds, and to `false` when it does not.) + fn new_constraint(db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) -> Self { + Self::Interior(InteriorNode::new( + db, + constraint, + Node::AlwaysTrue, + Node::AlwaysFalse, + )) + } + + /// Creates a new BDD node for a positive or negative individual constraint. (For a positive + /// constraint, this returns the same BDD node as [`new_constraint`][Self::new_constraint]. For + /// a negative constraint, it returns the negation of that BDD node.) + fn new_satisfied_constraint(db: &'db dyn Db, constraint: ConstraintAssignment<'db>) -> Self { + match constraint { + ConstraintAssignment::Positive(constraint) => Self::Interior(InteriorNode::new( + db, + constraint, + Node::AlwaysTrue, + Node::AlwaysFalse, + )), + ConstraintAssignment::Negative(constraint) => Self::Interior(InteriorNode::new( + db, + constraint, + Node::AlwaysFalse, + Node::AlwaysTrue, + )), + } + } + + /// Returns the BDD variable of the root node of this BDD, or `None` if this BDD is a terminal + /// node. + fn root_constraint(self, db: &'db dyn Db) -> Option> { + match self { + Node::Interior(interior) => Some(interior.constraint(db)), + _ => None, + } + } + + /// Returns whether this BDD represent the constant function `true`. + fn is_always_satisfied(self) -> bool { + matches!(self, Node::AlwaysTrue) + } + + /// Returns whether this BDD represent the constant function `false`. + fn is_never_satisfied(self) -> bool { + matches!(self, Node::AlwaysFalse) + } + + /// Returns the negation of this BDD. + fn negate(self, db: &'db dyn Db) -> Self { + match self { + Node::AlwaysTrue => Node::AlwaysFalse, + Node::AlwaysFalse => Node::AlwaysTrue, + Node::Interior(interior) => interior.negate(db), + } + } + + /// Returns the `or` or union of two BDDs. + fn or(self, db: &'db dyn Db, other: Self) -> Self { + match (self, other) { + (Node::AlwaysTrue, _) | (_, Node::AlwaysTrue) => Node::AlwaysTrue, + (Node::AlwaysFalse, other) | (other, Node::AlwaysFalse) => other, + (Node::Interior(a), Node::Interior(b)) => { + // OR is commutative, which lets us halve the cache requirements + let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) }; + a.or(db, b) + } + } + } + + /// Returns the `and` or intersection of two BDDs. + fn and(self, db: &'db dyn Db, other: Self) -> Self { + match (self, other) { + (Node::AlwaysFalse, _) | (_, Node::AlwaysFalse) => Node::AlwaysFalse, + (Node::AlwaysTrue, other) | (other, Node::AlwaysTrue) => other, + (Node::Interior(a), Node::Interior(b)) => { + // AND is commutative, which lets us halve the cache requirements + let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) }; + a.and(db, b) + } + } + } + + /// Returns a new BDD that evaluates to `true` when both input BDDs evaluate to the same + /// result. + fn iff(self, db: &'db dyn Db, other: Self) -> Self { + match (self, other) { + (Node::AlwaysFalse, Node::AlwaysFalse) | (Node::AlwaysTrue, Node::AlwaysTrue) => { + Node::AlwaysTrue + } + (Node::AlwaysTrue, Node::AlwaysFalse) | (Node::AlwaysFalse, Node::AlwaysTrue) => { + Node::AlwaysFalse + } + (Node::AlwaysTrue | Node::AlwaysFalse, Node::Interior(interior)) => Node::new( + db, + interior.constraint(db), + self.iff(db, interior.if_true(db)), + self.iff(db, interior.if_false(db)), + ), + (Node::Interior(interior), Node::AlwaysTrue | Node::AlwaysFalse) => Node::new( + db, + interior.constraint(db), + interior.if_true(db).iff(db, other), + interior.if_false(db).iff(db, other), + ), + (Node::Interior(a), Node::Interior(b)) => { + // IFF is commutative, which lets us halve the cache requirements + let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) }; + a.iff(db, b) + } + } + } + + /// Returns the `if-then-else` of three BDDs: when `self` evaluates to `true`, it returns what + /// `then_node` evaluates to; otherwise it returns what `else_node` evaluates to. + fn ite(self, db: &'db dyn Db, then_node: Self, else_node: Self) -> Self { + self.and(db, then_node) + .or(db, self.negate(db).and(db, else_node)) + } + + /// Returns a new BDD that returns the same results as `self`, but with some inputs fixed to + /// particular values. (Those variables will not be checked when evaluating the result, and + /// will not be present in the result.) /// - /// Panics if `lower` and `upper` are not both fully static. - fn negated_range( + /// Also returns whether _all_ of the restricted variables appeared in the BDD. + fn restrict( + self, db: &'db dyn Db, - lower: Type<'db>, - upper: Type<'db>, - ) -> Satisfiable> { - debug_assert_eq!(lower, lower.bottom_materialization(db)); - debug_assert_eq!(upper, upper.top_materialization(db)); + assignment: impl IntoIterator>, + ) -> (Self, bool) { + assignment + .into_iter() + .fold((self, true), |(restricted, found), assignment| { + let (restricted, found_this) = restricted.restrict_one(db, assignment); + (restricted, found && found_this) + }) + } - // If `lower ≰ upper`, then the negated constraint is always satisfied, since there is no - // type that is both greater than `lower`, and less than `upper`. - if !lower.is_subtype_of(db, upper) { - return Satisfiable::Always; + /// Returns a new BDD that returns the same results as `self`, but with one input fixed to a + /// particular value. (That variable will be not be checked when evaluating the result, and + /// will not be present in the result.) + /// + /// Also returns whether the restricted variable appeared in the BDD. + fn restrict_one(self, db: &'db dyn Db, assignment: ConstraintAssignment<'db>) -> (Self, bool) { + match self { + Node::AlwaysTrue => (Node::AlwaysTrue, false), + Node::AlwaysFalse => (Node::AlwaysFalse, false), + Node::Interior(interior) => interior.restrict_one(db, assignment), } + } - // If the requested constraint is `¬(Never ≤ T ≤ object)`, then the constraint cannot be - // satisfied. - let negative = NegatedRangeConstraint { - hole: RangeConstraint { lower, upper }, + /// Returns a new BDD with any occurrence of `left ∧ right` replaced with `replacement`. + fn substitute_intersection( + self, + db: &'db dyn Db, + left: ConstraintAssignment<'db>, + right: ConstraintAssignment<'db>, + replacement: Node<'db>, + ) -> Self { + // We perform a Shannon expansion to find out what the input BDD evaluates to when: + // - left and right are both true + // - left is false + // - left is true and right is false + // This covers the entire truth table of `left ∧ right`. + let (when_left_and_right, both_found) = self.restrict(db, [left, right]); + if !both_found { + // If left and right are not both present in the input BDD, we should not even attempt + // the substitution, since the Shannon expansion might introduce the missing variables! + // That confuses us below when we try to detect whether the substitution is consistent + // with the input. + return self; + } + let (when_not_left, _) = self.restrict(db, [left.negated()]); + let (when_left_but_not_right, _) = self.restrict(db, [left, right.negated()]); + + // The result should test `replacement`, and when it's true, it should produce the same + // output that input would when `left ∧ right` is true. When replacement is false, it + // should fall back on testing left and right individually to make sure we produce the + // correct outputs in the `¬(left ∧ right)` case. So the result is + // + // if replacement + // when_left_and_right + // else if not left + // when_not_left + // else if not right + // when_left_but_not_right + // else + // false + // + // (Note that the `else` branch shouldn't be reachable, but we have to provide something!) + let left_node = Node::new_satisfied_constraint(db, left); + let right_node = Node::new_satisfied_constraint(db, right); + let right_result = right_node.ite(db, Node::AlwaysFalse, when_left_but_not_right); + let left_result = left_node.ite(db, right_result, when_not_left); + let result = replacement.ite(db, when_left_and_right, left_result); + + // Lastly, verify that the result is consistent with the input. (It must produce the same + // results when `left ∧ right`.) If it doesn't, the substitution isn't valid, and we should + // return the original BDD unmodified. + let validity = replacement.iff(db, left_node.and(db, right_node)); + let constrained_original = self.and(db, validity); + let constrained_replacement = result.and(db, validity); + if constrained_original == constrained_replacement { + result + } else { + self + } + } + + /// Returns a new BDD with any occurrence of `left ∨ right` replaced with `replacement`. + fn substitute_union( + self, + db: &'db dyn Db, + left: ConstraintAssignment<'db>, + right: ConstraintAssignment<'db>, + replacement: Node<'db>, + ) -> Self { + // We perform a Shannon expansion to find out what the input BDD evaluates to when: + // - left and right are both true + // - left is true and right is false + // - left is false and right is true + // - left and right are both false + // This covers the entire truth table of `left ∨ right`. + let (when_l1_r1, both_found) = self.restrict(db, [left, right]); + if !both_found { + // If left and right are not both present in the input BDD, we should not even attempt + // the substitution, since the Shannon expansion might introduce the missing variables! + // That confuses us below when we try to detect whether the substitution is consistent + // with the input. + return self; + } + let (when_l0_r0, _) = self.restrict(db, [left.negated(), right.negated()]); + let (when_l1_r0, _) = self.restrict(db, [left, right.negated()]); + let (when_l0_r1, _) = self.restrict(db, [left.negated(), right]); + + // The result should test `replacement`, and when it's true, it should produce the same + // output that input would when `left ∨ right` is true. For OR, this is the union of what + // the input produces for the three cases that comprise `left ∨ right`. When `replacement` + // is false, the result should produce the same output that input would when + // `¬(left ∨ right)`, i.e. when `left ∧ right`. So the result is + // + // if replacement + // or(when_l1_r1, when_l1_r0, when_r0_l1) + // else + // when_l0_r0 + let result = replacement.ite( + db, + when_l1_r0.or(db, when_l0_r1.or(db, when_l1_r1)), + when_l0_r0, + ); + + // Lastly, verify that the result is consistent with the input. (It must produce the same + // results when `left ∨ right`.) If it doesn't, the substitution isn't valid, and we should + // return the original BDD unmodified. + let left_node = Node::new_satisfied_constraint(db, left); + let right_node = Node::new_satisfied_constraint(db, right); + let validity = replacement.iff(db, left_node.or(db, right_node)); + let constrained_original = self.and(db, validity); + let constrained_replacement = result.and(db, validity); + if constrained_original == constrained_replacement { + result + } else { + self + } + } + + /// Invokes a closure for each constraint variable that appears anywhere in a BDD. (Any given + /// constraint can appear multiple times in different paths from the root; we do not + /// deduplicate those constraints, and will instead invoke the callback each time we encounter + /// the constraint.) + fn for_each_constraint(self, db: &'db dyn Db, f: &mut dyn FnMut(ConstrainedTypeVar<'db>)) { + let Node::Interior(interior) = self else { + return; }; - if negative.hole.is_always() { - return Satisfiable::Never; + f(interior.constraint(db)); + interior.if_true(db).for_each_constraint(db, f); + interior.if_false(db).for_each_constraint(db, f); + } + + /// Simplifies a BDD, replacing constraints with simpler or smaller constraints where possible. + fn simplify(self, db: &'db dyn Db) -> Self { + match self { + Node::AlwaysTrue | Node::AlwaysFalse => self, + Node::Interior(interior) => interior.simplify(db), + } + } + + /// Returns clauses describing all of the variable assignments that cause this BDD to evaluate + /// to `true`. (This translates the boolean function that this BDD represents into DNF form.) + fn satisfied_clauses(self, db: &'db dyn Db) -> SatisfiedClauses<'db> { + struct Searcher<'db> { + clauses: SatisfiedClauses<'db>, + current_clause: SatisfiedClause<'db>, } - Satisfiable::Constrained(Constraint { - positive: RangeConstraint::always(), - negative: smallvec![negative], - }) - } -} + impl<'db> Searcher<'db> { + fn visit_node(&mut self, db: &'db dyn Db, node: Node<'db>) { + match node { + Node::AlwaysFalse => {} + Node::AlwaysTrue => self.clauses.push(self.current_clause.clone()), + Node::Interior(interior) => { + let interior_constraint = interior.constraint(db); + self.current_clause.push(interior_constraint.when_true()); + self.visit_node(db, interior.if_true(db)); + self.current_clause.pop(); + self.current_clause.push(interior_constraint.when_false()); + self.visit_node(db, interior.if_false(db)); + self.current_clause.pop(); + } + } + } + } -impl<'db> NegatedRangeConstraint<'db> { - /// Clips this negative hole to be the smallest hole that removes the same types from the given - /// positive range. - fn clip_to_positive(&self, db: &'db dyn Db, positive: &RangeConstraint<'db>) -> Option { - self.hole - .intersect(db, positive) - .map(|hole| NegatedRangeConstraint { hole }) + let mut searcher = Searcher { + clauses: SatisfiedClauses::default(), + current_clause: SatisfiedClause::default(), + }; + searcher.visit_node(db, self); + searcher.clauses } - /// Returns the union of two negative constraints. (This this is _intersection_ of the - /// constraints' holes.) - fn union_negative( - &self, - db: &'db dyn Db, - positive: &NegatedRangeConstraint<'db>, - ) -> Option { - self.hole - .intersect(db, &positive.hole) - .map(|hole| NegatedRangeConstraint { hole }) - } - - /// Returns the intersection of two negative constraints. (This this is _union_ of the - /// constraints' holes.) - fn intersect_negative( - &self, - db: &'db dyn Db, - other: &NegatedRangeConstraint<'db>, - ) -> Option { - self.hole - .union(db, &other.hole) - .map(|hole| NegatedRangeConstraint { hole }) - } - - fn display(&self, db: &'db dyn Db, typevar: impl Display) -> impl Display { - struct DisplayNegatedRangeConstraint<'a, 'db, D> { - constraint: &'a NegatedRangeConstraint<'db>, - typevar: D, + fn display(self, db: &'db dyn Db) -> impl Display { + // To render a BDD in DNF form, you perform a depth-first search of the BDD tree, looking + // for any path that leads to the AlwaysTrue terminal. Each such path represents one of the + // intersection clauses in the DNF form. The path traverses zero or more interior nodes, + // and takes either the true or false edge from each one. That gives you the positive or + // negative individual constraints in the path's clause. + struct DisplayNode<'db> { + node: Node<'db>, db: &'db dyn Db, } - impl Display for DisplayNegatedRangeConstraint<'_, '_, D> { + impl Display for DisplayNode<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if (self.constraint.hole.lower) - .is_equivalent_to(self.db, self.constraint.hole.upper) - { - return write!( - f, - "({} ≠ {})", - &self.typevar, - self.constraint.hole.lower.display(self.db) - ); + match self.node { + Node::AlwaysTrue => f.write_str("always"), + Node::AlwaysFalse => f.write_str("never"), + Node::Interior(_) => { + let mut clauses = self.node.satisfied_clauses(self.db); + clauses.simplify(); + clauses.display(self.db).fmt(f) + } } - - f.write_str("¬")?; - self.constraint.hole.display(self.db, &self.typevar).fmt(f) } } - DisplayNegatedRangeConstraint { + DisplayNode { node: self, db } + } +} + +/// An interior node of a BDD +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +struct InteriorNode<'db> { + constraint: ConstrainedTypeVar<'db>, + if_true: Node<'db>, + if_false: Node<'db>, +} + +// The Salsa heap is tracked separately. +impl get_size2::GetSize for InteriorNode<'_> {} + +#[salsa::tracked] +impl<'db> InteriorNode<'db> { + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn negate(self, db: &'db dyn Db) -> Node<'db> { + Node::new( + db, + self.constraint(db), + self.if_true(db).negate(db), + self.if_false(db).negate(db), + ) + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn or(self, db: &'db dyn Db, other: Self) -> Node<'db> { + let self_constraint = self.constraint(db); + let other_constraint = other.constraint(db); + match self_constraint.cmp(&other_constraint) { + Ordering::Equal => Node::new( + db, + self_constraint, + self.if_true(db).or(db, other.if_true(db)), + self.if_false(db).or(db, other.if_false(db)), + ), + Ordering::Less => Node::new( + db, + self_constraint, + self.if_true(db).or(db, Node::Interior(other)), + self.if_false(db).or(db, Node::Interior(other)), + ), + Ordering::Greater => Node::new( + db, + other_constraint, + Node::Interior(self).or(db, other.if_true(db)), + Node::Interior(self).or(db, other.if_false(db)), + ), + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn and(self, db: &'db dyn Db, other: Self) -> Node<'db> { + let self_constraint = self.constraint(db); + let other_constraint = other.constraint(db); + match self_constraint.cmp(&other_constraint) { + Ordering::Equal => Node::new( + db, + self_constraint, + self.if_true(db).and(db, other.if_true(db)), + self.if_false(db).and(db, other.if_false(db)), + ), + Ordering::Less => Node::new( + db, + self_constraint, + self.if_true(db).and(db, Node::Interior(other)), + self.if_false(db).and(db, Node::Interior(other)), + ), + Ordering::Greater => Node::new( + db, + other_constraint, + Node::Interior(self).and(db, other.if_true(db)), + Node::Interior(self).and(db, other.if_false(db)), + ), + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn iff(self, db: &'db dyn Db, other: Self) -> Node<'db> { + let self_constraint = self.constraint(db); + let other_constraint = other.constraint(db); + match self_constraint.cmp(&other_constraint) { + Ordering::Equal => Node::new( + db, + self_constraint, + self.if_true(db).iff(db, other.if_true(db)), + self.if_false(db).iff(db, other.if_false(db)), + ), + Ordering::Less => Node::new( + db, + self_constraint, + self.if_true(db).iff(db, Node::Interior(other)), + self.if_false(db).iff(db, Node::Interior(other)), + ), + Ordering::Greater => Node::new( + db, + other_constraint, + Node::Interior(self).iff(db, other.if_true(db)), + Node::Interior(self).iff(db, other.if_false(db)), + ), + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn restrict_one( + self, + db: &'db dyn Db, + assignment: ConstraintAssignment<'db>, + ) -> (Node<'db>, bool) { + // If this node's variable is larger than the assignment's variable, then we have reached a + // point in the BDD where the assignment can no longer affect the result, + // and we can return early. + let self_constraint = self.constraint(db); + if assignment.constraint() < self_constraint { + return (Node::Interior(self), false); + } + + // Otherwise, check if this node's variable is in the assignment. If so, substitute the + // variable by replacing this node with its if_false/if_true edge, accordingly. + if assignment == self_constraint.when_true() { + (self.if_true(db), true) + } else if assignment == self_constraint.when_false() { + (self.if_false(db), true) + } else { + let (if_true, found_in_true) = self.if_true(db).restrict_one(db, assignment); + let (if_false, found_in_false) = self.if_false(db).restrict_one(db, assignment); + ( + Node::new(db, self_constraint, if_true, if_false), + found_in_true || found_in_false, + ) + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn simplify(self, db: &'db dyn Db) -> Node<'db> { + // To simplify a non-terminal BDD, we find all pairs of constraints that are mentioned in + // the BDD. If any of those pairs can be simplified to some other BDD, we perform a + // substitution to replace the pair with the simplification. + // + // Some of the simplifications create _new_ constraints that weren't originally present in + // the BDD. If we encounter one of those cases, we need to check if we can simplify things + // further relative to that new constraint. + // + // To handle this, we keep track of the individual constraints that we have already + // discovered (`seen_constraints`), and a queue of constraint pairs that we still need to + // check (`to_visit`). + + // Seed the seen set with all of the constraints that are present in the input BDD, and the + // visit queue with all pairs of those constraints. (We use "combinations" because we don't + // need to compare a constraint against itself, and because ordering doesn't matter.) + let mut seen_constraints = FxHashSet::default(); + Node::Interior(self).for_each_constraint(db, &mut |constraint| { + seen_constraints.insert(constraint); + }); + let mut to_visit: Vec<(_, _)> = (seen_constraints.iter().copied()) + .tuple_combinations() + .collect(); + + // Repeatedly pop constraint pairs off of the visit queue, checking whether each pair can + // be simplified. + let mut simplified = Node::Interior(self); + while let Some((left_constraint, right_constraint)) = to_visit.pop() { + // If the constraints refer to different typevars, they trivially cannot be compared. + // TODO: We might need to consider when one constraint's upper or lower bound refers to + // the other constraint's typevar. + let typevar = left_constraint.typevar(db); + if typevar != right_constraint.typevar(db) { + continue; + } + + // Containment: The range of one constraint might completely contain the range of the + // other. If so, there are several potential simplifications. + let larger_smaller = if left_constraint.contains(db, right_constraint) { + Some((left_constraint, right_constraint)) + } else if right_constraint.contains(db, left_constraint) { + Some((right_constraint, left_constraint)) + } else { + None + }; + if let Some((larger_constraint, smaller_constraint)) = larger_smaller { + // larger ∨ smaller = larger + simplified = simplified.substitute_union( + db, + larger_constraint.when_true(), + smaller_constraint.when_true(), + Node::new_satisfied_constraint(db, larger_constraint.when_true()), + ); + + // ¬larger ∧ ¬smaller = ¬larger + simplified = simplified.substitute_intersection( + db, + larger_constraint.when_false(), + smaller_constraint.when_false(), + Node::new_satisfied_constraint(db, larger_constraint.when_false()), + ); + + // smaller ∧ ¬larger = false + // (¬larger removes everything that's present in smaller) + simplified = simplified.substitute_intersection( + db, + larger_constraint.when_false(), + smaller_constraint.when_true(), + Node::AlwaysFalse, + ); + + // larger ∨ ¬smaller = true + // (larger fills in everything that's missing in ¬smaller) + simplified = simplified.substitute_union( + db, + larger_constraint.when_true(), + smaller_constraint.when_false(), + Node::AlwaysTrue, + ); + } + + // There are some simplifications we can make when the intersection of the two + // constraints is empty, and others that we can make when the intersection is + // non-empty. + match left_constraint.intersect(db, right_constraint) { + Some(intersection_constraint) => { + // If the intersection is non-empty, we need to create a new constraint to + // represent that intersection. We also need to add the new constraint to our + // seen set and (if we haven't already seen it) to the to-visit queue. + if seen_constraints.insert(intersection_constraint) { + to_visit.extend( + (seen_constraints.iter().copied()) + .filter(|seen| *seen != intersection_constraint) + .map(|seen| (seen, intersection_constraint)), + ); + } + let positive_intersection_node = + Node::new_satisfied_constraint(db, intersection_constraint.when_true()); + let negative_intersection_node = + Node::new_satisfied_constraint(db, intersection_constraint.when_false()); + + // left ∧ right = intersection + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_true(), + positive_intersection_node, + ); + + // ¬left ∨ ¬right = ¬intersection + simplified = simplified.substitute_union( + db, + left_constraint.when_false(), + right_constraint.when_false(), + negative_intersection_node, + ); + + // left ∧ ¬right = left ∧ ¬intersection + // (clip the negative constraint to the smallest range that actually removes + // something from positive constraint) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_false(), + Node::new_satisfied_constraint(db, left_constraint.when_true()) + .and(db, negative_intersection_node), + ); + + // ¬left ∧ right = ¬intersection ∧ right + // (save as above but reversed) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_false(), + right_constraint.when_true(), + Node::new_satisfied_constraint(db, right_constraint.when_true()) + .and(db, negative_intersection_node), + ); + + // left ∨ ¬right = intersection ∨ ¬right + // (clip the positive constraint to the smallest range that actually adds + // something to the negative constraint) + simplified = simplified.substitute_union( + db, + left_constraint.when_true(), + right_constraint.when_false(), + Node::new_satisfied_constraint(db, right_constraint.when_false()) + .or(db, positive_intersection_node), + ); + + // ¬left ∨ right = ¬left ∨ intersection + // (save as above but reversed) + simplified = simplified.substitute_union( + db, + left_constraint.when_false(), + right_constraint.when_true(), + Node::new_satisfied_constraint(db, left_constraint.when_false()) + .or(db, positive_intersection_node), + ); + } + + None => { + // All of the below hold because we just proved that the intersection of left + // and right is empty. + + // left ∧ right = false + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_true(), + Node::AlwaysFalse, + ); + + // ¬left ∨ ¬right = true + simplified = simplified.substitute_union( + db, + left_constraint.when_false(), + right_constraint.when_false(), + Node::AlwaysTrue, + ); + + // left ∧ ¬right = left + // (there is nothing in the hole of ¬right that overlaps with left) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_false(), + Node::new_constraint(db, left_constraint), + ); + + // ¬left ∧ right = right + // (save as above but reversed) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_false(), + right_constraint.when_true(), + Node::new_constraint(db, right_constraint), + ); + } + } + } + + simplified + } +} + +/// An assignment of one BDD variable to either `true` or `false`. (When evaluating a BDD, we +/// must provide an assignment for each variable present in the BDD.) +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +enum ConstraintAssignment<'db> { + Positive(ConstrainedTypeVar<'db>), + Negative(ConstrainedTypeVar<'db>), +} + +impl<'db> ConstraintAssignment<'db> { + fn constraint(self) -> ConstrainedTypeVar<'db> { + match self { + ConstraintAssignment::Positive(constraint) => constraint, + ConstraintAssignment::Negative(constraint) => constraint, + } + } + + fn negated(self) -> Self { + match self { + ConstraintAssignment::Positive(constraint) => { + ConstraintAssignment::Negative(constraint) + } + ConstraintAssignment::Negative(constraint) => { + ConstraintAssignment::Positive(constraint) + } + } + } + + fn negate(&mut self) { + *self = self.negated(); + } + + // Keep this for future debugging needs, even though it's not currently used when rendering + // constraint sets. + #[expect(dead_code)] + fn display(self, db: &'db dyn Db) -> impl Display { + struct DisplayConstraintAssignment<'db> { + constraint: ConstraintAssignment<'db>, + db: &'db dyn Db, + } + + impl Display for DisplayConstraintAssignment<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.constraint { + ConstraintAssignment::Positive(constraint) => { + constraint.display(self.db).fmt(f) + } + ConstraintAssignment::Negative(constraint) => { + constraint.display_negated(self.db).fmt(f) + } + } + } + } + + DisplayConstraintAssignment { constraint: self, - typevar, db, } } } -/// Wraps a constraint (or clause, or set), while using distinct variants to represent when the -/// constraint is never satisfiable or always satisfiable. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum Satisfiable { - Never, - Always, - Constrained(T), +/// A single clause in the DNF representation of a BDD +#[derive(Clone, Debug, Default, Eq, PartialEq)] +struct SatisfiedClause<'db> { + constraints: Vec>, } -impl Satisfiable { - fn map(self, f: impl FnOnce(T) -> U) -> Satisfiable { - match self { - Satisfiable::Never => Satisfiable::Never, - Satisfiable::Always => Satisfiable::Always, - Satisfiable::Constrained(t) => Satisfiable::Constrained(f(t)), - } +impl<'db> SatisfiedClause<'db> { + fn push(&mut self, constraint: ConstraintAssignment<'db>) { + self.constraints.push(constraint); } -} -/// The result of trying to simplify two constraints (or clauses, or sets). Like [`Satisfiable`], -/// we use distinct variants to represent when the simplification is never satisfiable or always -/// satisfiable. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Simplifiable { - NeverSatisfiable, - AlwaysSatisfiable, - Simplified(T), - NotSimplified(T, T), -} - -impl Simplifiable { - fn from_one(constraint: Satisfiable) -> Self { - match constraint { - Satisfiable::Never => Simplifiable::NeverSatisfiable, - Satisfiable::Always => Simplifiable::AlwaysSatisfiable, - Satisfiable::Constrained(constraint) => Simplifiable::Simplified(constraint), - } + fn pop(&mut self) { + self.constraints + .pop() + .expect("clause vector should not be empty"); } -} -impl Simplifiable { - fn from_union(first: Satisfiable, second: Satisfiable) -> Self { - match (first, second) { - (Satisfiable::Always, _) | (_, Satisfiable::Always) => Simplifiable::AlwaysSatisfiable, - (Satisfiable::Never, Satisfiable::Never) => Simplifiable::NeverSatisfiable, - (Satisfiable::Never, Satisfiable::Constrained(constraint)) - | (Satisfiable::Constrained(constraint), Satisfiable::Never) => { - Simplifiable::Simplified(constraint) - } - (Satisfiable::Constrained(first), Satisfiable::Constrained(second)) => { - Simplifiable::NotSimplified(first, second) + /// Invokes a closure with the last constraint in this clause negated. Returns the clause back + /// to its original state after invoking the closure. + fn with_negated_last_constraint(&mut self, f: impl for<'a> FnOnce(&'a Self)) { + if self.constraints.is_empty() { + return; + } + let last_index = self.constraints.len() - 1; + self.constraints[last_index].negate(); + f(self); + self.constraints[last_index].negate(); + } + + /// Removes another clause from this clause, if it appears as a prefix of this clause. Returns + /// whether the prefix was removed. + fn remove_prefix(&mut self, prefix: &SatisfiedClause<'db>) -> bool { + if self.constraints.starts_with(&prefix.constraints) { + self.constraints.drain(0..prefix.constraints.len()); + return true; + } + false + } + + fn display(&self, db: &'db dyn Db) -> impl Display { + struct DisplaySatisfiedClause<'a, 'db> { + clause: &'a SatisfiedClause<'db>, + db: &'db dyn Db, + } + + impl Display for DisplaySatisfiedClause<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.clause.constraints.len() > 1 { + f.write_str("(")?; + } + for (i, constraint) in self.clause.constraints.iter().enumerate() { + if i > 0 { + f.write_str(" ∧ ")?; + } + match constraint { + ConstraintAssignment::Positive(constraint) => { + write!(f, "{}", constraint.display(self.db))?; + } + ConstraintAssignment::Negative(constraint) => { + write!(f, "{}", constraint.display_negated(self.db))?; + } + } + } + if self.clause.constraints.len() > 1 { + f.write_str(")")?; + } + Ok(()) } } - } - fn map(self, mut f: impl FnMut(T) -> U) -> Simplifiable { - match self { - Simplifiable::NeverSatisfiable => Simplifiable::NeverSatisfiable, - Simplifiable::AlwaysSatisfiable => Simplifiable::AlwaysSatisfiable, - Simplifiable::Simplified(t) => Simplifiable::Simplified(f(t)), - Simplifiable::NotSimplified(t1, t2) => Simplifiable::NotSimplified(f(t1), f(t2)), - } - } - - fn reverse(self) -> Self { - match self { - Simplifiable::NeverSatisfiable - | Simplifiable::AlwaysSatisfiable - | Simplifiable::Simplified(_) => self, - Simplifiable::NotSimplified(t1, t2) => Simplifiable::NotSimplified(t2, t1), - } + DisplaySatisfiedClause { clause: self, db } + } +} + +/// A list of the clauses that satisfy a BDD. This is a DNF representation of the boolean function +/// that the BDD represents. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +struct SatisfiedClauses<'db> { + clauses: Vec>, +} + +impl<'db> SatisfiedClauses<'db> { + fn push(&mut self, clause: SatisfiedClause<'db>) { + self.clauses.push(clause); + } + + /// Simplifies the DNF representation, removing redundancies that do not change the underlying + /// function. (This is used when displaying a BDD, to make sure that the representation that we + /// show is as simple as possible while still producing the same results.) + fn simplify(&mut self) { + while self.simplify_one_round() { + // Keep going + } + } + + fn simplify_one_round(&mut self) -> bool { + let mut changes_made = false; + + // First remove any duplicate clauses. (The clause list will start out with no duplicates + // in the first round of simplification, because of the guarantees provided by the BDD + // structure. But earlier rounds of simplification might have made some clauses redundant.) + // Note that we have to loop through the vector element indexes manually, since we might + // remove elements in each iteration. + let mut i = 0; + while i < self.clauses.len() { + let mut j = i + 1; + while j < self.clauses.len() { + if self.clauses[i] == self.clauses[j] { + self.clauses.swap_remove(j); + changes_made = true; + } else { + j += 1; + } + } + i += 1; + } + if changes_made { + return true; + } + + // Then look for "prefix simplifications". That is, looks for patterns + // + // (A ∧ B) ∨ (A ∧ ¬B ∧ ...) + // + // and replaces them with + // + // (A ∧ B) ∨ (...) + for i in 0..self.clauses.len() { + let (clause, rest) = self.clauses[..=i] + .split_last_mut() + .expect("index should be in range"); + clause.with_negated_last_constraint(|clause| { + for existing in rest { + changes_made |= existing.remove_prefix(clause); + } + }); + + let (clause, rest) = self.clauses[i..] + .split_first_mut() + .expect("index should be in range"); + clause.with_negated_last_constraint(|clause| { + for existing in rest { + changes_made |= existing.remove_prefix(clause); + } + }); + + if changes_made { + return true; + } + } + + false + } + + fn display(&self, db: &'db dyn Db) -> impl Display { + struct DisplaySatisfiedClauses<'a, 'db> { + clauses: &'a SatisfiedClauses<'db>, + db: &'db dyn Db, + } + + impl Display for DisplaySatisfiedClauses<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.clauses.clauses.is_empty() { + return f.write_str("always"); + } + for (i, clause) in self.clauses.clauses.iter().enumerate() { + if i > 0 { + f.write_str(" ∨ ")?; + } + clause.display(self.db).fmt(f)?; + } + Ok(()) + } + } + + DisplaySatisfiedClauses { clauses: self, db } } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 7650a4eac3..0edb463366 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -877,7 +877,7 @@ impl<'db> Specialization<'db> { } TypeVarVariance::Bivariant => ConstraintSet::from(true), }; - if result.intersect(db, &compatible).is_never_satisfied() { + if result.intersect(db, compatible).is_never_satisfied() { return result; } } @@ -918,7 +918,7 @@ impl<'db> Specialization<'db> { } TypeVarVariance::Bivariant => ConstraintSet::from(true), }; - if result.intersect(db, &compatible).is_never_satisfied() { + if result.intersect(db, compatible).is_never_satisfied() { return result; } } @@ -928,7 +928,7 @@ impl<'db> Specialization<'db> { (None, None) => {} (Some(self_tuple), Some(other_tuple)) => { let compatible = self_tuple.is_equivalent_to_impl(db, other_tuple, visitor); - if result.intersect(db, &compatible).is_never_satisfied() { + if result.intersect(db, compatible).is_never_satisfied() { return result; } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 298c3b4644..79cd816b0a 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6947,7 +6947,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::UnaryOp::Invert, Type::KnownInstance(KnownInstanceType::ConstraintSet(constraints)), ) => { - let constraints = constraints.constraints(self.db()).clone(); + let constraints = constraints.constraints(self.db()); let result = constraints.negate(self.db()); Type::KnownInstance(KnownInstanceType::ConstraintSet(TrackedConstraintSet::new( self.db(), @@ -7311,9 +7311,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), ast::Operator::BitAnd, ) => { - let left = left.constraints(self.db()).clone(); - let right = right.constraints(self.db()).clone(); - let result = left.and(self.db(), || right); + let left = left.constraints(self.db()); + let right = right.constraints(self.db()); + let result = left.and(self.db(), || *right); Some(Type::KnownInstance(KnownInstanceType::ConstraintSet( TrackedConstraintSet::new(self.db(), result), ))) @@ -7324,9 +7324,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), ast::Operator::BitOr, ) => { - let left = left.constraints(self.db()).clone(); - let right = right.constraints(self.db()).clone(); - let result = left.or(self.db(), || right); + let left = left.constraints(self.db()); + let right = right.constraints(self.db()); + let result = left.or(self.db(), || *right); Some(Type::KnownInstance(KnownInstanceType::ConstraintSet( TrackedConstraintSet::new(self.db(), result), ))) diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index ddc06951fc..d5ae64fdaf 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -551,10 +551,7 @@ impl<'db> Signature<'db> { let self_type = self_type.unwrap_or(Type::unknown()); let other_type = other_type.unwrap_or(Type::unknown()); !result - .intersect( - db, - &self_type.is_equivalent_to_impl(db, other_type, visitor), - ) + .intersect(db, self_type.is_equivalent_to_impl(db, other_type, visitor)) .is_never_satisfied() }; @@ -699,10 +696,7 @@ impl<'db> Signature<'db> { let type1 = type1.unwrap_or(Type::unknown()); let type2 = type2.unwrap_or(Type::unknown()); !result - .intersect( - db, - &type1.has_relation_to_impl(db, type2, relation, visitor), - ) + .intersect(db, type1.has_relation_to_impl(db, type2, relation, visitor)) .is_never_satisfied() }; diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index db7c4d26b7..c6d669612a 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -439,7 +439,7 @@ impl<'db> FixedLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, *other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -452,7 +452,7 @@ impl<'db> FixedLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, *other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -774,7 +774,7 @@ impl<'db> VariableLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -788,7 +788,7 @@ impl<'db> VariableLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -832,7 +832,7 @@ impl<'db> VariableLengthTuple> { return ConstraintSet::from(false); } }; - if result.intersect(db, &pair_constraints).is_never_satisfied() { + if result.intersect(db, pair_constraints).is_never_satisfied() { return result; } } @@ -858,7 +858,7 @@ impl<'db> VariableLengthTuple> { return ConstraintSet::from(false); } }; - if result.intersect(db, &pair_constraints).is_never_satisfied() { + if result.intersect(db, pair_constraints).is_never_satisfied() { return result; } }