From 02ebb2ee6189e5f07c57c14ac47eec1df27bb4d0 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Thu, 25 Sep 2025 21:55:35 -0400 Subject: [PATCH] [ty] Change to BDD representation for constraint sets (#20533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While working on #20093, I kept running into test failures due to constraint sets not simplifying as much as they could, and therefore not being easily testable against "always true" and "always false". This PR updates our constraint set representation to use BDDs. Because BDDs are reduced and ordered, they are canonical — equivalent boolean formulas are represented by the same interned BDD node. That said, there is a wrinkle, in that the "variables" that we use in these BDDs — the individual constraints like `Lower ≤ T ≤ Upper` — are not always independent of each other. As an example, given types `A ≤ B ≤ C ≤ D` and a typevar `T`, the constraints `A ≤ T ≤ C` and `B ≤ T ≤ D` "overlap" — their intersection is non-empty. So we should be able to simplify ``` (A ≤ T ≤ C) ∧ (B ≤ T ≤ D) == (B ≤ T ≤ C) ``` That's not a simplification that the BDD structure can perform itself, since those three constraints are modeled as separate BDD variables, and are therefore "opaque" to the BDD algorithms. That means we need to perform this kind of simplification ourselves. We look at pairs of constraints that appear in a BDD and see if they can be simplified relative to each other, and if so, replace the pair with the simplification. A large part of the toil of getting this PR to work was identifying all of those patterns and getting that substitution logic correct. With this new representation, all existing tests pass, as well as some new ones that represent test failures that were occurring on #20093.
--------- Co-authored-by: Carl Meyer --- .../mdtest/type_properties/constraints.md | 57 +- .../src/types/constraints.rs | 2221 ++++++++--------- .../ty_python_semantic/src/types/generics.rs | 6 +- .../src/types/infer/builder.rs | 14 +- .../src/types/signatures.rs | 10 +- crates/ty_python_semantic/src/types/tuple.rs | 12 +- 6 files changed, 1100 insertions(+), 1220 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md b/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md index 3239c9c2f1..fb3254f87a 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md @@ -305,9 +305,9 @@ range. ```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[((SubSub ≤ T@_ ≤ Base) ∧ ¬(Sub ≤ T@_ ≤ Base))] + # revealed: ty_extensions.ConstraintSet[(¬(Sub ≤ T@_ ≤ Base) ∧ (SubSub ≤ T@_ ≤ Base))] reveal_type(range_constraint(SubSub, T, Base) & negated_range_constraint(Sub, T, Super)) - # revealed: ty_extensions.ConstraintSet[((SubSub ≤ T@_ ≤ Super) ∧ ¬(Sub ≤ T@_ ≤ Base))] + # revealed: ty_extensions.ConstraintSet[(¬(Sub ≤ T@_ ≤ Base) ∧ (SubSub ≤ T@_ ≤ Super))] reveal_type(range_constraint(SubSub, T, Super) & negated_range_constraint(Sub, T, Base)) ``` @@ -339,9 +339,9 @@ Otherwise, the union cannot be simplified. 
```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[(¬(Sub ≤ T@_ ≤ Base) ∧ ¬(Base ≤ T@_ ≤ Super))] + # revealed: ty_extensions.ConstraintSet[(¬(Base ≤ T@_ ≤ Super) ∧ ¬(Sub ≤ T@_ ≤ Base))] reveal_type(negated_range_constraint(Sub, T, Base) & negated_range_constraint(Base, T, Super)) - # revealed: ty_extensions.ConstraintSet[(¬(SubSub ≤ T@_ ≤ Sub) ∧ ¬(Base ≤ T@_ ≤ Super))] + # revealed: ty_extensions.ConstraintSet[(¬(Base ≤ T@_ ≤ Super) ∧ ¬(SubSub ≤ T@_ ≤ Sub))] reveal_type(negated_range_constraint(SubSub, T, Sub) & negated_range_constraint(Base, T, Super)) # revealed: ty_extensions.ConstraintSet[(¬(SubSub ≤ T@_ ≤ Sub) ∧ ¬(Unrelated ≤ T@_))] reveal_type(negated_range_constraint(SubSub, T, Sub) & negated_range_constraint(Unrelated, T, object)) @@ -385,7 +385,7 @@ We cannot simplify the union of constraints that refer to different typevars. def _[T, U]() -> None: # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ (Sub ≤ U@_ ≤ Base)] reveal_type(range_constraint(Sub, T, Base) | range_constraint(Sub, U, Base)) - # revealed: ty_extensions.ConstraintSet[¬(Sub ≤ T@_ ≤ Base) ∨ ¬(Sub ≤ U@_ ≤ Base)] + # revealed: ty_extensions.ConstraintSet[¬(Sub ≤ U@_ ≤ Base) ∨ ¬(Sub ≤ T@_ ≤ Base)] reveal_type(negated_range_constraint(Sub, T, Base) | negated_range_constraint(Sub, U, Base)) ``` @@ -417,9 +417,9 @@ Otherwise, the union cannot be simplified. 
```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ (Base ≤ T@_ ≤ Super)] + # revealed: ty_extensions.ConstraintSet[(Base ≤ T@_ ≤ Super) ∨ (Sub ≤ T@_ ≤ Base)] reveal_type(range_constraint(Sub, T, Base) | range_constraint(Base, T, Super)) - # revealed: ty_extensions.ConstraintSet[(SubSub ≤ T@_ ≤ Sub) ∨ (Base ≤ T@_ ≤ Super)] + # revealed: ty_extensions.ConstraintSet[(Base ≤ T@_ ≤ Super) ∨ (SubSub ≤ T@_ ≤ Sub)] reveal_type(range_constraint(SubSub, T, Sub) | range_constraint(Base, T, Super)) # revealed: ty_extensions.ConstraintSet[(SubSub ≤ T@_ ≤ Sub) ∨ (Unrelated ≤ T@_)] reveal_type(range_constraint(SubSub, T, Sub) | range_constraint(Unrelated, T, object)) @@ -488,9 +488,9 @@ range. ```py def _[T]() -> None: - # revealed: ty_extensions.ConstraintSet[¬(SubSub ≤ T@_ ≤ Base) ∨ (Sub ≤ T@_ ≤ Base)] + # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ ¬(SubSub ≤ T@_ ≤ Base)] reveal_type(negated_range_constraint(SubSub, T, Base) | range_constraint(Sub, T, Super)) - # revealed: ty_extensions.ConstraintSet[¬(SubSub ≤ T@_ ≤ Super) ∨ (Sub ≤ T@_ ≤ Base)] + # revealed: ty_extensions.ConstraintSet[(Sub ≤ T@_ ≤ Base) ∨ ¬(SubSub ≤ T@_ ≤ Super)] reveal_type(negated_range_constraint(SubSub, T, Super) | range_constraint(Sub, T, Base)) ``` @@ -562,3 +562,42 @@ def _[T]() -> None: # revealed: ty_extensions.ConstraintSet[always] reveal_type(constraint | ~constraint) ``` + +### Negation of constraints involving two variables + +```py +from typing import final, Never +from ty_extensions import range_constraint + +class Base: ... + +@final +class Unrelated: ... + +def _[T, U]() -> None: + # revealed: ty_extensions.ConstraintSet[¬(U@_ ≤ Base) ∨ ¬(T@_ ≤ Base)] + reveal_type(~(range_constraint(Never, T, Base) & range_constraint(Never, U, Base))) +``` + +The union of a constraint and its negation should always be satisfiable. 
+ +```py +def _[T, U]() -> None: + c1 = range_constraint(Never, T, Base) & range_constraint(Never, U, Base) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(c1 | ~c1) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(~c1 | c1) + + c2 = range_constraint(Unrelated, T, object) & range_constraint(Unrelated, U, object) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(c2 | ~c2) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(~c2 | c2) + + union = c1 | c2 + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(union | ~union) + # revealed: ty_extensions.ConstraintSet[always] + reveal_type(~union | union) +``` diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 7087f01301..cb4e504196 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -14,26 +14,16 @@ //! This module provides the machinery for representing the "under what constraints" part of that //! question. //! -//! An individual constraint restricts the specialization of a single typevar. You can then build -//! up more complex constraint sets using union, intersection, and negation operations. We use a -//! disjunctive normal form (DNF) representation, just like we do for types: a [constraint -//! set][ConstraintSet] is the union of zero or more [clauses][ConstraintClause], each of which is -//! the intersection of zero or more [individual constraints][ConstrainedTypeVar]. Note that the -//! constraint set that contains no clauses is never satisfiable (`⋃ {} = 0`); and the constraint -//! set that contains a single clause, where that clause contains no constraints, is always -//! satisfiable (`⋃ {⋂ {}} = 1`). -//! -//! An individual constraint consists of a _positive range_ and zero or more _negative holes_. The -//! positive range and each negative hole consists of a lower and upper bound. 
A type is within a -//! lower and upper bound if it is a supertype of the lower bound and a subtype of the upper bound. -//! The typevar can specialize to any type that is within the positive range, and is not within any -//! of the negative holes. (You can think of the constraint as the set of types that are within the -//! positive range, with the negative holes "removed" from that set.) +//! An individual constraint restricts the specialization of a single typevar to be within a +//! particular lower and upper bound. (A type is within a lower and upper bound if it is a +//! supertype of the lower bound and a subtype of the upper bound.) You can then build up more +//! complex constraint sets using union, intersection, and negation operations. We use a [binary +//! decision diagram][bdd] (BDD) to represent a constraint set. //! //! Note that all lower and upper bounds in a constraint must be fully static. We take the bottom //! and top materializations of the types to remove any gradual forms if needed. //! -//! NOTE: This module is currently in a transitional state. We've added the DNF [`ConstraintSet`] +//! NOTE: This module is currently in a transitional state. We've added the BDD [`ConstraintSet`] //! representation, and updated all of our property checks to build up a constraint set and then //! check whether it is ever or always satisfiable, as appropriate. We are not yet inferring //! specializations from those constraints. @@ -60,23 +50,18 @@ //! constraint `(int ≤ T ≤ int) ∪ (str ≤ T ≤ str)`. When the lower and upper bounds are the same, //! the constraint says that the typevar must specialize to that _exact_ type, not to a subtype or //! supertype of it. +//! +//! 
[bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram +use std::cmp::Ordering; use std::fmt::Display; -use itertools::{EitherOrBoth, Itertools}; -use smallvec::{SmallVec, smallvec}; +use itertools::Itertools; +use rustc_hash::FxHashSet; use crate::Db; use crate::types::{BoundTypeVarInstance, IntersectionType, Type, UnionType}; -fn comparable<'db>(db: &'db dyn Db, left: Type<'db>, right: Type<'db>) -> bool { - left.is_subtype_of(db, right) || right.is_subtype_of(db, left) -} - -fn incomparable<'db>(db: &'db dyn Db, left: Type<'db>, right: Type<'db>) -> bool { - !comparable(db, left, right) -} - /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { /// Returns a constraint set that is always satisfiable if the option is `None`; otherwise @@ -154,7 +139,7 @@ where ) -> ConstraintSet<'db> { let mut result = ConstraintSet::always(); for child in self { - if result.intersect(db, &f(child)).is_never_satisfied() { + if result.intersect(db, f(child)).is_never_satisfied() { return result; } } @@ -164,66 +149,55 @@ where /// A set of constraints under which a type property holds. /// -/// We use a DNF representation, so a set contains a list of zero or more -/// [clauses][ConstraintClause], each of which is an intersection of zero or more -/// [constraints][ConstrainedTypeVar]. -/// /// This is called a "set of constraint sets", and denoted _𝒮_, in [[POPL2015][]]. /// -/// ### Invariants -/// -/// - The clauses are simplified as much as possible — there are no two clauses in the set that can -/// be simplified into a single clause. 
-/// /// [POPL2015]: https://doi.org/10.1145/2676726.2676991 -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] pub struct ConstraintSet<'db> { - // NOTE: We use 2 here because there are a couple of places where we create unions of 2 clauses - // as temporary values — in particular when negating a constraint — and this lets us avoid - // spilling the temporary value to the heap. - clauses: SmallVec<[ConstraintClause<'db>; 2]>, + /// The BDD representing this constraint set + node: Node<'db>, } impl<'db> ConstraintSet<'db> { fn never() -> Self { Self { - clauses: smallvec![], + node: Node::AlwaysFalse, } } fn always() -> Self { - Self::singleton(ConstraintClause::always()) + Self { + node: Node::AlwaysTrue, + } } /// Returns whether this constraint set never holds - pub(crate) fn is_never_satisfied(&self) -> bool { - self.clauses.is_empty() + pub(crate) fn is_never_satisfied(self) -> bool { + self.node.is_never_satisfied() } /// Returns whether this constraint set always holds - pub(crate) fn is_always_satisfied(&self) -> bool { - self.clauses.len() == 1 && self.clauses[0].is_always() + pub(crate) fn is_always_satisfied(self) -> bool { + self.node.is_always_satisfied() } /// Updates this constraint set to hold the union of itself and another constraint set. - pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> &Self { - self.union_set(db, other); - self + pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> Self { + self.node = self.node.or(db, other.node).simplify(db); + *self } /// Updates this constraint set to hold the intersection of itself and another constraint set. 
- pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: &Self) -> &Self { - self.intersect_set(db, other); - self + pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: Self) -> Self { + self.node = self.node.and(db, other.node).simplify(db); + *self } /// Returns the negation of this constraint set. - pub(crate) fn negate(&self, db: &'db dyn Db) -> Self { - let mut result = Self::always(); - for clause in &self.clauses { - result.intersect_set(db, &clause.negate(db)); + pub(crate) fn negate(self, db: &'db dyn Db) -> Self { + Self { + node: self.node.negate(db).simplify(db), } - result } /// Returns the intersection of this constraint set and another. The other constraint set is @@ -231,7 +205,7 @@ impl<'db> ConstraintSet<'db> { /// constraint set is already saturated. pub(crate) fn and(mut self, db: &'db dyn Db, other: impl FnOnce() -> Self) -> Self { if !self.is_never_satisfied() { - self.intersect(db, &other()); + self.intersect(db, other()); } self } @@ -246,13 +220,6 @@ impl<'db> ConstraintSet<'db> { self } - /// Returns a constraint set that contains a single clause. 
- fn singleton(clause: ConstraintClause<'db>) -> Self { - Self { - clauses: smallvec![clause], - } - } - pub(crate) fn range( db: &'db dyn Db, lower: Type<'db>, @@ -261,10 +228,9 @@ impl<'db> ConstraintSet<'db> { ) -> Self { let lower = lower.bottom_materialization(db); let upper = upper.top_materialization(db); - let constraint = Constraint::range(db, lower, upper).constrain(typevar); - let mut result = Self::never(); - result.union_constraint(db, constraint); - result + Self { + node: ConstrainedTypeVar::new_node(db, lower, typevar, upper), + } } pub(crate) fn negated_range( @@ -273,130 +239,11 @@ impl<'db> ConstraintSet<'db> { typevar: BoundTypeVarInstance<'db>, upper: Type<'db>, ) -> Self { - let lower = lower.bottom_materialization(db); - let upper = upper.top_materialization(db); - let constraint = Constraint::negated_range(db, lower, upper).constrain(typevar); - let mut result = Self::never(); - result.union_constraint(db, constraint); - result + Self::range(db, lower, typevar, upper).negate(db) } - /// Updates this set to be the union of itself and a constraint. - fn union_constraint( - &mut self, - db: &'db dyn Db, - constraint: Satisfiable>, - ) { - self.union_clause(db, constraint.map(ConstraintClause::singleton)); - } - - /// Updates this set to be the union of itself and a clause. To maintain the invariants of this - /// type, we must simplify this clause against all existing clauses, if possible. - fn union_clause(&mut self, db: &'db dyn Db, clause: Satisfiable>) { - let mut clause = match clause { - // If the new constraint can always be satisfied, that causes this whole set to be - // always satisfied too. - Satisfiable::Always => { - self.clauses.clear(); - self.clauses.push(ConstraintClause::always()); - return; - } - - // If the new clause can never satisfied, then the set does not change. 
- Satisfiable::Never => return, - - Satisfiable::Constrained(clause) => clause, - }; - - // Naively, we would just append the new clause to the set's list of clauses. But that - // doesn't ensure that the clauses are simplified with respect to each other. So instead, - // we iterate through the list of existing clauses, and try to simplify the new clause - // against each one in turn. (We can assume that the existing clauses are already - // simplified with respect to each other, since we can assume that the invariant holds upon - // entry to this method.) - let mut existing_clauses = std::mem::take(&mut self.clauses).into_iter(); - for existing in existing_clauses.by_ref() { - // Try to simplify the new clause against an existing clause. - match existing.simplify_clauses(db, clause) { - Simplifiable::NeverSatisfiable => { - // If two clauses cancel out to 0, that does NOT cause the entire set to become - // 0. We need to keep whatever clauses have already been added to the result, - // and also need to copy over any later clauses that we hadn't processed yet. - self.clauses.extend(existing_clauses); - return; - } - - Simplifiable::AlwaysSatisfiable => { - // If two clauses cancel out to 1, that makes the entire set 1, and all - // existing clauses are simplified away. - self.clauses.clear(); - self.clauses.push(ConstraintClause::always()); - return; - } - - Simplifiable::NotSimplified(existing, c) => { - // We couldn't simplify the new clause relative to this existing clause, so add - // the existing clause to the result. Continue trying to simplify the new - // clause against the later existing clauses. - self.clauses.push(existing); - clause = c; - } - - Simplifiable::Simplified(c) => { - // We were able to simplify the new clause relative to this existing clause. - // Don't add it to the result yet; instead, try to simplify the result further - // against later existing clauses. 
- clause = c; - } - } - } - - // If we fall through then we need to add the new clause to the clause list (either because - // we couldn't simplify it with anything, or because we did without it canceling out). - self.clauses.push(clause); - } - - /// Updates this set to be the union of itself and another set. - fn union_set(&mut self, db: &'db dyn Db, other: Self) { - for clause in other.clauses { - self.union_clause(db, Satisfiable::Constrained(clause)); - } - } - - /// Updates this set to be the intersection of itself and another set. - fn intersect_set(&mut self, db: &'db dyn Db, other: &Self) { - // This is the distributive law: - // (A ∪ B) ∩ (C ∪ D ∪ E) = (A ∩ C) ∪ (A ∩ D) ∪ (A ∩ E) ∪ (B ∩ C) ∪ (B ∩ D) ∪ (B ∩ E) - let self_clauses = std::mem::take(&mut self.clauses); - for self_clause in &self_clauses { - for other_clause in &other.clauses { - self.union_clause(db, self_clause.intersect_clause(db, other_clause)); - } - } - } - - pub(crate) fn display(&self, db: &'db dyn Db) -> impl Display { - struct DisplayConstraintSet<'a, 'db> { - set: &'a ConstraintSet<'db>, - db: &'db dyn Db, - } - - impl Display for DisplayConstraintSet<'_, '_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.set.clauses.is_empty() { - return f.write_str("0"); - } - for (i, clause) in self.set.clauses.iter().enumerate() { - if i > 0 { - f.write_str(" ∨ ")?; - } - clause.display(self.db).fmt(f)?; - } - Ok(()) - } - } - - DisplayConstraintSet { set: self, db } + pub(crate) fn display(self, db: &'db dyn Db) -> impl Display { + self.node.display(db) } } @@ -406,823 +253,68 @@ impl From for ConstraintSet<'_> { } } -/// The intersection of zero or more individual constraints. -/// -/// This is called a "constraint set", and denoted _C_, in [[POPL2015][]]. 
-/// -/// [POPL2015]: https://doi.org/10.1145/2676726.2676991 -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct ConstraintClause<'db> { - // NOTE: We use 1 here because most clauses only mention a single typevar. - constraints: SmallVec<[ConstrainedTypeVar<'db>; 1]>, -} - -impl<'db> ConstraintClause<'db> { - fn new(constraints: SmallVec<[ConstrainedTypeVar<'db>; 1]>) -> Satisfiable { - if constraints.is_empty() { - Satisfiable::Always - } else { - Satisfiable::Constrained(Self { constraints }) - } - } - - /// Returns the clause that is always satisfiable. - fn always() -> Self { - Self { - constraints: smallvec![], - } - } - - /// Returns a clause containing a single constraint. - fn singleton(constraint: ConstrainedTypeVar<'db>) -> Self { - Self { - constraints: smallvec![constraint], - } - } - - /// Returns whether this constraint is always satisfiable. - fn is_always(&self) -> bool { - self.constraints.is_empty() - } - - fn is_satisfiable(&self) -> Satisfiable<()> { - if self.is_always() { - Satisfiable::Always - } else { - Satisfiable::Constrained(()) - } - } - - /// Updates this clause to be the intersection of itself and an individual constraint. Returns - /// a flag indicating whether the updated clause is never, always, or sometimes satisfied. - fn intersect_constraint( - &mut self, - db: &'db dyn Db, - constraint: Satisfiable>, - ) -> Satisfiable<()> { - let mut constraint = match constraint { - // If the new constraint cannot be satisfied, that causes this whole clause to be - // unsatisfiable too. - Satisfiable::Never => return Satisfiable::Never, - - // If the new constraint can always satisfied, then the clause does not change. It was - // not always satisfiable before, and so it still isn't. 
- Satisfiable::Always => return Satisfiable::Constrained(()), - - Satisfiable::Constrained(constraint) => constraint, - }; - - // Naively, we would just append the new constraint to the clauses's list of constraints. - // But that doesn't ensure that the constraints are simplified with respect to each other. - // So instead, we iterate through the list of existing constraints, and try to simplify the - // new constraint against each one in turn. (We can assume that the existing constraints - // are already simplified with respect to each other, since we can assume that the - // invariant holds upon entry to this method.) - let mut existing_constraints = std::mem::take(&mut self.constraints).into_iter(); - for existing in existing_constraints.by_ref() { - // Try to simplify the new constraint against an existing constraint. - match existing.intersect(db, &constraint) { - Some(Satisfiable::Never) => { - // If two constraints cancel out to 0, that makes the entire clause 0, and all - // existing constraints are simplified away. - return Satisfiable::Never; - } - - Some(Satisfiable::Always) => { - // If two constraints cancel out to 1, that does NOT cause the entire clause to - // become 1. We need to keep whatever constraints have already been added to - // the result, and also need to copy over any later constraints that we hadn't - // processed yet. - self.constraints.extend(existing_constraints); - return self.is_satisfiable(); - } - - None => { - // We couldn't simplify the new constraint relative to this existing - // constraint, so add the existing constraint to the result. Continue trying to - // simplify the new constraint against the later existing constraints. - self.constraints.push(existing.clone()); - } - - Some(Satisfiable::Constrained(simplified)) => { - // We were able to simplify the new constraint relative to this existing - // constraint. 
Don't add it to the result yet; instead, try to simplify the - // result further against later existing constraints. - constraint = simplified; - } - } - } - - // If we fall through then we need to add the new constraint to the constraint list (either - // because we couldn't simplify it with anything, or because we did without it canceling - // out). - self.constraints.push(constraint); - self.is_satisfiable() - } - - /// Returns the intersection of this clause with another. - fn intersect_clause(&self, db: &'db dyn Db, other: &Self) -> Satisfiable { - // Add each `other` constraint in turn. Short-circuit if the result ever becomes 0. - let mut result = self.clone(); - for constraint in &other.constraints { - match result.intersect_constraint(db, Satisfiable::Constrained(constraint.clone())) { - Satisfiable::Never => return Satisfiable::Never, - Satisfiable::Always | Satisfiable::Constrained(()) => {} - } - } - if result.is_always() { - Satisfiable::Always - } else { - Satisfiable::Constrained(result) - } - } - - /// Tries to simplify the union of two clauses into a single clause, if possible. - fn simplify_clauses(self, db: &'db dyn Db, other: Self) -> Simplifiable { - // Saturation - // - // If either clause is always satisfiable, the union is too. (`1 ∪ C₂ = 1`, `C₁ ∪ 1 = 1`) - // - // ```py - // class A[T]: ... 
- // - // class C1[U]: - // # T can specialize to any type, so this is "always satisfiable", or `1` - // x: A[U] - // - // class C2[V: int]: - // # `T ≤ int` - // x: A[V] - // - // class Saturation[U, V: int]: - // # `1 ∪ (T ≤ int)` - // # simplifies via saturation to - // # `T ≤ int` - // x: A[U] | A[V] - // ``` - if self.is_always() || other.is_always() { - return Simplifiable::Simplified(Self::always()); - } - - // Subsumption - // - // If either clause subsumes (is "smaller than") the other, then the union simplifies to - // the "bigger" clause (the one being subsumed): - // - // - `A ∩ B` must be at least as large as `A ∩ B ∩ C` - // - Therefore, `(A ∩ B) ∪ (A ∩ B ∩ C) = (A ∩ B)` - // - // (Note that possibly counterintuitively, "bigger" here means _fewer_ constraints in the - // intersection, since intersecting more things can only make the result smaller.) - // - // ```py - // class A[T, U, V]: ... - // - // class C1[X: int, Y: str, Z]: - // # `(T ≤ int ∩ U ≤ str)` - // x: A[X, Y, Z] - // - // class C2[X: int, Y: str, Z: bytes]: - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X, Y, Z] - // - // class Subsumption[X1: int, Y1: str, Z2, X2: int, Y2: str, Z2: bytes]: - // # `(T ≤ int ∩ U ≤ str) ∪ (T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // # simplifies via subsumption to - // # `(T ≤ int ∩ U ≤ str)` - // x: A[X1, Y1, Z2] | A[X2, Y2, Z2] - // ``` - // - // TODO: Consider checking both directions in one pass, possibly via a tri-valued return - // value. - if self.subsumes_via_intersection(db, &other) { - return Simplifiable::Simplified(other); - } - if other.subsumes_via_intersection(db, &self) { - return Simplifiable::Simplified(self); - } - - // Distribution - // - // If the two clauses constrain the same typevar in an "overlapping" way, we can factor - // that out: - // - // (A₁ ∩ B ∩ C) ∪ (A₂ ∩ B ∩ C) = (A₁ ∪ A₂) ∩ B ∩ C - // - // ```py - // class A[T, U, V]: ... 
- // - // class C1[X: int, Y: str, Z: bytes]: - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X, Y, Z] - // - // class C2[X: bool, Y: str, Z: bytes]: - // # `(T ≤ bool ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X, Y, Z] - // - // class Distribution[X1: int, Y1: str, Z2: bytes, X2: bool, Y2: str, Z2: bytes]: - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes) ∪ (T ≤ bool ∩ U ≤ str ∩ V ≤ bytes)` - // # simplifies via distribution to - // # `(T ≤ int ∪ T ≤ bool) ∩ U ≤ str ∩ V ≤ bytes)` - // # which (because `bool ≤ int`) is equivalent to - // # `(T ≤ int ∩ U ≤ str ∩ V ≤ bytes)` - // x: A[X1, Y1, Z2] | A[X2, Y2, Z2] - // ``` - if let Some(simplified) = self.simplifies_via_distribution(db, &other) { - return simplified; - } - - // Can't be simplified - Simplifiable::NotSimplified(self, other) - } - - /// Returns whether this clause subsumes `other` via intersection — that is, if the - /// intersection of `self` and `other` is `self`. - fn subsumes_via_intersection(&self, db: &'db dyn Db, other: &Self) -> bool { - // See the notes in `simplify_clauses` for more details on subsumption, including Python - // examples that cause it to fire. - - if self.constraints.len() != other.constraints.len() { - return false; - } - - let pairwise = (self.constraints.iter()) - .merge_join_by(&other.constraints, |a, b| a.typevar.cmp(&b.typevar)); - for pair in pairwise { - match pair { - // `other` contains a constraint whose typevar doesn't appear in `self`, so `self` - // cannot be smaller. - EitherOrBoth::Right(_) => return false, - - // `self` contains a constraint whose typevar doesn't appear in `other`. `self` - // might be smaller, but we still have to check the remaining constraints. - EitherOrBoth::Left(_) => continue, - - // Both clauses contain a constraint with this typevar; verify that the constraint - // in `self` is smaller. 
- EitherOrBoth::Both(self_constraint, other_constraint) => { - if !self_constraint.subsumes(db, other_constraint) { - return false; - } - } - } - } - true - } - - /// If the union of two clauses is simpler than either of the individual clauses, returns the - /// union. This happens when they mention the same set of typevars and the constraints for all - /// but one typevar are identical. Moreover, for the other typevar, the union of the - /// constraints for that typevar simplifies to (a) a single constraint, or (b) two constraints - /// where one of them is smaller than before. That is, - /// - /// ```text - /// (A₁ ∩ B ∩ C) ∪ (A₂ ∩ B ∩ C) = A₁₂ ∩ B ∩ C - /// or (A₁' ∪ A₂) ∩ B ∩ C - /// or (A₁ ∪ A₂') ∩ B ∩ C - /// ``` - /// - /// where `B` and `C` are the constraints that are identical for all but one typevar, and `A₁` - /// and `A₂` are the constraints for the other typevar; and where `A₁ ∪ A₂` either simplifies - /// to a single constraint (`A₁₂`), or to a union where either `A₁` or `A₂` becomes smaller - /// (`A₁'` or `A₂'`, respectively). - /// - /// Otherwise returns `None`. - fn simplifies_via_distribution( - &self, - db: &'db dyn Db, - other: &Self, - ) -> Option> { - // See the notes in `simplify_clauses` for more details on distribution, including Python - // examples that cause it to fire. - - if self.constraints.len() != other.constraints.len() { - return None; - } - - // Verify that the constraints for precisely one typevar simplify, and the constraints for - // all other typevars are identical. Remember the index of the typevar whose constraints - // simplify. - let mut simplified_index = None; - let pairwise = (self.constraints.iter()) - .merge_join_by(&other.constraints, |a, b| a.typevar.cmp(&b.typevar)); - for (index, pair) in pairwise.enumerate() { - match pair { - // If either clause contains a constraint whose typevar doesn't appear in the - // other, the clauses don't simplify. 
- EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => return None, - - EitherOrBoth::Both(self_constraint, other_constraint) => { - if self_constraint == other_constraint { - continue; - } - let Some(union_constraint) = - self_constraint.simplified_union(db, other_constraint) - else { - // The constraints for this typevar are not identical, nor do they - // simplify. - return None; - }; - if simplified_index - .replace((index, union_constraint)) - .is_some() - { - // More than one constraint simplify, which doesn't allow the clause as a - // whole to simplify. - return None; - } - } - } - } - - let Some((index, union_constraint)) = simplified_index else { - // We never found a typevar whose constraints simplify. - return None; - }; - let mut constraints = self.constraints.clone(); - match union_constraint { - Simplifiable::NeverSatisfiable => { - panic!("unioning two non-never constraints should not be never") - } - Simplifiable::AlwaysSatisfiable => { - // If the simplified union of constraints is Always, then we can remove this typevar - // from the constraint completely. - constraints.remove(index); - Some(Simplifiable::from_one(Self::new(constraints))) - } - Simplifiable::Simplified(union_constraint) => { - constraints[index] = union_constraint; - Some(Simplifiable::from_one(Self::new(constraints))) - } - Simplifiable::NotSimplified(left, right) => { - let mut left_constraints = constraints.clone(); - let mut right_constraints = constraints; - left_constraints[index] = left; - right_constraints[index] = right; - Some(Simplifiable::from_union( - Self::new(left_constraints), - Self::new(right_constraints), - )) - } - } - } - - /// Returns the negation of this clause. The result is a set since negating an intersection - /// produces a union. 
- fn negate(&self, db: &'db dyn Db) -> ConstraintSet<'db> { - let mut result = ConstraintSet::never(); - for constraint in &self.constraints { - constraint.negate_into(db, &mut result); - } - result - } - - fn display(&self, db: &'db dyn Db) -> impl Display { - struct DisplayConstraintClause<'a, 'db> { - clause: &'a ConstraintClause<'db>, - db: &'db dyn Db, - } - - impl Display for DisplayConstraintClause<'_, '_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.clause.constraints.is_empty() { - return f.write_str("1"); - } - - let clause_count: usize = (self.clause.constraints.iter()) - .map(ConstrainedTypeVar::clause_count) - .sum(); - if clause_count > 1 { - f.write_str("(")?; - } - for (i, constraint) in self.clause.constraints.iter().enumerate() { - if i > 0 { - f.write_str(" ∧ ")?; - } - constraint.display(self.db).fmt(f)?; - } - if clause_count > 1 { - f.write_str(")")?; - } - Ok(()) - } - } - - DisplayConstraintClause { clause: self, db } - } -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] +/// An individual constraint in a constraint set. This restricts a single typevar to be within a +/// lower and upper bound. +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(PartialOrd, Ord)] pub(crate) struct ConstrainedTypeVar<'db> { typevar: BoundTypeVarInstance<'db>, - constraint: Constraint<'db>, -} - -impl<'db> ConstrainedTypeVar<'db> { - fn clause_count(&self) -> usize { - self.constraint.clause_count() - } - - /// Returns the intersection of this individual constraint and another, or `None` if the two - /// constraints do not refer to the same typevar (and therefore cannot be simplified to a - /// single constraint). 
- fn intersect(&self, db: &'db dyn Db, other: &Self) -> Option> { - if self.typevar != other.typevar { - return None; - } - Some( - self.constraint - .intersect(db, &other.constraint) - .map(|constraint| constraint.constrain(self.typevar)), - ) - } - - /// Returns the union of this individual constraint and another, if it can be simplified to a - /// union of two constraints or fewer. Returns `None` if the union cannot be simplified that - /// much. - fn simplified_union(&self, db: &'db dyn Db, other: &Self) -> Option> { - if self.typevar != other.typevar { - return None; - } - self.constraint - .simplified_union(db, &other.constraint) - .map(|constraint| constraint.map(|constraint| constraint.constrain(self.typevar))) - } - - /// Adds the negation of this individual constraint to a constraint set. - fn negate_into(&self, db: &'db dyn Db, set: &mut ConstraintSet<'db>) { - self.constraint.negate_into(db, self.typevar, set); - } - - /// Returns whether `self` has tighter bounds than `other` — that is, if the intersection of - /// `self` and `other` is `self`. 
- fn subsumes(&self, db: &'db dyn Db, other: &Self) -> bool { - debug_assert_eq!(self.typevar, other.typevar); - match self.intersect(db, other) { - Some(Satisfiable::Constrained(intersection)) => intersection == *self, - _ => false, - } - } - - fn display(&self, db: &'db dyn Db) -> impl Display { - self.constraint.display(db, self.typevar.display(db)) - } -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct Constraint<'db> { - positive: RangeConstraint<'db>, - negative: SmallVec<[NegatedRangeConstraint<'db>; 1]>, -} - -impl<'db> Constraint<'db> { - fn constrain(self, typevar: BoundTypeVarInstance<'db>) -> ConstrainedTypeVar<'db> { - ConstrainedTypeVar { - typevar, - constraint: self, - } - } - - fn clause_count(&self) -> usize { - usize::from(!self.positive.is_always()) + self.negative.len() - } - - fn satisfiable(self, db: &'db dyn Db) -> Satisfiable { - if self.positive.is_always() && self.negative.is_empty() { - return Satisfiable::Always; - } - if (self.negative.iter()).any(|negative| negative.hole.contains(db, &self.positive)) { - return Satisfiable::Never; - } - Satisfiable::Constrained(self) - } - - fn intersect(&self, db: &'db dyn Db, other: &Constraint<'db>) -> Satisfiable> { - let Some(positive) = self.positive.intersect(db, &other.positive) else { - // If the positive intersection is empty, none of the negative holes matter, since - // there are no types for the holes to remove. - return Satisfiable::Never; - }; - - // The negative portion of the intersection is given by - // - // ¬(s₁ ≤ α ≤ t₁) ∧ ¬(s₂ ≤ α ≤ t₂) = ¬((s₁ ≤ α ≤ t₁) ∨ (s₂ ≤ α ≤ t₂)) - // - // That is, we union together the holes from `self` and `other`. If any of the holes - // entirely contain another, we can simplify those two down to the larger hole. We use the - // same trick as above in `union_clause` and `intersect_constraint` to look for pairs that - // we can simplify. 
- // - // We also want to clip each negative hole to the minimum range that overlaps with the - // positive range. We'll do that now to all of the holes from `self`, and we'll do that to - // holes from `other` below when we try to simplify them. - let mut previous: SmallVec<[NegatedRangeConstraint<'db>; 1]> = SmallVec::new(); - let mut current: SmallVec<_> = (self.negative.iter()) - .filter_map(|negative| negative.clip_to_positive(db, &positive)) - .collect(); - for other_negative in &other.negative { - let Some(mut other_negative) = other_negative.clip_to_positive(db, &positive) else { - continue; - }; - std::mem::swap(&mut previous, &mut current); - let mut previous_negative = previous.iter(); - for self_negative in previous_negative.by_ref() { - match self_negative.intersect_negative(db, &other_negative) { - None => { - // We couldn't simplify the new hole relative to this existing holes, so - // add the existing hole to the result. Continue trying to simplify the new - // hole against the later existing holes. - current.push(self_negative.clone()); - } - - Some(union) => { - // We were able to simplify the new hole relative to this existing hole. - // Don't add it to the result yet; instead, try to simplify the result - // further against later existing holes. - other_negative = union.clone(); - } - } - } - - // If we fall through then we need to add the new hole to the hole list (either because - // we couldn't simplify it with anything, or because we did without it canceling out). - current.push(other_negative); - } - - let result = Self { - positive, - negative: current, - }; - result.satisfiable(db) - } - - fn simplified_union( - &self, - db: &'db dyn Db, - other: &Constraint<'db>, - ) -> Option>> { - // (ap ∧ ¬an₁ ∧ ¬an₂ ∧ ...) ∨ (bp ∧ ¬bn₁ ∧ ¬bn₂ ∧ ...) - // = (ap ∨ bp) ∧ (ap ∨ ¬bn₁) ∧ (ap ∨ ¬bn₂) ∧ ... - // ∧ (¬an₁ ∨ bp) ∧ (¬an₁ ∨ ¬bn₁) ∧ (¬an₁ ∨ ¬bn₂) ∧ ... - // ∧ (¬an₂ ∨ bp) ∧ (¬an₂ ∨ ¬bn₁) ∧ (¬an₂ ∨ ¬bn₂) ∧ ... 
- // - // We use a helper type to build up the result of the union of two constraints, since we - // need to calculate the Cartesian product of the the positive and negative portions of the - // two inputs. We cannot use `ConstraintSet` for this, since it would try to invoke the - // `simplify_union` logic, which this method is part of the definition of! So we have to - // reproduce some of that logic here, in a simplified form since we know we're only ever - // looking at pairs of individual constraints at a time. - - struct Results<'db> { - next: Vec>, - results: Vec>, - } - - impl<'db> Results<'db> { - fn new(constraint: Constraint<'db>) -> Results<'db> { - Results { - next: vec![], - results: vec![constraint], - } - } - - fn flip(&mut self) { - std::mem::swap(&mut self.next, &mut self.results); - self.next.clear(); - } - - /// Adds a constraint by intersecting it with any currently pending results. - fn add_constraint(&mut self, db: &'db dyn Db, constraint: &Constraint<'db>) { - self.next.extend(self.results.iter().filter_map(|result| { - match result.intersect(db, constraint) { - Satisfiable::Never => None, - Satisfiable::Always => Some(Constraint { - positive: RangeConstraint::always(), - negative: smallvec![], - }), - Satisfiable::Constrained(constraint) => Some(constraint), - } - })); - } - - /// Adds a single negative range constraint to the pending results. - fn add_negated_range( - &mut self, - db: &'db dyn Db, - negative: Option>, - ) { - let negative = match negative { - Some(negative) => Constraint { - positive: RangeConstraint::always(), - negative: smallvec![negative], - }, - // If the intersection of these two holes is empty, then they don't remove - // anything from the final union. - None => return, - }; - self.add_constraint(db, &negative); - self.flip(); - } - - /// Adds a possibly simplified constraint to the pending results. 
If the parameter has - /// been simplified to a single constraint, it is intersected with each currently - /// pending result. If it could not be simplified (i.e., it is the union of two - /// constraints), then we duplicate any pending results, so that we can _separately_ - /// intersect each non-simplified constraint with the results. - fn add_simplified_constraint( - &mut self, - db: &'db dyn Db, - constraints: Simplifiable>, - ) { - match constraints { - Simplifiable::NeverSatisfiable => { - self.results.clear(); - return; - } - Simplifiable::AlwaysSatisfiable => { - return; - } - Simplifiable::Simplified(constraint) => { - self.add_constraint(db, &constraint); - } - Simplifiable::NotSimplified(first, second) => { - self.add_constraint(db, &first); - self.add_constraint(db, &second); - } - } - self.flip(); - } - - /// If there are two or fewer final results, translates them into a [`Simplifiable`] - /// result. Otherwise returns `None`, indicating that the union cannot be simplified - /// enough for our purposes. - fn into_result(self, db: &'db dyn Db) -> Option>> { - let mut results = self.results.into_iter(); - let Some(first) = results.next() else { - return Some(Simplifiable::NeverSatisfiable); - }; - let Some(second) = results.next() else { - return Some(Simplifiable::from_one(first.satisfiable(db))); - }; - if results.next().is_some() { - return None; - } - Some(Simplifiable::from_union( - first.satisfiable(db), - second.satisfiable(db), - )) - } - } - - let mut results = match self.positive.union(db, &other.positive) { - Some(positive) => Results::new(Constraint { - positive: positive.clone(), - negative: smallvec![], - }), - None => return None, - }; - for other_negative in &other.negative { - results.add_simplified_constraint( - db, - self.positive.union_negated_range(db, other_negative), - ); - } - for self_negative in &self.negative { - // Reverse the results here so that we always add items from `self` first. 
This ensures - // that the output we produce is ordered consistently with the input we receive. - results.add_simplified_constraint( - db, - other - .positive - .union_negated_range(db, self_negative) - .reverse(), - ); - } - for self_negative in &self.negative { - for other_negative in &other.negative { - results.add_negated_range(db, self_negative.union_negative(db, other_negative)); - } - } - - results.into_result(db) - } - - fn negate_into( - &self, - db: &'db dyn Db, - typevar: BoundTypeVarInstance<'db>, - set: &mut ConstraintSet<'db>, - ) { - for negative in &self.negative { - set.union_constraint( - db, - Constraint::range(db, negative.hole.lower, negative.hole.upper).constrain(typevar), - ); - } - set.union_constraint( - db, - Constraint::negated_range(db, self.positive.lower, self.positive.upper) - .constrain(typevar), - ); - } - - fn display(&self, db: &'db dyn Db, typevar: impl Display) -> impl Display { - struct DisplayConstraint<'a, 'db, D> { - constraint: &'a Constraint<'db>, - typevar: D, - db: &'db dyn Db, - } - - impl Display for DisplayConstraint<'_, '_, D> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut first = true; - if !self.constraint.positive.is_always() { - (self.constraint.positive) - .display(self.db, &self.typevar) - .fmt(f)?; - first = false; - } - for negative in &self.constraint.negative { - if first { - first = false; - } else { - f.write_str(" ∧ ")?; - } - negative.display(self.db, &self.typevar).fmt(f)?; - } - Ok(()) - } - } - - DisplayConstraint { - constraint: self, - typevar, - db, - } - } -} - -impl<'db> Satisfiable> { - fn constrain(self, typevar: BoundTypeVarInstance<'db>) -> Satisfiable> { - self.map(|constraint| constraint.constrain(typevar)) - } -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct RangeConstraint<'db> { lower: Type<'db>, upper: Type<'db>, } -impl<'db> Constraint<'db> { +// The Salsa heap is tracked separately. 
+impl get_size2::GetSize for ConstrainedTypeVar<'_> {} + +#[salsa::tracked] +impl<'db> ConstrainedTypeVar<'db> { /// Returns a new range constraint. /// /// Panics if `lower` and `upper` are not both fully static. - fn range(db: &'db dyn Db, lower: Type<'db>, upper: Type<'db>) -> Satisfiable> { + fn new_node( + db: &'db dyn Db, + lower: Type<'db>, + typevar: BoundTypeVarInstance<'db>, + upper: Type<'db>, + ) -> Node<'db> { debug_assert_eq!(lower, lower.bottom_materialization(db)); debug_assert_eq!(upper, upper.top_materialization(db)); // If `lower ≰ upper`, then the constraint cannot be satisfied, since there is no type that // is both greater than `lower`, and less than `upper`. if !lower.is_subtype_of(db, upper) { - return Satisfiable::Never; + return Node::AlwaysFalse; } // If the requested constraint is `Never ≤ T ≤ object`, then the typevar can be specialized // to _any_ type, and the constraint does nothing. - let positive = RangeConstraint { lower, upper }; - if positive.is_always() { - return Satisfiable::Always; + if lower.is_never() && upper.is_object() { + return Node::AlwaysTrue; } - Satisfiable::Constrained(Constraint { - positive, - negative: smallvec![], - }) + Node::new_constraint(db, ConstrainedTypeVar::new(db, typevar, lower, upper)) + } + fn when_true(self) -> ConstraintAssignment<'db> { + ConstraintAssignment::Positive(self) } -} -impl<'db> RangeConstraint<'db> { - fn always() -> Self { - Self { - lower: Type::Never, - upper: Type::object(), + fn when_false(self) -> ConstraintAssignment<'db> { + ConstraintAssignment::Negative(self) + } + + fn contains(self, db: &'db dyn Db, other: Self) -> bool { + if self.typevar(db) != other.typevar(db) { + return false; } - } - - fn contains(&self, db: &'db dyn Db, other: &RangeConstraint<'db>) -> bool { - self.lower.is_subtype_of(db, other.lower) && other.upper.is_subtype_of(db, self.upper) - } - - fn is_always(&self) -> bool { - self.lower.is_never() && self.upper.is_object() + 
self.lower(db).is_subtype_of(db, other.lower(db)) + && other.upper(db).is_subtype_of(db, self.upper(db)) } /// Returns the intersection of two range constraints, or `None` if the intersection is empty. - fn intersect(&self, db: &'db dyn Db, other: &RangeConstraint<'db>) -> Option { + fn intersect(self, db: &'db dyn Db, other: Self) -> Option { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) - let lower = UnionType::from_elements(db, [self.lower, other.lower]); - let upper = IntersectionType::from_elements(db, [self.upper, other.upper]); + let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); + let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]); // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. @@ -1230,274 +322,1029 @@ impl<'db> RangeConstraint<'db> { return None; } - Some(Self { lower, upper }) + Some(Self::new(db, self.typevar(db), lower, upper)) } - /// Returns the union of two range constraints if it can be simplified to a single constraint. - /// Otherwise returns `None`. - fn union(&self, db: &'db dyn Db, other: &RangeConstraint<'db>) -> Option { - // When one of the bounds is entirely contained within the other, the union simplifies to - // the larger bounds. - if self.lower.is_subtype_of(db, other.lower) && other.upper.is_subtype_of(db, self.upper) { - return Some(self.clone()); - } - if other.lower.is_subtype_of(db, self.lower) && self.upper.is_subtype_of(db, other.upper) { - return Some(other.clone()); - } - - // Otherwise the result cannot be simplified. - None + fn display(self, db: &'db dyn Db) -> impl Display { + self.display_inner(db, false) } - /// Returns the union of a positive range with a negative hole. 
- fn union_negated_range( - &self, - db: &'db dyn Db, - negated: &NegatedRangeConstraint<'db>, - ) -> Simplifiable> { - // If the positive range completely contains the negative range, then the union is always - // satisfied. - if self.contains(db, &negated.hole) { - return Simplifiable::AlwaysSatisfiable; - } - - // If the positive range is disjoint from the negative range, the positive range doesn't - // add anything; the union is the negative range. - if incomparable(db, self.lower, negated.hole.upper) - || incomparable(db, negated.hole.lower, self.upper) - { - return Simplifiable::from_one(Constraint::negated_range( - db, - negated.hole.lower, - negated.hole.upper, - )); - } - - // Otherwise we clip the positive constraint to the minimum range that overlaps with the - // negative range. - Simplifiable::from_union( - Constraint::range( - db, - UnionType::from_elements(db, [self.lower, negated.hole.lower]), - IntersectionType::from_elements(db, [self.upper, negated.hole.upper]), - ), - Constraint::negated_range(db, negated.hole.lower, negated.hole.upper), - ) + fn display_negated(self, db: &'db dyn Db) -> impl Display { + self.display_inner(db, true) } - fn display(&self, db: &'db dyn Db, typevar: impl Display) -> impl Display { - struct DisplayRangeConstraint<'a, 'db, D> { - constraint: &'a RangeConstraint<'db>, - typevar: D, + fn display_inner(self, db: &'db dyn Db, negated: bool) -> impl Display { + struct DisplayConstrainedTypeVar<'db> { + constraint: ConstrainedTypeVar<'db>, + negated: bool, db: &'db dyn Db, } - impl Display for DisplayRangeConstraint<'_, '_, D> { + impl Display for DisplayConstrainedTypeVar<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if (self.constraint.lower).is_equivalent_to(self.db, self.constraint.upper) { + let lower = self.constraint.lower(self.db); + let upper = self.constraint.upper(self.db); + if lower.is_equivalent_to(self.db, upper) { return write!( f, - "({} = {})", - &self.typevar, - 
self.constraint.lower.display(self.db) + "({} {} {})", + self.constraint.typevar(self.db).display(self.db), + if self.negated { "≠" } else { "=" }, + lower.display(self.db) ); } - f.write_str("(")?; - if !self.constraint.lower.is_never() { - write!(f, "{} ≤ ", self.constraint.lower.display(self.db))?; + if self.negated { + f.write_str("¬")?; } - self.typevar.fmt(f)?; - if !self.constraint.upper.is_object() { - write!(f, " ≤ {}", self.constraint.upper.display(self.db))?; + f.write_str("(")?; + if !lower.is_never() { + write!(f, "{} ≤ ", lower.display(self.db))?; + } + self.constraint.typevar(self.db).display(self.db).fmt(f)?; + if !upper.is_object() { + write!(f, " ≤ {}", upper.display(self.db))?; } f.write_str(")") } } - DisplayRangeConstraint { + DisplayConstrainedTypeVar { constraint: self, - typevar, + negated, db, } } } -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] -pub(crate) struct NegatedRangeConstraint<'db> { - hole: RangeConstraint<'db>, +/// A BDD node. +/// +/// The "variables" of a constraint set BDD are individual constraints, represented by an interned +/// [`ConstrainedTypeVar`]. +/// +/// Terminal nodes (`false` and `true`) have their own dedicated enum variants. The +/// [`Interior`][InteriorNode] variant represents interior nodes. +/// +/// BDD nodes are _reduced_, which means that there are no duplicate nodes (which we handle via +/// Salsa interning), and that there are no redundant nodes, with `if_true` and `if_false` edges +/// that point at the same node. +/// +/// BDD nodes are also _ordered_, meaning that every path from the root of a BDD to a terminal node +/// visits variables in the same order. [`ConstrainedTypeVar`]s are interned, so we can use the IDs +/// that salsa assigns to define this order. 
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] +enum Node<'db> { + AlwaysFalse, + AlwaysTrue, + Interior(InteriorNode<'db>), } -impl<'db> Constraint<'db> { - /// Returns a new negated range constraint. +impl<'db> Node<'db> { + /// Creates a new BDD node, ensuring that it is fully reduced. + fn new( + db: &'db dyn Db, + constraint: ConstrainedTypeVar<'db>, + if_true: Node<'db>, + if_false: Node<'db>, + ) -> Self { + debug_assert!( + (if_true.root_constraint(db)) + .is_none_or(|root_constraint| root_constraint > constraint) + ); + debug_assert!( + (if_false.root_constraint(db)) + .is_none_or(|root_constraint| root_constraint > constraint) + ); + if if_true == if_false { + return if_true; + } + Self::Interior(InteriorNode::new(db, constraint, if_true, if_false)) + } + + /// Creates a new BDD node for an individual constraint. (The BDD will evaluate to `true` when + /// the constraint holds, and to `false` when it does not.) + fn new_constraint(db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) -> Self { + Self::Interior(InteriorNode::new( + db, + constraint, + Node::AlwaysTrue, + Node::AlwaysFalse, + )) + } + + /// Creates a new BDD node for a positive or negative individual constraint. (For a positive + /// constraint, this returns the same BDD node as [`new_constraint`][Self::new_constraint]. For + /// a negative constraint, it returns the negation of that BDD node.) + fn new_satisfied_constraint(db: &'db dyn Db, constraint: ConstraintAssignment<'db>) -> Self { + match constraint { + ConstraintAssignment::Positive(constraint) => Self::Interior(InteriorNode::new( + db, + constraint, + Node::AlwaysTrue, + Node::AlwaysFalse, + )), + ConstraintAssignment::Negative(constraint) => Self::Interior(InteriorNode::new( + db, + constraint, + Node::AlwaysFalse, + Node::AlwaysTrue, + )), + } + } + + /// Returns the BDD variable of the root node of this BDD, or `None` if this BDD is a terminal + /// node. 
+ fn root_constraint(self, db: &'db dyn Db) -> Option> { + match self { + Node::Interior(interior) => Some(interior.constraint(db)), + _ => None, + } + } + + /// Returns whether this BDD represents the constant function `true`. + fn is_always_satisfied(self) -> bool { + matches!(self, Node::AlwaysTrue) + } + + /// Returns whether this BDD represents the constant function `false`. + fn is_never_satisfied(self) -> bool { + matches!(self, Node::AlwaysFalse) + } + + /// Returns the negation of this BDD. + fn negate(self, db: &'db dyn Db) -> Self { + match self { + Node::AlwaysTrue => Node::AlwaysFalse, + Node::AlwaysFalse => Node::AlwaysTrue, + Node::Interior(interior) => interior.negate(db), + } + } + + /// Returns the `or` or union of two BDDs. + fn or(self, db: &'db dyn Db, other: Self) -> Self { + match (self, other) { + (Node::AlwaysTrue, _) | (_, Node::AlwaysTrue) => Node::AlwaysTrue, + (Node::AlwaysFalse, other) | (other, Node::AlwaysFalse) => other, + (Node::Interior(a), Node::Interior(b)) => { + // OR is commutative, which lets us halve the cache requirements + let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) }; + a.or(db, b) + } + } + } + + /// Returns the `and` or intersection of two BDDs. + fn and(self, db: &'db dyn Db, other: Self) -> Self { + match (self, other) { + (Node::AlwaysFalse, _) | (_, Node::AlwaysFalse) => Node::AlwaysFalse, + (Node::AlwaysTrue, other) | (other, Node::AlwaysTrue) => other, + (Node::Interior(a), Node::Interior(b)) => { + // AND is commutative, which lets us halve the cache requirements + let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) }; + a.and(db, b) + } + } + } + + /// Returns a new BDD that evaluates to `true` when both input BDDs evaluate to the same + /// result.
+ fn iff(self, db: &'db dyn Db, other: Self) -> Self { + match (self, other) { + (Node::AlwaysFalse, Node::AlwaysFalse) | (Node::AlwaysTrue, Node::AlwaysTrue) => { + Node::AlwaysTrue + } + (Node::AlwaysTrue, Node::AlwaysFalse) | (Node::AlwaysFalse, Node::AlwaysTrue) => { + Node::AlwaysFalse + } + (Node::AlwaysTrue | Node::AlwaysFalse, Node::Interior(interior)) => Node::new( + db, + interior.constraint(db), + self.iff(db, interior.if_true(db)), + self.iff(db, interior.if_false(db)), + ), + (Node::Interior(interior), Node::AlwaysTrue | Node::AlwaysFalse) => Node::new( + db, + interior.constraint(db), + interior.if_true(db).iff(db, other), + interior.if_false(db).iff(db, other), + ), + (Node::Interior(a), Node::Interior(b)) => { + // IFF is commutative, which lets us halve the cache requirements + let (a, b) = if b.0 < a.0 { (b, a) } else { (a, b) }; + a.iff(db, b) + } + } + } + + /// Returns the `if-then-else` of three BDDs: when `self` evaluates to `true`, it returns what + /// `then_node` evaluates to; otherwise it returns what `else_node` evaluates to. + fn ite(self, db: &'db dyn Db, then_node: Self, else_node: Self) -> Self { + self.and(db, then_node) + .or(db, self.negate(db).and(db, else_node)) + } + + /// Returns a new BDD that returns the same results as `self`, but with some inputs fixed to + /// particular values. (Those variables will not be checked when evaluating the result, and + /// will not be present in the result.) /// - /// Panics if `lower` and `upper` are not both fully static. - fn negated_range( + /// Also returns whether _all_ of the restricted variables appeared in the BDD. 
+ fn restrict( + self, db: &'db dyn Db, - lower: Type<'db>, - upper: Type<'db>, - ) -> Satisfiable> { - debug_assert_eq!(lower, lower.bottom_materialization(db)); - debug_assert_eq!(upper, upper.top_materialization(db)); + assignment: impl IntoIterator>, + ) -> (Self, bool) { + assignment + .into_iter() + .fold((self, true), |(restricted, found), assignment| { + let (restricted, found_this) = restricted.restrict_one(db, assignment); + (restricted, found && found_this) + }) + } - // If `lower ≰ upper`, then the negated constraint is always satisfied, since there is no - // type that is both greater than `lower`, and less than `upper`. - if !lower.is_subtype_of(db, upper) { - return Satisfiable::Always; + /// Returns a new BDD that returns the same results as `self`, but with one input fixed to a + /// particular value. (That variable will not be checked when evaluating the result, and + /// will not be present in the result.) + /// + /// Also returns whether the restricted variable appeared in the BDD. + fn restrict_one(self, db: &'db dyn Db, assignment: ConstraintAssignment<'db>) -> (Self, bool) { + match self { + Node::AlwaysTrue => (Node::AlwaysTrue, false), + Node::AlwaysFalse => (Node::AlwaysFalse, false), + Node::Interior(interior) => interior.restrict_one(db, assignment), } + } - // If the requested constraint is `¬(Never ≤ T ≤ object)`, then the constraint cannot be - // satisfied. - let negative = NegatedRangeConstraint { - hole: RangeConstraint { lower, upper }, + /// Returns a new BDD with any occurrence of `left ∧ right` replaced with `replacement`. + fn substitute_intersection( + self, + db: &'db dyn Db, + left: ConstraintAssignment<'db>, + right: ConstraintAssignment<'db>, + replacement: Node<'db>, + ) -> Self { + // We perform a Shannon expansion to find out what the input BDD evaluates to when: + // - left and right are both true + // - left is false + // - left is true and right is false + // This covers the entire truth table of `left ∧ right`.
+ let (when_left_and_right, both_found) = self.restrict(db, [left, right]); + if !both_found { + // If left and right are not both present in the input BDD, we should not even attempt + // the substitution, since the Shannon expansion might introduce the missing variables! + // That confuses us below when we try to detect whether the substitution is consistent + // with the input. + return self; + } + let (when_not_left, _) = self.restrict(db, [left.negated()]); + let (when_left_but_not_right, _) = self.restrict(db, [left, right.negated()]); + + // The result should test `replacement`, and when it's true, it should produce the same + // output that input would when `left ∧ right` is true. When replacement is false, it + // should fall back on testing left and right individually to make sure we produce the + // correct outputs in the `¬(left ∧ right)` case. So the result is + // + // if replacement + // when_left_and_right + // else if not left + // when_not_left + // else if not right + // when_left_but_not_right + // else + // false + // + // (Note that the `else` branch shouldn't be reachable, but we have to provide something!) + let left_node = Node::new_satisfied_constraint(db, left); + let right_node = Node::new_satisfied_constraint(db, right); + let right_result = right_node.ite(db, Node::AlwaysFalse, when_left_but_not_right); + let left_result = left_node.ite(db, right_result, when_not_left); + let result = replacement.ite(db, when_left_and_right, left_result); + + // Lastly, verify that the result is consistent with the input. (It must produce the same + // results when `left ∧ right`.) If it doesn't, the substitution isn't valid, and we should + // return the original BDD unmodified. 
+ let validity = replacement.iff(db, left_node.and(db, right_node)); + let constrained_original = self.and(db, validity); + let constrained_replacement = result.and(db, validity); + if constrained_original == constrained_replacement { + result + } else { + self + } + } + + /// Returns a new BDD with any occurrence of `left ∨ right` replaced with `replacement`. + fn substitute_union( + self, + db: &'db dyn Db, + left: ConstraintAssignment<'db>, + right: ConstraintAssignment<'db>, + replacement: Node<'db>, + ) -> Self { + // We perform a Shannon expansion to find out what the input BDD evaluates to when: + // - left and right are both true + // - left is true and right is false + // - left is false and right is true + // - left and right are both false + // This covers the entire truth table of `left ∨ right`. + let (when_l1_r1, both_found) = self.restrict(db, [left, right]); + if !both_found { + // If left and right are not both present in the input BDD, we should not even attempt + // the substitution, since the Shannon expansion might introduce the missing variables! + // That confuses us below when we try to detect whether the substitution is consistent + // with the input. + return self; + } + let (when_l0_r0, _) = self.restrict(db, [left.negated(), right.negated()]); + let (when_l1_r0, _) = self.restrict(db, [left, right.negated()]); + let (when_l0_r1, _) = self.restrict(db, [left.negated(), right]); + + // The result should test `replacement`, and when it's true, it should produce the same + // output that input would when `left ∨ right` is true. For OR, this is the union of what + // the input produces for the three cases that comprise `left ∨ right`. When `replacement` + // is false, the result should produce the same output that input would when + // `¬(left ∨ right)`, i.e. when `¬left ∧ ¬right`.
So the result is + // + // if replacement + // or(when_l1_r1, when_l1_r0, when_l0_r1) + // else + // when_l0_r0 + let result = replacement.ite( + db, + when_l1_r0.or(db, when_l0_r1.or(db, when_l1_r1)), + when_l0_r0, + ); + + // Lastly, verify that the result is consistent with the input. (It must produce the same + // results when `left ∨ right`.) If it doesn't, the substitution isn't valid, and we should + // return the original BDD unmodified. + let left_node = Node::new_satisfied_constraint(db, left); + let right_node = Node::new_satisfied_constraint(db, right); + let validity = replacement.iff(db, left_node.or(db, right_node)); + let constrained_original = self.and(db, validity); + let constrained_replacement = result.and(db, validity); + if constrained_original == constrained_replacement { + result + } else { + self + } + } + + /// Invokes a closure for each constraint variable that appears anywhere in a BDD. (Any given + /// constraint can appear multiple times in different paths from the root; we do not + /// deduplicate those constraints, and will instead invoke the callback each time we encounter + /// the constraint.) + fn for_each_constraint(self, db: &'db dyn Db, f: &mut dyn FnMut(ConstrainedTypeVar<'db>)) { + let Node::Interior(interior) = self else { + return; }; - if negative.hole.is_always() { - return Satisfiable::Never; + f(interior.constraint(db)); + interior.if_true(db).for_each_constraint(db, f); + interior.if_false(db).for_each_constraint(db, f); + } + + /// Simplifies a BDD, replacing constraints with simpler or smaller constraints where possible. + fn simplify(self, db: &'db dyn Db) -> Self { + match self { + Node::AlwaysTrue | Node::AlwaysFalse => self, + Node::Interior(interior) => interior.simplify(db), + } + } + + /// Returns clauses describing all of the variable assignments that cause this BDD to evaluate + /// to `true`. (This translates the boolean function that this BDD represents into DNF form.)
+ fn satisfied_clauses(self, db: &'db dyn Db) -> SatisfiedClauses<'db> { + struct Searcher<'db> { + clauses: SatisfiedClauses<'db>, + current_clause: SatisfiedClause<'db>, } - Satisfiable::Constrained(Constraint { - positive: RangeConstraint::always(), - negative: smallvec![negative], - }) - } -} + impl<'db> Searcher<'db> { + fn visit_node(&mut self, db: &'db dyn Db, node: Node<'db>) { + match node { + Node::AlwaysFalse => {} + Node::AlwaysTrue => self.clauses.push(self.current_clause.clone()), + Node::Interior(interior) => { + let interior_constraint = interior.constraint(db); + self.current_clause.push(interior_constraint.when_true()); + self.visit_node(db, interior.if_true(db)); + self.current_clause.pop(); + self.current_clause.push(interior_constraint.when_false()); + self.visit_node(db, interior.if_false(db)); + self.current_clause.pop(); + } + } + } + } -impl<'db> NegatedRangeConstraint<'db> { - /// Clips this negative hole to be the smallest hole that removes the same types from the given - /// positive range. - fn clip_to_positive(&self, db: &'db dyn Db, positive: &RangeConstraint<'db>) -> Option { - self.hole - .intersect(db, positive) - .map(|hole| NegatedRangeConstraint { hole }) + let mut searcher = Searcher { + clauses: SatisfiedClauses::default(), + current_clause: SatisfiedClause::default(), + }; + searcher.visit_node(db, self); + searcher.clauses } - /// Returns the union of two negative constraints. (This this is _intersection_ of the - /// constraints' holes.) - fn union_negative( - &self, - db: &'db dyn Db, - positive: &NegatedRangeConstraint<'db>, - ) -> Option { - self.hole - .intersect(db, &positive.hole) - .map(|hole| NegatedRangeConstraint { hole }) - } - - /// Returns the intersection of two negative constraints. (This this is _union_ of the - /// constraints' holes.) 
- fn intersect_negative( - &self, - db: &'db dyn Db, - other: &NegatedRangeConstraint<'db>, - ) -> Option { - self.hole - .union(db, &other.hole) - .map(|hole| NegatedRangeConstraint { hole }) - } - - fn display(&self, db: &'db dyn Db, typevar: impl Display) -> impl Display { - struct DisplayNegatedRangeConstraint<'a, 'db, D> { - constraint: &'a NegatedRangeConstraint<'db>, - typevar: D, + fn display(self, db: &'db dyn Db) -> impl Display { + // To render a BDD in DNF form, you perform a depth-first search of the BDD tree, looking + // for any path that leads to the AlwaysTrue terminal. Each such path represents one of the + // intersection clauses in the DNF form. The path traverses zero or more interior nodes, + // and takes either the true or false edge from each one. That gives you the positive or + // negative individual constraints in the path's clause. + struct DisplayNode<'db> { + node: Node<'db>, db: &'db dyn Db, } - impl Display for DisplayNegatedRangeConstraint<'_, '_, D> { + impl Display for DisplayNode<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if (self.constraint.hole.lower) - .is_equivalent_to(self.db, self.constraint.hole.upper) - { - return write!( - f, - "({} ≠ {})", - &self.typevar, - self.constraint.hole.lower.display(self.db) - ); + match self.node { + Node::AlwaysTrue => f.write_str("always"), + Node::AlwaysFalse => f.write_str("never"), + Node::Interior(_) => { + let mut clauses = self.node.satisfied_clauses(self.db); + clauses.simplify(); + clauses.display(self.db).fmt(f) + } } - - f.write_str("¬")?; - self.constraint.hole.display(self.db, &self.typevar).fmt(f) } } - DisplayNegatedRangeConstraint { + DisplayNode { node: self, db } + } +} + +/// An interior node of a BDD +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +struct InteriorNode<'db> { + constraint: ConstrainedTypeVar<'db>, + if_true: Node<'db>, + if_false: Node<'db>, +} + +// The Salsa heap is tracked separately. 
+impl get_size2::GetSize for InteriorNode<'_> {} + +#[salsa::tracked] +impl<'db> InteriorNode<'db> { + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn negate(self, db: &'db dyn Db) -> Node<'db> { + Node::new( + db, + self.constraint(db), + self.if_true(db).negate(db), + self.if_false(db).negate(db), + ) + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn or(self, db: &'db dyn Db, other: Self) -> Node<'db> { + let self_constraint = self.constraint(db); + let other_constraint = other.constraint(db); + match self_constraint.cmp(&other_constraint) { + Ordering::Equal => Node::new( + db, + self_constraint, + self.if_true(db).or(db, other.if_true(db)), + self.if_false(db).or(db, other.if_false(db)), + ), + Ordering::Less => Node::new( + db, + self_constraint, + self.if_true(db).or(db, Node::Interior(other)), + self.if_false(db).or(db, Node::Interior(other)), + ), + Ordering::Greater => Node::new( + db, + other_constraint, + Node::Interior(self).or(db, other.if_true(db)), + Node::Interior(self).or(db, other.if_false(db)), + ), + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn and(self, db: &'db dyn Db, other: Self) -> Node<'db> { + let self_constraint = self.constraint(db); + let other_constraint = other.constraint(db); + match self_constraint.cmp(&other_constraint) { + Ordering::Equal => Node::new( + db, + self_constraint, + self.if_true(db).and(db, other.if_true(db)), + self.if_false(db).and(db, other.if_false(db)), + ), + Ordering::Less => Node::new( + db, + self_constraint, + self.if_true(db).and(db, Node::Interior(other)), + self.if_false(db).and(db, Node::Interior(other)), + ), + Ordering::Greater => Node::new( + db, + other_constraint, + Node::Interior(self).and(db, other.if_true(db)), + Node::Interior(self).and(db, other.if_false(db)), + ), + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn iff(self, db: &'db dyn Db, other: Self) -> Node<'db> { + let self_constraint = 
self.constraint(db); + let other_constraint = other.constraint(db); + match self_constraint.cmp(&other_constraint) { + Ordering::Equal => Node::new( + db, + self_constraint, + self.if_true(db).iff(db, other.if_true(db)), + self.if_false(db).iff(db, other.if_false(db)), + ), + Ordering::Less => Node::new( + db, + self_constraint, + self.if_true(db).iff(db, Node::Interior(other)), + self.if_false(db).iff(db, Node::Interior(other)), + ), + Ordering::Greater => Node::new( + db, + other_constraint, + Node::Interior(self).iff(db, other.if_true(db)), + Node::Interior(self).iff(db, other.if_false(db)), + ), + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn restrict_one( + self, + db: &'db dyn Db, + assignment: ConstraintAssignment<'db>, + ) -> (Node<'db>, bool) { + // If this node's variable is larger than the assignment's variable, then we have reached a + // point in the BDD where the assignment can no longer affect the result, + // and we can return early. + let self_constraint = self.constraint(db); + if assignment.constraint() < self_constraint { + return (Node::Interior(self), false); + } + + // Otherwise, check if this node's variable is in the assignment. If so, substitute the + // variable by replacing this node with its if_false/if_true edge, accordingly. + if assignment == self_constraint.when_true() { + (self.if_true(db), true) + } else if assignment == self_constraint.when_false() { + (self.if_false(db), true) + } else { + let (if_true, found_in_true) = self.if_true(db).restrict_one(db, assignment); + let (if_false, found_in_false) = self.if_false(db).restrict_one(db, assignment); + ( + Node::new(db, self_constraint, if_true, if_false), + found_in_true || found_in_false, + ) + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn simplify(self, db: &'db dyn Db) -> Node<'db> { + // To simplify a non-terminal BDD, we find all pairs of constraints that are mentioned in + // the BDD. 
If any of those pairs can be simplified to some other BDD, we perform a + // substitution to replace the pair with the simplification. + // + // Some of the simplifications create _new_ constraints that weren't originally present in + // the BDD. If we encounter one of those cases, we need to check if we can simplify things + // further relative to that new constraint. + // + // To handle this, we keep track of the individual constraints that we have already + // discovered (`seen_constraints`), and a queue of constraint pairs that we still need to + // check (`to_visit`). + + // Seed the seen set with all of the constraints that are present in the input BDD, and the + // visit queue with all pairs of those constraints. (We use "combinations" because we don't + // need to compare a constraint against itself, and because ordering doesn't matter.) + let mut seen_constraints = FxHashSet::default(); + Node::Interior(self).for_each_constraint(db, &mut |constraint| { + seen_constraints.insert(constraint); + }); + let mut to_visit: Vec<(_, _)> = (seen_constraints.iter().copied()) + .tuple_combinations() + .collect(); + + // Repeatedly pop constraint pairs off of the visit queue, checking whether each pair can + // be simplified. + let mut simplified = Node::Interior(self); + while let Some((left_constraint, right_constraint)) = to_visit.pop() { + // If the constraints refer to different typevars, they trivially cannot be compared. + // TODO: We might need to consider when one constraint's upper or lower bound refers to + // the other constraint's typevar. + let typevar = left_constraint.typevar(db); + if typevar != right_constraint.typevar(db) { + continue; + } + + // Containment: The range of one constraint might completely contain the range of the + // other. If so, there are several potential simplifications. 
+ let larger_smaller = if left_constraint.contains(db, right_constraint) { + Some((left_constraint, right_constraint)) + } else if right_constraint.contains(db, left_constraint) { + Some((right_constraint, left_constraint)) + } else { + None + }; + if let Some((larger_constraint, smaller_constraint)) = larger_smaller { + // larger ∨ smaller = larger + simplified = simplified.substitute_union( + db, + larger_constraint.when_true(), + smaller_constraint.when_true(), + Node::new_satisfied_constraint(db, larger_constraint.when_true()), + ); + + // ¬larger ∧ ¬smaller = ¬larger + simplified = simplified.substitute_intersection( + db, + larger_constraint.when_false(), + smaller_constraint.when_false(), + Node::new_satisfied_constraint(db, larger_constraint.when_false()), + ); + + // smaller ∧ ¬larger = false + // (¬larger removes everything that's present in smaller) + simplified = simplified.substitute_intersection( + db, + larger_constraint.when_false(), + smaller_constraint.when_true(), + Node::AlwaysFalse, + ); + + // larger ∨ ¬smaller = true + // (larger fills in everything that's missing in ¬smaller) + simplified = simplified.substitute_union( + db, + larger_constraint.when_true(), + smaller_constraint.when_false(), + Node::AlwaysTrue, + ); + } + + // There are some simplifications we can make when the intersection of the two + // constraints is empty, and others that we can make when the intersection is + // non-empty. + match left_constraint.intersect(db, right_constraint) { + Some(intersection_constraint) => { + // If the intersection is non-empty, we need to create a new constraint to + // represent that intersection. We also need to add the new constraint to our + // seen set and (if we haven't already seen it) to the to-visit queue. 
+ if seen_constraints.insert(intersection_constraint) { + to_visit.extend( + (seen_constraints.iter().copied()) + .filter(|seen| *seen != intersection_constraint) + .map(|seen| (seen, intersection_constraint)), + ); + } + let positive_intersection_node = + Node::new_satisfied_constraint(db, intersection_constraint.when_true()); + let negative_intersection_node = + Node::new_satisfied_constraint(db, intersection_constraint.when_false()); + + // left ∧ right = intersection + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_true(), + positive_intersection_node, + ); + + // ¬left ∨ ¬right = ¬intersection + simplified = simplified.substitute_union( + db, + left_constraint.when_false(), + right_constraint.when_false(), + negative_intersection_node, + ); + + // left ∧ ¬right = left ∧ ¬intersection + // (clip the negative constraint to the smallest range that actually removes + // something from the positive constraint) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_false(), + Node::new_satisfied_constraint(db, left_constraint.when_true()) + .and(db, negative_intersection_node), + ); + + // ¬left ∧ right = ¬intersection ∧ right + // (same as above but reversed) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_false(), + right_constraint.when_true(), + Node::new_satisfied_constraint(db, right_constraint.when_true()) + .and(db, negative_intersection_node), + ); + + // left ∨ ¬right = intersection ∨ ¬right + // (clip the positive constraint to the smallest range that actually adds + // something to the negative constraint) + simplified = simplified.substitute_union( + db, + left_constraint.when_true(), + right_constraint.when_false(), + Node::new_satisfied_constraint(db, right_constraint.when_false()) + .or(db, positive_intersection_node), + ); + + // ¬left ∨ right = ¬left ∨ intersection + // (same as above but reversed) + 
simplified = simplified.substitute_union( + db, + left_constraint.when_false(), + right_constraint.when_true(), + Node::new_satisfied_constraint(db, left_constraint.when_false()) + .or(db, positive_intersection_node), + ); + } + + None => { + // All of the below hold because we just proved that the intersection of left + // and right is empty. + + // left ∧ right = false + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_true(), + Node::AlwaysFalse, + ); + + // ¬left ∨ ¬right = true + simplified = simplified.substitute_union( + db, + left_constraint.when_false(), + right_constraint.when_false(), + Node::AlwaysTrue, + ); + + // left ∧ ¬right = left + // (there is nothing in the hole of ¬right that overlaps with left) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_true(), + right_constraint.when_false(), + Node::new_constraint(db, left_constraint), + ); + + // ¬left ∧ right = right + // (same as above but reversed) + simplified = simplified.substitute_intersection( + db, + left_constraint.when_false(), + right_constraint.when_true(), + Node::new_constraint(db, right_constraint), + ); + } + } + } + + simplified + } +} + +/// An assignment of one BDD variable to either `true` or `false`. (When evaluating a BDD, we +/// must provide an assignment for each variable present in the BDD.) 
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +enum ConstraintAssignment<'db> { + Positive(ConstrainedTypeVar<'db>), + Negative(ConstrainedTypeVar<'db>), +} + +impl<'db> ConstraintAssignment<'db> { + fn constraint(self) -> ConstrainedTypeVar<'db> { + match self { + ConstraintAssignment::Positive(constraint) => constraint, + ConstraintAssignment::Negative(constraint) => constraint, + } + } + + fn negated(self) -> Self { + match self { + ConstraintAssignment::Positive(constraint) => { + ConstraintAssignment::Negative(constraint) + } + ConstraintAssignment::Negative(constraint) => { + ConstraintAssignment::Positive(constraint) + } + } + } + + fn negate(&mut self) { + *self = self.negated(); + } + + // Keep this for future debugging needs, even though it's not currently used when rendering + // constraint sets. + #[expect(dead_code)] + fn display(self, db: &'db dyn Db) -> impl Display { + struct DisplayConstraintAssignment<'db> { + constraint: ConstraintAssignment<'db>, + db: &'db dyn Db, + } + + impl Display for DisplayConstraintAssignment<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.constraint { + ConstraintAssignment::Positive(constraint) => { + constraint.display(self.db).fmt(f) + } + ConstraintAssignment::Negative(constraint) => { + constraint.display_negated(self.db).fmt(f) + } + } + } + } + + DisplayConstraintAssignment { constraint: self, - typevar, db, } } } -/// Wraps a constraint (or clause, or set), while using distinct variants to represent when the -/// constraint is never satisfiable or always satisfiable. 
-#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum Satisfiable { - Never, - Always, - Constrained(T), +/// A single clause in the DNF representation of a BDD +#[derive(Clone, Debug, Default, Eq, PartialEq)] +struct SatisfiedClause<'db> { + constraints: Vec>, } -impl Satisfiable { - fn map(self, f: impl FnOnce(T) -> U) -> Satisfiable { - match self { - Satisfiable::Never => Satisfiable::Never, - Satisfiable::Always => Satisfiable::Always, - Satisfiable::Constrained(t) => Satisfiable::Constrained(f(t)), - } +impl<'db> SatisfiedClause<'db> { + fn push(&mut self, constraint: ConstraintAssignment<'db>) { + self.constraints.push(constraint); } -} -/// The result of trying to simplify two constraints (or clauses, or sets). Like [`Satisfiable`], -/// we use distinct variants to represent when the simplification is never satisfiable or always -/// satisfiable. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum Simplifiable { - NeverSatisfiable, - AlwaysSatisfiable, - Simplified(T), - NotSimplified(T, T), -} - -impl Simplifiable { - fn from_one(constraint: Satisfiable) -> Self { - match constraint { - Satisfiable::Never => Simplifiable::NeverSatisfiable, - Satisfiable::Always => Simplifiable::AlwaysSatisfiable, - Satisfiable::Constrained(constraint) => Simplifiable::Simplified(constraint), - } + fn pop(&mut self) { + self.constraints + .pop() + .expect("clause vector should not be empty"); } -} -impl Simplifiable { - fn from_union(first: Satisfiable, second: Satisfiable) -> Self { - match (first, second) { - (Satisfiable::Always, _) | (_, Satisfiable::Always) => Simplifiable::AlwaysSatisfiable, - (Satisfiable::Never, Satisfiable::Never) => Simplifiable::NeverSatisfiable, - (Satisfiable::Never, Satisfiable::Constrained(constraint)) - | (Satisfiable::Constrained(constraint), Satisfiable::Never) => { - Simplifiable::Simplified(constraint) - } - (Satisfiable::Constrained(first), Satisfiable::Constrained(second)) => { - Simplifiable::NotSimplified(first, second) + 
/// Invokes a closure with the last constraint in this clause negated. Returns the clause back + /// to its original state after invoking the closure. + fn with_negated_last_constraint(&mut self, f: impl for<'a> FnOnce(&'a Self)) { + if self.constraints.is_empty() { + return; + } + let last_index = self.constraints.len() - 1; + self.constraints[last_index].negate(); + f(self); + self.constraints[last_index].negate(); + } + + /// Removes another clause from this clause, if it appears as a prefix of this clause. Returns + /// whether the prefix was removed. + fn remove_prefix(&mut self, prefix: &SatisfiedClause<'db>) -> bool { + if self.constraints.starts_with(&prefix.constraints) { + self.constraints.drain(0..prefix.constraints.len()); + return true; + } + false + } + + fn display(&self, db: &'db dyn Db) -> impl Display { + struct DisplaySatisfiedClause<'a, 'db> { + clause: &'a SatisfiedClause<'db>, + db: &'db dyn Db, + } + + impl Display for DisplaySatisfiedClause<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.clause.constraints.len() > 1 { + f.write_str("(")?; + } + for (i, constraint) in self.clause.constraints.iter().enumerate() { + if i > 0 { + f.write_str(" ∧ ")?; + } + match constraint { + ConstraintAssignment::Positive(constraint) => { + write!(f, "{}", constraint.display(self.db))?; + } + ConstraintAssignment::Negative(constraint) => { + write!(f, "{}", constraint.display_negated(self.db))?; + } + } + } + if self.clause.constraints.len() > 1 { + f.write_str(")")?; + } + Ok(()) } } - } - fn map(self, mut f: impl FnMut(T) -> U) -> Simplifiable { - match self { - Simplifiable::NeverSatisfiable => Simplifiable::NeverSatisfiable, - Simplifiable::AlwaysSatisfiable => Simplifiable::AlwaysSatisfiable, - Simplifiable::Simplified(t) => Simplifiable::Simplified(f(t)), - Simplifiable::NotSimplified(t1, t2) => Simplifiable::NotSimplified(f(t1), f(t2)), - } - } - - fn reverse(self) -> Self { - match self { - 
Simplifiable::NeverSatisfiable - | Simplifiable::AlwaysSatisfiable - | Simplifiable::Simplified(_) => self, - Simplifiable::NotSimplified(t1, t2) => Simplifiable::NotSimplified(t2, t1), - } + DisplaySatisfiedClause { clause: self, db } + } +} + +/// A list of the clauses that satisfy a BDD. This is a DNF representation of the boolean function +/// that the BDD represents. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +struct SatisfiedClauses<'db> { + clauses: Vec>, +} + +impl<'db> SatisfiedClauses<'db> { + fn push(&mut self, clause: SatisfiedClause<'db>) { + self.clauses.push(clause); + } + + /// Simplifies the DNF representation, removing redundancies that do not change the underlying + /// function. (This is used when displaying a BDD, to make sure that the representation that we + /// show is as simple as possible while still producing the same results.) + fn simplify(&mut self) { + while self.simplify_one_round() { + // Keep going + } + } + + fn simplify_one_round(&mut self) -> bool { + let mut changes_made = false; + + // First remove any duplicate clauses. (The clause list will start out with no duplicates + // in the first round of simplification, because of the guarantees provided by the BDD + // structure. But earlier rounds of simplification might have made some clauses redundant.) + // Note that we have to loop through the vector element indexes manually, since we might + // remove elements in each iteration. + let mut i = 0; + while i < self.clauses.len() { + let mut j = i + 1; + while j < self.clauses.len() { + if self.clauses[i] == self.clauses[j] { + self.clauses.swap_remove(j); + changes_made = true; + } else { + j += 1; + } + } + i += 1; + } + if changes_made { + return true; + } + + // Then look for "prefix simplifications". That is, look for patterns + // + // (A ∧ B) ∨ (A ∧ ¬B ∧ ...) + // + // and replace them with + // + // (A ∧ B) ∨ (...) 
+ for i in 0..self.clauses.len() { + let (clause, rest) = self.clauses[..=i] + .split_last_mut() + .expect("index should be in range"); + clause.with_negated_last_constraint(|clause| { + for existing in rest { + changes_made |= existing.remove_prefix(clause); + } + }); + + let (clause, rest) = self.clauses[i..] + .split_first_mut() + .expect("index should be in range"); + clause.with_negated_last_constraint(|clause| { + for existing in rest { + changes_made |= existing.remove_prefix(clause); + } + }); + + if changes_made { + return true; + } + } + + false + } + + fn display(&self, db: &'db dyn Db) -> impl Display { + struct DisplaySatisfiedClauses<'a, 'db> { + clauses: &'a SatisfiedClauses<'db>, + db: &'db dyn Db, + } + + impl Display for DisplaySatisfiedClauses<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.clauses.clauses.is_empty() { + return f.write_str("always"); + } + for (i, clause) in self.clauses.clauses.iter().enumerate() { + if i > 0 { + f.write_str(" ∨ ")?; + } + clause.display(self.db).fmt(f)?; + } + Ok(()) + } + } + + DisplaySatisfiedClauses { clauses: self, db } } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 7650a4eac3..0edb463366 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -877,7 +877,7 @@ impl<'db> Specialization<'db> { } TypeVarVariance::Bivariant => ConstraintSet::from(true), }; - if result.intersect(db, &compatible).is_never_satisfied() { + if result.intersect(db, compatible).is_never_satisfied() { return result; } } @@ -918,7 +918,7 @@ impl<'db> Specialization<'db> { } TypeVarVariance::Bivariant => ConstraintSet::from(true), }; - if result.intersect(db, &compatible).is_never_satisfied() { + if result.intersect(db, compatible).is_never_satisfied() { return result; } } @@ -928,7 +928,7 @@ impl<'db> Specialization<'db> { (None, None) => {} (Some(self_tuple), 
Some(other_tuple)) => { let compatible = self_tuple.is_equivalent_to_impl(db, other_tuple, visitor); - if result.intersect(db, &compatible).is_never_satisfied() { + if result.intersect(db, compatible).is_never_satisfied() { return result; } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 298c3b4644..79cd816b0a 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6947,7 +6947,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::UnaryOp::Invert, Type::KnownInstance(KnownInstanceType::ConstraintSet(constraints)), ) => { - let constraints = constraints.constraints(self.db()).clone(); + let constraints = constraints.constraints(self.db()); let result = constraints.negate(self.db()); Type::KnownInstance(KnownInstanceType::ConstraintSet(TrackedConstraintSet::new( self.db(), @@ -7311,9 +7311,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), ast::Operator::BitAnd, ) => { - let left = left.constraints(self.db()).clone(); - let right = right.constraints(self.db()).clone(); - let result = left.and(self.db(), || right); + let left = left.constraints(self.db()); + let right = right.constraints(self.db()); + let result = left.and(self.db(), || *right); Some(Type::KnownInstance(KnownInstanceType::ConstraintSet( TrackedConstraintSet::new(self.db(), result), ))) @@ -7324,9 +7324,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), ast::Operator::BitOr, ) => { - let left = left.constraints(self.db()).clone(); - let right = right.constraints(self.db()).clone(); - let result = left.or(self.db(), || right); + let left = left.constraints(self.db()); + let right = right.constraints(self.db()); + let result = left.or(self.db(), || *right); Some(Type::KnownInstance(KnownInstanceType::ConstraintSet( 
TrackedConstraintSet::new(self.db(), result), ))) diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index ddc06951fc..d5ae64fdaf 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -551,10 +551,7 @@ impl<'db> Signature<'db> { let self_type = self_type.unwrap_or(Type::unknown()); let other_type = other_type.unwrap_or(Type::unknown()); !result - .intersect( - db, - &self_type.is_equivalent_to_impl(db, other_type, visitor), - ) + .intersect(db, self_type.is_equivalent_to_impl(db, other_type, visitor)) .is_never_satisfied() }; @@ -699,10 +696,7 @@ impl<'db> Signature<'db> { let type1 = type1.unwrap_or(Type::unknown()); let type2 = type2.unwrap_or(Type::unknown()); !result - .intersect( - db, - &type1.has_relation_to_impl(db, type2, relation, visitor), - ) + .intersect(db, type1.has_relation_to_impl(db, type2, relation, visitor)) .is_never_satisfied() }; diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index db7c4d26b7..c6d669612a 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -439,7 +439,7 @@ impl<'db> FixedLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, *other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -452,7 +452,7 @@ impl<'db> FixedLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, *other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -774,7 +774,7 @@ impl<'db> VariableLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, 
element_constraints) .is_never_satisfied() { return result; @@ -788,7 +788,7 @@ impl<'db> VariableLengthTuple> { let element_constraints = self_ty.has_relation_to_impl(db, other_ty, relation, visitor); if result - .intersect(db, &element_constraints) + .intersect(db, element_constraints) .is_never_satisfied() { return result; @@ -832,7 +832,7 @@ impl<'db> VariableLengthTuple> { return ConstraintSet::from(false); } }; - if result.intersect(db, &pair_constraints).is_never_satisfied() { + if result.intersect(db, pair_constraints).is_never_satisfied() { return result; } } @@ -858,7 +858,7 @@ impl<'db> VariableLengthTuple> { return ConstraintSet::from(false); } }; - if result.intersect(db, &pair_constraints).is_never_satisfied() { + if result.intersect(db, pair_constraints).is_never_satisfied() { return result; } }