feat: working TypeGuard

This commit is contained in:
Eric Mark Martin 2025-10-19 16:56:55 -04:00
parent 2562b47d84
commit ca5e6070be
4 changed files with 143 additions and 129 deletions

View file

@ -16,7 +16,9 @@ def f(*args: Unpack[Ts]) -> tuple[Unpack[Ts]]:
reveal_type(args) # revealed: tuple[@Todo(`Unpack[]` special form), ...]
return args
def g() -> TypeGuard[int]: ...
def g() -> TypeGuard[int]:
return True
def i(callback: Callable[Concatenate[int, P], R_co], *args: P.args, **kwargs: P.kwargs) -> R_co:
reveal_type(args) # revealed: P@i.args
reveal_type(kwargs) # revealed: P@i.kwargs

View file

@ -1383,8 +1383,7 @@ from typing_extensions import Any, TypeGuard, TypeIs
static_assert(is_assignable_to(TypeGuard[Unknown], bool))
static_assert(is_assignable_to(TypeIs[Any], bool))
# TODO no error
static_assert(not is_assignable_to(TypeGuard[Unknown], str)) # error: [static-assert-error]
static_assert(not is_assignable_to(TypeGuard[Unknown], str))
static_assert(not is_assignable_to(TypeIs[Any], str))
```

View file

@ -578,8 +578,7 @@ from typing_extensions import TypeGuard, TypeIs
static_assert(not is_disjoint_from(bool, TypeGuard[str]))
static_assert(not is_disjoint_from(bool, TypeIs[str]))
# TODO no error
static_assert(is_disjoint_from(str, TypeGuard[str])) # error: [static-assert-error]
static_assert(is_disjoint_from(str, TypeGuard[str]))
static_assert(is_disjoint_from(str, TypeIs[str]))
```

View file

