From 567f01841645eb709f867c461d4345b43ebb8b3e Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Tue, 23 Dec 2025 02:44:26 -0500 Subject: [PATCH] optimize the constraint DNF representation --- .../resources/mdtest/narrow/type_guards.md | 2 +- .../src/semantic_index/use_def.rs | 6 +- crates/ty_python_semantic/src/types/narrow.rs | 198 +++++++++--------- 3 files changed, 99 insertions(+), 107 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 7d5952c5f3..223b6a13ab 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -217,7 +217,7 @@ def is_b(val: object) -> TypeGuard[B]: def _(x: P): if isinstance(x, A) or is_b(x): - reveal_type(x) # revealed: (P & A) | B + reveal_type(x) # revealed: B | (P & A) ``` Attribute and subscript narrowing is supported: diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 9288a6df50..c34928184d 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -766,12 +766,12 @@ impl<'db> ConstraintsIterator<'_, 'db> { self.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place)) .reduce(|acc, constraint| { // See above---note the reverse application - constraint.merge_constraint_and(&acc, db) + constraint.merge_constraint_and(acc, db) }) .map_or(base_ty, |constraint| { NarrowingConstraint::regular(base_ty) - .merge_constraint_and(&constraint, db) - .evaluate_type_constraint(db) + .merge_constraint_and(constraint, db) + .evaluate_constraint_type(db) }) } } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 4782db357b..010fbfbef4 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -19,7 +19,7 @@ use crate::types::{ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_stdlib::identifiers::is_identifier; -use itertools::{Either, Itertools}; +use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; @@ -272,56 +272,37 @@ impl ClassInfoConstraintFunction { } } -/// Represents a single conjunction (AND) of constraints in Disjunctive Normal -/// Form (DNF). +/// Represents a TypeGuard-containing disjunct in a Disjunctive Normal Form +/// (DNF) narrowing constraint. /// -/// A conjunction may contain: - A regular constraint (intersection of types) - -/// An optional `TypeGuard` constraint that "replaces" the type rather than -/// intersecting +/// Such a constraint may optionally have a refinement applied after the type +/// guard, which is interpreted as being intersected with the type guard. /// -/// For example, `(Conjunction { constraint: A, typeguard: Some(B) } & -/// Conjunction { constraint: C, typeguard: Some(D)})` evaluates to -/// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard +/// For example, `(TypeGuardConstraint { typeguard: A, refinement: Some(B) } & +/// TypeGuardConstraint { typeguard: C, refinement: Some(D) })` evaluates to +/// `TypeGuardConstraint { typeguard: C, refinement: Some(D) }` because the type guard /// in the second conjunct clobbers that in the first. -#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy, salsa::Update, get_size2::GetSize)] -struct Conjunction<'db> { - /// The intersected constraints (represented as a type to intersect the guard with) - constraint: Type<'db>, - /// If any constraint in this conjunction is a `TypeGuard[T]`, this is `Some(T)` - typeguard: Option>, +#[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct TypeGuardConstraint<'db> { + /// If `TypeGuard[T]`, this is `Some(T)` + typeguard: Type<'db>, + /// If additional constraints are applied _after_ the TypeGuard, then they + /// go here + refinement: Option>, } -impl<'db> Conjunction<'db> { - /// Create a new conjunction with just a regular constraint - fn regular(constraint: Type<'db>) -> Self { - Self { - constraint, - typeguard: None, +impl<'db> TypeGuardConstraint<'db> { + /// Evaluate this typeguard constraint to a single type. + /// If there's a refinement, it's intersected with the typeguard constraint. + fn evaluate_constraint_type(self, db: &'db dyn Db) -> Type<'db> { + match self.refinement { + Some(refinement) => IntersectionBuilder::new(db) + .add_positive(self.typeguard) + .add_positive(refinement) + .build(), + None => self.typeguard, } } - - /// Create a new conjunction with a `TypeGuard` constraint - fn typeguard(constraint: Type<'db>) -> Self { - Self { - constraint: Type::object(), - typeguard: Some(constraint), - } - } - - /// Evaluate this conjunction to a single type. - /// If there's a `TypeGuard` constraint, it replaces the regular constraint. - /// Otherwise, returns the regular constraint. - fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { - self.typeguard.map_or_else( - || self.constraint, - |typeguard_constraint| { - IntersectionBuilder::new(db) - .add_positive(typeguard_constraint) - .add_positive(self.constraint) - .build() - }, - ) - } } /// Represents narrowing constraints in Disjunctive Normal Form (DNF). @@ -340,68 +321,92 @@ impl<'db> Conjunction<'db> { /// => evaluates to `(P & A) | B`, where `P` is our previously-known type #[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] pub(crate) struct NarrowingConstraint<'db> { + /// Regular constraint---we don't need a list here because we can represent + /// with a union type + regular_disjunct: Option>, /// Disjunction of conjunctions (DNF) - disjuncts: SmallVec<[Conjunction<'db>; 1]>, + typeguard_disjuncts: SmallVec<[TypeGuardConstraint<'db>; 1]>, } impl<'db> NarrowingConstraint<'db> { /// Create a constraint from a regular (non-`TypeGuard`) type pub(crate) fn regular(constraint: Type<'db>) -> Self { Self { - disjuncts: smallvec![Conjunction::regular(constraint)], + regular_disjunct: Some(constraint), + typeguard_disjuncts: smallvec![], } } /// Create a constraint from a `TypeGuard` type fn typeguard(constraint: Type<'db>) -> Self { Self { - disjuncts: smallvec![Conjunction::typeguard(constraint)], + regular_disjunct: None, + typeguard_disjuncts: smallvec![TypeGuardConstraint { + typeguard: constraint, + refinement: None, + }], } } /// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics (with `other` winning) - pub(crate) fn merge_constraint_and(&self, other: &Self, db: &'db dyn Db) -> Self { + pub(crate) fn merge_constraint_and(&self, other: Self, db: &'db dyn Db) -> Self { // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... - let new_disjuncts = other - .disjuncts - .iter() - .flat_map(|right_conj| { - // We iterate the RHS first because if it has a typeguard then we don't need to consider the LHS - if right_conj.typeguard.is_some() { - // If the right conjunct has a TypeGuard, it "wins" the conjunction - Either::Left(std::iter::once(*right_conj)) - } else { - // Otherwise, we need to consider all LHS disjuncts - Either::Right(self.disjuncts.iter().map(|left_conj| { - let new_regular = IntersectionBuilder::new(db) - .add_positive(left_conj.constraint) - .add_positive(right_conj.constraint) - .build(); + // + // In our representation, the RHS `typeguard_disjuncts` will all clobber + // the LHS disjuncts when they are anded, so they'll just stay as is. + // + // The thing we actually need to deal with is the RHS `regular_disjunct`. + // It gets anded onto the LHS `regular_disjunct` to form the new + // `regular_disjunct`, and anded onto each LHS `typeguard_disjunct` (via + // the refinement) to form new additional `typeguard_disjuncts`. + let Some(other_regular_disjunct) = other.regular_disjunct else { + return other; + }; - Conjunction { - constraint: new_regular, - typeguard: left_conj.typeguard, - } - })) - } - }) - .collect(); + let new_regular_disjunct = self.regular_disjunct.map(|regular_disjunct| { + IntersectionBuilder::new(db) + .add_positive(regular_disjunct) + .add_positive(other_regular_disjunct) + .build() + }); + + let additional_typeguard_disjuncts = + self.typeguard_disjuncts + .iter() + .map(|typeguard_disjunct| TypeGuardConstraint { + typeguard: typeguard_disjunct.typeguard, + refinement: match typeguard_disjunct.refinement { + Some(refinement) => Some( + IntersectionBuilder::new(db) + .add_positive(refinement) + .add_positive(other_regular_disjunct) + .build(), + ), + None => other.regular_disjunct, + }, + }); + + let mut new_typeguard_disjuncts = other.typeguard_disjuncts; + + new_typeguard_disjuncts.extend(additional_typeguard_disjuncts); NarrowingConstraint { - disjuncts: new_disjuncts, + typeguard_disjuncts: new_typeguard_disjuncts, + regular_disjunct: new_regular_disjunct, } } /// Evaluate the type this effectively constrains to /// /// Forgets whether each constraint originated from a `TypeGuard` or not - pub(crate) fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { + pub(crate) fn evaluate_constraint_type(self, db: &'db dyn Db) -> Type<'db> { UnionType::from_elements( db, - self.disjuncts + self.typeguard_disjuncts .into_iter() - .map(|disjunct| Conjunction::evaluate_type_constraint(disjunct, db)), + .map(|disjunct| disjunct.evaluate_constraint_type(db)) + .chain(self.regular_disjunct), ) } } @@ -433,7 +438,7 @@ fn merge_constraints_and<'db>( Entry::Occupied(mut entry) => { let into_constraint = entry.get(); - entry.insert(into_constraint.merge_constraint_and(&from_constraint, db)); + entry.insert(into_constraint.merge_constraint_and(from_constraint, db)); } Entry::Vacant(entry) => { entry.insert(from_constraint); @@ -449,9 +454,6 @@ fn merge_constraints_and<'db>( /// /// However, if a place appears in only one branch of the OR, we need to widen it /// to `object` in the overall result (because the other branch doesn't constrain it). -/// -/// When none of the disjuncts have `TypeGuard`, we simplify the constraint types -/// via `UnionBuilder` to enable simplifications like `~AlwaysFalsy | ~AlwaysTruthy -> object`. fn merge_constraints_or<'db>( into: &mut NarrowingConstraints<'db>, from: NarrowingConstraints<'db>, @@ -464,31 +466,21 @@ fn merge_constraints_or<'db>( match into.entry(key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get_mut(); - // Concatenate disjuncts - into_constraint.disjuncts.extend(from_constraint.disjuncts); + // Union the regular constraints + into_constraint.regular_disjunct = match ( + into_constraint.regular_disjunct, + from_constraint.regular_disjunct, + ) { + (Some(a), Some(b)) => Some(UnionType::from_elements(db, [a, b])), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + }; - // If none of the disjuncts have TypeGuard, we can simplify the constraint types - // via UnionBuilder. This enables simplifications like: - // `~AlwaysFalsy | ~AlwaysTruthy -> object` - let all_regular = into_constraint - .disjuncts - .iter() - .all(|conj| conj.typeguard.is_none()); - - if all_regular { - // Simplify via UnionBuilder - let simplified = UnionType::from_elements( - db, - into_constraint.disjuncts.iter().map(|conj| conj.constraint), - ); - if simplified.is_object() { - // If simplified to object, we can drop the constraint entirely - entry.remove(); - } else { - // Replace with simplified constraint - into_constraint.disjuncts = smallvec![Conjunction::regular(simplified)]; - } - } + // Concatenate typeguard disjuncts + into_constraint + .typeguard_disjuncts + .extend(from_constraint.typeguard_disjuncts); } Entry::Vacant(_) => { // Place only appears in `from`, not in `into`. No constraint needed.