[ty] Some more simplifications when rendering constraint sets (#21009)

This PR adds another useful simplification when rendering constraint
sets: `T = int` instead of `T = int ∧ T ≠ str`. (The "smaller"
constraint `T = int` implies the "larger" constraint `T ≠ str`.
Constraint set clauses are intersections, and if one constraint in a
clause implies another, we can throw away the "larger" constraint.)

While we're here, we also normalize the bounds of a constraint, so that
we equate e.g. `T ≤ int | str` with `T ≤ str | int`, and change the
ordering of BDD variables so that all constraints with the same typevar
are ordered adjacent to each other.

Lastly, we also add a new `display_graph` helper method that prints out
the full graph structure of a BDD.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Douglas Creager 2025-10-22 13:38:44 -04:00 committed by GitHub
parent 81c1d36088
commit 766ed5b5f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 314 additions and 23 deletions

2
Cargo.lock generated
View file

@ -4434,10 +4434,12 @@ dependencies = [
"glob",
"hashbrown 0.16.0",
"indexmap",
"indoc",
"insta",
"itertools 0.14.0",
"memchr",
"ordermap",
"pretty_assertions",
"quickcheck",
"quickcheck_macros",
"ruff_annotate_snippets",

View file

@ -63,7 +63,9 @@ ty_vendored = { workspace = true }
anyhow = { workspace = true }
dir-test = { workspace = true }
glob = { workspace = true }
indoc = { workspace = true }
insta = { workspace = true }
pretty_assertions = { workspace = true }
tempfile = { workspace = true }
quickcheck = { version = "1.0.3", default-features = false }
quickcheck_macros = { version = "1.0.0" }

View file

@ -601,3 +601,40 @@ def _[T, U]() -> None:
# revealed: ty_extensions.ConstraintSet[always]
reveal_type(~union | union)
```
## Other simplifications
When displaying a constraint set, we transform the internal BDD representation into a DNF formula
(i.e., the logical OR of several clauses, each of which is the logical AND of several constraints).
This section contains several examples that show that we simplify the DNF formula as much as we can
before displaying it.
```py
from ty_extensions import range_constraint
def f[T, U]():
t1 = range_constraint(str, T, str)
t2 = range_constraint(bool, T, bool)
u1 = range_constraint(str, U, str)
u2 = range_constraint(bool, U, bool)
# revealed: ty_extensions.ConstraintSet[(T@f = bool) (T@f = str)]
reveal_type(t1 | t2)
# revealed: ty_extensions.ConstraintSet[(U@f = bool) (U@f = str)]
reveal_type(u1 | u2)
# revealed: ty_extensions.ConstraintSet[((T@f = bool) ∧ (U@f = bool)) ((T@f = bool) ∧ (U@f = str)) ((T@f = str) ∧ (U@f = bool)) ((T@f = str) ∧ (U@f = str))]
reveal_type((t1 | t2) & (u1 | u2))
```
The lower and upper bounds of a constraint are normalized, so that we equate unions and
intersections whose elements appear in different orders.
```py
from typing import Never
def f[T]():
# revealed: ty_extensions.ConstraintSet[(T@f ≤ int | str)]
reveal_type(range_constraint(Never, T, str | int))
# revealed: ty_extensions.ConstraintSet[(T@f ≤ int | str)]
reveal_type(range_constraint(Never, T, int | str))
```

View file

@ -22,7 +22,12 @@ use crate::unpack::{Unpack, UnpackPosition};
/// because a new scope gets inserted before the `Definition` or a new place is inserted
/// before this `Definition`. However, the ID can be considered stable and it is okay to use
/// `Definition` in cross-module` salsa queries or as a field on other salsa tracked structs.
///
/// # Ordering
/// Ordering is based on the definition's salsa-assigned id and not on its values.
/// The id may change between runs, or when the definition was garbage collected and recreated.
#[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(Ord, PartialOrd)]
pub struct Definition<'db> {
/// The file in which the definition occurs.
pub file: File,

View file

@ -8461,7 +8461,9 @@ fn lazy_bound_or_constraints_cycle_initial<'db>(
}
/// Where a type variable is bound and usable.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Update, get_size2::GetSize,
)]
pub enum BindingContext<'db> {
/// The definition of the generic class, function, or type alias that binds this typevar.
Definition(Definition<'db>),
@ -8495,7 +8497,9 @@ impl<'db> BindingContext<'db> {
/// independent of the typevar's bounds or constraints. Two bound typevars have the same identity
/// if they represent the same logical typevar bound in the same context, even if their bounds
/// have been materialized differently.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)]
#[derive(
Debug, Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd, get_size2::GetSize, salsa::Update,
)]
pub struct BoundTypeVarIdentity<'db> {
pub(crate) identity: TypeVarIdentity<'db>,
pub(crate) binding_context: BindingContext<'db>,

