[ty] Some more simplifications when rendering constraint sets (#21009)

This PR adds another useful simplification when rendering constraint
sets: `T = int` instead of `T = int ∧ T ≠ str`. (The "smaller"
constraint `T = int` implies the "larger" constraint `T ≠ str`.
Constraint set clauses are intersections, and if one constraint in a
clause implies another, we can throw away the "larger" constraint.)

While we're here, we also normalize the bounds of a constraint, so that
we equate e.g. `T ≤ int | str` with `T ≤ str | int`, and change the
ordering of BDD variables so that all constraints with the same typevar
are ordered adjacent to each other.

Lastly, we also add a new `display_graph` helper method that prints out
the full graph structure of a BDD.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Douglas Creager 2025-10-22 13:38:44 -04:00 committed by GitHub
parent 81c1d36088
commit 766ed5b5f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 314 additions and 23 deletions

View file

@ -22,7 +22,12 @@ use crate::unpack::{Unpack, UnpackPosition};
/// because a new scope gets inserted before the `Definition` or a new place is inserted
/// before this `Definition`. However, the ID can be considered stable and it is okay to use
/// `Definition` in cross-module salsa queries or as a field on other salsa tracked structs.
///
/// # Ordering
/// Ordering is based on the definition's salsa-assigned id and not on its values.
/// The id may change between runs, or when the definition was garbage collected and recreated.
#[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(Ord, PartialOrd)]
pub struct Definition<'db> {
/// The file in which the definition occurs.
pub file: File,

View file

@ -8461,7 +8461,9 @@ fn lazy_bound_or_constraints_cycle_initial<'db>(
}
/// Where a type variable is bound and usable.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Update, get_size2::GetSize,
)]
pub enum BindingContext<'db> {
/// The definition of the generic class, function, or type alias that binds this typevar.
Definition(Definition<'db>),
@ -8495,7 +8497,9 @@ impl<'db> BindingContext<'db> {
/// independent of the typevar's bounds or constraints. Two bound typevars have the same identity
/// if they represent the same logical typevar bound in the same context, even if their bounds
/// have been materialized differently.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)]
#[derive(
Debug, Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd, get_size2::GetSize, salsa::Update,
)]
pub struct BoundTypeVarIdentity<'db> {
pub(crate) identity: TypeVarIdentity<'db>,
pub(crate) binding_context: BindingContext<'db>,

View file