@ -23,6 +23,7 @@ use itertools::Itertools;
use ruff_python_ast as ast;
use ruff_python_ast::{BoolOp, ExprBoolOp};
use rustc_hash::FxHashMap;
use smallvec::{SmallVec, smallvec};
use std::collections::hash_map::Entry;
use super::UnionType;
@ -274,19 +275,23 @@ impl ClassInfoConstraintFunction {
}
}
/// Represents a single conjunction (AND) of constraints in Disjunctive Normal Form (DNF).
/// Represents a single conjunction (AND) of constraints in Disjunctive Normal
/// Form (DNF).
///
/// A conjunction may contain:
/// - A regular constraint (intersection of types)
/// - An optional TypeGuard constraint that "replaces" the type rather than intersecting
/// A conjunction may contain: - A regular constraint (intersection of types) -
/// An optional `TypeGuard` constraint that "replaces" the type rather than
/// intersecting
///
/// For example, `(Regular(A) & TypeGuard(B))` evaluates to just `B` because TypeGuard clobbers.
#[derive(Hash, PartialEq, Debug, Eq, Clone)]
/// For example, `(Conjunction { constraint: A, typeguard: Some(B) } &
/// Conjunction { constraint: C, typeguard: Some(D)})` evlaluates to
/// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard
/// in the second clobbers the first.
#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)]
struct Conjunction<'db> {
/// The intersected constraints (represented as an intersection type)
constraint: Type<'db>,
/// If any constraint in this conjunction is a TypeGuard, this is Some
/// and contains the union of all TypeGuard types in this conjunction
/// If any constraint in this conjunction is a `TypeGuard`, this is Some and
/// contains the union of all `TypeGuard` types in this conjunction
typeguard: Option<Type<'db>>,
}
@ -301,7 +306,7 @@ impl<'db> Conjunction<'db> {
}
}
/// Create a new conjunction with a TypeGuard constraint
/// Create a new conjunction with a `TypeGuard` constraint
fn typeguard(constraint: Type<'db>) -> Self {
Self {
constraint: Type::object(),
@ -310,9 +315,9 @@ impl<'db> Conjunction<'db> {
}
/// Evaluate this conjunction to a single type.
/// If there's a TypeGuard constraint, it replaces the regular constraint.
/// If there's a `TypeGuard` constraint, it replaces the regular constraint.
/// Otherwise, returns the regular constraint.
fn to_type(self) -> Type<'db> {
fn evaluate_type_constraint(self) -> Type<'db> {
self.typeguard.unwrap_or(self.constraint)
}
}
@ -320,52 +325,50 @@ impl<'db> Conjunction<'db> {
/// Represents narrowing constraints in Disjunctive Normal Form (DNF).
///
/// This is a disjunction (OR) of conjunctions (AND) of constraints.
/// The DNF representation allows us to properly track TypeGuard constraints
/// The DNF representation allows us to properly track `TypeGuard` constraints
/// through boolean operations.
///
/// For example:
/// - `f(x) and g(x)` where f returns TypeIs[A] and g returns TypeGuard[B]
/// - `f(x) and g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]`
/// => `[Conjunction { constraint: A, typeguard: Some(B) }]`
/// => evaluates to `B` (TypeGuard clobbers)
/// => evaluates to `B` (`TypeGuard` clobbers)
///
/// - `f(x) or g(x)` where f returns TypeIs[A] and g returns TypeGuard[B]
/// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]`
/// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]`
/// => evaluates to `A | B`
#[derive(Hash, PartialEq, Debug, Eq, Clone)]
struct NarrowingConstraint<'db> {
/// Disjunctions of conjunctions (DNF)
disjuncts: Vec<Conjunction<'db>>,
disjuncts: SmallVec<[Conjunction<'db>; 4]>,
}
impl get_size2::GetSize for NarrowingConstraint<'_> {}
impl<'db> NarrowingConstraint<'db> {
/// Create a constraint from a regular (non-TypeGuard) type
/// Create a constraint from a regular (non-`TypeGuard`) type
fn regular(constraint: Type<'db>) -> Self {
Self {
disjuncts: vec![Conjunction::regular(constraint)],
disjuncts: smallvec![Conjunction::regular(constraint)],
}
}
/// Create a constraint from a TypeGuard type
/// Create a constraint from a `TypeGuard` type
fn typeguard(constraint: Type<'db>) -> Self {
Self {
disjuncts: vec![Conjunction::typeguard(constraint)],
disjuncts: smallvec![Conjunction::typeguard(constraint)],
}
}
/// Evaluate this constraint to a single type by evaluating each disjunct
/// and taking their union
fn to_type(self, db: &'db dyn Db) -> Type<'db> {
if self.disjuncts.is_empty() {
return Type::Never;
}
if self.disjuncts.len() == 1 {
return self.disjuncts.into_iter().next().unwrap().to_type();
}
UnionType::from_elements(db, self.disjuncts.into_iter().map(|c| c.to_type()))
/// Evaluate the type this effectively constrains to
///
/// Forgets whether each constraint originated from a `TypeGuard` or not
fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> {
UnionType::from_elements(
db,
self.disjuncts
.into_iter()
.map(|c| c.evaluate_type_constraint()),
)
}
}
@ -375,80 +378,84 @@ impl<'db> From<Type<'db>> for NarrowingConstraint<'db> {
}
}
/// Internal representation of constraints with DNF structure for tracking TypeGuard semantics
type InternalConstraints<'db> = FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>;
/// Public representation of constraints as returned by tracked functions
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
/// Helper trait to make inserting constraints more ergonomic
trait InternalConstraintsExt<'db> {
fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>);
fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>);
fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db>;
/// Internal representation of constraints with DNF structure for tracking `TypeGuard` semantics.
///
/// This is a newtype wrapper around `FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>` that
/// provides methods for working with constraints during boolean operation evaluation.
#[derive(Clone, Debug, Default)]
struct InternalConstraints<'db> {
constraints: FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>,
}
impl<'db> InternalConstraintsExt<'db> for InternalConstraints<'db> {
impl<'db> InternalConstraints<'db> {
/// Insert a regular (non-`TypeGuard`) constraint for a place
fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) {
self.insert(place, NarrowingConstraint::regular(ty));
self.constraints
.insert(place, NarrowingConstraint::regular(ty));
}
fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>) {
self.insert(place, NarrowingConstraint::typeguard(ty));
}
fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db> {
self.into_iter()
.map(|(place, constraint)| (place, constraint.to_type(db)))
/// Convert internal constraints to public constraints by evaluating each DNF constraint to a Type
fn evaluate_type_constraints(self, db: &'db dyn Db) -> NarrowingConstraints<'db> {
self.constraints
.into_iter()
.map(|(place, constraint)| (place, constraint.evaluate_type_constraint(db)))
.collect()
}
}
impl<'db> FromIterator<(ScopedPlaceId, NarrowingConstraint<'db>)> for InternalConstraints<'db> {
fn from_iter<T: IntoIterator<Item = (ScopedPlaceId, NarrowingConstraint<'db>)>>(
iter: T,
) -> Self {
Self {
constraints: FxHashMap::from_iter(iter),
}
}
}
/// Public representation of constraints as returned by tracked functions
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
/// Merge constraints with AND semantics (intersection/conjunction).
///
/// When we have `constraint1 AND constraint2`, we need to distribute AND over the OR
/// When we have `constraint1 & constraint2`, we need to distribute AND over the OR
/// in the DNF representations:
/// `(A | B) AND (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)`
/// `(A | B) & (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)`
///
/// For each conjunction pair, we:
/// - Intersect the regular constraints
/// - If either has a TypeGuard, the result gets a TypeGuard (TypeGuard "poisons" the AND)
/// - Take the right conjunct if it has a `TypeGuard`
/// - Intersect the constraints normally otherwise
fn merge_constraints_and<'db>(
into: &mut InternalConstraints<'db>,
from: &InternalConstraints<'db>,
db: &'db dyn Db,
) {
for (key, from_constraint) in from {
match into.entry(*key) {
for (key, from_constraint) in &from.constraints {
match into.constraints.entry(*key) {
Entry::Occupied(mut entry) => {
let into_constraint = entry.get().clone();
let into_constraint = entry.get();
// Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...)
// becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ...
let mut new_disjuncts = Vec::new();
let mut new_disjuncts = SmallVec::new();
for left_conj in &into_constraint.disjuncts {
for right_conj in &from_constraint.disjuncts {
// Intersect the regular constraints
let new_regular = IntersectionBuilder::new(db)
.add_positive(left_conj.constraint)
.add_positive(right_conj.constraint)
.build();
if right_conj.typeguard.is_some() {
// If the right conjunct has a TypeGuard, it "wins" the conjunction
new_disjuncts.push(*right_conj);
} else {
// Intersect the regular constraints
let new_regular = IntersectionBuilder::new(db)
.add_positive(left_conj.constraint)
.add_positive(right_conj.constraint)
.build();
// Union the TypeGuard constraints if both have them,
// or take the one that exists
let new_typeguard = match (left_conj.typeguard, right_conj.typeguard) {
(Some(left_tg), Some(right_tg)) => {
Some(UnionBuilder::new(db).add(left_tg).add(right_tg).build())
}
(Some(tg), None) | (None, Some(tg)) => Some(tg),
(None, None) => None,
};
new_disjuncts.push(Conjunction {
constraint: new_regular,
typeguard: new_typeguard,
});
new_disjuncts.push(Conjunction {
constraint: new_regular,
typeguard: left_conj.typeguard,
});
}
}
}
@ -475,18 +482,21 @@ fn merge_constraints_or<'db>(
from: &InternalConstraints<'db>,
_db: &'db dyn Db,
) {
for (key, from_constraint) in from {
match into.entry(*key) {
// For places that appear in `into` but not in `from`, widen to object
for (_key, value) in into.constraints.iter_mut() {
if !from.constraints.contains_key(_key) {
*value = NarrowingConstraint::regular(Type::object());
}
}
for (key, from_constraint) in &from.constraints {
match into.constraints.entry(*key) {
Entry::Occupied(mut entry) => {
let into_constraint = entry.get().clone();
// Simply concatenate the disjuncts
let mut new_disjuncts = into_constraint.disjuncts;
new_disjuncts.extend(from_constraint.disjuncts.clone());
*entry.get_mut() = NarrowingConstraint {
disjuncts: new_disjuncts,
};
entry
.get_mut()
.disjuncts
.extend(from_constraint.disjuncts.clone());
}
Entry::Vacant(entry) => {
// Place only appears in `from`, not in `into`.
@ -495,13 +505,6 @@ fn merge_constraints_or<'db>(
}
}
}
// For places that appear in `into` but not in `from`, widen to object
for (_key, value) in into.iter_mut() {
if !from.contains_key(_key) {
*value = NarrowingConstraint::regular(Type::object());
}
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
@ -578,8 +581,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(mut constraints) = constraints {
constraints.shrink_to_fit();
Some(constraints.to_public(self.db))
constraints.constraints.shrink_to_fit();
Some(constraints.evaluate_type_constraints(self.db))
} else {
None
}
@ -1162,34 +1165,30 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
{
let return_ty = inference.expression_type(expr_call);
let (guarded_ty, place, is_typeguard) = match return_ty {
let place_and_constraint = match return_ty {
Type::TypeIs(type_is) => {
let (_, place) = type_is.place_info(self.db)?;
(type_is.return_type(self.db), place, false)
Some((
place,
NarrowingConstraint::regular(
type_is
.return_type(self.db)
.negate_if(self.db, !is_positive),
),
))
}
Type::TypeGuard(type_guard) => {
// TypeGuard only narrows in the positive case
Type::TypeGuard(type_guard) if is_positive => {
let (_, place) = type_guard.place_info(self.db)?;
(type_guard.return_type(self.db), place, true)
Some((
place,
NarrowingConstraint::typeguard(type_guard.return_type(self.db)),
))
}
_ => return None,
};
_ => None,
}?;
// Apply negation if needed
let narrowed_ty = guarded_ty.negate_if(self.db, !is_positive);
// For TypeGuard in the positive case, use typeguard constraint
// For TypeGuard in the negative case OR TypeIs in any case, use regular constraint
// Note: TypeGuard only narrows in the positive case
let constraint = if is_typeguard && is_positive {
NarrowingConstraint::typeguard(narrowed_ty)
} else if is_typeguard && !is_positive {
// TypeGuard doesn't narrow in the negative case
return None;
} else {
NarrowingConstraint::regular(narrowed_ty)
};
Some(InternalConstraints::from_iter([(place, constraint)]))
Some(InternalConstraints::from_iter([place_and_constraint]))
}
// For the expression `len(E)`, we narrow the type based on whether len(E) is truthy
// (i.e., whether E is non-empty). We only narrow the parts of the type where we know
@ -1385,7 +1384,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
is_positive: bool,
) -> Option<InternalConstraints<'db>> {
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let mut sub_constraints = expr_bool_op
let sub_constraints = expr_bool_op
.values
.iter()
// filter our arms with statically known truthiness
@ -1413,17 +1412,32 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
aggregation
}
(BoolOp::Or, true) | (BoolOp::And, false) => {
let (first, rest) = sub_constraints.split_first_mut()?;
if let Some(first) = first {
let (mut first, rest) = {
let mut it = sub_constraints.into_iter();
(it.next()?, it)
};
if let Some(ref mut first) = first {
for rest_constraint in rest {
if let Some(rest_constraint) = rest_constraint {
merge_constraints_or(first, rest_constraint, self.db);
merge_constraints_or(first, &rest_constraint, self.db);
} else {
return None;
}
}
}
first.clone()
// let (first, rest) = sub_constraints.split_first_mut()?;
// if let Some(first) = first {
// for rest_constraint in rest {
// if let Some(rest_constraint) = rest_constraint {
// merge_constraints_or(first, rest_constraint, self.db);
// } else {
// return None;
// }
// }
// }
// first.clone()
}
}
}