View file

@ -23,6 +23,9 @@
//! Note that all lower and upper bounds in a constraint must be fully static. We take the bottom
//! and top materializations of the types to remove any gradual forms if needed.
//!
//! Lower and upper bounds must also be normalized. This lets us identify, for instance,
//! two constraints with equivalent but differently ordered unions as their bounds.
//!
//! NOTE: This module is currently in a transitional state. We've added the BDD [`ConstraintSet`]
//! representation, and updated all of our property checks to build up a constraint set and then
//! check whether it is ever or always satisfiable, as appropriate. We are not yet inferring
@ -58,6 +61,7 @@ use std::fmt::Display;
use itertools::Itertools;
use rustc_hash::FxHashSet;
use salsa::plumbing::AsId;
use crate::Db;
use crate::types::{BoundTypeVarIdentity, IntersectionType, Type, UnionType};
@ -183,20 +187,20 @@ impl<'db> ConstraintSet<'db> {
/// Updates this constraint set to hold the union of itself and another constraint set.
pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> Self {
self.node = self.node.or(db, other.node).simplify(db);
self.node = self.node.or(db, other.node);
*self
}
/// Updates this constraint set to hold the intersection of itself and another constraint set.
pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: Self) -> Self {
self.node = self.node.and(db, other.node).simplify(db);
self.node = self.node.and(db, other.node);
*self
}
/// Returns the negation of this constraint set.
pub(crate) fn negate(self, db: &'db dyn Db) -> Self {
Self {
node: self.node.negate(db).simplify(db),
node: self.node.negate(db),
}
}
@ -256,7 +260,6 @@ impl From<bool> for ConstraintSet<'_> {
/// An individual constraint in a constraint set. This restricts a single typevar to be within a
/// lower and upper bound.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub(crate) struct ConstrainedTypeVar<'db> {
typevar: BoundTypeVarIdentity<'db>,
lower: Type<'db>,
@ -292,8 +295,11 @@ impl<'db> ConstrainedTypeVar<'db> {
return Node::AlwaysTrue;
}
let lower = lower.normalized(db);
let upper = upper.normalized(db);
Node::new_constraint(db, ConstrainedTypeVar::new(db, typevar, lower, upper))
}
fn when_true(self) -> ConstraintAssignment<'db> {
ConstraintAssignment::Positive(self)
}
@ -310,11 +316,37 @@ impl<'db> ConstrainedTypeVar<'db> {
&& other.upper(db).is_subtype_of(db, self.upper(db))
}
/// Defines the ordering of the variables in a constraint set BDD.
///
/// If we only care about _correctness_, we can choose any ordering that we want, as long as
/// it's consistent. However, different orderings can have very different _performance_
/// characteristics. Many BDD libraries attempt to reorder variables on the fly while building
/// and working with BDDs. We don't do that, but we have tried to make some simple choices that
/// have clear wins.
///
/// In particular, we compare the _typevars_ of each constraint first, so that all constraints
/// for a single typevar are guaranteed to be adjacent in the BDD structure. There are several
/// simplifications that we perform that operate on constraints with the same typevar, and this
/// ensures that we can find all candidate simplifications more easily.
fn ordering(self, db: &'db dyn Db) -> impl Ord {
(self.typevar(db), self.as_id())
}
/// Returns whether this constraint implies another — i.e., whether every type that
/// satisfies this constraint also satisfies `other`.
///
/// This is used (among other places) to simplify how we display constraint sets, by removing
/// redundant constraints from a clause.
fn implies(self, db: &'db dyn Db, other: Self) -> bool {
other.contains(db, self)
}
/// Returns the intersection of two range constraints, or `None` if the intersection is empty.
fn intersect(self, db: &'db dyn Db, other: Self) -> Option<Self> {
// (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ s₂) ≤ α ≤ (t₁ ∩ t₂))
let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]);
let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]);
let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]).normalized(db);
let upper =
IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]).normalized(db);
// If `lower ≰ upper`, then the intersection is empty, since there is no type that is both
// greater than `lower`, and less than `upper`.
@ -390,8 +422,8 @@ impl<'db> ConstrainedTypeVar<'db> {
/// that point at the same node.
///
/// BDD nodes are also _ordered_, meaning that every path from the root of a BDD to a terminal node
/// visits variables in the same order. [`ConstrainedTypeVar`]s are interned, so we can use the IDs
/// that salsa assigns to define this order.
/// visits variables in the same order. [`ConstrainedTypeVar::ordering`] defines the variable
/// ordering that we use for constraint set BDDs.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)]
enum Node<'db> {
AlwaysFalse,
@ -407,13 +439,13 @@ impl<'db> Node<'db> {
if_true: Node<'db>,
if_false: Node<'db>,
) -> Self {
debug_assert!((if_true.root_constraint(db)).is_none_or(|root_constraint| {
root_constraint.ordering(db) > constraint.ordering(db)
}));
debug_assert!(
(if_true.root_constraint(db))
.is_none_or(|root_constraint| root_constraint > constraint)
);
debug_assert!(
(if_false.root_constraint(db))
.is_none_or(|root_constraint| root_constraint > constraint)
(if_false.root_constraint(db)).is_none_or(|root_constraint| {
root_constraint.ordering(db) > constraint.ordering(db)
})
);
if if_true == if_false {
return if_true;
@ -762,14 +794,87 @@ impl<'db> Node<'db> {
Node::AlwaysFalse => f.write_str("never"),
Node::Interior(_) => {
let mut clauses = self.node.satisfied_clauses(self.db);
clauses.simplify();
clauses.simplify(self.db);
clauses.display(self.db).fmt(f)
}
}
}
}
DisplayNode { node: self, db }
DisplayNode {
node: self.simplify(db),
db,
}
}
/// Displays the full graph structure of this BDD. `prefix` will be output before each line
/// other than the first. Produces output like the following:
///
/// ```text
/// (T@_ = str)
/// ┡━₁ (U@_ = str)
/// │ ┡━₁ always
/// │ └─₀ (U@_ = bool)
/// │ ┡━₁ always
/// │ └─₀ never
/// └─₀ (T@_ = bool)
/// ┡━₁ (U@_ = str)
/// │ ┡━₁ always
/// │ └─₀ (U@_ = bool)
/// │ ┡━₁ always
/// │ └─₀ never
/// └─₀ never
/// ```
#[cfg_attr(not(test), expect(dead_code))] // Keep this around for debugging purposes
fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
struct DisplayNode<'a, 'db> {
db: &'db dyn Db,
node: Node<'db>,
prefix: &'a dyn Display,
}
impl<'a, 'db> DisplayNode<'a, 'db> {
fn new(db: &'db dyn Db, node: Node<'db>, prefix: &'a dyn Display) -> Self {
Self { db, node, prefix }
}
}
impl Display for DisplayNode<'_, '_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.node {
Node::AlwaysTrue => write!(f, "always"),
Node::AlwaysFalse => write!(f, "never"),
Node::Interior(interior) => {
interior.constraint(self.db).display(self.db).fmt(f)?;
// Calling display_graph recursively here causes rustc to claim that the
// expect(unused) up above is unfulfilled!
write!(
f,
"\n{}┡━₁ {}",
self.prefix,
DisplayNode::new(
self.db,
interior.if_true(self.db),
&format_args!("{}", self.prefix)
),
)?;
write!(
f,
"\n{}└─₀ {}",
self.prefix,
DisplayNode::new(
self.db,
interior.if_false(self.db),
&format_args!("{} ", self.prefix)
),
)?;
Ok(())
}
}
}
}
DisplayNode::new(db, self, prefix)
}
}
@ -800,7 +905,7 @@ impl<'db> InteriorNode<'db> {
fn or(self, db: &'db dyn Db, other: Self) -> Node<'db> {
let self_constraint = self.constraint(db);
let other_constraint = other.constraint(db);
match self_constraint.cmp(&other_constraint) {
match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) {
Ordering::Equal => Node::new(
db,
self_constraint,
@ -826,7 +931,7 @@ impl<'db> InteriorNode<'db> {
fn and(self, db: &'db dyn Db, other: Self) -> Node<'db> {
let self_constraint = self.constraint(db);
let other_constraint = other.constraint(db);
match self_constraint.cmp(&other_constraint) {
match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) {
Ordering::Equal => Node::new(
db,
self_constraint,
@ -852,7 +957,7 @@ impl<'db> InteriorNode<'db> {
fn iff(self, db: &'db dyn Db, other: Self) -> Node<'db> {
let self_constraint = self.constraint(db);
let other_constraint = other.constraint(db);
match self_constraint.cmp(&other_constraint) {
match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) {
Ordering::Equal => Node::new(
db,
self_constraint,
@ -884,7 +989,7 @@ impl<'db> InteriorNode<'db> {
// point in the BDD where the assignment can no longer affect the result,
// and we can return early.
let self_constraint = self.constraint(db);
if assignment.constraint() < self_constraint {
if assignment.constraint().ordering(db) < self_constraint.ordering(db) {
return (Node::Interior(self), false);
}
@ -1141,6 +1246,55 @@ impl<'db> ConstraintAssignment<'db> {
*self = self.negated();
}
/// Returns whether this constraint implies another — i.e., whether every type that
/// satisfies this constraint also satisfies `other`.
///
/// This is used (among other places) to simplify how we display constraint sets, by removing
/// redundant constraints from a clause.
fn implies(self, db: &'db dyn Db, other: Self) -> bool {
match (self, other) {
// For two positive constraints, one range has to fully contain the other; the smaller
// constraint implies the larger.
//
// ....|----other-----|....
// ......|---self---|......
(
ConstraintAssignment::Positive(self_constraint),
ConstraintAssignment::Positive(other_constraint),
) => self_constraint.implies(db, other_constraint),
// For two negative constraints, one range has to fully contain the other; the ranges
// represent "holes", though, so the constraint with the larger range implies the one
// with the smaller.
//
// |-----|...other...|-----|
// |---|.....self......|---|
(
ConstraintAssignment::Negative(self_constraint),
ConstraintAssignment::Negative(other_constraint),
) => other_constraint.implies(db, self_constraint),
// For a positive and negative constraint, the ranges have to be disjoint, and the
// positive range implies the negative range.
//
// |---------------|...self...|---|
// ..|---other---|................|
(
ConstraintAssignment::Positive(self_constraint),
ConstraintAssignment::Negative(other_constraint),
) => self_constraint.intersect(db, other_constraint).is_none(),
// It's theoretically possible for a negative constraint to imply a positive constraint
// if the positive constraint is always satisfied (`Never ≤ T ≤ object`). But we never
// create constraints of that form, so with our representation, a negative constraint
// can never imply a positive constraint.
//
// |------other-------|
// |---|...self...|---|
(ConstraintAssignment::Negative(_), ConstraintAssignment::Positive(_)) => false,
}
}
// Keep this for future debugging needs, even though it's not currently used when rendering
// constraint sets.
#[expect(dead_code)]
@ -1209,6 +1363,43 @@ impl<'db> SatisfiedClause<'db> {
false
}
/// Simplifies this clause by removing constraints that are implied by other constraints in the
/// clause. (Clauses are the intersection of constraints, so if two clauses are redundant, we
/// want to remove the larger one and keep the smaller one.)
///
/// Returns a boolean that indicates whether any simplifications were made.
fn simplify(&mut self, db: &'db dyn Db) -> bool {
let mut changes_made = false;
let mut i = 0;
// Loop through each constraint, comparing it with any constraints that appear later in the
// list.
'outer: while i < self.constraints.len() {
let mut j = i + 1;
while j < self.constraints.len() {
if self.constraints[j].implies(db, self.constraints[i]) {
// If constraint `i` is removed, then we don't need to compare it with any
// later constraints in the list. Note that we continue the outer loop, instead
// of breaking from the inner loop, so that we don't bump index `i` below.
// (We'll have swapped another element into place at that index, and want to
// make sure that we process it.)
self.constraints.swap_remove(i);
changes_made = true;
continue 'outer;
} else if self.constraints[i].implies(db, self.constraints[j]) {
// If constraint `j` is removed, then we can continue the inner loop. We will
// swap a new element into place at index `j`, and will continue comparing the
// constraint at index `i` with later constraints.
self.constraints.swap_remove(j);
changes_made = true;
} else {
j += 1;
}
}
i += 1;
}
changes_made
}
fn display(&self, db: &'db dyn Db) -> String {
// This is a bit heavy-handed, but we need to output the constraints in a consistent order
// even though Salsa IDs are assigned non-deterministically. This Display output is only
@ -1258,7 +1449,13 @@ impl<'db> SatisfiedClauses<'db> {
/// Simplifies the DNF representation, removing redundancies that do not change the underlying
/// function. (This is used when displaying a BDD, to make sure that the representation that we
/// show is as simple as possible while still producing the same results.)
fn simplify(&mut self) {
fn simplify(&mut self, db: &'db dyn Db) {
// First simplify each clause individually, by removing constraints that are implied by
// other constraints in the clause.
for clause in &mut self.clauses {
clause.simplify(db);
}
while self.simplify_one_round() {
// Keep going
}
@ -1340,3 +1537,47 @@ impl<'db> SatisfiedClauses<'db> {
clauses.join(" ")
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use pretty_assertions::assert_eq;
use crate::db::tests::setup_db;
use crate::types::{BoundTypeVarInstance, KnownClass, TypeVarVariance};
#[test]
fn test_display_graph_output() {
let expected = indoc! {r#"
(T = str)
(U = str)
always
(U = bool)
always
never
(T = bool)
(U = str)
always
(U = bool)
always
never
never
"#}
.trim_end();
let db = setup_db();
let t = BoundTypeVarInstance::synthetic(&db, "T", TypeVarVariance::Invariant);
let u = BoundTypeVarInstance::synthetic(&db, "U", TypeVarVariance::Invariant);
let bool_type = KnownClass::Bool.to_instance(&db);
let str_type = KnownClass::Str.to_instance(&db);
let t_str = ConstraintSet::range(&db, str_type, t.identity(&db), str_type);
let t_bool = ConstraintSet::range(&db, bool_type, t.identity(&db), bool_type);
let u_str = ConstraintSet::range(&db, str_type, u.identity(&db), str_type);
let u_bool = ConstraintSet::range(&db, bool_type, u.identity(&db), bool_type);
let constraints = (t_str.or(&db, || t_bool)).and(&db, || u_str.or(&db, || u_bool));
let actual = constraints.node.display_graph(&db, &"").to_string();
assert_eq!(actual, expected);
}
}