From 766ed5b5f394b3dae19dba4859aac797b48fbd00 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 22 Oct 2025 13:38:44 -0400 Subject: [PATCH] [ty] Some more simplifications when rendering constraint sets (#21009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds another useful simplification when rendering constraint sets: `T = int` instead of `T = int ∧ T ≠ str`. (The "smaller" constraint `T = int` implies the "larger" constraint `T ≠ str`. Constraint set clauses are intersections, and if one constraint in a clause implies another, we can throw away the "larger" constraint.) While we're here, we also normalize the bounds of a constraint, so that we equate e.g. `T ≤ int | str` with `T ≤ str | int`, and change the ordering of BDD variables so that all constraints with the same typevar are ordered adjacent to each other. Lastly, we also add a new `display_graph` helper method that prints out the full graph structure of a BDD. --------- Co-authored-by: Alex Waygood --- Cargo.lock | 2 + crates/ty_python_semantic/Cargo.toml | 2 + .../mdtest/type_properties/constraints.md | 37 +++ .../src/semantic_index/definition.rs | 5 + crates/ty_python_semantic/src/types.rs | 8 +- .../src/types/constraints.rs | 283 ++++++++++++++++-- 6 files changed, 314 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3d917171a4..1bc41b0c9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4434,10 +4434,12 @@ dependencies = [ "glob", "hashbrown 0.16.0", "indexmap", + "indoc", "insta", "itertools 0.14.0", "memchr", "ordermap", + "pretty_assertions", "quickcheck", "quickcheck_macros", "ruff_annotate_snippets", diff --git a/crates/ty_python_semantic/Cargo.toml b/crates/ty_python_semantic/Cargo.toml index edeafa821f..faf5c37881 100644 --- a/crates/ty_python_semantic/Cargo.toml +++ b/crates/ty_python_semantic/Cargo.toml @@ -63,7 +63,9 @@ ty_vendored = { workspace = true } anyhow = { workspace = true } dir-test = { workspace = 
true } glob = { workspace = true } +indoc = { workspace = true } insta = { workspace = true } +pretty_assertions = { workspace = true } tempfile = { workspace = true } quickcheck = { version = "1.0.3", default-features = false } quickcheck_macros = { version = "1.0.0" } diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md b/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md index 5791ba00de..607501a349 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/constraints.md @@ -601,3 +601,40 @@ def _[T, U]() -> None: # revealed: ty_extensions.ConstraintSet[always] reveal_type(~union | union) ``` + +## Other simplifications + +When displaying a constraint set, we transform the internal BDD representation into a DNF formula +(i.e., the logical OR of several clauses, each of which is the logical AND of several constraints). +This section contains several examples that show that we simplify the DNF formula as much as we can +before displaying it. + +```py +from ty_extensions import range_constraint + +def f[T, U](): + t1 = range_constraint(str, T, str) + t2 = range_constraint(bool, T, bool) + u1 = range_constraint(str, U, str) + u2 = range_constraint(bool, U, bool) + + # revealed: ty_extensions.ConstraintSet[(T@f = bool) ∨ (T@f = str)] + reveal_type(t1 | t2) + # revealed: ty_extensions.ConstraintSet[(U@f = bool) ∨ (U@f = str)] + reveal_type(u1 | u2) + # revealed: ty_extensions.ConstraintSet[((T@f = bool) ∧ (U@f = bool)) ∨ ((T@f = bool) ∧ (U@f = str)) ∨ ((T@f = str) ∧ (U@f = bool)) ∨ ((T@f = str) ∧ (U@f = str))] + reveal_type((t1 | t2) & (u1 | u2)) +``` + +The lower and upper bounds of a constraint are normalized, so that we equate unions and +intersections whose elements appear in different orders. 
+ +```py +from typing import Never + +def f[T](): + # revealed: ty_extensions.ConstraintSet[(T@f ≤ int | str)] + reveal_type(range_constraint(Never, T, str | int)) + # revealed: ty_extensions.ConstraintSet[(T@f ≤ int | str)] + reveal_type(range_constraint(Never, T, int | str)) +``` diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index 368994fd34..81af22d314 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -22,7 +22,12 @@ use crate::unpack::{Unpack, UnpackPosition}; /// because a new scope gets inserted before the `Definition` or a new place is inserted /// before this `Definition`. However, the ID can be considered stable and it is okay to use /// `Definition` in cross-module` salsa queries or as a field on other salsa tracked structs. +/// +/// # Ordering +/// Ordering is based on the definition's salsa-assigned id and not on its values. +/// The id may change between runs, or when the definition was garbage collected and recreated. #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(Ord, PartialOrd)] pub struct Definition<'db> { /// The file in which the definition occurs. pub file: File, diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 465df24583..93c306ba8d 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -8461,7 +8461,9 @@ fn lazy_bound_or_constraints_cycle_initial<'db>( } /// Where a type variable is bound and usable. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +#[derive( + Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Update, get_size2::GetSize, +)] pub enum BindingContext<'db> { /// The definition of the generic class, function, or type alias that binds this typevar. 
Definition(Definition<'db>), @@ -8495,7 +8497,9 @@ impl<'db> BindingContext<'db> { /// independent of the typevar's bounds or constraints. Two bound typevars have the same identity /// if they represent the same logical typevar bound in the same context, even if their bounds /// have been materialized differently. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)] +#[derive( + Debug, Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd, get_size2::GetSize, salsa::Update, +)] pub struct BoundTypeVarIdentity<'db> { pub(crate) identity: TypeVarIdentity<'db>, pub(crate) binding_context: BindingContext<'db>, diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 3d2b23c09f..832aa71ab2 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -23,6 +23,9 @@ //! Note that all lower and upper bounds in a constraint must be fully static. We take the bottom //! and top materializations of the types to remove any gradual forms if needed. //! +//! Lower and upper bounds must also be normalized. This lets us identify, for instance, +//! two constraints with equivalent but differently ordered unions as their bounds. +//! //! NOTE: This module is currently in a transitional state. We've added the BDD [`ConstraintSet`] //! representation, and updated all of our property checks to build up a constraint set and then //! check whether it is ever or always satisfiable, as appropriate. We are not yet inferring @@ -58,6 +61,7 @@ use std::fmt::Display; use itertools::Itertools; use rustc_hash::FxHashSet; +use salsa::plumbing::AsId; use crate::Db; use crate::types::{BoundTypeVarIdentity, IntersectionType, Type, UnionType}; @@ -183,20 +187,20 @@ impl<'db> ConstraintSet<'db> { /// Updates this constraint set to hold the union of itself and another constraint set. 
pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> Self { - self.node = self.node.or(db, other.node).simplify(db); + self.node = self.node.or(db, other.node); *self } /// Updates this constraint set to hold the intersection of itself and another constraint set. pub(crate) fn intersect(&mut self, db: &'db dyn Db, other: Self) -> Self { - self.node = self.node.and(db, other.node).simplify(db); + self.node = self.node.and(db, other.node); *self } /// Returns the negation of this constraint set. pub(crate) fn negate(self, db: &'db dyn Db) -> Self { Self { - node: self.node.negate(db).simplify(db), + node: self.node.negate(db), } } @@ -256,7 +260,6 @@ impl From for ConstraintSet<'_> { /// An individual constraint in a constraint set. This restricts a single typevar to be within a /// lower and upper bound. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -#[derive(PartialOrd, Ord)] pub(crate) struct ConstrainedTypeVar<'db> { typevar: BoundTypeVarIdentity<'db>, lower: Type<'db>, @@ -292,8 +295,11 @@ impl<'db> ConstrainedTypeVar<'db> { return Node::AlwaysTrue; } + let lower = lower.normalized(db); + let upper = upper.normalized(db); Node::new_constraint(db, ConstrainedTypeVar::new(db, typevar, lower, upper)) } + fn when_true(self) -> ConstraintAssignment<'db> { ConstraintAssignment::Positive(self) } @@ -310,11 +316,37 @@ impl<'db> ConstrainedTypeVar<'db> { && other.upper(db).is_subtype_of(db, self.upper(db)) } + /// Defines the ordering of the variables in a constraint set BDD. + /// + /// If we only care about _correctness_, we can choose any ordering that we want, as long as + /// it's consistent. However, different orderings can have very different _performance_ + /// characteristics. Many BDD libraries attempt to reorder variables on the fly while building + /// and working with BDDs. We don't do that, but we have tried to make some simple choices that + /// have clear wins. 
+ /// + /// In particular, we compare the _typevars_ of each constraint first, so that all constraints + /// for a single typevar are guaranteed to be adjacent in the BDD structure. There are several + /// simplifications that we perform that operate on constraints with the same typevar, and this + /// ensures that we can find all candidate simplifications more easily. + fn ordering(self, db: &'db dyn Db) -> impl Ord { + (self.typevar(db), self.as_id()) + } + + /// Returns whether this constraint implies another — i.e., whether every type that + /// satisfies this constraint also satisfies `other`. + /// + /// This is used (among other places) to simplify how we display constraint sets, by removing + /// redundant constraints from a clause. + fn implies(self, db: &'db dyn Db, other: Self) -> bool { + other.contains(db, self) + } + /// Returns the intersection of two range constraints, or `None` if the intersection is empty. fn intersect(self, db: &'db dyn Db, other: Self) -> Option { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) - let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); - let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]); + let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]).normalized(db); + let upper = + IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]).normalized(db); // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. @@ -390,8 +422,8 @@ impl<'db> ConstrainedTypeVar<'db> { /// that point at the same node. /// /// BDD nodes are also _ordered_, meaning that every path from the root of a BDD to a terminal node -/// visits variables in the same order. [`ConstrainedTypeVar`]s are interned, so we can use the IDs -/// that salsa assigns to define this order. +/// visits variables in the same order. 
[`ConstrainedTypeVar::ordering`] defines the variable +/// ordering that we use for constraint set BDDs. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, get_size2::GetSize, salsa::Update)] enum Node<'db> { AlwaysFalse, @@ -407,13 +439,13 @@ impl<'db> Node<'db> { if_true: Node<'db>, if_false: Node<'db>, ) -> Self { + debug_assert!((if_true.root_constraint(db)).is_none_or(|root_constraint| { + root_constraint.ordering(db) > constraint.ordering(db) + })); debug_assert!( - (if_true.root_constraint(db)) - .is_none_or(|root_constraint| root_constraint > constraint) - ); - debug_assert!( - (if_false.root_constraint(db)) - .is_none_or(|root_constraint| root_constraint > constraint) + (if_false.root_constraint(db)).is_none_or(|root_constraint| { + root_constraint.ordering(db) > constraint.ordering(db) + }) ); if if_true == if_false { return if_true; @@ -762,14 +794,87 @@ impl<'db> Node<'db> { Node::AlwaysFalse => f.write_str("never"), Node::Interior(_) => { let mut clauses = self.node.satisfied_clauses(self.db); - clauses.simplify(); + clauses.simplify(self.db); clauses.display(self.db).fmt(f) } } } } - DisplayNode { node: self, db } + DisplayNode { + node: self.simplify(db), + db, + } + } + + /// Displays the full graph structure of this BDD. `prefix` will be output before each line + /// other than the first. 
Produces output like the following: + /// + /// ```text + /// (T@_ = str) + /// ┡━₁ (U@_ = str) + /// │ ┡━₁ always + /// │ └─₀ (U@_ = bool) + /// │ ┡━₁ always + /// │ └─₀ never + /// └─₀ (T@_ = bool) + /// ┡━₁ (U@_ = str) + /// │ ┡━₁ always + /// │ └─₀ (U@_ = bool) + /// │ ┡━₁ always + /// │ └─₀ never + /// └─₀ never + /// ``` + #[cfg_attr(not(test), expect(dead_code))] // Keep this around for debugging purposes + fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display { + struct DisplayNode<'a, 'db> { + db: &'db dyn Db, + node: Node<'db>, + prefix: &'a dyn Display, + } + + impl<'a, 'db> DisplayNode<'a, 'db> { + fn new(db: &'db dyn Db, node: Node<'db>, prefix: &'a dyn Display) -> Self { + Self { db, node, prefix } + } + } + + impl Display for DisplayNode<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.node { + Node::AlwaysTrue => write!(f, "always"), + Node::AlwaysFalse => write!(f, "never"), + Node::Interior(interior) => { + interior.constraint(self.db).display(self.db).fmt(f)?; + // Calling display_graph recursively here causes rustc to claim that the + // expect(unused) up above is unfulfilled! 
+ write!( + f, + "\n{}┡━₁ {}", + self.prefix, + DisplayNode::new( + self.db, + interior.if_true(self.db), + &format_args!("{}│ ", self.prefix) + ), + )?; + write!( + f, + "\n{}└─₀ {}", + self.prefix, + DisplayNode::new( + self.db, + interior.if_false(self.db), + &format_args!("{} ", self.prefix) + ), + )?; + Ok(()) + } + } + } + } + + DisplayNode::new(db, self, prefix) } } @@ -800,7 +905,7 @@ impl<'db> InteriorNode<'db> { fn or(self, db: &'db dyn Db, other: Self) -> Node<'db> { let self_constraint = self.constraint(db); let other_constraint = other.constraint(db); - match self_constraint.cmp(&other_constraint) { + match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) { Ordering::Equal => Node::new( db, self_constraint, @@ -826,7 +931,7 @@ impl<'db> InteriorNode<'db> { fn and(self, db: &'db dyn Db, other: Self) -> Node<'db> { let self_constraint = self.constraint(db); let other_constraint = other.constraint(db); - match self_constraint.cmp(&other_constraint) { + match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) { Ordering::Equal => Node::new( db, self_constraint, @@ -852,7 +957,7 @@ impl<'db> InteriorNode<'db> { fn iff(self, db: &'db dyn Db, other: Self) -> Node<'db> { let self_constraint = self.constraint(db); let other_constraint = other.constraint(db); - match self_constraint.cmp(&other_constraint) { + match (self_constraint.ordering(db)).cmp(&other_constraint.ordering(db)) { Ordering::Equal => Node::new( db, self_constraint, @@ -884,7 +989,7 @@ impl<'db> InteriorNode<'db> { // point in the BDD where the assignment can no longer affect the result, // and we can return early. 
let self_constraint = self.constraint(db); - if assignment.constraint() < self_constraint { + if assignment.constraint().ordering(db) < self_constraint.ordering(db) { return (Node::Interior(self), false); } @@ -1141,6 +1246,55 @@ impl<'db> ConstraintAssignment<'db> { *self = self.negated(); } + /// Returns whether this constraint implies another — i.e., whether every type that + /// satisfies this constraint also satisfies `other`. + /// + /// This is used (among other places) to simplify how we display constraint sets, by removing + /// redundant constraints from a clause. + fn implies(self, db: &'db dyn Db, other: Self) -> bool { + match (self, other) { + // For two positive constraints, one range has to fully contain the other; the smaller + // constraint implies the larger. + // + // ....|----other-----|.... + // ......|---self---|...... + ( + ConstraintAssignment::Positive(self_constraint), + ConstraintAssignment::Positive(other_constraint), + ) => self_constraint.implies(db, other_constraint), + + // For two negative constraints, one range has to fully contain the other; the ranges + // represent "holes", though, so the constraint with the larger range implies the one + // with the smaller. + // + // |-----|...other...|-----| + // |---|.....self......|---| + ( + ConstraintAssignment::Negative(self_constraint), + ConstraintAssignment::Negative(other_constraint), + ) => other_constraint.implies(db, self_constraint), + + // For a positive and negative constraint, the ranges have to be disjoint, and the + // positive range implies the negative range. 
+ // + // |---------------|...self...|---| + // ..|---other---|................| + ( + ConstraintAssignment::Positive(self_constraint), + ConstraintAssignment::Negative(other_constraint), + ) => self_constraint.intersect(db, other_constraint).is_none(), + + // It's theoretically possible for a negative constraint to imply a positive constraint + // if the positive constraint is always satisfied (`Never ≤ T ≤ object`). But we never + // create constraints of that form, so with our representation, a negative constraint + // can never imply a positive constraint. + // + // |------other-------| + // |---|...self...|---| + (ConstraintAssignment::Negative(_), ConstraintAssignment::Positive(_)) => false, + } + } + // Keep this for future debugging needs, even though it's not currently used when rendering // constraint sets. #[expect(dead_code)] @@ -1209,6 +1363,43 @@ impl<'db> SatisfiedClause<'db> { false } + /// Simplifies this clause by removing constraints that are implied by other constraints in the + /// clause. (Clauses are the intersection of constraints, so if two constraints are redundant, we + /// want to remove the larger one and keep the smaller one.) + /// + /// Returns a boolean that indicates whether any simplifications were made. + fn simplify(&mut self, db: &'db dyn Db) -> bool { + let mut changes_made = false; + let mut i = 0; + // Loop through each constraint, comparing it with any constraints that appear later in the + // list. + 'outer: while i < self.constraints.len() { + let mut j = i + 1; + while j < self.constraints.len() { + if self.constraints[j].implies(db, self.constraints[i]) { + // If constraint `i` is removed, then we don't need to compare it with any + // later constraints in the list. Note that we continue the outer loop, instead + // of breaking from the inner loop, so that we don't bump index `i` below. + // (We'll have swapped another element into place at that index, and want to + // make sure that we process it.)
+ self.constraints.swap_remove(i); + changes_made = true; + continue 'outer; + } else if self.constraints[i].implies(db, self.constraints[j]) { + // If constraint `j` is removed, then we can continue the inner loop. We will + // swap a new element into place at index `j`, and will continue comparing the + // constraint at index `i` with later constraints. + self.constraints.swap_remove(j); + changes_made = true; + } else { + j += 1; + } + } + i += 1; + } + changes_made + } + fn display(&self, db: &'db dyn Db) -> String { // This is a bit heavy-handed, but we need to output the constraints in a consistent order // even though Salsa IDs are assigned non-deterministically. This Display output is only @@ -1258,7 +1449,13 @@ impl<'db> SatisfiedClauses<'db> { /// Simplifies the DNF representation, removing redundancies that do not change the underlying /// function. (This is used when displaying a BDD, to make sure that the representation that we /// show is as simple as possible while still producing the same results.) - fn simplify(&mut self) { + fn simplify(&mut self, db: &'db dyn Db) { + // First simplify each clause individually, by removing constraints that are implied by + // other constraints in the clause. + for clause in &mut self.clauses { + clause.simplify(db); + } + while self.simplify_one_round() { // Keep going } @@ -1340,3 +1537,47 @@ impl<'db> SatisfiedClauses<'db> { clauses.join(" ∨ ") } } + +#[cfg(test)] +mod tests { + use super::*; + + use indoc::indoc; + use pretty_assertions::assert_eq; + + use crate::db::tests::setup_db; + use crate::types::{BoundTypeVarInstance, KnownClass, TypeVarVariance}; + + #[test] + fn test_display_graph_output() { + let expected = indoc! 
{r#" + (T = str) + ┡━₁ (U = str) + │ ┡━₁ always + │ └─₀ (U = bool) + │ ┡━₁ always + │ └─₀ never + └─₀ (T = bool) + ┡━₁ (U = str) + │ ┡━₁ always + │ └─₀ (U = bool) + │ ┡━₁ always + │ └─₀ never + └─₀ never + "#} + .trim_end(); + + let db = setup_db(); + let t = BoundTypeVarInstance::synthetic(&db, "T", TypeVarVariance::Invariant); + let u = BoundTypeVarInstance::synthetic(&db, "U", TypeVarVariance::Invariant); + let bool_type = KnownClass::Bool.to_instance(&db); + let str_type = KnownClass::Str.to_instance(&db); + let t_str = ConstraintSet::range(&db, str_type, t.identity(&db), str_type); + let t_bool = ConstraintSet::range(&db, bool_type, t.identity(&db), bool_type); + let u_str = ConstraintSet::range(&db, str_type, u.identity(&db), str_type); + let u_bool = ConstraintSet::range(&db, bool_type, u.identity(&db), bool_type); + let constraints = (t_str.or(&db, || t_bool)).and(&db, || u_str.or(&db, || u_bool)); + let actual = constraints.node.display_graph(&db, &"").to_string(); + assert_eq!(actual, expected); + } +}