From 65f24b122413d105934006c9c31d40243d4b1efe Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 18 Oct 2025 01:30:05 -0400 Subject: [PATCH 01/22] wip: initial typeguard --- crates/ty_ide/src/completion.rs | 1 + crates/ty_python_semantic/src/types.rs | 169 ++++++++++++++++-- .../src/types/bound_super.rs | 2 +- .../src/types/class_base.rs | 1 + .../ty_python_semantic/src/types/display.rs | 16 ++ .../ty_python_semantic/src/types/function.rs | 1 + .../src/types/infer/builder.rs | 6 +- .../src/types/list_members.rs | 3 +- crates/ty_python_semantic/src/types/narrow.rs | 1 + .../src/types/type_ordering.rs | 24 ++- .../ty_python_semantic/src/types/visitor.rs | 13 +- 11 files changed, 214 insertions(+), 23 deletions(-) diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index d8dabbf413..3bf666635a 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -353,6 +353,7 @@ impl<'db> Completion<'db> { Type::IntLiteral(_) | Type::BooleanLiteral(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::StringLiteral(_) | Type::LiteralString | Type::BytesLiteral(_) => CompletionKind::Value, diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index cc05c9c000..1d6a28b3d1 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -864,6 +864,8 @@ pub enum Type<'db> { BoundSuper(BoundSuperType<'db>), /// A subtype of `bool` that allows narrowing in both positive and negative cases. TypeIs(TypeIsType<'db>), + /// A subtype of `bool` that allows narrowing in only the positive case. + TypeGuard(TypeGuardType<'db>), /// A type that represents an inhabitant of a `TypedDict`. TypedDict(TypedDictType<'db>), /// An aliased type (lazily not-yet-unpacked to its value type). @@ -1617,6 +1619,9 @@ impl<'db> Type<'db> { Type::TypeIs(type_is) => visitor.visit(self, || { type_is.with_type(db, type_is.return_type(db).normalized_impl(db, visitor)) }), + Type::TypeGuard(type_guard) => visitor.visit(self, || { + type_guard.with_type(db, type_guard.return_type(db).normalized_impl(db, visitor)) + }), Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized()), Type::EnumLiteral(enum_literal) if is_single_member_enum(db, enum_literal.enum_class(db)) => @@ -1752,6 +1757,20 @@ impl<'db> Type<'db> { }; Some(type_is.with_type(db, ty)) } + // TODO: deduplicate + Type::TypeGuard(type_guard) => { + let ty = if nested { + type_guard + .return_type(db) + .recursive_type_normalized_impl(db, div, true)? + } else { + type_guard + .return_type(db) + .recursive_type_normalized_impl(db, div, true) + .unwrap_or(div) + }; + Some(type_guard.with_type(db, ty)) + } Type::Dynamic(dynamic) => Some(Type::Dynamic(dynamic.recursive_type_normalized())), Type::TypedDict(_) => { // TODO: Normalize TypedDicts @@ -1824,6 +1843,7 @@ impl<'db> Type<'db> { | Type::TypeVar(_) | Type::BoundSuper(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::TypeAlias(_) | Type::NewTypeInstance(_) => false, @@ -1936,6 +1956,7 @@ impl<'db> Type<'db> { | Type::LiteralString | Type::BytesLiteral(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) => None, // TODO @@ -2806,15 +2827,29 @@ impl<'db> Type<'db> { ) }), - // `TypeIs[T]` is a subtype of `bool`. - (Type::TypeIs(_), _) => KnownClass::Bool.to_instance(db).has_relation_to_impl( - db, - target, - inferable, - relation, - relation_visitor, - disjointness_visitor, - ), + // `TypeGuard` is covariant. + (Type::TypeGuard(left), Type::TypeGuard(right)) => { + left.return_type(db).has_relation_to_impl( + db, + right.return_type(db), + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + } + + // `TypeIs[T]` and `TypeGuard[T]` is a subtype of `bool`. + (Type::TypeIs(_) | Type::TypeGuard(_), _) => { + KnownClass::Bool.to_instance(db).has_relation_to_impl( + db, + target, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + } // Function-like callables are subtypes of `FunctionType` (Type::Callable(callable), _) if callable.is_function_like(db) => { @@ -3783,8 +3818,14 @@ impl<'db> Type<'db> { ConstraintSet::from(!known_instance.is_instance_of(db, instance.class(db))) } - (Type::BooleanLiteral(..) | Type::TypeIs(_), Type::NominalInstance(instance)) - | (Type::NominalInstance(instance), Type::BooleanLiteral(..) | Type::TypeIs(_)) => { + ( + Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_), + Type::NominalInstance(instance), + ) + | ( + Type::NominalInstance(instance), + Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_), + ) => { // A `Type::BooleanLiteral()` must be an instance of exactly `bool` // (it cannot be an instance of a `bool` subclass) KnownClass::Bool @@ -3792,8 +3833,10 @@ impl<'db> Type<'db> { .negate(db) } - (Type::BooleanLiteral(..) | Type::TypeIs(_), _) - | (_, Type::BooleanLiteral(..) | Type::TypeIs(_)) => ConstraintSet::from(true), + (Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_), _) + | (_, Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_)) => { + ConstraintSet::from(true) + } (Type::IntLiteral(..), Type::NominalInstance(instance)) | (Type::NominalInstance(instance), Type::IntLiteral(..)) => { @@ -4259,6 +4302,7 @@ impl<'db> Type<'db> { } Type::AlwaysTruthy | Type::AlwaysFalsy => false, Type::TypeIs(type_is) => type_is.is_bound(db), + Type::TypeGuard(type_guard) => type_guard.is_bound(db), Type::TypedDict(_) => false, Type::TypeAlias(alias) => alias.value_type(db).is_singleton(db), Type::NewTypeInstance(newtype) => newtype.concrete_base_type(db).is_singleton(db), @@ -4320,6 +4364,7 @@ impl<'db> Type<'db> { } Type::TypeIs(type_is) => type_is.is_bound(db), + Type::TypeGuard(type_is) => type_is.is_bound(db), Type::TypeAlias(alias) => alias.value_type(db).is_single_valued(db), @@ -4476,6 +4521,7 @@ impl<'db> Type<'db> { | Type::ProtocolInstance(_) | Type::PropertyInstance(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::NewTypeInstance(_) => None, } @@ -4604,7 +4650,7 @@ impl<'db> Type<'db> { } Type::IntLiteral(_) => KnownClass::Int.to_instance(db).instance_member(db, name), - Type::BooleanLiteral(_) | Type::TypeIs(_) => { + Type::BooleanLiteral(_) | Type::TypeIs(_) | Type::TypeGuard(_) => { KnownClass::Bool.to_instance(db).instance_member(db, name) } Type::StringLiteral(_) | Type::LiteralString => { @@ -5268,6 +5314,7 @@ impl<'db> Type<'db> { | Type::AlwaysTruthy | Type::AlwaysFalsy | Type::TypeIs(..) + | Type::TypeGuard(..) | Type::TypedDict(_) => { let fallback = self.instance_member(db, name_str); @@ -5595,7 +5642,8 @@ impl<'db> Type<'db> { | Type::Never | Type::Callable(_) | Type::LiteralString - | Type::TypeIs(_) => Truthiness::Ambiguous, + | Type::TypeIs(_) + | Type::TypeGuard(_) => Truthiness::Ambiguous, Type::TypedDict(_) => { // TODO: We could do better here, but it's unclear if this is important. @@ -6539,6 +6587,7 @@ impl<'db> Type<'db> { | Type::BoundSuper(_) | Type::ModuleLiteral(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) => CallableBinding::not_callable(self).into(), } } @@ -6768,6 +6817,7 @@ impl<'db> Type<'db> { | Type::EnumLiteral(_) | Type::BoundSuper(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) => None } } @@ -7382,6 +7432,7 @@ impl<'db> Type<'db> { | Type::AlwaysTruthy | Type::AlwaysFalsy | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::NewTypeInstance(_) => None, } @@ -7438,6 +7489,7 @@ impl<'db> Type<'db> { | Type::ProtocolInstance(_) | Type::PropertyInstance(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) => Err(InvalidTypeExpressionError { invalid_expressions: smallvec::smallvec_inline![ InvalidTypeExpression::InvalidType(*self, scope_id) @@ -7727,7 +7779,9 @@ impl<'db> Type<'db> { Type::SpecialForm(special_form) => special_form.to_meta_type(db), Type::PropertyInstance(_) => KnownClass::Property.to_class_literal(db), Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)), - Type::BooleanLiteral(_) | Type::TypeIs(_) => KnownClass::Bool.to_class_literal(db), + Type::BooleanLiteral(_) | Type::TypeIs(_) | Type::TypeGuard(_) => { + KnownClass::Bool.to_class_literal(db) + } Type::BytesLiteral(_) => KnownClass::Bytes.to_class_literal(db), Type::IntLiteral(_) => KnownClass::Int.to_class_literal(db), Type::EnumLiteral(enum_literal) => Type::ClassLiteral(enum_literal.enum_class(db)), @@ -8020,6 +8074,9 @@ impl<'db> Type<'db> { // TODO(jelle): Materialize should be handled differently, since TypeIs is invariant Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping, tcx)), + // TODO: check variance + Type::TypeGuard(type_guard) => type_guard.with_type(db, type_guard.return_type(db).apply_type_mapping(db, type_mapping, tcx)), + Type::TypeAlias(alias) => { if TypeMapping::EagerExpansion == *type_mapping { return alias.raw_value_type(db).expand_eagerly(db); @@ -8226,6 +8283,15 @@ impl<'db> Type<'db> { ); } + Type::TypeGuard(type_guard) => { + type_guard.return_type(db).find_legacy_typevars_impl( + db, + binding_context, + typevars, + visitor, + ); + } + Type::TypeAlias(alias) => { visitor.visit(self, || { alias.value_type(db).find_legacy_typevars_impl( @@ -8490,7 +8556,8 @@ impl<'db> Type<'db> { // These types have no definition Self::Dynamic(DynamicType::Divergent(_) | DynamicType::Todo(_) | DynamicType::TodoUnpack | DynamicType::TodoStarredExpression) | Self::Callable(_) - | Self::TypeIs(_) => None, + | Self::TypeIs(_) + | Self::TypeGuard(_) => None, } } @@ -8654,6 +8721,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> { .collect(), Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar), Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar), + Type::TypeGuard(type_guard_type) => type_guard_type.variance_of(db, typevar), Type::KnownInstance(known_instance) => known_instance.variance_of(db, typevar), Type::Dynamic(_) | Type::Never @@ -14621,6 +14689,73 @@ impl<'db> VarianceInferable<'db> for TypeIsType<'db> { } } +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +pub struct TypeGuardType<'db> { + return_type: Type<'db>, + /// The ID of the scope to which the place belongs + /// and the ID of the place itself within that scope. + place_info: Option<(ScopeId<'db>, ScopedPlaceId)>, +} + +fn walk_typeguard_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( + db: &'db dyn Db, + typeguard_type: TypeGuardType<'db>, + visitor: &V, +) { + visitor.visit_type(db, typeguard_type.return_type(db)); +} + +// The Salsa heap is tracked separately. +impl get_size2::GetSize for TypeGuardType<'_> {} + +impl<'db> TypeGuardType<'db> { + pub(crate) fn place_name(self, db: &'db dyn Db) -> Option { + let (scope, place) = self.place_info(db)?; + let table = place_table(db, scope); + + Some(format!("{}", table.place(place))) + } + + pub(crate) fn unbound(db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + Type::TypeGuard(Self::new(db, ty, None)) + } + + pub(crate) fn bound( + db: &'db dyn Db, + return_type: Type<'db>, + scope: ScopeId<'db>, + place: ScopedPlaceId, + ) -> Type<'db> { + Type::TypeGuard(Self::new(db, return_type, Some((scope, place)))) + } + + #[must_use] + pub(crate) fn bind( + self, + db: &'db dyn Db, + scope: ScopeId<'db>, + place: ScopedPlaceId, + ) -> Type<'db> { + Self::bound(db, self.return_type(db), scope, place) + } + + #[must_use] + pub(crate) fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + Type::TypeGuard(Self::new(db, ty, self.place_info(db))) + } + + pub(crate) fn is_bound(self, db: &'db dyn Db) -> bool { + self.place_info(db).is_some() + } +} + +impl<'db> VarianceInferable<'db> for TypeGuardType<'db> { + // TODO: comment + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + self.return_type(db).variance_of(db, typevar) + } +} + /// Walk the MRO of this class and return the last class just before the specified known base. /// This can be used to determine upper bounds for `Self` type variables on methods that are /// being added to the given class. diff --git a/crates/ty_python_semantic/src/types/bound_super.rs b/crates/ty_python_semantic/src/types/bound_super.rs index 47a9bdeac3..ad71662959 100644 --- a/crates/ty_python_semantic/src/types/bound_super.rs +++ b/crates/ty_python_semantic/src/types/bound_super.rs @@ -389,7 +389,7 @@ impl<'db> BoundSuperType<'db> { None => delegate_with_error_mapped(Type::object(), Some(type_var)), }; } - Type::BooleanLiteral(_) | Type::TypeIs(_) => { + Type::BooleanLiteral(_) | Type::TypeIs(_) | Type::TypeGuard(_) => { return delegate_to(KnownClass::Bool.to_instance(db)); } Type::IntLiteral(_) => return delegate_to(KnownClass::Int.to_instance(db)), diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 26b490fa3b..9dd74c7297 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -177,6 +177,7 @@ impl<'db> ClassBase<'db> { | Type::AlwaysFalsy | Type::AlwaysTruthy | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) => None, Type::KnownInstance(known_instance) => match known_instance { diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index f1c37debeb..16fe01babf 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -979,6 +979,22 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> { } f.write_str("]") } + // TODO: deduplicate + Type::TypeGuard(type_guard) => { + f.with_type(Type::SpecialForm(SpecialFormType::TypeGuard)) + .write_str("TypeGuard")?; + f.write_char('[')?; + type_guard + .return_type(self.db) + .display_with(self.db, self.settings.singleline()) + .fmt_detailed(f)?; + if let Some(name) = type_guard.place_name(self.db) { + f.set_invalid_syntax(); + f.write_str(" @ ")?; + f.write_str(&name)?; + } + f.write_str("]") + } Type::TypedDict(TypedDictType::Class(defining_class)) => defining_class .class_literal(self.db) .0 diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index ee54b42f06..cb3d27c4a4 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1301,6 +1301,7 @@ fn is_instance_truthiness<'db>( | Type::AlwaysFalsy | Type::BoundSuper(..) | Type::TypeIs(..) + | Type::TypeGuard(..) | Type::Callable(..) | Type::Dynamic(..) | Type::Never diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 55afa55a0b..16810fbb59 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -2171,7 +2171,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let declared_ty = self.file_expression_type(returns); let expected_ty = match declared_ty { - Type::TypeIs(_) => KnownClass::Bool.to_instance(self.db()), + Type::TypeIs(_) | Type::TypeGuard(_) => KnownClass::Bool.to_instance(self.db()), ty => ty, }; @@ -4572,6 +4572,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::AlwaysTruthy | Type::AlwaysFalsy | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::NewTypeInstance(_) => { // TODO: We could use the annotated parameter type of `__setattr__` as type context here. @@ -9893,6 +9894,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::BoundSuper(_) | Type::TypeVar(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::NewTypeInstance(_), ) => { @@ -10393,6 +10395,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::BoundSuper(_) | Type::TypeVar(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::NewTypeInstance(_), Type::FunctionLiteral(_) @@ -10423,6 +10426,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::BoundSuper(_) | Type::TypeVar(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::TypedDict(_) | Type::NewTypeInstance(_), op, diff --git a/crates/ty_python_semantic/src/types/list_members.rs b/crates/ty_python_semantic/src/types/list_members.rs index 4e4a32c294..0bebd2ae88 100644 --- a/crates/ty_python_semantic/src/types/list_members.rs +++ b/crates/ty_python_semantic/src/types/list_members.rs @@ -282,7 +282,8 @@ impl<'db> AllMembers<'db> { | Type::SpecialForm(_) | Type::KnownInstance(_) | Type::BoundSuper(_) - | Type::TypeIs(_) => match ty.to_meta_type(db) { + | Type::TypeIs(_) + | Type::TypeGuard(_) => match ty.to_meta_type(db) { Type::ClassLiteral(class_literal) => { self.extend_with_class_members(db, ty, class_literal); } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index ae36ea47ed..ab57aed5d6 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -265,6 +265,7 @@ impl ClassInfoConstraintFunction { | Type::IntLiteral(_) | Type::KnownInstance(_) | Type::TypeIs(_) + | Type::TypeGuard(_) | Type::WrapperDescriptor(_) | Type::DataclassTransformer(_) | Type::TypedDict(_) diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index c47720011c..51a0e7d4fd 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -5,7 +5,8 @@ use salsa::plumbing::AsId; use crate::{db::Db, types::bound_super::SuperOwnerKind}; use super::{ - DynamicType, TodoType, Type, TypeIsType, class_base::ClassBase, subclass_of::SubclassOfInner, + DynamicType, TodoType, Type, TypeGuardType, TypeIsType, class_base::ClassBase, + subclass_of::SubclassOfInner, }; /// Return an [`Ordering`] that describes the canonical order in which two types should appear @@ -132,6 +133,10 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::TypeIs(_), _) => Ordering::Less, (_, Type::TypeIs(_)) => Ordering::Greater, + (Type::TypeGuard(left), Type::TypeGuard(right)) => typeguard_ordering(db, *left, *right), + (Type::TypeGuard(_), _) => Ordering::Less, + (_, Type::TypeGuard(_)) => Ordering::Greater, + (Type::NominalInstance(left), Type::NominalInstance(right)) => { left.class(db).cmp(&right.class(db)) } @@ -307,3 +312,20 @@ fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering }, } } + +// TODO: de-duplicate +fn typeguard_ordering(db: &dyn Db, left: TypeGuardType, right: TypeGuardType) -> Ordering { + let (left_ty, right_ty) = (left.return_type(db), right.return_type(db)); + + match (left.place_info(db), right.place_info(db)) { + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + + (None, None) => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), + + (Some(_), Some(_)) => match left.place_name(db).cmp(&right.place_name(db)) { + Ordering::Equal => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), + ordering => ordering, + }, + } +} diff --git a/crates/ty_python_semantic/src/types/visitor.rs b/crates/ty_python_semantic/src/types/visitor.rs index 54ce30cc53..73a31c5f33 100644 --- a/crates/ty_python_semantic/src/types/visitor.rs +++ b/crates/ty_python_semantic/src/types/visitor.rs @@ -4,7 +4,7 @@ use crate::{ BoundMethodType, BoundSuperType, BoundTypeVarInstance, CallableType, GenericAlias, IntersectionType, KnownBoundMethodType, KnownInstanceType, NominalInstanceType, PropertyInstanceType, ProtocolInstanceType, SubclassOfType, Type, TypeAliasType, - TypeIsType, TypeVarInstance, TypedDictType, UnionType, + TypeGuardType, TypeIsType, TypeVarInstance, TypedDictType, UnionType, bound_super::walk_bound_super_type, class::walk_generic_alias, function::{FunctionType, walk_function_type}, @@ -14,7 +14,7 @@ use crate::{ walk_bound_method_type, walk_bound_type_var_type, walk_callable_type, walk_intersection_type, walk_known_instance_type, walk_method_wrapper_type, walk_property_instance_type, walk_type_alias_type, walk_type_var_type, - walk_typed_dict_type, walk_typeis_type, walk_union, + walk_typed_dict_type, walk_typeguard_type, walk_typeis_type, walk_union, }, }; use std::cell::{Cell, RefCell}; @@ -50,6 +50,10 @@ pub(crate) trait TypeVisitor<'db> { walk_typeis_type(db, type_is, self); } + fn visit_typeguard_type(&self, db: &'db dyn Db, type_is: TypeGuardType<'db>) { + walk_typeguard_type(db, type_is, self); + } + fn visit_subclass_of_type(&self, db: &'db dyn Db, subclass_of: SubclassOfType<'db>) { walk_subclass_of_type(db, subclass_of, self); } @@ -127,6 +131,7 @@ pub(super) enum NonAtomicType<'db> { NominalInstance(NominalInstanceType<'db>), PropertyInstance(PropertyInstanceType<'db>), TypeIs(TypeIsType<'db>), + TypeGuard(TypeGuardType<'db>), TypeVar(BoundTypeVarInstance<'db>), ProtocolInstance(ProtocolInstanceType<'db>), TypedDict(TypedDictType<'db>), @@ -195,6 +200,9 @@ impl<'db> From> for TypeKind<'db> { TypeKind::NonAtomic(NonAtomicType::TypeVar(bound_typevar)) } Type::TypeIs(type_is) => TypeKind::NonAtomic(NonAtomicType::TypeIs(type_is)), + Type::TypeGuard(type_guard) => { + TypeKind::NonAtomic(NonAtomicType::TypeGuard(type_guard)) + } Type::TypedDict(typed_dict) => { TypeKind::NonAtomic(NonAtomicType::TypedDict(typed_dict)) } @@ -233,6 +241,7 @@ pub(super) fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>( visitor.visit_property_instance_type(db, property); } NonAtomicType::TypeIs(type_is) => visitor.visit_typeis_type(db, type_is), + NonAtomicType::TypeGuard(type_guard) => visitor.visit_typeguard_type(db, type_guard), NonAtomicType::TypeVar(bound_typevar) => { visitor.visit_bound_type_var_type(db, bound_typevar); } From 2ed31d6c7c401c0f37a1cb4db6ed2360d2d41a1e Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 15:10:29 -0400 Subject: [PATCH 02/22] wip(typeguard): refactor constraint to CNF --- .../src/types/infer/builder.rs | 2 +- crates/ty_python_semantic/src/types/narrow.rs | 311 +++++++++++++++--- 2 files changed, 261 insertions(+), 52 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 16810fbb59..7a7f25eb31 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -8824,7 +8824,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // TODO: TypeGuard Type::TypeIs(type_is) => match find_narrowed_place() { Some(place) => type_is.bind(db, scope, place), - None => return_ty, + None => return_ty, // TODO(ericmarkmartin): ? }, _ => return_ty, } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index ab57aed5d6..328ed0e07d 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -274,46 +274,232 @@ impl ClassInfoConstraintFunction { } } +/// Represents a single conjunction (AND) of constraints in Disjunctive Normal Form (DNF). +/// +/// A conjunction may contain: +/// - A regular constraint (intersection of types) +/// - An optional TypeGuard constraint that "replaces" the type rather than intersecting +/// +/// For example, `(Regular(A) & TypeGuard(B))` evaluates to just `B` because TypeGuard clobbers. +#[derive(Hash, PartialEq, Debug, Eq, Clone)] +struct Conjunction<'db> { + /// The intersected constraints (represented as an intersection type) + constraint: Type<'db>, + /// If any constraint in this conjunction is a TypeGuard, this is Some + /// and contains the union of all TypeGuard types in this conjunction + typeguard: Option>, +} + +impl get_size2::GetSize for Conjunction<'_> {} + +impl<'db> Conjunction<'db> { + /// Create a new conjunction with just a regular constraint + fn regular(constraint: Type<'db>) -> Self { + Self { + constraint, + typeguard: None, + } + } + + /// Create a new conjunction with a TypeGuard constraint + fn typeguard(constraint: Type<'db>) -> Self { + Self { + constraint: Type::object(), + typeguard: Some(constraint), + } + } + + /// Evaluate this conjunction to a single type. + /// If there's a TypeGuard constraint, it replaces the regular constraint. + /// Otherwise, returns the regular constraint. + fn to_type(self) -> Type<'db> { + self.typeguard.unwrap_or(self.constraint) + } +} + +/// Represents narrowing constraints in Disjunctive Normal Form (DNF). +/// +/// This is a disjunction (OR) of conjunctions (AND) of constraints. +/// The DNF representation allows us to properly track TypeGuard constraints +/// through boolean operations. +/// +/// For example: +/// - `f(x) and g(x)` where f returns TypeIs[A] and g returns TypeGuard[B] +/// => `[Conjunction { constraint: A, typeguard: Some(B) }]` +/// => evaluates to `B` (TypeGuard clobbers) +/// +/// - `f(x) or g(x)` where f returns TypeIs[A] and g returns TypeGuard[B] +/// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]` +/// => evaluates to `A | B` +#[derive(Hash, PartialEq, Debug, Eq, Clone)] +struct NarrowingConstraint<'db> { + /// Disjunctions of conjunctions (DNF) + disjuncts: Vec>, +} + +impl get_size2::GetSize for NarrowingConstraint<'_> {} + +impl<'db> NarrowingConstraint<'db> { + /// Create a constraint from a regular (non-TypeGuard) type + fn regular(constraint: Type<'db>) -> Self { + Self { + disjuncts: vec![Conjunction::regular(constraint)], + } + } + + /// Create a constraint from a TypeGuard type + fn typeguard(constraint: Type<'db>) -> Self { + Self { + disjuncts: vec![Conjunction::typeguard(constraint)], + } + } + + /// Evaluate this constraint to a single type by evaluating each disjunct + /// and taking their union + fn to_type(self, db: &'db dyn Db) -> Type<'db> { + if self.disjuncts.is_empty() { + return Type::Never; + } + + if self.disjuncts.len() == 1 { + return self.disjuncts.into_iter().next().unwrap().to_type(); + } + + UnionType::from_elements(db, self.disjuncts.into_iter().map(|c| c.to_type())) + } +} + +impl<'db> From> for NarrowingConstraint<'db> { + fn from(constraint: Type<'db>) -> Self { + Self::regular(constraint) + } +} + +/// Internal representation of constraints with DNF structure for tracking TypeGuard semantics +type InternalConstraints<'db> = FxHashMap>; + +/// Public representation of constraints as returned by tracked functions type NarrowingConstraints<'db> = FxHashMap>; +/// Helper trait to make inserting constraints more ergonomic +trait InternalConstraintsExt<'db> { + fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>); + fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>); + fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db>; +} + +impl<'db> InternalConstraintsExt<'db> for InternalConstraints<'db> { + fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) { + self.insert(place, NarrowingConstraint::regular(ty)); + } + + fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>) { + self.insert(place, NarrowingConstraint::typeguard(ty)); + } + + fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db> { + self.into_iter() + .map(|(place, constraint)| (place, constraint.to_type(db))) + .collect() + } +} + +/// Merge constraints with AND semantics (intersection/conjunction). +/// +/// When we have `constraint1 AND constraint2`, we need to distribute AND over the OR +/// in the DNF representations: +/// `(A | B) AND (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)` +/// +/// For each conjunction pair, we: +/// - Intersect the regular constraints +/// - If either has a TypeGuard, the result gets a TypeGuard (TypeGuard "poisons" the AND) fn merge_constraints_and<'db>( - into: &mut NarrowingConstraints<'db>, - from: &NarrowingConstraints<'db>, + into: &mut InternalConstraints<'db>, + from: &InternalConstraints<'db>, db: &'db dyn Db, ) { - for (key, value) in from { + for (key, from_constraint) in from { match into.entry(*key) { Entry::Occupied(mut entry) => { - *entry.get_mut() = IntersectionBuilder::new(db) - .add_positive(*entry.get()) - .add_positive(*value) - .build(); + let into_constraint = entry.get().clone(); + + // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) + // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... + let mut new_disjuncts = Vec::new(); + + for left_conj in &into_constraint.disjuncts { + for right_conj in &from_constraint.disjuncts { + // Intersect the regular constraints + let new_regular = IntersectionBuilder::new(db) + .add_positive(left_conj.constraint) + .add_positive(right_conj.constraint) + .build(); + + // Union the TypeGuard constraints if both have them, + // or take the one that exists + let new_typeguard = match (left_conj.typeguard, right_conj.typeguard) { + (Some(left_tg), Some(right_tg)) => { + Some(UnionBuilder::new(db).add(left_tg).add(right_tg).build()) + } + (Some(tg), None) | (None, Some(tg)) => Some(tg), + (None, None) => None, + }; + + new_disjuncts.push(Conjunction { + constraint: new_regular, + typeguard: new_typeguard, + }); + } + } + + *entry.get_mut() = NarrowingConstraint { + disjuncts: new_disjuncts, + }; } Entry::Vacant(entry) => { - entry.insert(*value); + entry.insert(from_constraint.clone()); } } } } +/// Merge constraints with OR semantics (union/disjunction). +/// +/// When we have `constraint1 OR constraint2`, we simply concatenate the disjuncts +/// from both constraints: `(A | B) OR (C | D)` becomes `A | B | C | D` +/// +/// However, if a place appears in only one branch of the OR, we need to widen it +/// to `object` in the overall result (because the other branch doesn't constrain it). fn merge_constraints_or<'db>( - into: &mut NarrowingConstraints<'db>, - from: &NarrowingConstraints<'db>, - db: &'db dyn Db, + into: &mut InternalConstraints<'db>, + from: &InternalConstraints<'db>, + _db: &'db dyn Db, ) { - for (key, value) in from { + for (key, from_constraint) in from { match into.entry(*key) { Entry::Occupied(mut entry) => { - *entry.get_mut() = UnionBuilder::new(db).add(*entry.get()).add(*value).build(); + let into_constraint = entry.get().clone(); + + // Simply concatenate the disjuncts + let mut new_disjuncts = into_constraint.disjuncts; + new_disjuncts.extend(from_constraint.disjuncts.clone()); + + *entry.get_mut() = NarrowingConstraint { + disjuncts: new_disjuncts, + }; } Entry::Vacant(entry) => { - entry.insert(Type::object()); + // Place only appears in `from`, not in `into`. + // Widen to object since the other branch doesn't constrain it. + entry.insert(NarrowingConstraint::regular(Type::object())); } } } - for (key, value) in into.iter_mut() { - if !from.contains_key(key) { - *value = Type::object(); + + // For places that appear in `into` but not in `from`, widen to object + for (_key, value) in into.iter_mut() { + if !from.contains_key(_key) { + *value = NarrowingConstraint::regular(Type::object()); } } } @@ -381,7 +567,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } fn finish(mut self) -> Option> { - let constraints: Option> = match self.predicate { + let constraints: Option> = match self.predicate { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } @@ -393,7 +579,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }; if let Some(mut constraints) = constraints { constraints.shrink_to_fit(); - Some(constraints) + Some(constraints.to_public(self.db)) } else { None } @@ -403,7 +589,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let expression_node = expression.node_ref(self.db, self.module); self.evaluate_expression_node_predicate(expression_node, expression, is_positive) } @@ -413,7 +599,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression_node: &ruff_python_ast::Expr, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { match expression_node { ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { self.evaluate_simple_expr(expression_node, is_positive) @@ -438,7 +624,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { pattern_predicate_kind: &PatternPredicateKind<'db>, subject: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { match pattern_predicate_kind { PatternPredicateKind::Singleton(singleton) => { self.evaluate_match_pattern_singleton(subject, *singleton, is_positive) @@ -463,7 +649,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, pattern: PatternPredicate<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { self.evaluate_pattern_predicate_kind( pattern.kind(self.db), pattern.subject(self.db), @@ -573,7 +759,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr: &ast::Expr, is_positive: bool, - ) -> Option> { + ) -> Option> { let target = place_expr(expr)?; let place = self.expect_place(&target); @@ -583,14 +769,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { Type::AlwaysTruthy.negate(self.db) }; - Some(NarrowingConstraints::from_iter([(place, ty)])) + Some(InternalConstraints::from_iter([( + place, + NarrowingConstraint::regular(ty), + )])) } fn evaluate_expr_named( &mut self, expr_named: &ast::ExprNamed, is_positive: bool, - ) -> Option> { + ) -> Option> { self.evaluate_simple_expr(&expr_named.target, is_positive) } @@ -834,7 +1023,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_compare: &ast::ExprCompare, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { matches!( expr, @@ -876,7 +1065,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let comparator_tuples = std::iter::once(&**left) .chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); - let mut constraints = NarrowingConstraints::default(); + let mut constraints = InternalConstraints::default(); let mut last_rhs_ty: Option = None; @@ -895,7 +1084,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) { let place = self.expect_place(&left); - constraints.insert(place, ty); + constraints.insert_regular(place, ty); } } ast::Expr::Call(ast::ExprCall { @@ -941,7 +1130,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .is_some_and(|c| c.is_known(self.db, KnownClass::Type)) { let place = self.expect_place(&target); - constraints.insert( + constraints.insert_regular( place, Type::instance(self.db, rhs_class.unknown_specialization(self.db)) .negate_if(self.db, !is_positive), @@ -959,7 +1148,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_call: &ast::ExprCall, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); let callable_ty = inference.expression_type(&*expr_call.func); @@ -973,19 +1162,34 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { { let return_ty = inference.expression_type(expr_call); - let (guarded_ty, place) = match return_ty { - // TODO: TypeGuard + let (guarded_ty, place, is_typeguard) = match return_ty { Type::TypeIs(type_is) => { let (_, place) = type_is.place_info(self.db)?; - (type_is.return_type(self.db), place) + (type_is.return_type(self.db), place, false) + } + Type::TypeGuard(type_guard) => { + let (_, place) = type_guard.place_info(self.db)?; + (type_guard.return_type(self.db), place, true) } _ => return None, }; - Some(NarrowingConstraints::from_iter([( - place, - guarded_ty.negate_if(self.db, !is_positive), - )])) + // Apply negation if needed + let narrowed_ty = guarded_ty.negate_if(self.db, !is_positive); + + // For TypeGuard in the positive case, use typeguard constraint + // For TypeGuard in the negative case OR TypeIs in any case, use regular constraint + // Note: TypeGuard only narrows in the positive case + let constraint = if is_typeguard && is_positive { + NarrowingConstraint::typeguard(narrowed_ty) + } else if is_typeguard && !is_positive { + // TypeGuard doesn't narrow in the negative case + return None; + } else { + NarrowingConstraint::regular(narrowed_ty) + }; + + Some(InternalConstraints::from_iter([(place, constraint)])) } // For the expression `len(E)`, we narrow the type based on whether len(E) is truthy // (i.e., whether E is non-empty). We only narrow the parts of the type where we know @@ -1031,9 +1235,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let constraint = Type::protocol_with_readonly_members(self.db, [(attr, Type::object())]); - return Some(NarrowingConstraints::from_iter([( + return Some(InternalConstraints::from_iter([( place, - constraint.negate_if(self.db, !is_positive), + NarrowingConstraint::regular(constraint.negate_if(self.db, !is_positive)), )])); } @@ -1044,9 +1248,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { function .generate_constraint(self.db, class_info_ty) .map(|constraint| { - NarrowingConstraints::from_iter([( + InternalConstraints::from_iter([( place, - constraint.negate_if(self.db, !is_positive), + NarrowingConstraint::regular( + constraint.negate_if(self.db, !is_positive), + ), )]) }) } @@ -1071,7 +1277,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, singleton: ast::Singleton, is_positive: bool, - ) -> Option> { + ) -> Option> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); @@ -1081,7 +1287,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ast::Singleton::False => Type::BooleanLiteral(false), }; let ty = ty.negate_if(self.db, !is_positive); - Some(NarrowingConstraints::from_iter([(place, ty)])) + Some(InternalConstraints::from_iter([( + place, + NarrowingConstraint::regular(ty), + )])) } fn evaluate_match_pattern_class( @@ -1090,7 +1299,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { cls: Expression<'db>, kind: ClassPatternKind, is_positive: bool, - ) -> Option> { + ) -> Option> { if !kind.is_irrefutable() && !is_positive { // A class pattern like `case Point(x=0, y=0)` is not irrefutable. In the positive case, // we can still narrow the type of the match subject to `Point`. But in the negative case, @@ -1114,7 +1323,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { _ => return None, }; - Some(NarrowingConstraints::from_iter([(place, narrowed_type)])) + Some(InternalConstraints::from_iter([(place, NarrowingConstraint::regular(narrowed_type))])) } fn evaluate_match_pattern_value( @@ -1122,7 +1331,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, value: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let place = { let subject = place_expr(subject.node_ref(self.db, self.module))?; self.expect_place(&subject) @@ -1134,7 +1343,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) - .map(|ty| NarrowingConstraints::from_iter([(place, ty)])) + .map(|ty| InternalConstraints::from_iter([(place, NarrowingConstraint::regular(ty))])) } fn evaluate_match_pattern_or( @@ -1142,7 +1351,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, predicates: &Vec>, is_positive: bool, - ) -> Option> { + ) -> Option> { let db = self.db; // DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints. @@ -1168,7 +1377,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_bool_op: &ExprBoolOp, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); let mut sub_constraints = expr_bool_op .values @@ -1187,7 +1396,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .collect::>(); match (expr_bool_op.op, is_positive) { (BoolOp::And, true) | (BoolOp::Or, false) => { - let mut aggregation: Option = None; + let mut aggregation: Option = None; for sub_constraint in sub_constraints.into_iter().flatten() { if let Some(ref mut some_aggregation) = aggregation { merge_constraints_and(some_aggregation, &sub_constraint, self.db); From 2562b47d848a55a81fea447840804f8ec7a5017c Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 15:22:45 -0400 Subject: [PATCH 03/22] wip(typeguard): implement narrowing --- .../resources/mdtest/narrow/type_guards.md | 78 ++++++++++++------- .../src/types/infer/builder.rs | 7 +- .../types/infer/builder/type_expression.rs | 29 +++++-- crates/ty_python_semantic/src/types/narrow.rs | 10 ++- 4 files changed, 84 insertions(+), 40 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 0d94326c89..76e3de60a8 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -17,16 +17,15 @@ def _( e: TypeGuard, # error: [invalid-type-form] f: TypeIs, # error: [invalid-type-form] ): - # TODO: Should be `TypeGuard[str]` - reveal_type(a) # revealed: @Todo(`TypeGuard[]` special form) + reveal_type(a) # revealed: TypeGuard[str] reveal_type(b) # revealed: TypeIs[str | int] - # TODO: Should be `TypeGuard[complex & ~int & ~float]` - reveal_type(c) # revealed: @Todo(`TypeGuard[]` special form) + # TODO: Should be `TypeGuard[complex & ~int & ~float]` - intersection not preserved in type expression + reveal_type(c) # revealed: TypeGuard[complex] reveal_type(d) # revealed: TypeIs[tuple[]] reveal_type(e) # revealed: Unknown reveal_type(f) # revealed: Unknown -# TODO: error: [invalid-return-type] "Function always implicitly returns `None`, which is not assignable to return type `TypeGuard[str]`" +# error: [invalid-return-type] "Function always implicitly returns `None`, which is not assignable to return type `TypeGuard[str]`" def _(a) -> TypeGuard[str]: ... # error: [invalid-return-type] "Function always implicitly returns `None`, which is not assignable to return type `TypeIs[str]`" @@ -38,8 +37,7 @@ def g(a) -> TypeIs[str]: return True def _(a: object): - # TODO: Should be `TypeGuard[str @ a]` - reveal_type(f(a)) # revealed: @Todo(`TypeGuard[]` special form) + reveal_type(f(a)) # revealed: TypeGuard[str @ a] reveal_type(g(a)) # revealed: TypeIs[str @ a] ``` @@ -105,15 +103,14 @@ from typing_extensions import TypeGuard, TypeIs a = 123 -# TODO: error: [invalid-type-form] +# error: [invalid-type-form] "Special form `typing.TypeGuard` expected exactly one type parameter" def f(_) -> TypeGuard[int, str]: ... # error: [invalid-type-form] "Special form `typing.TypeIs` expected exactly one type parameter" # error: [invalid-type-form] "Variable of type `Literal[123]` is not allowed in a type expression" def g(_) -> TypeIs[a, str]: ... -# TODO: Should be `Unknown` -reveal_type(f(0)) # revealed: @Todo(`TypeGuard[]` special form) +reveal_type(f(0)) # revealed: Unknown reveal_type(g(0)) # revealed: Unknown ``` @@ -126,9 +123,10 @@ from typing_extensions import Literal, TypeGuard, TypeIs, assert_never def _(a: object, flag: bool) -> TypeGuard[str]: if flag: + # error: [invalid-return-type] "Return type does not match returned value: expected `TypeGuard[str]`, found `Literal[0]`" return 0 - # TODO: error: [invalid-return-type] "Return type does not match returned value: expected `TypeIs[str]`, found `Literal["foo"]`" + # error: [invalid-return-type] "Return type does not match returned value: expected `TypeGuard[str]`, found `Literal["foo"]`" return "foo" # error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `TypeIs[str]`" @@ -193,8 +191,7 @@ def is_bar(a: object) -> TypeIs[Bar]: def _(a: Foo | Bar): if guard_foo(a): - # TODO: Should be `Foo` - reveal_type(a) # revealed: Foo | Bar + reveal_type(a) # revealed: Foo else: reveal_type(a) # revealed: Foo | Bar @@ -215,23 +212,19 @@ class C(Generic[T]): v: T def _(a: tuple[Foo, Bar] | tuple[Bar, Foo], c: C[Any]): - # TODO: Should be `TypeGuard[Foo @ a[1]]` - if reveal_type(guard_foo(a[1])): # revealed: @Todo(`TypeGuard[]` special form) - # TODO: Should be `tuple[Bar, Foo]` + if reveal_type(guard_foo(a[1])): # revealed: TypeGuard[Foo @ a[1]] + # TODO: Should be `tuple[Bar, Foo]` - requires narrowing tuple by subscript reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo] - # TODO: Should be `Foo` - reveal_type(a[1]) # revealed: Bar | Foo + reveal_type(a[1]) # revealed: Foo if reveal_type(is_bar(a[0])): # revealed: TypeIs[Bar @ a[0]] # TODO: Should be `tuple[Bar, Bar & Foo]` reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo] reveal_type(a[0]) # revealed: Bar - # TODO: Should be `TypeGuard[Foo @ c.v]` - if reveal_type(guard_foo(c.v)): # revealed: @Todo(`TypeGuard[]` special form) + if reveal_type(guard_foo(c.v)): # revealed: TypeGuard[Foo @ c.v] reveal_type(c) # revealed: C[Any] - # TODO: Should be `Foo` - reveal_type(c.v) # revealed: Any + reveal_type(c.v) # revealed: Any & Foo if reveal_type(is_bar(c.v)): # revealed: TypeIs[Bar @ c.v] reveal_type(c) # revealed: C[Any] @@ -246,8 +239,7 @@ def _(a: Foo | Bar): c = is_bar(a) reveal_type(a) # revealed: Foo | Bar - # TODO: Should be `TypeGuard[Foo @ a]` - reveal_type(b) # revealed: @Todo(`TypeGuard[]` special form) + reveal_type(b) # revealed: TypeGuard[Foo @ a] reveal_type(c) # revealed: TypeIs[Bar @ a] if b: @@ -350,20 +342,46 @@ def is_bar(a: object) -> TypeIs[Bar]: def does_not_narrow_in_negative_case(a: Foo | Bar): if not guard_foo(a): - # TODO: Should be `Bar` reveal_type(a) # revealed: Foo | Bar else: - reveal_type(a) # revealed: Foo | Bar + reveal_type(a) # revealed: Foo def narrowed_type_must_be_exact(a: object, b: Baz): if guard_foo(b): - # TODO: Should be `Foo` - reveal_type(b) # revealed: Baz + reveal_type(b) # revealed: Baz & Foo if isinstance(a, Baz) and is_bar(a): reveal_type(a) # revealed: Baz if isinstance(a, Bar) and guard_foo(a): - # TODO: Should be `Foo` - reveal_type(a) # revealed: Bar + reveal_type(a) # revealed: Foo +``` + +## Complex boolean logic with TypeGuard and TypeIs + +TypeGuard constraints need to properly distribute through boolean operations. + +```py +from typing_extensions import TypeGuard, TypeIs + +class A: ... +class B: ... +class C: ... + +def f(x: object) -> TypeIs[A]: + return True + +def g(x: object) -> TypeGuard[B]: + return True + +def h(x: object) -> TypeIs[C]: + return True + +def _(x: object): + # g(x) or h(x) should give B | C + # Then f(x) and (...) should distribute: (f(x) and g(x)) or (f(x) and h(x)) + # Which is (Regular(A) & TypeGuard(B)) | (Regular(A) & Regular(C)) + # TypeGuard clobbers in the first branch, giving: B | (A & C) + if f(x) and (g(x) or h(x)): + reveal_type(x) # revealed: B | (A & C) ``` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 7a7f25eb31..ccd0ad8192 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -8821,10 +8821,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; match return_ty { - // TODO: TypeGuard Type::TypeIs(type_is) => match find_narrowed_place() { Some(place) => type_is.bind(db, scope, place), - None => return_ty, // TODO(ericmarkmartin): ? + None => return_ty, + }, + Type::TypeGuard(type_guard) => match find_narrowed_place() { + Some(place) => type_guard.bind(db, scope, place), + None => return_ty, }, _ => return_ty, } diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index d4c5701f1d..8550df420d 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -16,8 +16,8 @@ use crate::types::tuple::{TupleSpecBuilder, TupleType}; use crate::types::{ BindingContext, CallableType, DynamicType, GenericContext, IntersectionBuilder, KnownClass, KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType, - Type, TypeAliasType, TypeContext, TypeIsType, TypeMapping, TypeVarKind, UnionBuilder, - UnionType, any_over_type, todo_type, + Type, TypeAliasType, TypeContext, TypeGuardType, TypeIsType, TypeMapping, TypeVarKind, + UnionBuilder, UnionType, any_over_type, todo_type, }; /// Type expressions @@ -1521,10 +1521,27 @@ impl<'db> TypeInferenceBuilder<'db, '_> { .top_materialization(self.db()), ), }, - SpecialFormType::TypeGuard => { - self.infer_type_expression(arguments_slice); - todo_type!("`TypeGuard[]` special form") - } + // TODO: deduplicate + SpecialFormType::TypeGuard => match arguments_slice { + ast::Expr::Tuple(_) => { + self.infer_type_expression(arguments_slice); + + if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) { + let diag = builder.into_diagnostic(format_args!( + "Special form `typing.TypeGuard` expected exactly one type parameter", + )); + diagnostic::add_type_expression_reference_link(diag); + } + + Type::unknown() + } + _ => TypeGuardType::unbound( + self.db(), + // Similar to TypeIs, use top materialization + self.infer_type_expression(arguments_slice) + .top_materialization(self.db()), + ), + }, SpecialFormType::Concatenate => { let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice { &*tuple.elts diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 328ed0e07d..4beab1e583 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1207,7 +1207,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) { let target = place_expr(arg)?; let place = self.expect_place(&target); - Some(NarrowingConstraints::from_iter([(place, narrowed_ty)])) + Some(InternalConstraints::from_iter([( + place, + NarrowingConstraint::regular(narrowed_ty), + )])) } else { None } @@ -1323,7 +1326,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { _ => return None, }; - Some(InternalConstraints::from_iter([(place, NarrowingConstraint::regular(narrowed_type))])) + Some(InternalConstraints::from_iter([( + place, + NarrowingConstraint::regular(narrowed_type), + )])) } fn evaluate_match_pattern_value( From ca5e6070be33887174ce6b46350999d1d03385c8 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 16:56:55 -0400 Subject: [PATCH 04/22] feat: working `TypeGuard` --- .../annotations/unsupported_special_forms.md | 4 +- .../type_properties/is_assignable_to.md | 3 +- .../type_properties/is_disjoint_from.md | 3 +- crates/ty_python_semantic/src/types/narrow.rs | 262 +++++++++--------- 4 files changed, 143 insertions(+), 129 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md index 18ebd03682..6c0947c5e0 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md @@ -16,7 +16,9 @@ def f(*args: Unpack[Ts]) -> tuple[Unpack[Ts]]: reveal_type(args) # revealed: tuple[@Todo(`Unpack[]` special form), ...] return args -def g() -> TypeGuard[int]: ... +def g() -> TypeGuard[int]: + return True + def i(callback: Callable[Concatenate[int, P], R_co], *args: P.args, **kwargs: P.kwargs) -> R_co: reveal_type(args) # revealed: P@i.args reveal_type(kwargs) # revealed: P@i.kwargs diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index f2e38485c5..cfb7b5d6e0 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -1383,8 +1383,7 @@ from typing_extensions import Any, TypeGuard, TypeIs static_assert(is_assignable_to(TypeGuard[Unknown], bool)) static_assert(is_assignable_to(TypeIs[Any], bool)) -# TODO no error -static_assert(not is_assignable_to(TypeGuard[Unknown], str)) # error: [static-assert-error] +static_assert(not is_assignable_to(TypeGuard[Unknown], str)) static_assert(not is_assignable_to(TypeIs[Any], str)) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md index d4aa7db231..0b2d95842f 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md @@ -578,8 +578,7 @@ from typing_extensions import TypeGuard, TypeIs static_assert(not is_disjoint_from(bool, TypeGuard[str])) static_assert(not is_disjoint_from(bool, TypeIs[str])) -# TODO no error -static_assert(is_disjoint_from(str, TypeGuard[str])) # error: [static-assert-error] +static_assert(is_disjoint_from(str, TypeGuard[str])) static_assert(is_disjoint_from(str, TypeIs[str])) ``` diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 4beab1e583..b4eb3789ff 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -23,6 +23,7 @@ use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; +use smallvec::{SmallVec, smallvec}; use std::collections::hash_map::Entry; use super::UnionType; @@ -274,19 +275,23 @@ impl ClassInfoConstraintFunction { } } -/// Represents a single conjunction (AND) of constraints in Disjunctive Normal Form (DNF). +/// Represents a single conjunction (AND) of constraints in Disjunctive Normal +/// Form (DNF). /// -/// A conjunction may contain: -/// - A regular constraint (intersection of types) -/// - An optional TypeGuard constraint that "replaces" the type rather than intersecting +/// A conjunction may contain: - A regular constraint (intersection of types) - +/// An optional `TypeGuard` constraint that "replaces" the type rather than +/// intersecting /// -/// For example, `(Regular(A) & TypeGuard(B))` evaluates to just `B` because TypeGuard clobbers. -#[derive(Hash, PartialEq, Debug, Eq, Clone)] +/// For example, `(Conjunction { constraint: A, typeguard: Some(B) } & +/// Conjunction { constraint: C, typeguard: Some(D)})` evlaluates to +/// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard +/// in the second clobbers the first. +#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)] struct Conjunction<'db> { /// The intersected constraints (represented as an intersection type) constraint: Type<'db>, - /// If any constraint in this conjunction is a TypeGuard, this is Some - /// and contains the union of all TypeGuard types in this conjunction + /// If any constraint in this conjunction is a `TypeGuard`, this is Some and + /// contains the union of all `TypeGuard` types in this conjunction typeguard: Option>, } @@ -301,7 +306,7 @@ impl<'db> Conjunction<'db> { } } - /// Create a new conjunction with a TypeGuard constraint + /// Create a new conjunction with a `TypeGuard` constraint fn typeguard(constraint: Type<'db>) -> Self { Self { constraint: Type::object(), @@ -310,9 +315,9 @@ impl<'db> Conjunction<'db> { } /// Evaluate this conjunction to a single type. - /// If there's a TypeGuard constraint, it replaces the regular constraint. + /// If there's a `TypeGuard` constraint, it replaces the regular constraint. /// Otherwise, returns the regular constraint. - fn to_type(self) -> Type<'db> { + fn evaluate_type_constraint(self) -> Type<'db> { self.typeguard.unwrap_or(self.constraint) } } @@ -320,52 +325,50 @@ impl<'db> Conjunction<'db> { /// Represents narrowing constraints in Disjunctive Normal Form (DNF). /// /// This is a disjunction (OR) of conjunctions (AND) of constraints. -/// The DNF representation allows us to properly track TypeGuard constraints +/// The DNF representation allows us to properly track `TypeGuard` constraints /// through boolean operations. /// /// For example: -/// - `f(x) and g(x)` where f returns TypeIs[A] and g returns TypeGuard[B] +/// - `f(x) and g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => `[Conjunction { constraint: A, typeguard: Some(B) }]` -/// => evaluates to `B` (TypeGuard clobbers) +/// => evaluates to `B` (`TypeGuard` clobbers) /// -/// - `f(x) or g(x)` where f returns TypeIs[A] and g returns TypeGuard[B] +/// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]` /// => evaluates to `A | B` #[derive(Hash, PartialEq, Debug, Eq, Clone)] struct NarrowingConstraint<'db> { /// Disjunctions of conjunctions (DNF) - disjuncts: Vec>, + disjuncts: SmallVec<[Conjunction<'db>; 4]>, } impl get_size2::GetSize for NarrowingConstraint<'_> {} impl<'db> NarrowingConstraint<'db> { - /// Create a constraint from a regular (non-TypeGuard) type + /// Create a constraint from a regular (non-`TypeGuard`) type fn regular(constraint: Type<'db>) -> Self { Self { - disjuncts: vec![Conjunction::regular(constraint)], + disjuncts: smallvec![Conjunction::regular(constraint)], } } - /// Create a constraint from a TypeGuard type + /// Create a constraint from a `TypeGuard` type fn typeguard(constraint: Type<'db>) -> Self { Self { - disjuncts: vec![Conjunction::typeguard(constraint)], + disjuncts: smallvec![Conjunction::typeguard(constraint)], } } - /// Evaluate this constraint to a single type by evaluating each disjunct - /// and taking their union - fn to_type(self, db: &'db dyn Db) -> Type<'db> { - if self.disjuncts.is_empty() { - return Type::Never; - } - - if self.disjuncts.len() == 1 { - return self.disjuncts.into_iter().next().unwrap().to_type(); - } - - UnionType::from_elements(db, self.disjuncts.into_iter().map(|c| c.to_type())) + /// Evaluate the type this effectively constrains to + /// + /// Forgets whether each constraint originated from a `TypeGuard` or not + fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { + UnionType::from_elements( + db, + self.disjuncts + .into_iter() + .map(|c| c.evaluate_type_constraint()), + ) } } @@ -375,80 +378,84 @@ impl<'db> From> for NarrowingConstraint<'db> { } } -/// Internal representation of constraints with DNF structure for tracking TypeGuard semantics -type InternalConstraints<'db> = FxHashMap>; - -/// Public representation of constraints as returned by tracked functions -type NarrowingConstraints<'db> = FxHashMap>; - -/// Helper trait to make inserting constraints more ergonomic -trait InternalConstraintsExt<'db> { - fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>); - fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>); - fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db>; +/// Internal representation of constraints with DNF structure for tracking `TypeGuard` semantics. +/// +/// This is a newtype wrapper around `FxHashMap>` that +/// provides methods for working with constraints during boolean operation evaluation. +#[derive(Clone, Debug, Default)] +struct InternalConstraints<'db> { + constraints: FxHashMap>, } -impl<'db> InternalConstraintsExt<'db> for InternalConstraints<'db> { +impl<'db> InternalConstraints<'db> { + /// Insert a regular (non-`TypeGuard`) constraint for a place fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) { - self.insert(place, NarrowingConstraint::regular(ty)); + self.constraints + .insert(place, NarrowingConstraint::regular(ty)); } - fn insert_typeguard(&mut self, place: ScopedPlaceId, ty: Type<'db>) { - self.insert(place, NarrowingConstraint::typeguard(ty)); - } - - fn to_public(self, db: &'db dyn Db) -> NarrowingConstraints<'db> { - self.into_iter() - .map(|(place, constraint)| (place, constraint.to_type(db))) + /// Convert internal constraints to public constraints by evaluating each DNF constraint to a Type + fn evaluate_type_constraints(self, db: &'db dyn Db) -> NarrowingConstraints<'db> { + self.constraints + .into_iter() + .map(|(place, constraint)| (place, constraint.evaluate_type_constraint(db))) .collect() } } +impl<'db> FromIterator<(ScopedPlaceId, NarrowingConstraint<'db>)> for InternalConstraints<'db> { + fn from_iter)>>( + iter: T, + ) -> Self { + Self { + constraints: FxHashMap::from_iter(iter), + } + } +} + +/// Public representation of constraints as returned by tracked functions +type NarrowingConstraints<'db> = FxHashMap>; + /// Merge constraints with AND semantics (intersection/conjunction). /// -/// When we have `constraint1 AND constraint2`, we need to distribute AND over the OR +/// When we have `constraint1 & constraint2`, we need to distribute AND over the OR /// in the DNF representations: -/// `(A | B) AND (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)` +/// `(A | B) & (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)` /// /// For each conjunction pair, we: -/// - Intersect the regular constraints -/// - If either has a TypeGuard, the result gets a TypeGuard (TypeGuard "poisons" the AND) +/// - Take the right conjunct if it has a `TypeGuard` +/// - Intersect the constraints normally otherwise fn merge_constraints_and<'db>( into: &mut InternalConstraints<'db>, from: &InternalConstraints<'db>, db: &'db dyn Db, ) { - for (key, from_constraint) in from { - match into.entry(*key) { + for (key, from_constraint) in &from.constraints { + match into.constraints.entry(*key) { Entry::Occupied(mut entry) => { - let into_constraint = entry.get().clone(); + let into_constraint = entry.get(); // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... - let mut new_disjuncts = Vec::new(); + let mut new_disjuncts = SmallVec::new(); for left_conj in &into_constraint.disjuncts { for right_conj in &from_constraint.disjuncts { - // Intersect the regular constraints - let new_regular = IntersectionBuilder::new(db) - .add_positive(left_conj.constraint) - .add_positive(right_conj.constraint) - .build(); + if right_conj.typeguard.is_some() { + // If the right conjunct has a TypeGuard, it "wins" the conjunction + new_disjuncts.push(*right_conj); + } else { + // Intersect the regular constraints + let new_regular = IntersectionBuilder::new(db) + .add_positive(left_conj.constraint) + .add_positive(right_conj.constraint) + .build(); - // Union the TypeGuard constraints if both have them, - // or take the one that exists - let new_typeguard = match (left_conj.typeguard, right_conj.typeguard) { - (Some(left_tg), Some(right_tg)) => { - Some(UnionBuilder::new(db).add(left_tg).add(right_tg).build()) - } - (Some(tg), None) | (None, Some(tg)) => Some(tg), - (None, None) => None, - }; - - new_disjuncts.push(Conjunction { - constraint: new_regular, - typeguard: new_typeguard, - }); + new_disjuncts.push(Conjunction { + constraint: new_regular, + typeguard: left_conj.typeguard, + }); + } } } @@ -475,18 +482,21 @@ fn merge_constraints_or<'db>( from: &InternalConstraints<'db>, _db: &'db dyn Db, ) { - for (key, from_constraint) in from { - match into.entry(*key) { + // For places that appear in `into` but not in `from`, widen to object + for (_key, value) in into.constraints.iter_mut() { + if !from.constraints.contains_key(_key) { + *value = NarrowingConstraint::regular(Type::object()); + } + } + + for (key, from_constraint) in &from.constraints { + match into.constraints.entry(*key) { Entry::Occupied(mut entry) => { - let into_constraint = entry.get().clone(); - // Simply concatenate the disjuncts - let mut new_disjuncts = into_constraint.disjuncts; - new_disjuncts.extend(from_constraint.disjuncts.clone()); - - *entry.get_mut() = NarrowingConstraint { - disjuncts: new_disjuncts, - }; + entry + .get_mut() + .disjuncts + .extend(from_constraint.disjuncts.clone()); } Entry::Vacant(entry) => { // Place only appears in `from`, not in `into`. @@ -495,13 +505,6 @@ fn merge_constraints_or<'db>( } } } - - // For places that appear in `into` but not in `from`, widen to object - for (_key, value) in into.iter_mut() { - if !from.contains_key(_key) { - *value = NarrowingConstraint::regular(Type::object()); - } - } } fn place_expr(expr: &ast::Expr) -> Option { @@ -578,8 +581,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(mut constraints) = constraints { - constraints.shrink_to_fit(); - Some(constraints.to_public(self.db)) + constraints.constraints.shrink_to_fit(); + Some(constraints.evaluate_type_constraints(self.db)) } else { None } @@ -1162,34 +1165,30 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { { let return_ty = inference.expression_type(expr_call); - let (guarded_ty, place, is_typeguard) = match return_ty { + let place_and_constraint = match return_ty { Type::TypeIs(type_is) => { let (_, place) = type_is.place_info(self.db)?; - (type_is.return_type(self.db), place, false) + Some(( + place, + NarrowingConstraint::regular( + type_is + .return_type(self.db) + .negate_if(self.db, !is_positive), + ), + )) } - Type::TypeGuard(type_guard) => { + // TypeGuard only narrows in the positive case + Type::TypeGuard(type_guard) if is_positive => { let (_, place) = type_guard.place_info(self.db)?; - (type_guard.return_type(self.db), place, true) + Some(( + place, + NarrowingConstraint::typeguard(type_guard.return_type(self.db)), + )) } - _ => return None, - }; + _ => None, + }?; - // Apply negation if needed - let narrowed_ty = guarded_ty.negate_if(self.db, !is_positive); - - // For TypeGuard in the positive case, use typeguard constraint - // For TypeGuard in the negative case OR TypeIs in any case, use regular constraint - // Note: TypeGuard only narrows in the positive case - let constraint = if is_typeguard && is_positive { - NarrowingConstraint::typeguard(narrowed_ty) - } else if is_typeguard && !is_positive { - // TypeGuard doesn't narrow in the negative case - return None; - } else { - NarrowingConstraint::regular(narrowed_ty) - }; - - Some(InternalConstraints::from_iter([(place, constraint)])) + Some(InternalConstraints::from_iter([place_and_constraint])) } // For the expression `len(E)`, we narrow the type based on whether len(E) is truthy // (i.e., whether E is non-empty). We only narrow the parts of the type where we know @@ -1385,7 +1384,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { is_positive: bool, ) -> Option> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); - let mut sub_constraints = expr_bool_op + let sub_constraints = expr_bool_op .values .iter() // filter our arms with statically known truthiness @@ -1413,17 +1412,32 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { aggregation } (BoolOp::Or, true) | (BoolOp::And, false) => { - let (first, rest) = sub_constraints.split_first_mut()?; - if let Some(first) = first { + let (mut first, rest) = { + let mut it = sub_constraints.into_iter(); + (it.next()?, it) + }; + + if let Some(ref mut first) = first { for rest_constraint in rest { if let Some(rest_constraint) = rest_constraint { - merge_constraints_or(first, rest_constraint, self.db); + merge_constraints_or(first, &rest_constraint, self.db); } else { return None; } } } first.clone() + // let (first, rest) = sub_constraints.split_first_mut()?; + // if let Some(first) = first { + // for rest_constraint in rest { + // if let Some(rest_constraint) = rest_constraint { + // merge_constraints_or(first, rest_constraint, self.db); + // } else { + // return None; + // } + // } + // } + // first.clone() } } } From b610969a9fbac3a4788b108dca23283207bfdfb5 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 18:31:29 -0400 Subject: [PATCH 05/22] cleanup(type_guards.md): add explanation for weird `complex` intersection case --- .../resources/mdtest/narrow/type_guards.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 76e3de60a8..9c88845f12 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -19,7 +19,9 @@ def _( ): reveal_type(a) # revealed: TypeGuard[str] reveal_type(b) # revealed: TypeIs[str | int] - # TODO: Should be `TypeGuard[complex & ~int & ~float]` - intersection not preserved in type expression + # not `TypeGuard[complex & ~int & ~float]`: `complex` in argument position + # means `complex & int & float` semantically so `Intersection[complex, + # Not[int], Not[float]]` means `complex` semantically reveal_type(c) # revealed: TypeGuard[complex] reveal_type(d) # revealed: TypeIs[tuple[]] reveal_type(e) # revealed: Unknown @@ -213,7 +215,7 @@ class C(Generic[T]): def _(a: tuple[Foo, Bar] | tuple[Bar, Foo], c: C[Any]): if reveal_type(guard_foo(a[1])): # revealed: TypeGuard[Foo @ a[1]] - # TODO: Should be `tuple[Bar, Foo]` - requires narrowing tuple by subscript + # TODO: Should be `tuple[Bar, Foo]` reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo] reveal_type(a[1]) # revealed: Foo From 6b5f24530bb83f1c8a7d3e5c450d7d0452bd3cf4 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 18:32:16 -0400 Subject: [PATCH 06/22] test(variance): update/add variance tests for `TypeGuard` --- .../mdtest/generics/pep695/variance.md | 38 +++++++++++++++++++ .../mdtest/type_properties/is_subtype_of.md | 11 +++--- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md index 7dc9392b21..77a41caa00 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md @@ -790,6 +790,44 @@ static_assert(not is_assignable_to(C[B], C[A])) static_assert(not is_assignable_to(C[A], C[B])) ``` +## TypeGuard + +`TypeGuard[T]` is covariant in `T`. The typing spec doesn't explicitly call this out, but it follows +from similar logic to invariance of `TypeIs` except without the negative case. + +Formally, suppose we have types `A` and `B` with `B < A`. Take `x: object` to be the value that all +subsequent `TypeGuard`s are narrowing. + +We can assign `p: TypeGuard[A] = q` where `q: TypeGuard[B]` because + +- if `q` is `False`, then no constraints were learned on `x` before and none are now learned, so + nothing changes +- if `q` is `True`, then we know `x: B`. From `B < A`, we conclude `x: A`. + +We _cannot_ assign `p: TypeGuard[B] = q` where `q: TypeGuard[A]` because if `q` is `True`, we would +be concluding `x: B` from `x: A`, which is an unsafe downcast. + +```py +from typing import TypeGuard +from ty_extensions import is_assignable_to, is_subtype_of, static_assert + +class A: + pass + +class B(A): + pass + +class C[T]: + def check(x: object) -> TypeGuard[T]: + # this is a bad check, but we only care about it type-checking + return False + +static_assert(is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) +static_assert(is_assignable_to(C[B], C[A])) +static_assert(not is_assignable_to(C[A], C[B])) +``` + ## Type aliases The variance of the type alias matches the variance of the value type (RHS type). diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index a2b9ca89d0..9877c14a86 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -670,9 +670,8 @@ Fully-static `TypeGuard[...]` and `TypeIs[...]` are subtypes of `bool`. from ty_extensions import is_subtype_of, static_assert from typing_extensions import TypeGuard, TypeIs -# TODO: TypeGuard -# static_assert(is_subtype_of(TypeGuard[int], bool)) -# static_assert(is_subtype_of(TypeGuard[int], int)) +static_assert(is_subtype_of(TypeGuard[str], bool)) +static_assert(is_subtype_of(TypeGuard[str], int)) static_assert(is_subtype_of(TypeIs[str], bool)) static_assert(is_subtype_of(TypeIs[str], int)) ``` @@ -683,12 +682,12 @@ static_assert(is_subtype_of(TypeIs[str], int)) from ty_extensions import is_equivalent_to, is_subtype_of, static_assert from typing_extensions import TypeGuard, TypeIs -# TODO: TypeGuard -# static_assert(is_subtype_of(TypeGuard[int], TypeGuard[int])) -# static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int])) +static_assert(is_subtype_of(TypeGuard[int], TypeGuard[int])) +static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int])) static_assert(is_subtype_of(TypeIs[int], TypeIs[int])) static_assert(is_subtype_of(TypeIs[int], TypeIs[int])) +static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int])) static_assert(not is_subtype_of(TypeGuard[int], TypeGuard[bool])) static_assert(not is_subtype_of(TypeIs[bool], TypeIs[int])) static_assert(not is_subtype_of(TypeIs[int], TypeIs[bool])) From 57d61c487e1c13448302b9503f8e8b0d9345af45 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 18:51:56 -0400 Subject: [PATCH 07/22] refactor: remove unneeded clone --- crates/ty_python_semantic/src/types/narrow.rs | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index b4eb3789ff..783a5a5e00 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1426,18 +1426,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } } - first.clone() - // let (first, rest) = sub_constraints.split_first_mut()?; - // if let Some(first) = first { - // for rest_constraint in rest { - // if let Some(rest_constraint) = rest_constraint { - // merge_constraints_or(first, rest_constraint, self.db); - // } else { - // return None; - // } - // } - // } - // first.clone() + first } } } From e0f37ba713fb6e55394e4dc2d9a98468a870f040 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 19 Oct 2025 19:08:09 -0400 Subject: [PATCH 08/22] fix: clippy --- crates/ty_python_semantic/src/types/narrow.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 783a5a5e00..72e34c83db 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -367,7 +367,7 @@ impl<'db> NarrowingConstraint<'db> { db, self.disjuncts .into_iter() - .map(|c| c.evaluate_type_constraint()), + .map(Conjunction::evaluate_type_constraint), ) } } @@ -483,8 +483,8 @@ fn merge_constraints_or<'db>( _db: &'db dyn Db, ) { // For places that appear in `into` but not in `from`, widen to object - for (_key, value) in into.constraints.iter_mut() { - if !from.constraints.contains_key(_key) { + for (key, value) in &mut into.constraints { + if !from.constraints.contains_key(key) { *value = NarrowingConstraint::regular(Type::object()); } } From f7804ea29da839e747cd40afcf4016bde76da19c Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Wed, 22 Oct 2025 17:09:23 -0400 Subject: [PATCH 09/22] fix typeguard overriding logic --- .../resources/mdtest/narrow/type_guards.md | 28 ++++++++++++++++++- crates/ty_python_semantic/src/types/narrow.rs | 14 ++++++++-- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 9c88845f12..c2a5f8feb6 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -359,7 +359,33 @@ def narrowed_type_must_be_exact(a: object, b: Baz): reveal_type(a) # revealed: Foo ``` -## Complex boolean logic with TypeGuard and TypeIs +## TypeGuard overrides normal constraints + +TypeGuard constraints override any previous narrowing, but additional "regular" constraints can be +added on to TypeGuard constraints. + +```py +from typing_extensions import TypeGuard, TypeIs + +class A: ... +class B: ... +class C: ... + +def f(x: object) -> TypeGuard[A]: + return True + +def g(x: object) -> TypeGuard[B]: + return True + +def h(x: object) -> TypeIs[C]: + return True + +def _(x: object): + if f(x) and g(x) and h(x): + reveal_type(x) # revealed: B & C +``` + +## Boolean logic with TypeGuard and TypeIs TypeGuard constraints need to properly distribute through boolean operations. diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 72e34c83db..b7cdf10e02 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -317,8 +317,16 @@ impl<'db> Conjunction<'db> { /// Evaluate this conjunction to a single type. /// If there's a `TypeGuard` constraint, it replaces the regular constraint. /// Otherwise, returns the regular constraint. - fn evaluate_type_constraint(self) -> Type<'db> { - self.typeguard.unwrap_or(self.constraint) + fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { + self.typeguard.map_or_else( + || self.constraint, + |typeguard_constraint| { + IntersectionBuilder::new(db) + .add_positive(typeguard_constraint) + .add_positive(self.constraint) + .build() + }, + ) } } @@ -367,7 +375,7 @@ impl<'db> NarrowingConstraint<'db> { db, self.disjuncts .into_iter() - .map(Conjunction::evaluate_type_constraint), + .map(|disjunct| Conjunction::evaluate_type_constraint(disjunct, db)), ) } } From bf9857e056a1a284011c1b763d5ab5dd2303165c Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sun, 16 Nov 2025 16:12:11 -0500 Subject: [PATCH 10/22] de-duplicate TypeGuard and TypeIs ordering --- .../src/types/type_ordering.rs | 94 +++++++++++++------ 1 file changed, 65 insertions(+), 29 deletions(-) diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 51a0e7d4fd..67875461b4 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -2,7 +2,11 @@ use std::cmp::Ordering; use salsa::plumbing::AsId; -use crate::{db::Db, types::bound_super::SuperOwnerKind}; +use crate::{ + db::Db, + semantic_index::{place::ScopedPlaceId, scope::ScopeId}, + types::bound_super::SuperOwnerKind, +}; use super::{ DynamicType, TodoType, Type, TypeGuardType, TypeIsType, class_base::ClassBase, @@ -291,41 +295,73 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering } } -/// Determine a canonical order for two instances of [`TypeIsType`]. +/// Trait for type guard-like types that can be ordered canonically. +trait GuardLikeOrdering<'db>: Copy { + fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)>; + fn place_name(self, db: &'db dyn Db) -> Option; + fn return_type(self, db: &'db dyn Db) -> Type<'db>; +} + +impl<'db> GuardLikeOrdering<'db> for TypeIsType<'db> { + fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> { + TypeIsType::place_info(self, db) + } + + fn place_name(self, db: &'db dyn Db) -> Option { + TypeIsType::place_name(self, db) + } + + fn return_type(self, db: &'db dyn Db) -> Type<'db> { + TypeIsType::return_type(self, db) + } +} + +impl<'db> GuardLikeOrdering<'db> for TypeGuardType<'db> { + fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> { + TypeGuardType::place_info(self, db) + } + + fn place_name(self, db: &'db dyn Db) -> Option { + TypeGuardType::place_name(self, db) + } + + fn return_type(self, db: &'db dyn Db) -> Type<'db> { + TypeGuardType::return_type(self, db) + } +} + +/// Generic helper for ordering type guard-like types. /// /// The following criteria are considered, in order: /// * Boundness: Unbound precedes bound /// * Symbol name: String comparison /// * Guarded type: [`union_or_intersection_elements_ordering`] +fn guard_like_ordering<'db, T: GuardLikeOrdering<'db>>( + db: &'db dyn Db, + left: T, + right: T, +) -> Ordering { + let (left_ty, right_ty) = (left.return_type(db), right.return_type(db)); + + match (left.place_info(db), right.place_info(db)) { + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + + (None, None) => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), + + (Some(_), Some(_)) => match left.place_name(db).cmp(&right.place_name(db)) { + Ordering::Equal => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), + ordering => ordering, + }, + } +} + +/// Determine a canonical order for two instances of [`TypeIsType`]. fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering { - let (left_ty, right_ty) = (left.return_type(db), right.return_type(db)); - - match (left.place_info(db), right.place_info(db)) { - (None, Some(_)) => Ordering::Less, - (Some(_), None) => Ordering::Greater, - - (None, None) => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), - - (Some(_), Some(_)) => match left.place_name(db).cmp(&right.place_name(db)) { - Ordering::Equal => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), - ordering => ordering, - }, - } + guard_like_ordering(db, left, right) } -// TODO: de-duplicate +/// Determine a canonical order for two instances of [`TypeGuardType`]. fn typeguard_ordering(db: &dyn Db, left: TypeGuardType, right: TypeGuardType) -> Ordering { - let (left_ty, right_ty) = (left.return_type(db), right.return_type(db)); - - match (left.place_info(db), right.place_info(db)) { - (None, Some(_)) => Ordering::Less, - (Some(_), None) => Ordering::Greater, - - (None, None) => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), - - (Some(_), Some(_)) => match left.place_name(db).cmp(&right.place_name(db)) { - Ordering::Equal => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), - ordering => ordering, - }, - } + guard_like_ordering(db, left, right) } From 27b715215972f0526890df3e2d82549c239d75f3 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Mon, 17 Nov 2025 20:45:01 -0500 Subject: [PATCH 11/22] some cleanup --- .../mdtest/annotations/unsupported_special_forms.md | 3 --- .../resources/mdtest/narrow/type_guards.md | 7 ++----- crates/ty_python_semantic/src/types.rs | 3 ++- crates/ty_python_semantic/src/types/narrow.rs | 9 ++++----- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md index 6c0947c5e0..50df54b556 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md @@ -16,9 +16,6 @@ def f(*args: Unpack[Ts]) -> tuple[Unpack[Ts]]: reveal_type(args) # revealed: tuple[@Todo(`Unpack[]` special form), ...] return args -def g() -> TypeGuard[int]: - return True - def i(callback: Callable[Concatenate[int, P], R_co], *args: P.args, **kwargs: P.kwargs) -> R_co: reveal_type(args) # revealed: P@i.args reveal_type(kwargs) # revealed: P@i.kwargs diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index c2a5f8feb6..fdd5a5aee6 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -12,17 +12,14 @@ from typing_extensions import TypeGuard, TypeIs def _( a: TypeGuard[str], b: TypeIs[str | int], - c: TypeGuard[Intersection[complex, Not[int], Not[float]]], + c: TypeGuard[bool], d: TypeIs[tuple[TypeOf[bytes]]], e: TypeGuard, # error: [invalid-type-form] f: TypeIs, # error: [invalid-type-form] ): reveal_type(a) # revealed: TypeGuard[str] reveal_type(b) # revealed: TypeIs[str | int] - # not `TypeGuard[complex & ~int & ~float]`: `complex` in argument position - # means `complex & int & float` semantically so `Intersection[complex, - # Not[int], Not[float]]` means `complex` semantically - reveal_type(c) # revealed: TypeGuard[complex] + reveal_type(c) # revealed: TypeGuard[bool] reveal_type(d) # revealed: TypeIs[tuple[]] reveal_type(e) # revealed: Unknown reveal_type(f) # revealed: Unknown diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 1d6a28b3d1..94f33c07f7 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -14750,7 +14750,8 @@ impl<'db> TypeGuardType<'db> { } impl<'db> VarianceInferable<'db> for TypeGuardType<'db> { - // TODO: comment + // `TypeGuard` is covariant in its type parameter. See the `TypeGuard` + // section of mdtest/generics/pep695/variance.md for details. fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { self.return_type(db).variance_of(db, typevar) } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index b7cdf10e02..f5cf1b4d38 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -288,10 +288,9 @@ impl ClassInfoConstraintFunction { /// in the second clobbers the first. #[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)] struct Conjunction<'db> { - /// The intersected constraints (represented as an intersection type) + /// The intersected constraints (represented as a type to intersect the guard with) constraint: Type<'db>, - /// If any constraint in this conjunction is a `TypeGuard`, this is Some and - /// contains the union of all `TypeGuard` types in this conjunction + /// If any constraint in this conjunction is a `TypeGuard[T]`, this is `Some(T)` typeguard: Option>, } @@ -343,10 +342,10 @@ impl<'db> Conjunction<'db> { /// /// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]` -/// => evaluates to `A | B` +/// => evaluates to `(P & A) | B`, where `P` is our previously-known type #[derive(Hash, PartialEq, Debug, Eq, Clone)] struct NarrowingConstraint<'db> { - /// Disjunctions of conjunctions (DNF) + /// Disjunction of conjunctions (DNF) disjuncts: SmallVec<[Conjunction<'db>; 4]>, } From 9fdb64dfcacbbd697bb61b05884411b325f1e0c5 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Fri, 19 Dec 2025 20:12:31 -0500 Subject: [PATCH 12/22] --amend --- crates/ty_python_semantic/src/types/narrow.rs | 47 +++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index f5cf1b4d38..d5a1373f6b 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -68,7 +68,10 @@ pub(crate) fn infer_narrowing_constraint<'db>( PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { - constraints.get(&place).copied() + constraints + .constraints + .get(&place) + .map(|constraint| constraint.clone().evaluate_type_constraint(db)) } else { None } @@ -78,7 +81,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish() } @@ -91,7 +94,7 @@ fn all_narrowing_constraints_for_pattern<'db>( fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true) .finish() @@ -105,7 +108,7 @@ fn all_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false) .finish() @@ -115,7 +118,7 @@ fn all_negative_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish() } @@ -124,7 +127,7 @@ fn constraints_for_expression_cycle_initial<'db>( _db: &'db dyn Db, _id: salsa::Id, _expression: Expression<'db>, -) -> Option> { +) -> Option> { None } @@ -132,7 +135,7 @@ fn negative_constraints_for_expression_cycle_initial<'db>( _db: &'db dyn Db, _id: salsa::Id, _expression: Expression<'db>, -) -> Option> { +) -> Option> { None } @@ -389,11 +392,27 @@ impl<'db> From> for NarrowingConstraint<'db> { /// /// This is a newtype wrapper around `FxHashMap>` that /// provides methods for working with constraints during boolean operation evaluation. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] struct InternalConstraints<'db> { constraints: FxHashMap>, } +impl get_size2::GetSize for InternalConstraints<'_> {} + +// SAFETY: InternalConstraints contains only `'db` lifetimes which are covariant, +// and the inner types (FxHashMap, ScopedPlaceId, NarrowingConstraint) are all safe to transmute +unsafe impl salsa::Update for InternalConstraints<'_> { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + let old_ref = unsafe { &mut (*old_pointer) }; + if *old_ref != new_value { + *old_ref = new_value; + true + } else { + false + } + } +} + impl<'db> InternalConstraints<'db> { /// Insert a regular (non-`TypeGuard`) constraint for a place fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) { @@ -576,8 +595,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } - fn finish(mut self) -> Option> { - let constraints: Option> = match self.predicate { + fn finish(mut self) -> Option> { + let mut constraints: Option> = match self.predicate { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } @@ -587,12 +606,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { PredicateNode::ReturnsNever(_) => return None, PredicateNode::StarImportPlaceholder(_) => return None, }; - if let Some(mut constraints) = constraints { + + if let Some(ref mut constraints) = constraints { constraints.constraints.shrink_to_fit(); - Some(constraints.evaluate_type_constraints(self.db)) - } else { - None } + + constraints } fn evaluate_expression_predicate( From 5fef55ff493b2ff7d4ea9789a8ade64629ea8f27 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Fri, 19 Dec 2025 23:13:08 -0500 Subject: [PATCH 13/22] more cleanup --- crates/ty_python_semantic/src/types/narrow.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index d5a1373f6b..811a50cbec 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -349,7 +349,7 @@ impl<'db> Conjunction<'db> { #[derive(Hash, PartialEq, Debug, Eq, Clone)] struct NarrowingConstraint<'db> { /// Disjunction of conjunctions (DNF) - disjuncts: SmallVec<[Conjunction<'db>; 4]>, + disjuncts: SmallVec<[Conjunction<'db>; 1]>, } impl get_size2::GetSize for NarrowingConstraint<'_> {} @@ -524,10 +524,8 @@ fn merge_constraints_or<'db>( .disjuncts .extend(from_constraint.disjuncts.clone()); } - Entry::Vacant(entry) => { - // Place only appears in `from`, not in `into`. - // Widen to object since the other branch doesn't constrain it. - entry.insert(NarrowingConstraint::regular(Type::object())); + Entry::Vacant(_) => { + // Place only appears in `from`, not in `into`. No constraint needed. } } } From f8185ce8be96b46c2c23e7333f17d8548abab859 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 20 Dec 2025 00:19:43 -0500 Subject: [PATCH 14/22] correct typeguard evaluation --- .../resources/mdtest/narrow/type_guards.md | 27 +++++-- .../src/semantic_index/use_def.rs | 23 ++---- crates/ty_python_semantic/src/types.rs | 2 +- crates/ty_python_semantic/src/types/narrow.rs | 72 ++++++++++--------- 4 files changed, 67 insertions(+), 57 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index fdd5a5aee6..09f1cc69ad 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -200,6 +200,27 @@ def _(a: Foo | Bar): reveal_type(a) # revealed: Foo & ~Bar ``` +```py +from typing import TypeGuard, reveal_type + +class P: + pass + +class A: + pass + +class B: + pass + +def is_b(val: object) -> TypeGuard[B]: + return isinstance(val, B) + +def _(x: P): + if isinstance(x, A) or is_b(x): + # currently reveals `(P & A) | (P & B)`, should reveal `(P & A) | B` + reveal_type(x) # revealed: (P & A) | B +``` + Attribute and subscript narrowing is supported: ```py @@ -212,18 +233,16 @@ class C(Generic[T]): def _(a: tuple[Foo, Bar] | tuple[Bar, Foo], c: C[Any]): if reveal_type(guard_foo(a[1])): # revealed: TypeGuard[Foo @ a[1]] - # TODO: Should be `tuple[Bar, Foo]` reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo] reveal_type(a[1]) # revealed: Foo if reveal_type(is_bar(a[0])): # revealed: TypeIs[Bar @ a[0]] - # TODO: Should be `tuple[Bar, Bar & Foo]` reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo] reveal_type(a[0]) # revealed: Bar if reveal_type(guard_foo(c.v)): # revealed: TypeGuard[Foo @ c.v] reveal_type(c) # revealed: C[Any] - reveal_type(c.v) # revealed: Any & Foo + reveal_type(c.v) # revealed: Foo if reveal_type(is_bar(c.v)): # revealed: TypeIs[Bar @ c.v] reveal_type(c) # revealed: C[Any] @@ -347,7 +366,7 @@ def does_not_narrow_in_negative_case(a: Foo | Bar): def narrowed_type_must_be_exact(a: object, b: Baz): if guard_foo(b): - reveal_type(b) # revealed: Baz & Foo + reveal_type(b) # revealed: Foo if isinstance(a, Baz) and is_bar(a): reveal_type(a) # revealed: Baz diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index dbd26595fd..09e35b3e27 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -266,7 +266,7 @@ use crate::semantic_index::use_def::place_state::{ LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId, }; use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex}; -use crate::types::{IntersectionBuilder, Truthiness, Type, infer_narrowing_constraint}; +use crate::types::{NarrowingConstraint, Truthiness, Type, infer_narrowing_constraint}; mod place_state; @@ -757,22 +757,11 @@ impl<'db> ConstraintsIterator<'_, 'db> { base_ty: Type<'db>, place: ScopedPlaceId, ) -> Type<'db> { - let constraint_tys: Vec<_> = self - .filter_map(|constraint| infer_narrowing_constraint(db, constraint, place)) - .collect(); - - if constraint_tys.is_empty() { - base_ty - } else { - constraint_tys - .into_iter() - .rev() - .fold( - IntersectionBuilder::new(db).add_positive(base_ty), - IntersectionBuilder::add_positive, - ) - .build() - } + self.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place)) + .fold(NarrowingConstraint::regular(base_ty), |acc, constraint| { + acc.merge_constraint_and(&constraint, db) + }) + .evaluate_type_constraint(db) } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 94f33c07f7..dfd7cecf91 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -64,7 +64,7 @@ use crate::types::generics::{ walk_generic_context, }; use crate::types::mro::{Mro, MroError, MroIterator}; -pub(crate) use crate::types::narrow::infer_narrowing_constraint; +pub(crate) use crate::types::narrow::{NarrowingConstraint, infer_narrowing_constraint}; use crate::types::newtype::NewType; pub(crate) use crate::types::signatures::{Parameter, Parameters}; use crate::types::signatures::{ParameterForm, walk_signature}; diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 811a50cbec..ac00a4ee65 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -48,7 +48,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( db: &'db dyn Db, predicate: Predicate<'db>, place: ScopedPlaceId, -) -> Option> { +) -> Option> { let constraints = match predicate.node { PredicateNode::Expression(expression) => { if predicate.is_positive { @@ -68,10 +68,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { - constraints - .constraints - .get(&place) - .map(|constraint| constraint.clone().evaluate_type_constraint(db)) + constraints.constraints.get(&place).cloned() } else { None } @@ -347,7 +344,7 @@ impl<'db> Conjunction<'db> { /// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]` /// => evaluates to `(P & A) | B`, where `P` is our previously-known type #[derive(Hash, PartialEq, Debug, Eq, Clone)] -struct NarrowingConstraint<'db> { +pub(crate) struct NarrowingConstraint<'db> { /// Disjunction of conjunctions (DNF) disjuncts: SmallVec<[Conjunction<'db>; 1]>, } @@ -356,7 +353,7 @@ impl get_size2::GetSize for NarrowingConstraint<'_> {} impl<'db> NarrowingConstraint<'db> { /// Create a constraint from a regular (non-`TypeGuard`) type - fn regular(constraint: Type<'db>) -> Self { + pub(crate) fn regular(constraint: Type<'db>) -> Self { Self { disjuncts: smallvec![Conjunction::regular(constraint)], } @@ -369,10 +366,41 @@ impl<'db> NarrowingConstraint<'db> { } } + /// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics + pub(crate) fn merge_constraint_and(&self, other: &Self, db: &'db dyn Db) -> Self { + let mut new_disjuncts = SmallVec::new(); + + // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) + // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... + for left_conj in &self.disjuncts { + for right_conj in &other.disjuncts { + if right_conj.typeguard.is_some() { + // If the right conjunct has a TypeGuard, it "wins" the conjunction + new_disjuncts.push(*right_conj); + } else { + // Intersect the regular constraints + let new_regular = IntersectionBuilder::new(db) + .add_positive(left_conj.constraint) + .add_positive(right_conj.constraint) + .build(); + + new_disjuncts.push(Conjunction { + constraint: new_regular, + typeguard: left_conj.typeguard, + }); + } + } + } + + NarrowingConstraint { + disjuncts: new_disjuncts, + } + } + /// Evaluate the type this effectively constrains to /// /// Forgets whether each constraint originated from a `TypeGuard` or not - fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { + pub(crate) fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { UnionType::from_elements( db, self.disjuncts @@ -461,33 +489,7 @@ fn merge_constraints_and<'db>( Entry::Occupied(mut entry) => { let into_constraint = entry.get(); - // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) - // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... - let mut new_disjuncts = SmallVec::new(); - - for left_conj in &into_constraint.disjuncts { - for right_conj in &from_constraint.disjuncts { - if right_conj.typeguard.is_some() { - // If the right conjunct has a TypeGuard, it "wins" the conjunction - new_disjuncts.push(*right_conj); - } else { - // Intersect the regular constraints - let new_regular = IntersectionBuilder::new(db) - .add_positive(left_conj.constraint) - .add_positive(right_conj.constraint) - .build(); - - new_disjuncts.push(Conjunction { - constraint: new_regular, - typeguard: left_conj.typeguard, - }); - } - } - } - - *entry.get_mut() = NarrowingConstraint { - disjuncts: new_disjuncts, - }; + entry.insert(into_constraint.merge_constraint_and(&from_constraint, db)); } Entry::Vacant(entry) => { entry.insert(from_constraint.clone()); From 051894165d6a39642989d52f6ae2bf1e05aa100e Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 20 Dec 2025 03:05:52 -0500 Subject: [PATCH 15/22] more cleanup --- .../resources/mdtest/narrow/type_guards.md | 1 - crates/ty_python_semantic/src/types/narrow.rs | 187 ++++++------------ 2 files changed, 65 insertions(+), 123 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 09f1cc69ad..6306c0639a 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -217,7 +217,6 @@ def is_b(val: object) -> TypeGuard[B]: def _(x: P): if isinstance(x, A) or is_b(x): - # currently reveals `(P & A) | (P & B)`, should reveal `(P & A) | B` reveal_type(x) # revealed: (P & A) | B ``` diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index ac00a4ee65..ae1f02972e 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -68,7 +68,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { - constraints.constraints.get(&place).cloned() + constraints.get(&place).cloned() } else { None } @@ -78,7 +78,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish() } @@ -91,7 +91,7 @@ fn all_narrowing_constraints_for_pattern<'db>( fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true) .finish() @@ -105,7 +105,7 @@ fn all_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false) .finish() @@ -115,7 +115,7 @@ fn all_negative_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish() } @@ -124,7 +124,7 @@ fn constraints_for_expression_cycle_initial<'db>( _db: &'db dyn Db, _id: salsa::Id, _expression: Expression<'db>, -) -> Option> { +) -> Option> { None } @@ -132,7 +132,7 @@ fn negative_constraints_for_expression_cycle_initial<'db>( _db: &'db dyn Db, _id: salsa::Id, _expression: Expression<'db>, -) -> Option> { +) -> Option> { None } @@ -283,10 +283,10 @@ impl ClassInfoConstraintFunction { /// intersecting /// /// For example, `(Conjunction { constraint: A, typeguard: Some(B) } & -/// Conjunction { constraint: C, typeguard: Some(D)})` evlaluates to +/// Conjunction { constraint: C, typeguard: Some(D)})` evaluates to /// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard -/// in the second clobbers the first. -#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)] +/// in the second conjunct clobbers that in the first. +#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy, salsa::Update, get_size2::GetSize)] struct Conjunction<'db> { /// The intersected constraints (represented as a type to intersect the guard with) constraint: Type<'db>, @@ -294,8 +294,6 @@ struct Conjunction<'db> { typeguard: Option>, } -impl get_size2::GetSize for Conjunction<'_> {} - impl<'db> Conjunction<'db> { /// Create a new conjunction with just a regular constraint fn regular(constraint: Type<'db>) -> Self { @@ -343,14 +341,12 @@ impl<'db> Conjunction<'db> { /// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => `[Conjunction { constraint: A, typeguard: None }, Conjunction { constraint: object, typeguard: Some(B) }]` /// => evaluates to `(P & A) | B`, where `P` is our previously-known type -#[derive(Hash, PartialEq, Debug, Eq, Clone)] +#[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] pub(crate) struct NarrowingConstraint<'db> { /// Disjunction of conjunctions (DNF) disjuncts: SmallVec<[Conjunction<'db>; 1]>, } -impl get_size2::GetSize for NarrowingConstraint<'_> {} - impl<'db> NarrowingConstraint<'db> { /// Create a constraint from a regular (non-`TypeGuard`) type pub(crate) fn regular(constraint: Type<'db>) -> Self { @@ -366,17 +362,18 @@ impl<'db> NarrowingConstraint<'db> { } } - /// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics + /// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics (with `other` winning) pub(crate) fn merge_constraint_and(&self, other: &Self, db: &'db dyn Db) -> Self { - let mut new_disjuncts = SmallVec::new(); - // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... - for left_conj in &self.disjuncts { - for right_conj in &other.disjuncts { + let new_disjuncts = self + .disjuncts + .iter() + .cartesian_product(other.disjuncts.iter()) + .map(|(left_conj, right_conj)| { if right_conj.typeguard.is_some() { // If the right conjunct has a TypeGuard, it "wins" the conjunction - new_disjuncts.push(*right_conj); + *right_conj } else { // Intersect the regular constraints let new_regular = IntersectionBuilder::new(db) @@ -384,13 +381,13 @@ impl<'db> NarrowingConstraint<'db> { .add_positive(right_conj.constraint) .build(); - new_disjuncts.push(Conjunction { + Conjunction { constraint: new_regular, typeguard: left_conj.typeguard, - }); + } } - } - } + }) + .collect::>(); NarrowingConstraint { disjuncts: new_disjuncts, @@ -416,59 +413,7 @@ impl<'db> From> for NarrowingConstraint<'db> { } } -/// Internal representation of constraints with DNF structure for tracking `TypeGuard` semantics. -/// -/// This is a newtype wrapper around `FxHashMap>` that -/// provides methods for working with constraints during boolean operation evaluation. -#[derive(Clone, Debug, Default, PartialEq, Eq)] -struct InternalConstraints<'db> { - constraints: FxHashMap>, -} - -impl get_size2::GetSize for InternalConstraints<'_> {} - -// SAFETY: InternalConstraints contains only `'db` lifetimes which are covariant, -// and the inner types (FxHashMap, ScopedPlaceId, NarrowingConstraint) are all safe to transmute -unsafe impl salsa::Update for InternalConstraints<'_> { - unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { - let old_ref = unsafe { &mut (*old_pointer) }; - if *old_ref != new_value { - *old_ref = new_value; - true - } else { - false - } - } -} - -impl<'db> InternalConstraints<'db> { - /// Insert a regular (non-`TypeGuard`) constraint for a place - fn insert_regular(&mut self, place: ScopedPlaceId, ty: Type<'db>) { - self.constraints - .insert(place, NarrowingConstraint::regular(ty)); - } - - /// Convert internal constraints to public constraints by evaluating each DNF constraint to a Type - fn evaluate_type_constraints(self, db: &'db dyn Db) -> NarrowingConstraints<'db> { - self.constraints - .into_iter() - .map(|(place, constraint)| (place, constraint.evaluate_type_constraint(db))) - .collect() - } -} - -impl<'db> FromIterator<(ScopedPlaceId, NarrowingConstraint<'db>)> for InternalConstraints<'db> { - fn from_iter)>>( - iter: T, - ) -> Self { - Self { - constraints: FxHashMap::from_iter(iter), - } - } -} - -/// Public representation of constraints as returned by tracked functions -type NarrowingConstraints<'db> = FxHashMap>; +type NarrowingConstraints<'db> = FxHashMap>; /// Merge constraints with AND semantics (intersection/conjunction). /// @@ -480,16 +425,16 @@ type NarrowingConstraints<'db> = FxHashMap>; /// - Take the right conjunct if it has a `TypeGuard` /// - Intersect the constraints normally otherwise fn merge_constraints_and<'db>( - into: &mut InternalConstraints<'db>, - from: &InternalConstraints<'db>, + into: &mut NarrowingConstraints<'db>, + from: &NarrowingConstraints<'db>, db: &'db dyn Db, ) { - for (key, from_constraint) in &from.constraints { - match into.constraints.entry(*key) { + for (key, from_constraint) in from { + match into.entry(*key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get(); - entry.insert(into_constraint.merge_constraint_and(&from_constraint, db)); + entry.insert(into_constraint.merge_constraint_and(from_constraint, db)); } Entry::Vacant(entry) => { entry.insert(from_constraint.clone()); @@ -506,19 +451,15 @@ fn merge_constraints_and<'db>( /// However, if a place appears in only one branch of the OR, we need to widen it /// to `object` in the overall result (because the other branch doesn't constrain it). fn merge_constraints_or<'db>( - into: &mut InternalConstraints<'db>, - from: &InternalConstraints<'db>, + into: &mut NarrowingConstraints<'db>, + from: &NarrowingConstraints<'db>, _db: &'db dyn Db, ) { // For places that appear in `into` but not in `from`, widen to object - for (key, value) in &mut into.constraints { - if !from.constraints.contains_key(key) { - *value = NarrowingConstraint::regular(Type::object()); - } - } + into.retain(|key, _| from.contains_key(key)); - for (key, from_constraint) in &from.constraints { - match into.constraints.entry(*key) { + for (key, from_constraint) in from { + match into.entry(*key) { Entry::Occupied(mut entry) => { // Simply concatenate the disjuncts entry @@ -595,8 +536,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } - fn finish(mut self) -> Option> { - let mut constraints: Option> = match self.predicate { + fn finish(mut self) -> Option> { + let mut constraints: Option> = match self.predicate { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } @@ -608,7 +549,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }; if let Some(ref mut constraints) = constraints { - constraints.constraints.shrink_to_fit(); + constraints.shrink_to_fit(); } constraints @@ -618,7 +559,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let expression_node = expression.node_ref(self.db, self.module); self.evaluate_expression_node_predicate(expression_node, expression, is_positive) } @@ -628,7 +569,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression_node: &ruff_python_ast::Expr, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { match expression_node { ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { self.evaluate_simple_expr(expression_node, is_positive) @@ -653,7 +594,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { pattern_predicate_kind: &PatternPredicateKind<'db>, subject: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { match pattern_predicate_kind { PatternPredicateKind::Singleton(singleton) => { self.evaluate_match_pattern_singleton(subject, *singleton, is_positive) @@ -678,7 +619,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, pattern: PatternPredicate<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { self.evaluate_pattern_predicate_kind( pattern.kind(self.db), pattern.subject(self.db), @@ -788,7 +729,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr: &ast::Expr, is_positive: bool, - ) -> Option> { + ) -> Option> { let target = place_expr(expr)?; let place = self.expect_place(&target); @@ -798,7 +739,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { Type::AlwaysTruthy.negate(self.db) }; - Some(InternalConstraints::from_iter([( + Some(NarrowingConstraints::from_iter([( place, NarrowingConstraint::regular(ty), )])) @@ -808,7 +749,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr_named: &ast::ExprNamed, is_positive: bool, - ) -> Option> { + ) -> Option> { self.evaluate_simple_expr(&expr_named.target, is_positive) } @@ -1052,7 +993,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_compare: &ast::ExprCompare, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { matches!( expr, @@ -1094,7 +1035,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let comparator_tuples = std::iter::once(&**left) .chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); - let mut constraints = InternalConstraints::default(); + let mut constraints = NarrowingConstraints::default(); let mut last_rhs_ty: Option = None; @@ -1113,7 +1054,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) { let place = self.expect_place(&left); - constraints.insert_regular(place, ty); + constraints.insert(place, NarrowingConstraint::regular(ty)); } } ast::Expr::Call(ast::ExprCall { @@ -1159,10 +1100,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .is_some_and(|c| c.is_known(self.db, KnownClass::Type)) { let place = self.expect_place(&target); - constraints.insert_regular( + constraints.insert( place, - Type::instance(self.db, rhs_class.unknown_specialization(self.db)) - .negate_if(self.db, !is_positive), + NarrowingConstraint::regular( + Type::instance(self.db, rhs_class.unknown_specialization(self.db)) + .negate_if(self.db, !is_positive), + ), ); } } @@ -1177,7 +1120,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_call: &ast::ExprCall, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); let callable_ty = inference.expression_type(&*expr_call.func); @@ -1214,7 +1157,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { _ => None, }?; - Some(InternalConstraints::from_iter([place_and_constraint])) + Some(NarrowingConstraints::from_iter([place_and_constraint])) } // For the expression `len(E)`, we narrow the type based on whether len(E) is truthy // (i.e., whether E is non-empty). We only narrow the parts of the type where we know @@ -1232,7 +1175,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) { let target = place_expr(arg)?; let place = self.expect_place(&target); - Some(InternalConstraints::from_iter([( + Some(NarrowingConstraints::from_iter([( place, NarrowingConstraint::regular(narrowed_ty), )])) @@ -1263,7 +1206,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let constraint = Type::protocol_with_readonly_members(self.db, [(attr, Type::object())]); - return Some(InternalConstraints::from_iter([( + return Some(NarrowingConstraints::from_iter([( place, NarrowingConstraint::regular(constraint.negate_if(self.db, !is_positive)), )])); @@ -1276,7 +1219,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { function .generate_constraint(self.db, class_info_ty) .map(|constraint| { - InternalConstraints::from_iter([( + NarrowingConstraints::from_iter([( place, NarrowingConstraint::regular( constraint.negate_if(self.db, !is_positive), @@ -1305,7 +1248,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, singleton: ast::Singleton, is_positive: bool, - ) -> Option> { + ) -> Option> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); @@ -1315,7 +1258,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ast::Singleton::False => Type::BooleanLiteral(false), }; let ty = ty.negate_if(self.db, !is_positive); - Some(InternalConstraints::from_iter([( + Some(NarrowingConstraints::from_iter([( place, NarrowingConstraint::regular(ty), )])) @@ -1327,7 +1270,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { cls: Expression<'db>, kind: ClassPatternKind, is_positive: bool, - ) -> Option> { + ) -> Option> { if !kind.is_irrefutable() && !is_positive { // A class pattern like `case Point(x=0, y=0)` is not irrefutable. In the positive case, // we can still narrow the type of the match subject to `Point`. But in the negative case, @@ -1351,7 +1294,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { _ => return None, }; - Some(InternalConstraints::from_iter([( + Some(NarrowingConstraints::from_iter([( place, NarrowingConstraint::regular(narrowed_type), )])) @@ -1362,7 +1305,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, value: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let place = { let subject = place_expr(subject.node_ref(self.db, self.module))?; self.expect_place(&subject) @@ -1374,7 +1317,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) - .map(|ty| InternalConstraints::from_iter([(place, NarrowingConstraint::regular(ty))])) + .map(|ty| NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))])) } fn evaluate_match_pattern_or( @@ -1382,7 +1325,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, predicates: &Vec>, is_positive: bool, - ) -> Option> { + ) -> Option> { let db = self.db; // DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints. @@ -1408,7 +1351,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_bool_op: &ExprBoolOp, expression: Expression<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); let sub_constraints = expr_bool_op .values @@ -1427,7 +1370,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .collect::>(); match (expr_bool_op.op, is_positive) { (BoolOp::And, true) | (BoolOp::Or, false) => { - let mut aggregation: Option = None; + let mut aggregation: Option = None; for sub_constraint in sub_constraints.into_iter().flatten() { if let Some(ref mut some_aggregation) = aggregation { merge_constraints_and(some_aggregation, &sub_constraint, self.db); From 7b3eea38ef6e3478a47d50d1b8df174376c911f4 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 20 Dec 2025 12:58:31 -0500 Subject: [PATCH 16/22] fix constraint merge order --- .../resources/mdtest/narrow/type_guards.md | 5 +++++ .../src/semantic_index/use_def.rs | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 6306c0639a..7d5952c5f3 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -354,6 +354,9 @@ class Baz(Bar): ... def guard_foo(a: object) -> TypeGuard[Foo]: return True +def guard_bar(a: object) -> TypeGuard[Bar]: + return True + def is_bar(a: object) -> TypeIs[Bar]: return True @@ -372,6 +375,8 @@ def narrowed_type_must_be_exact(a: object, b: Baz): if isinstance(a, Bar) and guard_foo(a): reveal_type(a) # revealed: Foo + if guard_bar(a): + reveal_type(a) # revealed: Bar ``` ## TypeGuard overrides normal constraints diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 09e35b3e27..9288a6df50 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -757,11 +757,22 @@ impl<'db> ConstraintsIterator<'_, 'db> { base_ty: Type<'db>, place: ScopedPlaceId, ) -> Type<'db> { + // Constraints are in reverse-source order. Due to TypeGuard semantics + // constraint AND is non-commutative and so we _must_ apply in + // source order. + // + // Fortunately, constraint AND is still associative, so we can still iterate left-to-right + // and accumulate rightward. self.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place)) - .fold(NarrowingConstraint::regular(base_ty), |acc, constraint| { - acc.merge_constraint_and(&constraint, db) + .reduce(|acc, constraint| { + // See above---note the reverse application + constraint.merge_constraint_and(&acc, db) + }) + .map_or(base_ty, |constraint| { + NarrowingConstraint::regular(base_ty) + .merge_constraint_and(&constraint, db) + .evaluate_type_constraint(db) }) - .evaluate_type_constraint(db) } } From 34390c8b9c0ef2bdc7616656a29b992cbd1ef13c Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 20 Dec 2025 14:02:20 -0500 Subject: [PATCH 17/22] some optimization and simplification --- crates/ty_python_semantic/src/types/narrow.rs | 67 +++++++++++++------ 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index ae1f02972e..14e0d6851c 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -19,7 +19,7 @@ use crate::types::{ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_stdlib::identifiers::is_identifier; -use itertools::Itertools; +use itertools::{Either, Itertools}; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; @@ -366,28 +366,30 @@ impl<'db> NarrowingConstraint<'db> { pub(crate) fn merge_constraint_and(&self, other: &Self, db: &'db dyn Db) -> Self { // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... - let new_disjuncts = self + let new_disjuncts = other .disjuncts .iter() - .cartesian_product(other.disjuncts.iter()) - .map(|(left_conj, right_conj)| { + .flat_map(|right_conj| { + // We iterate the RHS first because if it has a typeguard then we don't need to consider the LHS if right_conj.typeguard.is_some() { // If the right conjunct has a TypeGuard, it "wins" the conjunction - *right_conj + Either::Left(std::iter::once(*right_conj)) } else { - // Intersect the regular constraints - let new_regular = IntersectionBuilder::new(db) - .add_positive(left_conj.constraint) - .add_positive(right_conj.constraint) - .build(); + // Otherwise, we need to consider all LHS disjuncts + Either::Right(self.disjuncts.iter().map(|left_conj| { + let new_regular = IntersectionBuilder::new(db) + .add_positive(left_conj.constraint) + .add_positive(right_conj.constraint) + .build(); - Conjunction { - constraint: new_regular, - typeguard: left_conj.typeguard, - } + Conjunction { + constraint: new_regular, + typeguard: left_conj.typeguard, + } + })) } }) - .collect::>(); + .collect(); NarrowingConstraint { disjuncts: new_disjuncts, @@ -450,10 +452,13 @@ fn merge_constraints_and<'db>( /// /// However, if a place appears in only one branch of the OR, we need to widen it /// to `object` in the overall result (because the other branch doesn't constrain it). +/// +/// When none of the disjuncts have `TypeGuard`, we simplify the constraint types +/// via `UnionBuilder` to enable simplifications like `~AlwaysFalsy | ~AlwaysTruthy -> object`. fn merge_constraints_or<'db>( into: &mut NarrowingConstraints<'db>, from: &NarrowingConstraints<'db>, - _db: &'db dyn Db, + db: &'db dyn Db, ) { // For places that appear in `into` but not in `from`, widen to object into.retain(|key, _| from.contains_key(key)); @@ -461,11 +466,35 @@ fn merge_constraints_or<'db>( for (key, from_constraint) in from { match into.entry(*key) { Entry::Occupied(mut entry) => { - // Simply concatenate the disjuncts - entry - .get_mut() + let into_constraint = entry.get_mut(); + // Concatenate disjuncts + into_constraint .disjuncts .extend(from_constraint.disjuncts.clone()); + + // If none of the disjuncts have TypeGuard, we can simplify the constraint types + // via UnionBuilder. This enables simplifications like: + // `~AlwaysFalsy | ~AlwaysTruthy -> object` + let all_regular = into_constraint + .disjuncts + .iter() + .all(|conj| conj.typeguard.is_none()); + + if all_regular { + // Simplify via UnionBuilder + let simplified = UnionType::from_elements( + db, + into_constraint.disjuncts.iter().map(|conj| conj.constraint), + ); + // If simplified to object, we can drop the constraint entirely + if simplified.is_object() { + // Remove this entry since it provides no constraint + entry.remove(); + } else { + // Replace with simplified constraint + into_constraint.disjuncts = smallvec![Conjunction::regular(simplified)]; + } + } } Entry::Vacant(_) => { // Place only appears in `from`, not in `into`. No constraint needed. From 25428c22afdd394014120215225d99666924fff3 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 20 Dec 2025 14:14:11 -0500 Subject: [PATCH 18/22] remove some clones --- crates/ty_python_semantic/src/types/narrow.rs | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 14e0d6851c..4782db357b 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -67,11 +67,8 @@ pub(crate) fn infer_narrowing_constraint<'db>( PredicateNode::ReturnsNever(_) => return None, PredicateNode::StarImportPlaceholder(_) => return None, }; - if let Some(constraints) = constraints { - constraints.get(&place).cloned() - } else { - None - } + + constraints.and_then(|constraints| constraints.get(&place).cloned()) } #[salsa::tracked(returns(as_ref), heap_size=ruff_memory_usage::heap_size)] @@ -428,18 +425,18 @@ type NarrowingConstraints<'db> = FxHashMap( into: &mut NarrowingConstraints<'db>, - from: &NarrowingConstraints<'db>, + from: NarrowingConstraints<'db>, db: &'db dyn Db, ) { for (key, from_constraint) in from { - match into.entry(*key) { + match into.entry(key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get(); - entry.insert(into_constraint.merge_constraint_and(from_constraint, db)); + entry.insert(into_constraint.merge_constraint_and(&from_constraint, db)); } Entry::Vacant(entry) => { - entry.insert(from_constraint.clone()); + entry.insert(from_constraint); } } } @@ -457,20 +454,18 @@ fn merge_constraints_and<'db>( /// via `UnionBuilder` to enable simplifications like `~AlwaysFalsy | ~AlwaysTruthy -> object`. fn merge_constraints_or<'db>( into: &mut NarrowingConstraints<'db>, - from: &NarrowingConstraints<'db>, + from: NarrowingConstraints<'db>, db: &'db dyn Db, ) { // For places that appear in `into` but not in `from`, widen to object into.retain(|key, _| from.contains_key(key)); for (key, from_constraint) in from { - match into.entry(*key) { + match into.entry(key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get_mut(); // Concatenate disjuncts - into_constraint - .disjuncts - .extend(from_constraint.disjuncts.clone()); + into_constraint.disjuncts.extend(from_constraint.disjuncts); // If none of the disjuncts have TypeGuard, we can simplify the constraint types // via UnionBuilder. This enables simplifications like: @@ -486,9 +481,8 @@ fn merge_constraints_or<'db>( db, into_constraint.disjuncts.iter().map(|conj| conj.constraint), ); - // If simplified to object, we can drop the constraint entirely if simplified.is_object() { - // Remove this entry since it provides no constraint + // If simplified to object, we can drop the constraint entirely entry.remove(); } else { // Replace with simplified constraint @@ -1370,7 +1364,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { self.evaluate_pattern_predicate_kind(predicate, subject, is_positive) }) .reduce(|mut constraints, constraints_| { - merge_constraints(&mut constraints, &constraints_, db); + merge_constraints(&mut constraints, constraints_, db); constraints }) } @@ -1402,7 +1396,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let mut aggregation: Option = None; for sub_constraint in sub_constraints.into_iter().flatten() { if let Some(ref mut some_aggregation) = aggregation { - merge_constraints_and(some_aggregation, &sub_constraint, self.db); + merge_constraints_and(some_aggregation, sub_constraint, self.db); } else { aggregation = Some(sub_constraint); } @@ -1418,7 +1412,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if let Some(ref mut first) = first { for rest_constraint in rest { if let Some(rest_constraint) = rest_constraint { - merge_constraints_or(first, &rest_constraint, self.db); + merge_constraints_or(first, rest_constraint, self.db); } else { return None; } From 9d8ae5f074c8e4df0d19a84a9d826ae7be0933a4 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Sat, 20 Dec 2025 14:24:16 -0500 Subject: [PATCH 19/22] stop top-materialization --- .../src/types/infer/builder/type_expression.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 8550df420d..f130a9d7b6 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -1537,9 +1537,9 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } _ => TypeGuardType::unbound( self.db(), - // Similar to TypeIs, use top materialization - self.infer_type_expression(arguments_slice) - .top_materialization(self.db()), + // Unlike `TypeIs`, don't use top materialization, because + // `TypeGuard` clobbering behavior makes it counterintuitive + self.infer_type_expression(arguments_slice), ), }, SpecialFormType::Concatenate => { From 22afd6a45c5366b7aef752006b927be27d8258f3 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Tue, 23 Dec 2025 02:43:46 -0500 Subject: [PATCH 20/22] DRY some TypeGuard/TypeIs deduplication with trait --- crates/ty_python_semantic/src/types.rs | 115 ++++++++++++++---- .../ty_python_semantic/src/types/display.rs | 57 ++++----- .../types/infer/builder/type_expression.rs | 1 - .../src/types/type_ordering.rs | 49 +------- 4 files changed, 121 insertions(+), 101 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index dfd7cecf91..ab496a0d5c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -879,6 +879,26 @@ pub enum Type<'db> { NewTypeInstance(NewType<'db>), } +/// Helper for `recursive_type_normalized_impl` for TypeGuardLike types. +fn recursive_type_normalize_type_guard_like<'db, T: TypeGuardLike<'db>>( + db: &'db dyn Db, + guard: T, + div: Type<'db>, + nested: bool, +) -> Option> { + let ty = if nested { + guard + .return_type(db) + .recursive_type_normalized_impl(db, div, true)? + } else { + guard + .return_type(db) + .recursive_type_normalized_impl(db, div, true) + .unwrap_or(div) + }; + Some(guard.with_type(db, ty)) +} + #[salsa::tracked] impl<'db> Type<'db> { pub(crate) const fn any() -> Self { @@ -1745,31 +1765,10 @@ impl<'db> Type<'db> { .recursive_type_normalized_impl(db, div, nested) .map(Type::KnownInstance), Type::TypeIs(type_is) => { - let ty = if nested { - type_is - .return_type(db) - .recursive_type_normalized_impl(db, div, true)? - } else { - type_is - .return_type(db) - .recursive_type_normalized_impl(db, div, true) - .unwrap_or(div) - }; - Some(type_is.with_type(db, ty)) + recursive_type_normalize_type_guard_like(db, type_is, div, nested) } - // TODO: deduplicate Type::TypeGuard(type_guard) => { - let ty = if nested { - type_guard - .return_type(db) - .recursive_type_normalized_impl(db, div, true)? - } else { - type_guard - .return_type(db) - .recursive_type_normalized_impl(db, div, true) - .unwrap_or(div) - }; - Some(type_guard.with_type(db, ty)) + recursive_type_normalize_type_guard_like(db, type_guard, div, nested) } Type::Dynamic(dynamic) => Some(Type::Dynamic(dynamic.recursive_type_normalized())), Type::TypedDict(_) => { @@ -14757,6 +14756,76 @@ impl<'db> VarianceInferable<'db> for TypeGuardType<'db> { } } +/// Common trait for TypeIs and TypeGuard types that share similar structure +/// but have different semantic behaviors. +pub(crate) trait TypeGuardLike<'db>: Copy { + /// The name of this type guard form (for error messages and display) + const FORM_NAME: &'static str; + + /// Get the return type that the type guard narrows to + fn return_type(self, db: &'db dyn Db) -> Type<'db>; + + /// Get the place info (scope and place ID) if bound + fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)>; + + /// Get the human-readable place name if bound + fn place_name(self, db: &'db dyn Db) -> Option; + + /// Create a new instance with a different return type, wrapped in Type + fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db>; + + /// The SpecialFormType for display purposes + fn special_form() -> SpecialFormType; +} + +impl<'db> TypeGuardLike<'db> for TypeIsType<'db> { + const FORM_NAME: &'static str = "TypeIs"; + + fn return_type(self, db: &'db dyn Db) -> Type<'db> { + TypeIsType::return_type(self, db) + } + + fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> { + TypeIsType::place_info(self, db) + } + + fn place_name(self, db: &'db dyn Db) -> Option { + TypeIsType::place_name(self, db) + } + + fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + TypeIsType::with_type(self, db, ty) + } + + fn special_form() -> SpecialFormType { + SpecialFormType::TypeIs + } +} + +impl<'db> TypeGuardLike<'db> for TypeGuardType<'db> { + const FORM_NAME: &'static str = "TypeGuard"; + + fn return_type(self, db: &'db dyn Db) -> Type<'db> { + TypeGuardType::return_type(self, db) + } + + fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> { + TypeGuardType::place_info(self, db) + } + + fn place_name(self, db: &'db dyn Db) -> Option { + TypeGuardType::place_name(self, db) + } + + fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + TypeGuardType::with_type(self, db, ty) + } + + fn special_form() -> SpecialFormType { + SpecialFormType::TypeGuard + } +} + /// Walk the MRO of this class and return the last class just before the specified known base. /// This can be used to determine upper bounds for `Self` type variables on methods that are /// being added to the given class. diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 16fe01babf..36a801c57d 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -27,8 +27,8 @@ use crate::types::visitor::TypeVisitor; use crate::types::{ BoundTypeVarIdentity, CallableType, CallableTypeKind, IntersectionType, KnownBoundMethodType, KnownClass, KnownInstanceType, MaterializationKind, Protocol, ProtocolInstanceType, - SpecialFormType, StringLiteralType, SubclassOfInner, Type, TypedDictType, UnionType, - WrapperDescriptorKind, visitor, + SpecialFormType, StringLiteralType, SubclassOfInner, Type, TypeGuardLike, TypedDictType, + UnionType, WrapperDescriptorKind, visitor, }; /// Settings for displaying types and signatures @@ -584,6 +584,28 @@ impl Display for ClassDisplay<'_> { } } +/// Helper for displaying TypeGuardLike types (TypeIs and TypeGuard). +fn fmt_type_guard_like<'db, T: TypeGuardLike<'db>>( + db: &'db dyn Db, + guard: T, + settings: &DisplaySettings<'db>, + f: &mut TypeWriter<'_, '_, 'db>, +) -> fmt::Result { + f.with_type(Type::SpecialForm(T::special_form())) + .write_str(T::FORM_NAME)?; + f.write_char('[')?; + guard + .return_type(db) + .display_with(db, settings.singleline()) + .fmt_detailed(f)?; + if let Some(name) = guard.place_name(db) { + f.set_invalid_type_annotation(); + f.write_str(" @ ")?; + f.write_str(&name)?; + } + f.write_str("]") +} + /// Writes the string representation of a type, which is the value displayed either as /// `Literal[]` or `Literal[, ]` for literal types or as `` for /// non literals @@ -964,36 +986,9 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> { .fmt_detailed(f)?; f.write_str(">") } - Type::TypeIs(type_is) => { - f.with_type(Type::SpecialForm(SpecialFormType::TypeIs)) - .write_str("TypeIs")?; - f.write_char('[')?; - type_is - .return_type(self.db) - .display_with(self.db, self.settings.singleline()) - .fmt_detailed(f)?; - if let Some(name) = type_is.place_name(self.db) { - f.set_invalid_type_annotation(); - f.write_str(" @ ")?; - f.write_str(&name)?; - } - f.write_str("]") - } - // TODO: deduplicate + Type::TypeIs(type_is) => fmt_type_guard_like(self.db, type_is, &self.settings, f), Type::TypeGuard(type_guard) => { - f.with_type(Type::SpecialForm(SpecialFormType::TypeGuard)) - .write_str("TypeGuard")?; - f.write_char('[')?; - type_guard - .return_type(self.db) - .display_with(self.db, self.settings.singleline()) - .fmt_detailed(f)?; - if let Some(name) = type_guard.place_name(self.db) { - f.set_invalid_syntax(); - f.write_str(" @ ")?; - f.write_str(&name)?; - } - f.write_str("]") + fmt_type_guard_like(self.db, type_guard, &self.settings, f) } Type::TypedDict(TypedDictType::Class(defining_class)) => defining_class .class_literal(self.db) diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index f130a9d7b6..4ca06e9ca9 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -1521,7 +1521,6 @@ impl<'db> TypeInferenceBuilder<'db, '_> { .top_materialization(self.db()), ), }, - // TODO: deduplicate SpecialFormType::TypeGuard => match arguments_slice { ast::Expr::Tuple(_) => { self.infer_type_expression(arguments_slice); diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 67875461b4..fdf834b8b9 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -2,14 +2,10 @@ use std::cmp::Ordering; use salsa::plumbing::AsId; -use crate::{ - db::Db, - semantic_index::{place::ScopedPlaceId, scope::ScopeId}, - types::bound_super::SuperOwnerKind, -}; +use crate::{db::Db, types::bound_super::SuperOwnerKind}; use super::{ - DynamicType, TodoType, Type, TypeGuardType, TypeIsType, class_base::ClassBase, + DynamicType, TodoType, Type, TypeGuardLike, TypeGuardType, TypeIsType, class_base::ClassBase, subclass_of::SubclassOfInner, }; @@ -295,52 +291,13 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering } } -/// Trait for type guard-like types that can be ordered canonically. -trait GuardLikeOrdering<'db>: Copy { - fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)>; - fn place_name(self, db: &'db dyn Db) -> Option; - fn return_type(self, db: &'db dyn Db) -> Type<'db>; -} - -impl<'db> GuardLikeOrdering<'db> for TypeIsType<'db> { - fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> { - TypeIsType::place_info(self, db) - } - - fn place_name(self, db: &'db dyn Db) -> Option { - TypeIsType::place_name(self, db) - } - - fn return_type(self, db: &'db dyn Db) -> Type<'db> { - TypeIsType::return_type(self, db) - } -} - -impl<'db> GuardLikeOrdering<'db> for TypeGuardType<'db> { - fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> { - TypeGuardType::place_info(self, db) - } - - fn place_name(self, db: &'db dyn Db) -> Option { - TypeGuardType::place_name(self, db) - } - - fn return_type(self, db: &'db dyn Db) -> Type<'db> { - TypeGuardType::return_type(self, db) - } -} - /// Generic helper for ordering type guard-like types. /// /// The following criteria are considered, in order: /// * Boundness: Unbound precedes bound /// * Symbol name: String comparison /// * Guarded type: [`union_or_intersection_elements_ordering`] -fn guard_like_ordering<'db, T: GuardLikeOrdering<'db>>( - db: &'db dyn Db, - left: T, - right: T, -) -> Ordering { +fn guard_like_ordering<'db, T: TypeGuardLike<'db>>(db: &'db dyn Db, left: T, right: T) -> Ordering { let (left_ty, right_ty) = (left.return_type(db), right.return_type(db)); match (left.place_info(db), right.place_info(db)) { From 567f01841645eb709f867c461d4345b43ebb8b3e Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Tue, 23 Dec 2025 02:44:26 -0500 Subject: [PATCH 21/22] optimize the constraint DNF representation --- .../resources/mdtest/narrow/type_guards.md | 2 +- .../src/semantic_index/use_def.rs | 6 +- crates/ty_python_semantic/src/types/narrow.rs | 198 +++++++++--------- 3 files changed, 99 insertions(+), 107 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 7d5952c5f3..223b6a13ab 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -217,7 +217,7 @@ def is_b(val: object) -> TypeGuard[B]: def _(x: P): if isinstance(x, A) or is_b(x): - reveal_type(x) # revealed: (P & A) | B + reveal_type(x) # revealed: B | (P & A) ``` Attribute and subscript narrowing is supported: diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 9288a6df50..c34928184d 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -766,12 +766,12 @@ impl<'db> ConstraintsIterator<'_, 'db> { self.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place)) .reduce(|acc, constraint| { // See above---note the reverse application - constraint.merge_constraint_and(&acc, db) + constraint.merge_constraint_and(acc, db) }) .map_or(base_ty, |constraint| { NarrowingConstraint::regular(base_ty) - .merge_constraint_and(&constraint, db) - .evaluate_type_constraint(db) + .merge_constraint_and(constraint, db) + .evaluate_constraint_type(db) }) } } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 4782db357b..010fbfbef4 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -19,7 +19,7 @@ use crate::types::{ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_stdlib::identifiers::is_identifier; -use itertools::{Either, Itertools}; +use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; @@ -272,56 +272,37 @@ impl ClassInfoConstraintFunction { } } -/// Represents a single conjunction (AND) of constraints in Disjunctive Normal -/// Form (DNF). +/// Represents a TypeGuard-containing disjunct in a Disjunctive Normal Form +/// (DNF) narrowing constraint. /// -/// A conjunction may contain: - A regular constraint (intersection of types) - -/// An optional `TypeGuard` constraint that "replaces" the type rather than -/// intersecting +/// Such a constraint may optionally have a refinement applied after the type +/// guard, which is interpreted as being intersected with the type guard. /// -/// For example, `(Conjunction { constraint: A, typeguard: Some(B) } & -/// Conjunction { constraint: C, typeguard: Some(D)})` evaluates to -/// `Conjunction { constraint: C, typeguard: Some(D) }` because the type guard +/// For example, `(TypeGuardConstraint { typeguard: A, refinement: Some(B) } & +/// TypeGuardConstraint { typeguard: C, refinement: Some(D) })` evaluates to +/// `TypeGuardConstraint { typeguard: C, refinement: Some(D) }` because the type guard /// in the second conjunct clobbers that in the first. -#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy, salsa::Update, get_size2::GetSize)] -struct Conjunction<'db> { - /// The intersected constraints (represented as a type to intersect the guard with) - constraint: Type<'db>, - /// If any constraint in this conjunction is a `TypeGuard[T]`, this is `Some(T)` - typeguard: Option>, +#[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct TypeGuardConstraint<'db> { + /// If `TypeGuard[T]`, this is `Some(T)` + typeguard: Type<'db>, + /// If additional constraints are applied _after_ the TypeGuard, then they + /// go here + refinement: Option>, } -impl<'db> Conjunction<'db> { - /// Create a new conjunction with just a regular constraint - fn regular(constraint: Type<'db>) -> Self { - Self { - constraint, - typeguard: None, +impl<'db> TypeGuardConstraint<'db> { + /// Evaluate this typeguard constraint to a single type. + /// If there's a refinement, it's intersected with the typeguard constraint. + fn evaluate_constraint_type(self, db: &'db dyn Db) -> Type<'db> { + match self.refinement { + Some(refinement) => IntersectionBuilder::new(db) + .add_positive(self.typeguard) + .add_positive(refinement) + .build(), + None => self.typeguard, } } - - /// Create a new conjunction with a `TypeGuard` constraint - fn typeguard(constraint: Type<'db>) -> Self { - Self { - constraint: Type::object(), - typeguard: Some(constraint), - } - } - - /// Evaluate this conjunction to a single type. - /// If there's a `TypeGuard` constraint, it replaces the regular constraint. - /// Otherwise, returns the regular constraint. - fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { - self.typeguard.map_or_else( - || self.constraint, - |typeguard_constraint| { - IntersectionBuilder::new(db) - .add_positive(typeguard_constraint) - .add_positive(self.constraint) - .build() - }, - ) - } } /// Represents narrowing constraints in Disjunctive Normal Form (DNF). @@ -340,68 +321,92 @@ impl<'db> Conjunction<'db> { /// => evaluates to `(P & A) | B`, where `P` is our previously-known type #[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] pub(crate) struct NarrowingConstraint<'db> { + /// Regular constraint---we don't need a list here because we can represent + /// with a union type + regular_disjunct: Option>, /// Disjunction of conjunctions (DNF) - disjuncts: SmallVec<[Conjunction<'db>; 1]>, + typeguard_disjuncts: SmallVec<[TypeGuardConstraint<'db>; 1]>, } impl<'db> NarrowingConstraint<'db> { /// Create a constraint from a regular (non-`TypeGuard`) type pub(crate) fn regular(constraint: Type<'db>) -> Self { Self { - disjuncts: smallvec![Conjunction::regular(constraint)], + regular_disjunct: Some(constraint), + typeguard_disjuncts: smallvec![], } } /// Create a constraint from a `TypeGuard` type fn typeguard(constraint: Type<'db>) -> Self { Self { - disjuncts: smallvec![Conjunction::typeguard(constraint)], + regular_disjunct: None, + typeguard_disjuncts: smallvec![TypeGuardConstraint { + typeguard: constraint, + refinement: None, + }], } } /// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics (with `other` winning) - pub(crate) fn merge_constraint_and(&self, other: &Self, db: &'db dyn Db) -> Self { + pub(crate) fn merge_constraint_and(&self, other: Self, db: &'db dyn Db) -> Self { // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... - let new_disjuncts = other - .disjuncts - .iter() - .flat_map(|right_conj| { - // We iterate the RHS first because if it has a typeguard then we don't need to consider the LHS - if right_conj.typeguard.is_some() { - // If the right conjunct has a TypeGuard, it "wins" the conjunction - Either::Left(std::iter::once(*right_conj)) - } else { - // Otherwise, we need to consider all LHS disjuncts - Either::Right(self.disjuncts.iter().map(|left_conj| { - let new_regular = IntersectionBuilder::new(db) - .add_positive(left_conj.constraint) - .add_positive(right_conj.constraint) - .build(); + // + // In our representation, the RHS `typeguard_disjuncts` will all clobber + // the LHS disjuncts when they are anded, so they'll just stay as is. + // + // The thing we actually need to deal with is the RHS `regular_disjunct`. + // It gets anded onto the LHS `regular_disjunct` to form the new + // `regular_disjunct`, and anded onto each LHS `typeguard_disjunct` (via + // the refinement) to form new additional `typeguard_disjuncts`. + let Some(other_regular_disjunct) = other.regular_disjunct else { + return other; + }; - Conjunction { - constraint: new_regular, - typeguard: left_conj.typeguard, - } - })) - } - }) - .collect(); + let new_regular_disjunct = self.regular_disjunct.map(|regular_disjunct| { + IntersectionBuilder::new(db) + .add_positive(regular_disjunct) + .add_positive(other_regular_disjunct) + .build() + }); + + let additional_typeguard_disjuncts = + self.typeguard_disjuncts + .iter() + .map(|typeguard_disjunct| TypeGuardConstraint { + typeguard: typeguard_disjunct.typeguard, + refinement: match typeguard_disjunct.refinement { + Some(refinement) => Some( + IntersectionBuilder::new(db) + .add_positive(refinement) + .add_positive(other_regular_disjunct) + .build(), + ), + None => other.regular_disjunct, + }, + }); + + let mut new_typeguard_disjuncts = other.typeguard_disjuncts; + + new_typeguard_disjuncts.extend(additional_typeguard_disjuncts); NarrowingConstraint { - disjuncts: new_disjuncts, + typeguard_disjuncts: new_typeguard_disjuncts, + regular_disjunct: new_regular_disjunct, } } /// Evaluate the type this effectively constrains to /// /// Forgets whether each constraint originated from a `TypeGuard` or not - pub(crate) fn evaluate_type_constraint(self, db: &'db dyn Db) -> Type<'db> { + pub(crate) fn evaluate_constraint_type(self, db: &'db dyn Db) -> Type<'db> { UnionType::from_elements( db, - self.disjuncts + self.typeguard_disjuncts .into_iter() - .map(|disjunct| Conjunction::evaluate_type_constraint(disjunct, db)), + .map(|disjunct| disjunct.evaluate_constraint_type(db)) + .chain(self.regular_disjunct), ) } } @@ -433,7 +438,7 @@ fn merge_constraints_and<'db>( Entry::Occupied(mut entry) => { let into_constraint = entry.get(); - entry.insert(into_constraint.merge_constraint_and(&from_constraint, db)); + entry.insert(into_constraint.merge_constraint_and(from_constraint, db)); } Entry::Vacant(entry) => { entry.insert(from_constraint); @@ -449,9 +454,6 @@ fn merge_constraints_and<'db>( /// /// However, if a place appears in only one branch of the OR, we need to widen it /// to `object` in the overall result (because the other branch doesn't constrain it). -/// -/// When none of the disjuncts have `TypeGuard`, we simplify the constraint types -/// via `UnionBuilder` to enable simplifications like `~AlwaysFalsy | ~AlwaysTruthy -> object`. fn merge_constraints_or<'db>( into: &mut NarrowingConstraints<'db>, from: NarrowingConstraints<'db>, @@ -464,31 +466,21 @@ fn merge_constraints_or<'db>( match into.entry(key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get_mut(); - // Concatenate disjuncts - into_constraint.disjuncts.extend(from_constraint.disjuncts); + // Union the regular constraints + into_constraint.regular_disjunct = match ( + into_constraint.regular_disjunct, + from_constraint.regular_disjunct, + ) { + (Some(a), Some(b)) => Some(UnionType::from_elements(db, [a, b])), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + }; - // If none of the disjuncts have TypeGuard, we can simplify the constraint types - // via UnionBuilder. This enables simplifications like: - // `~AlwaysFalsy | ~AlwaysTruthy -> object` - let all_regular = into_constraint - .disjuncts - .iter() - .all(|conj| conj.typeguard.is_none()); - - if all_regular { - // Simplify via UnionBuilder - let simplified = UnionType::from_elements( - db, - into_constraint.disjuncts.iter().map(|conj| conj.constraint), - ); - if simplified.is_object() { - // If simplified to object, we can drop the constraint entirely - entry.remove(); - } else { - // Replace with simplified constraint - into_constraint.disjuncts = smallvec![Conjunction::regular(simplified)]; - } - } + // Concatenate typeguard disjuncts + into_constraint + .typeguard_disjuncts + .extend(from_constraint.typeguard_disjuncts); } Entry::Vacant(_) => { // Place only appears in `from`, not in `into`. No constraint needed. From 5a1305ccbc2cc57c1a7867407acc4fdac5cef8ce Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Tue, 23 Dec 2025 03:01:50 -0500 Subject: [PATCH 22/22] clippy --- crates/ty_python_semantic/src/types.rs | 6 +++--- crates/ty_python_semantic/src/types/display.rs | 2 +- crates/ty_python_semantic/src/types/narrow.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index ab496a0d5c..4adf77f30a 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -879,7 +879,7 @@ pub enum Type<'db> { NewTypeInstance(NewType<'db>), } -/// Helper for `recursive_type_normalized_impl` for TypeGuardLike types. +/// Helper for `recursive_type_normalized_impl` for `TypeGuardLike` types. fn recursive_type_normalize_type_guard_like<'db, T: TypeGuardLike<'db>>( db: &'db dyn Db, guard: T, @@ -14756,7 +14756,7 @@ impl<'db> VarianceInferable<'db> for TypeGuardType<'db> { } } -/// Common trait for TypeIs and TypeGuard types that share similar structure +/// Common trait for `TypeIs` and `TypeGuard` types that share similar structure /// but have different semantic behaviors. pub(crate) trait TypeGuardLike<'db>: Copy { /// The name of this type guard form (for error messages and display) @@ -14774,7 +14774,7 @@ pub(crate) trait TypeGuardLike<'db>: Copy { /// Create a new instance with a different return type, wrapped in Type fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db>; - /// The SpecialFormType for display purposes + /// The `SpecialFormType` for display purposes fn special_form() -> SpecialFormType; } diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 36a801c57d..1cb1836d7e 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -584,7 +584,7 @@ impl Display for ClassDisplay<'_> { } } -/// Helper for displaying TypeGuardLike types (TypeIs and TypeGuard). +/// Helper for displaying `TypeGuardLike` types `TypeIs` and `TypeGuard`. fn fmt_type_guard_like<'db, T: TypeGuardLike<'db>>( db: &'db dyn Db, guard: T, diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 010fbfbef4..6b5a7a417e 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -286,7 +286,7 @@ impl ClassInfoConstraintFunction { struct TypeGuardConstraint<'db> { /// If `TypeGuard[T]`, this is `Some(T)` typeguard: Type<'db>, - /// If additional constraints are applied _after_ the TypeGuard, then they + /// If additional constraints are applied _after_ the `TypeGuard`, then they /// go here refinement: Option>, }