@ -23,6 +23,9 @@
//! Note that all lower and upper bounds in a constraint must be fully static. We take the bottom
//! and top materializations of the types to remove any gradual forms if needed.
//!
//! Lower and upper bounds must also be normalized. This lets us identify, for instance,
//! two constraints with equivalent but differently ordered unions as their bounds.
//!
//! NOTE: This module is currently in a transitional state. We've added the BDD [`ConstraintSet`]
//! representation, and updated all of our property checks to build up a constraint set and then
//! check whether it is ever or always satisfiable, as appropriate. We are not yet inferring
@ -58,6 +61,7 @@ use std::fmt::Display;
use itertools::Itertools;
use rustc_hash::FxHashSet;
use salsa::plumbing::AsId;
use crate::Db;
use crate::types::{BoundTypeVarIdentity, IntersectionType, Type, UnionType};
@ -183,20 +187,20 @@ impl<'db> ConstraintSet<'db> {
/// Updates this constraint set to hold the union of itself and another constraint set.
pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> Self {
self.node = self.node.or(db, other.node).simplify(db);
self.node = self.node.or(db, other.node);
*self
}
/// Updates this constraint set to hold the intersection of itself and another constraint set.
pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: Self) -> Self {
self.node = self.node.and(db, other.node).simplify(db);
self.node = self.node.and(db, other.node);
*self
}
/// Returns the negation of this constraint set.
pub(crate) fn negate(self, db: &'db dyn Db) -> Self {
Self {
node: self.node.negate(db).simplify(db),
node: self.node.negate(db),
}
}
@ -256,7 +260,6 @@ impl From<bool> for ConstraintSet<'_> {
/// An individual constraint in a constraint set. This restricts a single typevar to be within a
/// lower and upper bound.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub(crate) struct ConstrainedTypeVar<'db> {
typevar: BoundTypeVarIdentity<'db>,
lower: Type<'db>,
@ -292,8 +295,11 @@ impl<'db> ConstrainedTypeVar<'db> {
return Node::AlwaysTrue;
}
let lower = lower.normalized(db);
let upper = upper.normalized(db);
Node::new_constraint(db, ConstrainedTypeVar::new(db, typevar, lower, upper))
}
fn when_true(self) -> ConstraintAssignment<'db> {
ConstraintAssignment::Positive(self)
}
@ -310,11 +316,37 @@ impl<'db> ConstrainedTypeVar<'db> {
&& other.upper(db).is_subtype_of(db, self.upper(db))
}
/// Produces the sort key that fixes the variable order of a constraint set BDD.
///
/// Any consistent ordering yields a _correct_ BDD, but the choice can have a large effect on
/// _performance_. Many BDD libraries reorder variables on the fly; we don't, and instead rely
/// on a fixed key chosen for a few clear wins.
///
/// The key is the pair `(typevar, id)`: the constraint's typevar is compared first, and the
/// constraint's own id only breaks ties between constraints on the same typevar. As a result,
/// all constraints on a given typevar are adjacent in the BDD structure, which makes it easier
/// to find every candidate for the simplifications that operate on constraints sharing a
/// typevar.
fn ordering(self, db: &'db dyn Db) -> impl Ord {
    let typevar = self.typevar(db);
    let id = self.as_id();
    (typevar, id)
}
/// Returns whether this constraint implies another — i.e., whether every type that
/// satisfies this constraint also satisfies `other`.
///
/// This is answered by asking whether `other` contains `self`, so the receiver and the
/// argument swap roles in the `contains` call below.
///
/// This is used (among other places) to simplify how we display constraint sets, by removing
/// redundant constraints from a clause.
fn implies(self, db: &'db dyn Db, other: Self) -> bool {
other.contains(db, self)
}
/// Returns the intersection of two range constraints, or `None` if the intersection is empty.
fn intersect(self, db: &'db dyn Db, other: Self) -> Option<Self> {
// (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)
let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]);
let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]);
let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]).normalized(db);
let upper =
IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]).normalized(db);
// If `lower ≰ upper`, then the intersection is empty, since there is no type that is both
// greater than `lower`, and less than `upper`.
@ -390,8 +422,8 @@ impl<'db> ConstrainedTypeVar<'db> {
/// that point at the same node.
///
/// BDD nodes are also _ordered_, meaning that every path from the root of a BDD to a terminal node
/// visits variables in the same order. [`ConstrainedTypeVar`]s are interned, so we can use the IDs
/// that salsa assigns to define this order.
/// visits variables in the same order. [`ConstrainedTypeVar::ordering`] defines the variable
/// ordering that we use for constraint set BDDs.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)]
enum Node<'db> {
AlwaysFalse,
@ -407,13 +439,13 @@ impl<'db> Node<'db> {
if_true: Node<'db>,
if_false: Node<'db>,
) -> Self {
debug_assert!((if_true.root_constraint(db)).is_none_or(|root_constraint| {
root_constraint.ordering(db) > constraint.ordering(db)
}));
debug_assert!(
(if_true.root_constraint(db))
.is_none_or(|root_constraint| root_constraint > constraint)
);
debug_assert!(
(if_false.root_constraint(db))
.is_none_or(|root_constraint| root_constraint > constraint)
(if_false.root_constraint(db)).is_none_or(|root_constraint| {
root_constraint.ordering(db) > constraint.ordering(db)
})
);
if if_true == if_false {
return if_true;
@ -762,14 +794,87 @@ impl<'db> Node<'db> {
Node::AlwaysFalse => f.write_str("never"),
Node::Interior(_) => {
let mut clauses = self.node.satisfied_clauses(self.db);
clauses.simplify();
clauses.simplify(self.db);
clauses.display(self.db).fmt(f)
}
}
}
}
DisplayNode { node: self, db }
DisplayNode {
node: self.simplify(db),
db,
}
}
/// Displays the full graph structure of this BDD. `prefix` will be output before each line
/// other than the first. Produces output like the following:
///
/// ```text
/// (T@_ = str)
/// ┡━₁ (U@_ = str)
/// │   ┡━₁ always
/// │   └─₀ (U@_ = bool)
/// │       ┡━₁ always
/// │       └─₀ never
/// └─₀ (T@_ = bool)
///     ┡━₁ (U@_ = str)
///     │   ┡━₁ always
///     │   └─₀ (U@_ = bool)
///     │       ┡━₁ always
///     │       └─₀ never
///     └─₀ never
/// ```
///
/// For each interior node, the `┡━₁` edge leads to the node's `if_true` subgraph and the `└─₀`
/// edge leads to its `if_false` subgraph. Terminal nodes are rendered as `always` and `never`.
#[cfg_attr(not(test), expect(dead_code))] // Keep this around for debugging purposes
fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
    struct DisplayNode<'a, 'db> {
        db: &'db dyn Db,
        node: Node<'db>,
        prefix: &'a dyn Display,
    }

    impl<'a, 'db> DisplayNode<'a, 'db> {
        fn new(db: &'db dyn Db, node: Node<'db>, prefix: &'a dyn Display) -> Self {
            Self { db, node, prefix }
        }
    }

    impl Display for DisplayNode<'_, '_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self.node {
                Node::AlwaysTrue => write!(f, "always"),
                Node::AlwaysFalse => write!(f, "never"),
                Node::Interior(interior) => {
                    interior.constraint(self.db).display(self.db).fmt(f)?;
                    // We nest `DisplayNode`s directly rather than calling `display_graph`
                    // recursively, because the recursive call causes rustc to claim that the
                    // `expect(dead_code)` up above is unfulfilled!
                    //
                    // The `if_true` subgraph is drawn first. Its continuation lines extend the
                    // prefix with a vertical rail, since the `if_false` edge of this node is
                    // still to be drawn underneath it.
                    write!(
                        f,
                        "\n{}┡━₁ {}",
                        self.prefix,
                        DisplayNode::new(
                            self.db,
                            interior.if_true(self.db),
                            &format_args!("{}│   ", self.prefix)
                        ),
                    )?;
                    // The `if_false` subgraph is the last child, so its continuation lines
                    // are padded with blanks of the same width instead of a rail.
                    write!(
                        f,
                        "\n{}└─₀ {}",
                        self.prefix,
                        DisplayNode::new(
                            self.db,
                            interior.if_false(self.db),
                            &format_args!("{}    ", self.prefix)
                        ),
                    )
                }
            }
        }
    }

    DisplayNode::new(db, self, prefix)
}
}
@ -800,7 +905,7 @@ impl<'db> InteriorNode<'db> {
fn or(self, db: &'db dyn Db, other: Self) -> Node<'db> {
let self_constraint = self.constraint(db);
let other_constraint = other.constraint(db);
match self_constraint.cmp(&other_constraint) {
match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) {
Ordering::Equal => Node::new(
db,
self_constraint,
@ -826,7 +931,7 @@ impl<'db> InteriorNode<'db> {
fn and(self, db: &'db dyn Db, other: Self) -> Node<'db> {
let self_constraint = self.constraint(db);
let other_constraint = other.constraint(db);
match self_constraint.cmp(&other_constraint) {
match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) {
Ordering::Equal => Node::new(
db,
self_constraint,
@ -852,7 +957,7 @@ impl<'db> InteriorNode<'db> {
fn iff(self, db: &'db dyn Db, other: Self) -> Node<'db> {
let self_constraint = self.constraint(db);
let other_constraint = other.constraint(db);
match self_constraint.cmp(&other_constraint) {
match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) {
Ordering::Equal => Node::new(
db,
self_constraint,
@ -884,7 +989,7 @@ impl<'db> InteriorNode<'db> {
// point in the BDD where the assignment can no longer affect the result,
// and we can return early.
let self_constraint = self.constraint(db);
if assignment.constraint() < self_constraint {
if assignment.constraint().ordering(db) < self_constraint.ordering(db) {
return (Node::Interior(self), false);
}
@ -1141,6 +1246,55 @@ impl<'db> ConstraintAssignment<'db> {
*self = self.negated();
}
/// Reports whether `self` implies `other`: that is, whether every type satisfying this
/// assignment is guaranteed to satisfy `other` as well.
///
/// One consumer of this check is the rendering of constraint sets, where it lets us drop
/// constraints from a clause that are made redundant by another constraint in the clause.
///
/// NOTE(review): none of the arms below compares the typevars of the two constraints
/// directly; presumably the `implies`/`intersect` helpers or the callers account for that —
/// TODO confirm.
fn implies(self, db: &'db dyn Db, other: Self) -> bool {
    match self {
        ConstraintAssignment::Positive(this) => match other {
            // Positive vs. positive: one range must sit entirely inside the other, and it is
            // the narrower range that implies the wider one.
            //
            //     ....|----other-----|....
            //     ......|---self---|......
            ConstraintAssignment::Positive(that) => this.implies(db, that),

            // Positive vs. negative: the positive range implies the negative one exactly
            // when the two ranges are disjoint, i.e. their intersection is empty.
            //
            //     |---------------|...self...|---|
            //     ..|---other---|................|
            ConstraintAssignment::Negative(that) => this.intersect(db, that).is_none(),
        },

        ConstraintAssignment::Negative(this) => match other {
            // Negative vs. negative: the ranges are "holes" here, so the direction flips —
            // the assignment with the wider hole implies the one with the narrower hole.
            //
            //     |-----|...other...|-----|
            //     |---|.....self......|---|
            ConstraintAssignment::Negative(that) => that.implies(db, this),

            // Negative vs. positive: this could only hold if the positive constraint were
            // always satisfied (`Never ≤ T ≤ object`). We never create constraints of that
            // form, so with our representation a negative constraint never implies a
            // positive one.
            //
            //     |------other-------|
            //     |---|...self...|---|
            ConstraintAssignment::Positive(_) => false,
        },
    }
}
// Keep this for future debugging needs, even though it's not currently used when rendering
// constraint sets.
#[expect(dead_code)]
@ -1209,6 +1363,43 @@ impl<'db> SatisfiedClause<'db> {
false
}
/// Simplifies this clause by removing constraints that are implied by other constraints in the
/// clause. (Clauses are the intersection of constraints, so if two constraints are redundant, we
/// want to remove the larger one and keep the smaller one.)
///
/// Every pair of constraints is compared in both directions. Removal uses `swap_remove`, which
/// moves another element into the vacated index, so the loop indices are only advanced when
/// nothing was removed at that position.
///
/// Returns a boolean that indicates whether any simplifications were made.
fn simplify(&mut self, db: &'db dyn Db) -> bool {
let mut changes_made = false;
let mut i = 0;
// Loop through each constraint, comparing it with any constraints that appear later in the
// list.
'outer: while i < self.constraints.len() {
let mut j = i + 1;
while j < self.constraints.len() {
if self.constraints[j].implies(db, self.constraints[i]) {
// Constraint `j` implies constraint `i`, so `i` is the redundant one.
//
// If constraint `i` is removed, then we don't need to compare it with any
// later constraints in the list. Note that we continue the outer loop, instead
// of breaking from the inner loop, so that we don't bump index `i` below.
// (We'll have swapped another element into place at that index, and want to
// make sure that we process it.)
self.constraints.swap_remove(i);
changes_made = true;
continue 'outer;
} else if self.constraints[i].implies(db, self.constraints[j]) {
// Constraint `i` implies constraint `j`, so `j` is the redundant one.
//
// If constraint `j` is removed, then we can continue the inner loop. We will
// swap a new element into place at index `j`, and will continue comparing the
// constraint at index `i` with later constraints.
self.constraints.swap_remove(j);
changes_made = true;
} else {
// Neither constraint implies the other; move on to the next candidate.
j += 1;
}
}
// Constraint `i` survived every comparison, so it stays in the clause.
i += 1;
}
changes_made
}
fn display(&self, db: &'db dyn Db) -> String {
// This is a bit heavy-handed, but we need to output the constraints in a consistent order
// even though Salsa IDs are assigned non-deterministically. This Display output is only
@ -1258,7 +1449,13 @@ impl<'db> SatisfiedClauses<'db> {
/// Simplifies the DNF representation, removing redundancies that do not change the underlying
/// function. (This is used when displaying a BDD, to make sure that the representation that we
/// show is as simple as possible while still producing the same results.)
fn simplify(&mut self) {
fn simplify(&mut self, db: &'db dyn Db) {
// First simplify each clause individually, by removing constraints that are implied by
// other constraints in the clause.
for clause in &mut self.clauses {
clause.simplify(db);
}
while self.simplify_one_round() {
// Keep going
}
@ -1340,3 +1537,47 @@ impl<'db> SatisfiedClauses<'db> {
clauses.join(" ")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    use indoc::indoc;
    use pretty_assertions::assert_eq;

    use crate::db::tests::setup_db;
    use crate::types::{BoundTypeVarInstance, KnownClass, TypeVarVariance};

    /// Builds `(T = str ∨ T = bool) ∧ (U = str ∨ U = bool)` and checks the tree that
    /// `display_graph` renders for it, including the `┡━₁` / `└─₀` edge markers and the
    /// rail/padding prefixes of nested lines.
    #[test]
    fn test_display_graph_output() {
        // `indoc!` strips the common leading indentation; `trim_end` drops the final newline,
        // since `display_graph` does not emit one.
        let expected = indoc! {r#"
            (T = str)
            ┡━₁ (U = str)
            │   ┡━₁ always
            │   └─₀ (U = bool)
            │       ┡━₁ always
            │       └─₀ never
            └─₀ (T = bool)
                ┡━₁ (U = str)
                │   ┡━₁ always
                │   └─₀ (U = bool)
                │       ┡━₁ always
                │       └─₀ never
                └─₀ never
        "#}
        .trim_end();

        let db = setup_db();
        let t = BoundTypeVarInstance::synthetic(&db, "T", TypeVarVariance::Invariant);
        let u = BoundTypeVarInstance::synthetic(&db, "U", TypeVarVariance::Invariant);
        let bool_type = KnownClass::Bool.to_instance(&db);
        let str_type = KnownClass::Str.to_instance(&db);

        // Each constraint uses the same type as lower and upper bound, i.e. an equality.
        let t_str = ConstraintSet::range(&db, str_type, t.identity(&db), str_type);
        let t_bool = ConstraintSet::range(&db, bool_type, t.identity(&db), bool_type);
        let u_str = ConstraintSet::range(&db, str_type, u.identity(&db), str_type);
        let u_bool = ConstraintSet::range(&db, bool_type, u.identity(&db), bool_type);

        let constraints = (t_str.or(&db, || t_bool)).and(&db, || u_str.or(&db, || u_bool));
        let actual = constraints.node.display_graph(&db, &"").to_string();
        assert_eq!(actual, expected);
    }
}