From ca5e6070be33887174ce6b46350999d1d03385c8 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 16:56:55 -0400 Subject: [PATCH] feat: working `TypeGuard` --- .../annotations/unsupported_special_forms.md | 4 +- .../type_properties/is_assignable_to.md | 3 +- .../type_properties/is_disjoint_from.md | 3 +- crates/ty_python_semantic/src/types/narrow.rs | 262 +++++++++--------- 4 files changed, 143 insertions(+), 129 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md index 18ebd03682..6c0947c5e0 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md @@ -16,7 +16,9 @@ def f(*args: Unpack[Ts]) -> tuple[Unpack[Ts]]: reveal_type(args) # revealed: tuple[@Todo(`Unpack[]` special form), ...] return args -def g() -> TypeGuard[int]: ... 
+def g() -> TypeGuard[int]: + return True + def i(callback: Callable[Concatenate[int, P], R_co], *args: P.args, **kwargs: P.kwargs) -> R_co: reveal_type(args) # revealed: P@i.args reveal_type(kwargs) # revealed: P@i.kwargs diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index f2e38485c5..cfb7b5d6e0 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -1383,8 +1383,7 @@ from typing_extensions import Any, TypeGuard, TypeIs static_assert(is_assignable_to(TypeGuard[Unknown], bool)) static_assert(is_assignable_to(TypeIs[Any], bool)) -# TODO no error -static_assert(not is_assignable_to(TypeGuard[Unknown], str)) # error: [static-assert-error] +static_assert(not is_assignable_to(TypeGuard[Unknown], str)) static_assert(not is_assignable_to(TypeIs[Any], str)) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md index d4aa7db231..0b2d95842f 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md @@ -578,8 +578,7 @@ from typing_extensions import TypeGuard, TypeIs static_assert(not is_disjoint_from(bool, TypeGuard[str])) static_assert(not is_disjoint_from(bool, TypeIs[str])) -# TODO no error -static_assert(is_disjoint_from(str, TypeGuard[str])) # error: [static-assert-error] +static_assert(is_disjoint_from(str, TypeGuard[str])) static_assert(is_disjoint_from(str, TypeIs[str])) ``` diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 4beab1e583..b4eb3789ff 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ 
b/crates/ty_python_semantic/src/types/narrow.rs @@ -23,6 +23,7 @@ use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; +use smallvec::{SmallVec, smallvec}; use std::collections::hash_map::Entry; use super::UnionType; @@ -274,19 +275,23 @@ impl ClassInfoConstraintFunction { } } -/// Represents a single conjunction (AND) of constraints in Disjunctive Normal Form (DNF). +/// Represents a single conjunction (AND) of constraints in Disjunctive Normal +/// Form (DNF). /// -/// A conjunction may contain: -/// - A regular constraint (intersection of types) -/// - An optional TypeGuard constraint that "replaces" the type rather than intersecting +/// A conjunction may contain: - A regular constraint (intersection of types) - +/// An optional `TypeGuard` constraint that "replaces" the type rather than +/// intersecting /// -/// For example, `(Regular(A) & TypeGuard(B))` evaluates to just `B` because TypeGuard clobbers. -#[derive(Hash, PartialEq, Debug, Eq, Clone)] +/// For example, `(Conjunction { constraint: A, typeguard: Some(B) } & +/// Conjunction { constraint: C, typeguard: Some(D) })` evaluates to +/// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard +/// in the second clobbers the first. 
+#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)] struct Conjunction<'db> { /// The intersected constraints (represented as an intersection type) constraint: Type<'db>, - /// If any constraint in this conjunction is a TypeGuard, this is Some - /// and contains the union of all TypeGuard types in this conjunction + /// If any constraint in this conjunction is a `TypeGuard`, this is Some and + /// contains the union of all `TypeGuard` types in this conjunction typeguard: Option>, } @@ -301,7 +306,7 @@ impl<'db> Conjunction<'db> { } } - /// Create a new conjunction with a TypeGuard constraint + /// Create a new conjunction with a `TypeGuard` constraint fn typeguard(constraint: Type<'db>) -> Self { Self { constraint: Type::object(), @@ -310,9 +315,9 @@ impl<'db> Conjunction<'db> { } /// Evaluate this conjunction to a single type. - /// If there's a TypeGuard constraint, it replaces the regular constraint. + /// If there's a `TypeGuard` constraint, it replaces the regular constraint. /// Otherwise, returns the regular constraint. - fn to_type(self) -> Type<'db> { + fn evaluate_type_constraint(self) -> Type<'db> { self.typeguard.unwrap_or(self.constraint) } } @@ -320,52 +325,50 @@ impl<'db> Conjunction<'db> { /// Represents narrowing constraints in Disjunctive Normal Form (DNF). /// /// This is a disjunction (OR) of conjunctions (AND) of constraints. -/// The DNF representation allows us to properly track TypeGuard constraints +/// The DNF representation allows us to properly track `TypeGuard` constraints /// through boolean operations. 
/// /// For example: -/// - `f(x) and g(x)` where f returns TypeIs[A] and g returns TypeGuard[B] +/// - `f(x) and g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => `[Conjunction { constraint: A, typeguard: Some(B) }]` -/// => evaluates to `B` (TypeGuard clobbers) +/// => evaluates to `B` (`TypeGuard` clobbers) /// -/// - `f(x) or g(x)` where f returns TypeIs[A] and g returns TypeGuard[B] +/// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]` /// => evaluates to `A | B` #[derive(Hash, PartialEq, Debug, Eq, Clone)] struct NarrowingConstraint<'db> { /// Disjunctions of conjunctions (DNF) - disjuncts: Vec>, + disjuncts: SmallVec<[Conjunction<'db>; 4]>, } impl get_size2::GetSize for NarrowingConstraint<'_> {} impl<'db> NarrowingConstraint<'db> { - /// Create a constraint from a regular (non-TypeGuard) type + /// Create a constraint from a regular (non-`TypeGuard`) type fn regular(constraint: Type<'db>) -> Self { Self { - disjuncts: vec![Conjunction::regular(constraint)], + disjuncts: smallvec![Conjunction::regular(constraint)], } } - /// Create a constraint from a TypeGuard type + /// Create a constraint from a `TypeGuard` type fn typeguard(constraint: Type<'db>) -> Self { Self { - disjuncts: vec![Conjunction::typeguard(constraint)], + disjuncts: smallvec![Conjunction::typeguard(constraint)], } } - /// Evaluate this constraint to a single type by evaluating each disjunct - /// and taking their union - fn to_type(self, db: &'db dyn Db) -> Type<'db> { - if self.disjuncts.is_empty() { - return Type::Never; - } - - if self.disjuncts.len() == 1 { - return self.disjuncts.into_iter().next().unwrap().to_type(); - } - - UnionType::from_elements(db, self.disjuncts.into_iter().map(|c| c.to_type())) + /// Evaluate the type this effectively constrains to + /// + /// Forgets whether each constraint originated from a `TypeGuard` 
or not + fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { + UnionType::from_elements( + db, + self.disjuncts + .into_iter() + .map(|c| c.evaluate_type_constraint()), + ) } } @@ -375,80 +378,84 @@ impl<'db> From> for NarrowingConstraint<'db> { } } -/// Internal representation of constraints with DNF structure for tracking TypeGuard semantics -type InternalConstraints<'db> = FxHashMap>; - -/// Public representation of constraints as returned by tracked functions -type NarrowingConstraints<'db> = FxHashMap>; - -/// Helper trait to make inserting constraints more ergonomic -trait InternalConstraintsExt<'db> { - fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>); - fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>); - fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db>; +/// Internal representation of constraints with DNF structure for tracking `TypeGuard` semantics. +/// +/// This is a newtype wrapper around `FxHashMap>` that +/// provides methods for working with constraints during boolean operation evaluation. 
+#[derive(Clone, Debug, Default)] +struct InternalConstraints<'db> { + constraints: FxHashMap>, } -impl<'db> InternalConstraintsExt<'db> for InternalConstraints<'db> { +impl<'db> InternalConstraints<'db> { + /// Insert a regular (non-`TypeGuard`) constraint for a place fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) { - self.insert(place, NarrowingConstraint::regular(ty)); + self.constraints + .insert(place, NarrowingConstraint::regular(ty)); } - fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>) { - self.insert(place, NarrowingConstraint::typeguard(ty)); - } - - fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db> { - self.into_iter() - .map(|(place, constraint)| (place, constraint.to_type(db))) + /// Convert internal constraints to public constraints by evaluating each DNF constraint to a Type + fn evaluate_type_constraints(self, db: &'db dyn Db) -> NarrowingConstraints<'db> { + self.constraints + .into_iter() + .map(|(place, constraint)| (place, constraint.evaluate_type_constraint(db))) .collect() } } +impl<'db> FromIterator<(ScopedPlaceId, NarrowingConstraint<'db>)> for InternalConstraints<'db> { + fn from_iter)>>( + iter: T, + ) -> Self { + Self { + constraints: FxHashMap::from_iter(iter), + } + } +} + +/// Public representation of constraints as returned by tracked functions +type NarrowingConstraints<'db> = FxHashMap>; + /// Merge constraints with AND semantics (intersection/conjunction). 
/// -/// When we have `constraint1 AND constraint2`, we need to distribute AND over the OR +/// When we have `constraint1 & constraint2`, we need to distribute AND over the OR /// in the DNF representations: -/// `(A | B) AND (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)` +/// `(A | B) & (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)` /// /// For each conjunction pair, we: -/// - Intersect the regular constraints -/// - If either has a TypeGuard, the result gets a TypeGuard (TypeGuard "poisons" the AND) +/// - Take the right conjunct if it has a `TypeGuard` +/// - Intersect the constraints normally otherwise fn merge_constraints_and<'db>( into: &mut InternalConstraints<'db>, from: &InternalConstraints<'db>, db: &'db dyn Db, ) { - for (key, from_constraint) in from { - match into.entry(*key) { + for (key, from_constraint) in &from.constraints { + match into.constraints.entry(*key) { Entry::Occupied(mut entry) => { - let into_constraint = entry.get().clone(); + let into_constraint = entry.get(); // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... 
- let mut new_disjuncts = Vec::new(); + let mut new_disjuncts = SmallVec::new(); for left_conj in &into_constraint.disjuncts { for right_conj in &from_constraint.disjuncts { - // Intersect the regular constraints - let new_regular = IntersectionBuilder::new(db) - .add_positive(left_conj.constraint) - .add_positive(right_conj.constraint) - .build(); + if right_conj.typeguard.is_some() { + // If the right conjunct has a TypeGuard, it "wins" the conjunction + new_disjuncts.push(*right_conj); + } else { + // Intersect the regular constraints + let new_regular = IntersectionBuilder::new(db) + .add_positive(left_conj.constraint) + .add_positive(right_conj.constraint) + .build(); - // Union the TypeGuard constraints if both have them, - // or take the one that exists - let new_typeguard = match (left_conj.typeguard, right_conj.typeguard) { - (Some(left_tg), Some(right_tg)) => { - Some(UnionBuilder::new(db).add(left_tg).add(right_tg).build()) - } - (Some(tg), None) | (None, Some(tg)) => Some(tg), - (None, None) => None, - }; - - new_disjuncts.push(Conjunction { - constraint: new_regular, - typeguard: new_typeguard, - }); + new_disjuncts.push(Conjunction { + constraint: new_regular, + typeguard: left_conj.typeguard, + }); + } } } @@ -475,18 +482,21 @@ fn merge_constraints_or<'db>( from: &InternalConstraints<'db>, _db: &'db dyn Db, ) { - for (key, from_constraint) in from { - match into.entry(*key) { + // For places that appear in `into` but not in `from`, widen to object + for (_key, value) in into.constraints.iter_mut() { + if !from.constraints.contains_key(_key) { + *value = NarrowingConstraint::regular(Type::object()); + } + } + + for (key, from_constraint) in &from.constraints { + match into.constraints.entry(*key) { Entry::Occupied(mut entry) => { - let into_constraint = entry.get().clone(); - // Simply concatenate the disjuncts - let mut new_disjuncts = into_constraint.disjuncts; - new_disjuncts.extend(from_constraint.disjuncts.clone()); - - *entry.get_mut() = 
NarrowingConstraint { - disjuncts: new_disjuncts, - }; + entry + .get_mut() + .disjuncts + .extend(from_constraint.disjuncts.clone()); } Entry::Vacant(entry) => { // Place only appears in `from`, not in `into`. @@ -495,13 +505,6 @@ fn merge_constraints_or<'db>( } } } - - // For places that appear in `into` but not in `from`, widen to object - for (_key, value) in into.iter_mut() { - if !from.contains_key(_key) { - *value = NarrowingConstraint::regular(Type::object()); - } - } } fn place_expr(expr: &ast::Expr) -> Option { @@ -578,8 +581,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(mut constraints) = constraints { - constraints.shrink_to_fit(); - Some(constraints.to_public(self.db)) + constraints.constraints.shrink_to_fit(); + Some(constraints.evaluate_type_constraints(self.db)) } else { None } @@ -1162,34 +1165,30 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { { let return_ty = inference.expression_type(expr_call); - let (guarded_ty, place, is_typeguard) = match return_ty { + let place_and_constraint = match return_ty { Type::TypeIs(type_is) => { let (_, place) = type_is.place_info(self.db)?; - (type_is.return_type(self.db), place, false) + Some(( + place, + NarrowingConstraint::regular( + type_is + .return_type(self.db) + .negate_if(self.db, !is_positive), + ), + )) } - Type::TypeGuard(type_guard) => { + // TypeGuard only narrows in the positive case + Type::TypeGuard(type_guard) if is_positive => { let (_, place) = type_guard.place_info(self.db)?; - (type_guard.return_type(self.db), place, true) + Some(( + place, + NarrowingConstraint::typeguard(type_guard.return_type(self.db)), + )) } - _ => return None, - }; + _ => None, + }?; - // Apply negation if needed - let narrowed_ty = guarded_ty.negate_if(self.db, !is_positive); - - // For TypeGuard in the positive case, use typeguard constraint - // For TypeGuard in the negative case OR TypeIs in any case, use regular 
constraint - // Note: TypeGuard only narrows in the positive case - let constraint = if is_typeguard && is_positive { - NarrowingConstraint::typeguard(narrowed_ty) - } else if is_typeguard && !is_positive { - // TypeGuard doesn't narrow in the negative case - return None; - } else { - NarrowingConstraint::regular(narrowed_ty) - }; - - Some(InternalConstraints::from_iter([(place, constraint)])) + Some(InternalConstraints::from_iter([place_and_constraint])) } // For the expression `len(E)`, we narrow the type based on whether len(E) is truthy // (i.e., whether E is non-empty). We only narrow the parts of the type where we know @@ -1385,7 +1384,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { is_positive: bool, ) -> Option> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); - let mut sub_constraints = expr_bool_op + let sub_constraints = expr_bool_op .values .iter() // filter our arms with statically known truthiness @@ -1413,17 +1412,32 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { aggregation } (BoolOp::Or, true) | (BoolOp::And, false) => { - let (first, rest) = sub_constraints.split_first_mut()?; - if let Some(first) = first { + let (mut first, rest) = { + let mut it = sub_constraints.into_iter(); + (it.next()?, it) + }; + + if let Some(ref mut first) = first { for rest_constraint in rest { if let Some(rest_constraint) = rest_constraint { - merge_constraints_or(first, rest_constraint, self.db); + merge_constraints_or(first, &rest_constraint, self.db); } else { return None; } } } first.clone() + // let (first, rest) = sub_constraints.split_first_mut()?; + // if let Some(first) = first { + // for rest_constraint in rest { + // if let Some(rest_constraint) = rest_constraint { + // merge_constraints_or(first, rest_constraint, self.db); + // } else { + // return None; + // } + // } + // } + // first.clone() } } }