[ty] detect cycles in Type::is_disjoint_from (#19139)

This commit is contained in:
Carl Meyer 2025-07-04 06:31:44 -07:00 committed by GitHub
parent 7712c2fd15
commit 411cccb35e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 129 additions and 52 deletions

View file

@ -1862,6 +1862,21 @@ class Bar(Protocol):
static_assert(is_equivalent_to(Foo, Bar))
```
### Disjointness of recursive protocol and recursive final type
```py
from typing import Protocol
from ty_extensions import is_disjoint_from, static_assert
class Proto(Protocol):
x: "Proto"
class Nominal:
x: "Nominal"
static_assert(not is_disjoint_from(Proto, Nominal))
```
### Regression test: narrowing with self-referential protocols
This snippet caused us to panic on an early version of the implementation for protocols.

View file

@ -19,7 +19,7 @@ use ruff_text_size::{Ranged, TextRange};
use type_ordering::union_or_intersection_elements_ordering;
pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder};
pub(crate) use self::cyclic::TypeTransformer;
pub(crate) use self::cyclic::{PairVisitor, TypeTransformer};
pub use self::diagnostic::TypeCheckDiagnostics;
pub(crate) use self::diagnostic::register_lints;
pub(crate) use self::infer::{
@ -1637,17 +1637,30 @@ impl<'db> Type<'db> {
/// Note: This function aims to have no false positives, but might return
/// wrong `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let mut visitor = PairVisitor::new(false);
self.is_disjoint_from_impl(db, other, &mut visitor)
}
pub(crate) fn is_disjoint_from_impl(
self,
db: &'db dyn Db,
other: Type<'db>,
visitor: &mut PairVisitor<'db>,
) -> bool {
fn any_protocol_members_absent_or_disjoint<'db>(
db: &'db dyn Db,
protocol: ProtocolInstanceType<'db>,
other: Type<'db>,
visitor: &mut PairVisitor<'db>,
) -> bool {
protocol.interface(db).members(db).any(|member| {
other
.member(db, member.name())
.place
.ignore_possibly_unbound()
.is_none_or(|attribute_type| member.has_disjoint_type_from(db, attribute_type))
.is_none_or(|attribute_type| {
member.has_disjoint_type_from(db, attribute_type, visitor)
})
})
}
@ -1681,19 +1694,19 @@ impl<'db> Type<'db> {
match typevar.bound_or_constraints(db) {
None => false,
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
bound.is_disjoint_from(db, other)
bound.is_disjoint_from_impl(db, other, visitor)
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => constraints
.elements(db)
.iter()
.all(|constraint| constraint.is_disjoint_from(db, other)),
.all(|constraint| constraint.is_disjoint_from_impl(db, other, visitor)),
}
}
(Type::Union(union), other) | (other, Type::Union(union)) => union
.elements(db)
.iter()
.all(|e| e.is_disjoint_from(db, other)),
.all(|e| e.is_disjoint_from_impl(db, other, visitor)),
// If we have two intersections, we test the positive elements of each one against the other intersection
// Negative elements need a positive element on the other side in order to be disjoint.
@ -1702,11 +1715,11 @@ impl<'db> Type<'db> {
self_intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
.any(|p| p.is_disjoint_from_impl(db, other, visitor))
|| other_intersection
.positive(db)
.iter()
.any(|p: &Type<'_>| p.is_disjoint_from(db, self))
.any(|p: &Type<'_>| p.is_disjoint_from_impl(db, self, visitor))
}
(Type::Intersection(intersection), other)
@ -1714,7 +1727,7 @@ impl<'db> Type<'db> {
intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
.any(|p| p.is_disjoint_from_impl(db, other, visitor))
// A & B & Not[C] is disjoint from C
|| intersection
.negative(db)
@ -1828,17 +1841,17 @@ impl<'db> Type<'db> {
}
(Type::ProtocolInstance(left), Type::ProtocolInstance(right)) => {
left.is_disjoint_from(db, right)
left.is_disjoint_from_impl(db, right, visitor)
}
(Type::ProtocolInstance(protocol), Type::SpecialForm(special_form))
| (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => {
any_protocol_members_absent_or_disjoint(db, protocol, special_form.instance_fallback(db))
any_protocol_members_absent_or_disjoint(db, protocol, special_form.instance_fallback(db), visitor)
}
(Type::ProtocolInstance(protocol), Type::KnownInstance(known_instance))
| (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => {
any_protocol_members_absent_or_disjoint(db, protocol, known_instance.instance_fallback(db))
any_protocol_members_absent_or_disjoint(db, protocol, known_instance.instance_fallback(db), visitor)
}
// The absence of a protocol member on one of these types guarantees
@ -1891,7 +1904,7 @@ impl<'db> Type<'db> {
| Type::ModuleLiteral(..)
| Type::GenericAlias(..)
| Type::IntLiteral(..)),
) => any_protocol_members_absent_or_disjoint(db, protocol, ty),
) => any_protocol_members_absent_or_disjoint(db, protocol, ty, visitor),
// This is the same as the branch above --
// once guard patterns are stabilised, it could be unified with that branch
@ -1900,7 +1913,7 @@ impl<'db> Type<'db> {
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol))
if n.class.is_final(db) =>
{
any_protocol_members_absent_or_disjoint(db, protocol, nominal)
any_protocol_members_absent_or_disjoint(db, protocol, nominal, visitor)
}
(Type::ProtocolInstance(protocol), other)
@ -1908,7 +1921,7 @@ impl<'db> Type<'db> {
protocol.interface(db).members(db).any(|member| {
matches!(
other.member(db, member.name()).place,
Place::Type(attribute_type, _) if member.has_disjoint_type_from(db, attribute_type)
Place::Type(attribute_type, _) if member.has_disjoint_type_from(db, attribute_type, visitor)
)
})
}
@ -1931,18 +1944,18 @@ impl<'db> Type<'db> {
}
}
(Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from(db, right),
(Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from_impl(db, right),
// for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`,
// so although the type is dynamic we can still determine disjointedness in some situations
(Type::SubclassOf(subclass_of_ty), other)
| (other, Type::SubclassOf(subclass_of_ty)) => match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => {
KnownClass::Type.to_instance(db).is_disjoint_from(db, other)
KnownClass::Type.to_instance(db).is_disjoint_from_impl(db, other, visitor)
}
SubclassOfInner::Class(class) => class
.metaclass_instance_type(db)
.is_disjoint_from(db, other),
.is_disjoint_from_impl(db, other, visitor),
},
(Type::SpecialForm(special_form), Type::NominalInstance(instance))
@ -2027,18 +2040,18 @@ impl<'db> Type<'db> {
(Type::BoundMethod(_), other) | (other, Type::BoundMethod(_)) => KnownClass::MethodType
.to_instance(db)
.is_disjoint_from(db, other),
.is_disjoint_from_impl(db, other, visitor),
(Type::MethodWrapper(_), other) | (other, Type::MethodWrapper(_)) => {
KnownClass::MethodWrapperType
.to_instance(db)
.is_disjoint_from(db, other)
.is_disjoint_from_impl(db, other, visitor)
}
(Type::WrapperDescriptor(_), other) | (other, Type::WrapperDescriptor(_)) => {
KnownClass::WrapperDescriptorType
.to_instance(db)
.is_disjoint_from(db, other)
.is_disjoint_from_impl(db, other, visitor)
}
(Type::Callable(_) | Type::FunctionLiteral(_), Type::Callable(_))
@ -2100,15 +2113,15 @@ impl<'db> Type<'db> {
(Type::ModuleLiteral(..), other @ Type::NominalInstance(..))
| (other @ Type::NominalInstance(..), Type::ModuleLiteral(..)) => {
// Modules *can* actually be instances of `ModuleType` subclasses
other.is_disjoint_from(db, KnownClass::ModuleType.to_instance(db))
other.is_disjoint_from_impl(db, KnownClass::ModuleType.to_instance(db), visitor)
}
(Type::NominalInstance(left), Type::NominalInstance(right)) => {
left.is_disjoint_from(db, right)
left.is_disjoint_from_impl(db, right)
}
(Type::Tuple(tuple), Type::Tuple(other_tuple)) => {
tuple.is_disjoint_from(db, other_tuple)
tuple.is_disjoint_from_impl(db, other_tuple, visitor)
}
(Type::Tuple(tuple), Type::NominalInstance(instance))
@ -2121,13 +2134,13 @@ impl<'db> Type<'db> {
(Type::PropertyInstance(_), other) | (other, Type::PropertyInstance(_)) => {
KnownClass::Property
.to_instance(db)
.is_disjoint_from(db, other)
.is_disjoint_from_impl(db, other, visitor)
}
(Type::BoundSuper(_), Type::BoundSuper(_)) => !self.is_equivalent_to(db, other),
(Type::BoundSuper(_), other) | (other, Type::BoundSuper(_)) => KnownClass::Super
.to_instance(db)
.is_disjoint_from(db, other),
.is_disjoint_from_impl(db, other, visitor),
}
}

View file

@ -1,23 +1,39 @@
use crate::FxIndexSet;
use crate::types::Type;
use std::cmp::Eq;
use std::hash::Hash;
#[derive(Debug, Default)]
pub(crate) struct TypeTransformer<'db> {
seen: FxIndexSet<Type<'db>>,
pub(crate) type TypeTransformer<'db> = CycleDetector<Type<'db>, Type<'db>>;
impl Default for TypeTransformer<'_> {
fn default() -> Self {
// TODO: proper recursive type handling
// This must be Any, not e.g. a todo type, because Any is the normalized form of the
// dynamic type (that is, todo types are normalized to Any).
CycleDetector::new(Type::any())
}
}
impl<'db> TypeTransformer<'db> {
pub(crate) fn visit(
&mut self,
ty: Type<'db>,
func: impl FnOnce(&mut Self) -> Type<'db>,
) -> Type<'db> {
if !self.seen.insert(ty) {
// TODO: proper recursive type handling
pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;
// This must be Any, not e.g. a todo type, because Any is the normalized form of the
// dynamic type (that is, todo types are normalized to Any).
return Type::any();
#[derive(Debug)]
pub(crate) struct CycleDetector<T, R> {
seen: FxIndexSet<T>,
fallback: R,
}
impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
pub(crate) fn new(fallback: R) -> Self {
CycleDetector {
seen: FxIndexSet::default(),
fallback,
}
}
pub(crate) fn visit(&mut self, item: T, func: impl FnOnce(&mut Self) -> R) -> R {
if !self.seen.insert(item) {
return self.fallback;
}
let ret = func(self);
self.seen.pop();

View file

@ -5,6 +5,7 @@ use std::marker::PhantomData;
use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance};
use crate::place::PlaceAndQualifiers;
use crate::types::cyclic::PairVisitor;
use crate::types::protocol_class::walk_protocol_interface;
use crate::types::tuple::TupleType;
use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance};
@ -118,7 +119,7 @@ impl<'db> NominalInstanceType<'db> {
self.class.is_equivalent_to(db, other.class)
}
pub(super) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool {
pub(super) fn is_disjoint_from_impl(self, db: &'db dyn Db, other: Self) -> bool {
!self.class.could_coexist_in_mro_with(db, other.class)
}
@ -277,7 +278,12 @@ impl<'db> ProtocolInstanceType<'db> {
/// TODO: a protocol `X` is disjoint from a protocol `Y` if `X` and `Y`
/// have a member with the same name but disjoint types
#[expect(clippy::unused_self)]
pub(super) fn is_disjoint_from(self, _db: &'db dyn Db, _other: Self) -> bool {
pub(super) fn is_disjoint_from_impl(
self,
_db: &'db dyn Db,
_other: Self,
_visitor: &mut PairVisitor<'db>,
) -> bool {
false
}

View file

@ -11,6 +11,7 @@ use crate::{
types::{
CallableType, ClassBase, ClassLiteral, KnownFunction, PropertyInstanceType, Signature,
Type, TypeMapping, TypeQualifiers, TypeRelation, TypeTransformer, TypeVarInstance,
cyclic::PairVisitor,
signatures::{Parameter, Parameters},
},
};
@ -359,11 +360,18 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
}
}
pub(super) fn has_disjoint_type_from(&self, db: &'db dyn Db, other: Type<'db>) -> bool {
pub(super) fn has_disjoint_type_from(
&self,
db: &'db dyn Db,
other: Type<'db>,
visitor: &mut PairVisitor<'db>,
) -> bool {
match &self.kind {
// TODO: implement disjointness for property/method members as well as attribute members
ProtocolMemberKind::Property(_) | ProtocolMemberKind::Method(_) => false,
ProtocolMemberKind::Other(ty) => ty.is_disjoint_from(db, other),
ProtocolMemberKind::Other(ty) => {
visitor.visit((*ty, other), |v| ty.is_disjoint_from_impl(db, other, v))
}
}
}

View file

@ -170,7 +170,7 @@ impl<'db> SubclassOfType<'db> {
/// Return` true` if `self` is a disjoint type from `other`.
///
/// See [`Type::is_disjoint_from`] for more details.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool {
pub(crate) fn is_disjoint_from_impl(self, db: &'db dyn Db, other: Self) -> bool {
match (self.subclass_of, other.subclass_of) {
(SubclassOfInner::Dynamic(_), _) | (_, SubclassOfInner::Dynamic(_)) => false,
(SubclassOfInner::Class(self_class), SubclassOfInner::Class(other_class)) => {

View file

@ -25,7 +25,7 @@ use itertools::{Either, EitherOrBoth, Itertools};
use crate::types::class::{ClassType, KnownClass};
use crate::types::{
Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance, TypeVarVariance,
UnionBuilder, UnionType,
UnionBuilder, UnionType, cyclic::PairVisitor,
};
use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError};
use crate::{Db, FxOrderSet};
@ -227,8 +227,14 @@ impl<'db> TupleType<'db> {
self.tuple(db).is_equivalent_to(db, other.tuple(db))
}
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool {
self.tuple(db).is_disjoint_from(db, other.tuple(db))
pub(crate) fn is_disjoint_from_impl(
self,
db: &'db dyn Db,
other: Self,
visitor: &mut PairVisitor<'db>,
) -> bool {
self.tuple(db)
.is_disjoint_from_impl(db, other.tuple(db), visitor)
}
pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool {
@ -1058,7 +1064,12 @@ impl<'db> Tuple<Type<'db>> {
}
}
fn is_disjoint_from(&self, db: &'db dyn Db, other: &Self) -> bool {
fn is_disjoint_from_impl(
&'db self,
db: &'db dyn Db,
other: &'db Self,
visitor: &mut PairVisitor<'db>,
) -> bool {
// Two tuples with an incompatible number of required elements must always be disjoint.
let (self_min, self_max) = self.len().size_hint();
let (other_min, other_max) = other.len().size_hint();
@ -1075,15 +1086,16 @@ impl<'db> Tuple<Type<'db>> {
db: &'db dyn Db,
a: impl IntoIterator<Item = &'db Type<'db>>,
b: impl IntoIterator<Item = &'db Type<'db>>,
visitor: &mut PairVisitor<'db>,
) -> bool {
a.into_iter().zip(b).any(|(self_element, other_element)| {
self_element.is_disjoint_from(db, *other_element)
self_element.is_disjoint_from_impl(db, *other_element, visitor)
})
}
match (self, other) {
(Tuple::Fixed(self_tuple), Tuple::Fixed(other_tuple)) => {
if any_disjoint(db, self_tuple.elements(), other_tuple.elements()) {
if any_disjoint(db, self_tuple.elements(), other_tuple.elements(), visitor) {
return true;
}
}
@ -1093,6 +1105,7 @@ impl<'db> Tuple<Type<'db>> {
db,
self_tuple.prefix_elements(),
other_tuple.prefix_elements(),
visitor,
) {
return true;
}
@ -1100,6 +1113,7 @@ impl<'db> Tuple<Type<'db>> {
db,
self_tuple.suffix_elements().rev(),
other_tuple.suffix_elements().rev(),
visitor,
) {
return true;
}
@ -1107,10 +1121,15 @@ impl<'db> Tuple<Type<'db>> {
(Tuple::Fixed(fixed), Tuple::Variable(variable))
| (Tuple::Variable(variable), Tuple::Fixed(fixed)) => {
if any_disjoint(db, fixed.elements(), variable.prefix_elements()) {
if any_disjoint(db, fixed.elements(), variable.prefix_elements(), visitor) {
return true;
}
if any_disjoint(db, fixed.elements().rev(), variable.suffix_elements().rev()) {
if any_disjoint(
db,
fixed.elements().rev(),
variable.suffix_elements().rev(),
visitor,
) {
return true;
}
}