[ty] Improve assignability/subtyping between two protocol types

This commit is contained in:
Alex Waygood 2025-09-12 20:27:35 +01:00
parent 1cf19732b9
commit 5ccc8b2a86
13 changed files with 900 additions and 333 deletions

View file

@ -232,7 +232,7 @@ static STATIC_FRAME: std::sync::LazyLock<Benchmark<'static>> = std::sync::LazyLo
max_dep_date: "2025-08-09", max_dep_date: "2025-08-09",
python_version: PythonVersion::PY311, python_version: PythonVersion::PY311,
}, },
600, 620,
) )
}); });

View file

@ -1221,11 +1221,7 @@ from typing_extensions import LiteralString
def f(a: Foo, b: list[str], c: list[LiteralString], e): def f(a: Foo, b: list[str], c: list[LiteralString], e):
reveal_type(e) # revealed: Unknown reveal_type(e) # revealed: Unknown
reveal_type(a.join(b)) # revealed: str
# TODO: we should select the second overload here and reveal `str`
# (the incorrect result is due to missing logic in protocol subtyping/assignability)
reveal_type(a.join(b)) # revealed: LiteralString
reveal_type(a.join(c)) # revealed: LiteralString reveal_type(a.join(c)) # revealed: LiteralString
# since both overloads match and they have return types that are not equivalent, # since both overloads match and they have return types that are not equivalent,

View file

@ -205,8 +205,8 @@ from typing import Protocol
import proto_a import proto_a
import proto_b import proto_b
# TODO should be error: [invalid-assignment] "Object of type `proto_b.Drawable` is not assignable to `proto_a.Drawable`"
def _(drawable_b: proto_b.Drawable): def _(drawable_b: proto_b.Drawable):
# error: [invalid-assignment] "Object of type `proto_b.Drawable` is not assignable to `proto_a.Drawable`"
drawable: proto_a.Drawable = drawable_b drawable: proto_a.Drawable = drawable_b
``` ```

View file

@ -247,8 +247,7 @@ class StrIterator:
def f(x: IntIterator | StrIterator): def f(x: IntIterator | StrIterator):
for a in x: for a in x:
# TODO: this should be `int | str` (https://github.com/astral-sh/ty/issues/1089) reveal_type(a) # revealed: int | str
reveal_type(a) # revealed: int
``` ```
Most real-world iterable types use `Iterator` as the return annotation of their `__iter__` methods: Most real-world iterable types use `Iterator` as the return annotation of their `__iter__` methods:
@ -260,14 +259,11 @@ def g(
c: Literal["foo", b"bar"], c: Literal["foo", b"bar"],
): ):
for x in a: for x in a:
# TODO: should be `int | str` (https://github.com/astral-sh/ty/issues/1089) reveal_type(x) # revealed: int | str
reveal_type(x) # revealed: int
for y in b: for y in b:
# TODO: should be `str | int` (https://github.com/astral-sh/ty/issues/1089) reveal_type(y) # revealed: str | int
reveal_type(y) # revealed: str
for z in c: for z in c:
# TODO: should be `LiteralString | int` (https://github.com/astral-sh/ty/issues/1089) reveal_type(z) # revealed: LiteralString | int
reveal_type(z) # revealed: LiteralString
``` ```
## Union type as iterable where one union element has no `__iter__` method ## Union type as iterable where one union element has no `__iter__` method

View file

@ -617,11 +617,10 @@ static_assert(is_assignable_to(Foo, HasX))
static_assert(not is_subtype_of(Foo, HasXY)) static_assert(not is_subtype_of(Foo, HasXY))
static_assert(not is_assignable_to(Foo, HasXY)) static_assert(not is_assignable_to(Foo, HasXY))
# TODO: these should pass static_assert(not is_subtype_of(HasXIntSub, HasX))
static_assert(not is_subtype_of(HasXIntSub, HasX)) # error: [static-assert-error] static_assert(not is_assignable_to(HasXIntSub, HasX))
static_assert(not is_assignable_to(HasXIntSub, HasX)) # error: [static-assert-error] static_assert(not is_subtype_of(HasX, HasXIntSub))
static_assert(not is_subtype_of(HasX, HasXIntSub)) # error: [static-assert-error] static_assert(not is_assignable_to(HasX, HasXIntSub))
static_assert(not is_assignable_to(HasX, HasXIntSub)) # error: [static-assert-error]
class FooSub(Foo): ... class FooSub(Foo): ...
@ -2286,10 +2285,9 @@ class MethodPUnrelated(Protocol):
static_assert(is_subtype_of(MethodPSub, MethodPSuper)) static_assert(is_subtype_of(MethodPSub, MethodPSuper))
# TODO: these should pass static_assert(not is_assignable_to(MethodPUnrelated, MethodPSuper))
static_assert(not is_assignable_to(MethodPUnrelated, MethodPSuper)) # error: [static-assert-error] static_assert(not is_assignable_to(MethodPSuper, MethodPUnrelated))
static_assert(not is_assignable_to(MethodPSuper, MethodPUnrelated)) # error: [static-assert-error] static_assert(not is_assignable_to(MethodPSuper, MethodPSub))
static_assert(not is_assignable_to(MethodPSuper, MethodPSub)) # error: [static-assert-error]
``` ```
## Subtyping between protocols with method members and protocols with non-method members ## Subtyping between protocols with method members and protocols with non-method members
@ -2348,8 +2346,7 @@ And for the same reason, they are never assignable to attribute members (which a
class Attribute(Protocol): class Attribute(Protocol):
f: Callable[[], bool] f: Callable[[], bool]
# TODO: should pass static_assert(not is_assignable_to(Method, Attribute))
static_assert(not is_assignable_to(Method, Attribute)) # error: [static-assert-error]
``` ```
Protocols with attribute members, meanwhile, cannot be assigned to protocols with method members, Protocols with attribute members, meanwhile, cannot be assigned to protocols with method members,
@ -2358,9 +2355,8 @@ this is not true for attribute members. The same principle also applies for prot
members members
```py ```py
# TODO: this should pass static_assert(not is_assignable_to(PropertyBool, Method))
static_assert(not is_assignable_to(PropertyBool, Method)) # error: [static-assert-error] static_assert(not is_assignable_to(Attribute, Method))
static_assert(not is_assignable_to(Attribute, Method)) # error: [static-assert-error]
``` ```
But an exception to this rule is if an attribute member is marked as `ClassVar`, as this guarantees But an exception to this rule is if an attribute member is marked as `ClassVar`, as this guarantees
@ -2379,9 +2375,8 @@ static_assert(is_assignable_to(ClassVarAttribute, Method))
class ClassVarAttributeBad(Protocol): class ClassVarAttributeBad(Protocol):
f: ClassVar[Callable[[], str]] f: ClassVar[Callable[[], str]]
# TODO: these should pass: static_assert(not is_subtype_of(ClassVarAttributeBad, Method))
static_assert(not is_subtype_of(ClassVarAttributeBad, Method)) # error: [static-assert-error] static_assert(not is_assignable_to(ClassVarAttributeBad, Method))
static_assert(not is_assignable_to(ClassVarAttributeBad, Method)) # error: [static-assert-error]
``` ```
## Narrowing of protocols ## Narrowing of protocols
@ -2702,9 +2697,8 @@ class RecursiveNonFullyStatic(Protocol):
parent: RecursiveNonFullyStatic parent: RecursiveNonFullyStatic
x: Any x: Any
# TODO: these should pass, once we take into account types of members static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic)) # error: [static-assert-error] static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic)) # error: [static-assert-error]
static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveNonFullyStatic)) static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveNonFullyStatic))
static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic)) static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
@ -2722,9 +2716,7 @@ class RecursiveOptionalParent(Protocol):
static_assert(is_assignable_to(RecursiveOptionalParent, RecursiveOptionalParent)) static_assert(is_assignable_to(RecursiveOptionalParent, RecursiveOptionalParent))
# Due to invariance of mutable attribute members, neither is assignable to the other # Due to invariance of mutable attribute members, neither is assignable to the other
# static_assert(not is_assignable_to(RecursiveNonFullyStatic, RecursiveOptionalParent))
# TODO: should pass
static_assert(not is_assignable_to(RecursiveNonFullyStatic, RecursiveOptionalParent)) # error: [static-assert-error]
static_assert(not is_assignable_to(RecursiveOptionalParent, RecursiveNonFullyStatic)) static_assert(not is_assignable_to(RecursiveOptionalParent, RecursiveNonFullyStatic))
class Other(Protocol): class Other(Protocol):

File diff suppressed because it is too large Load diff

View file

@ -27,11 +27,11 @@ use crate::types::typed_dict::typed_dict_params_from_class_def;
use crate::types::{ use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext, MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType,
TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance,
TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound, TypeVarKind, TypedDictParams, UnionBuilder, VarianceInferable, declaration_type,
infer_definition_types, determine_upper_bound, infer_definition_types,
}; };
use crate::{ use crate::{
Db, FxIndexMap, FxOrderSet, Program, Db, FxIndexMap, FxOrderSet, Program,
@ -538,6 +538,7 @@ impl<'db> ClassType<'db> {
other, other,
TypeRelation::Subtyping, TypeRelation::Subtyping,
&HasRelationToVisitor::default(), &HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
) )
} }
@ -546,7 +547,8 @@ impl<'db> ClassType<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: Self, other: Self,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
self.iter_mro(db).when_any(db, |base| { self.iter_mro(db).when_any(db, |base| {
match base { match base {
@ -568,7 +570,8 @@ impl<'db> ClassType<'db> {
db, db,
other.specialization(db), other.specialization(db),
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
) )
}) })
} }

View file

@ -16,9 +16,9 @@ use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type};
use crate::types::{ use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind,
Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance,
TypeVarVariance, UnionType, binding_type, declaration_type, TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type,
}; };
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderSet};
@ -481,7 +481,8 @@ fn is_subtype_in_invariant_position<'db>(
derived_materialization: MaterializationKind, derived_materialization: MaterializationKind,
base_type: &Type<'db>, base_type: &Type<'db>,
base_materialization: MaterializationKind, base_materialization: MaterializationKind,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
let derived_top = derived_type.top_materialization(db); let derived_top = derived_type.top_materialization(db);
let derived_bottom = derived_type.bottom_materialization(db); let derived_bottom = derived_type.bottom_materialization(db);
@ -501,7 +502,13 @@ fn is_subtype_in_invariant_position<'db>(
return ConstraintSet::from(true); return ConstraintSet::from(true);
} }
derived.has_relation_to_impl(db, base, TypeRelation::Subtyping, visitor) derived.has_relation_to_impl(
db,
base,
TypeRelation::Subtyping,
relation_visitor,
disjointness_visitor,
)
}; };
match (derived_materialization, base_materialization) { match (derived_materialization, base_materialization) {
// `Derived` is a subtype of `Base` if the range of materializations covered by `Derived` // `Derived` is a subtype of `Base` if the range of materializations covered by `Derived`
@ -543,6 +550,7 @@ fn is_subtype_in_invariant_position<'db>(
/// Whether two types encountered in an invariant position /// Whether two types encountered in an invariant position
/// have a relation (subtyping or assignability), taking into account /// have a relation (subtyping or assignability), taking into account
/// that the two types may come from a top or bottom materialization. /// that the two types may come from a top or bottom materialization.
#[expect(clippy::too_many_arguments)]
fn has_relation_in_invariant_position<'db>( fn has_relation_in_invariant_position<'db>(
db: &'db dyn Db, db: &'db dyn Db,
derived_type: &Type<'db>, derived_type: &Type<'db>,
@ -550,7 +558,8 @@ fn has_relation_in_invariant_position<'db>(
base_type: &Type<'db>, base_type: &Type<'db>,
base_materialization: Option<MaterializationKind>, base_materialization: Option<MaterializationKind>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match (derived_materialization, base_materialization, relation) { match (derived_materialization, base_materialization, relation) {
// Top and bottom materializations are fully static types, so subtyping // Top and bottom materializations are fully static types, so subtyping
@ -561,19 +570,27 @@ fn has_relation_in_invariant_position<'db>(
derived_mat, derived_mat,
base_type, base_type,
base_mat, base_mat,
visitor, relation_visitor,
disjointness_visitor,
), ),
// Subtyping between invariant type parameters without a top/bottom materialization involved // Subtyping between invariant type parameters without a top/bottom materialization involved
// is equivalence // is equivalence
(None, None, TypeRelation::Subtyping) => derived_type.when_equivalent_to(db, *base_type), (None, None, TypeRelation::Subtyping) => derived_type.when_equivalent_to(db, *base_type),
(None, None, TypeRelation::Assignability) => derived_type (None, None, TypeRelation::Assignability) => derived_type
.has_relation_to_impl(db, *base_type, TypeRelation::Assignability, visitor) .has_relation_to_impl(
db,
*base_type,
TypeRelation::Assignability,
relation_visitor,
disjointness_visitor,
)
.and(db, || { .and(db, || {
base_type.has_relation_to_impl( base_type.has_relation_to_impl(
db, db,
*derived_type, *derived_type,
TypeRelation::Assignability, TypeRelation::Assignability,
visitor, relation_visitor,
disjointness_visitor,
) )
}), }),
// For gradual types, A <: B (subtyping) is defined as Top[A] <: Bottom[B] // For gradual types, A <: B (subtyping) is defined as Top[A] <: Bottom[B]
@ -583,7 +600,8 @@ fn has_relation_in_invariant_position<'db>(
MaterializationKind::Top, MaterializationKind::Top,
base_type, base_type,
base_mat, base_mat,
visitor, relation_visitor,
disjointness_visitor,
), ),
(Some(derived_mat), None, TypeRelation::Subtyping) => is_subtype_in_invariant_position( (Some(derived_mat), None, TypeRelation::Subtyping) => is_subtype_in_invariant_position(
db, db,
@ -591,7 +609,8 @@ fn has_relation_in_invariant_position<'db>(
derived_mat, derived_mat,
base_type, base_type,
MaterializationKind::Bottom, MaterializationKind::Bottom,
visitor, relation_visitor,
disjointness_visitor,
), ),
// And A <~ B (assignability) is Bottom[A] <: Top[B] // And A <~ B (assignability) is Bottom[A] <: Top[B]
(None, Some(base_mat), TypeRelation::Assignability) => is_subtype_in_invariant_position( (None, Some(base_mat), TypeRelation::Assignability) => is_subtype_in_invariant_position(
@ -600,7 +619,8 @@ fn has_relation_in_invariant_position<'db>(
MaterializationKind::Bottom, MaterializationKind::Bottom,
base_type, base_type,
base_mat, base_mat,
visitor, relation_visitor,
disjointness_visitor,
), ),
(Some(derived_mat), None, TypeRelation::Assignability) => is_subtype_in_invariant_position( (Some(derived_mat), None, TypeRelation::Assignability) => is_subtype_in_invariant_position(
db, db,
@ -608,7 +628,8 @@ fn has_relation_in_invariant_position<'db>(
derived_mat, derived_mat,
base_type, base_type,
MaterializationKind::Top, MaterializationKind::Top,
visitor, relation_visitor,
disjointness_visitor,
), ),
} }
} }
@ -819,7 +840,8 @@ impl<'db> Specialization<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: Self, other: Self,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
let generic_context = self.generic_context(db); let generic_context = self.generic_context(db);
if generic_context != other.generic_context(db) { if generic_context != other.generic_context(db) {
@ -828,7 +850,13 @@ impl<'db> Specialization<'db> {
if let (Some(self_tuple), Some(other_tuple)) = (self.tuple_inner(db), other.tuple_inner(db)) if let (Some(self_tuple), Some(other_tuple)) = (self.tuple_inner(db), other.tuple_inner(db))
{ {
return self_tuple.has_relation_to_impl(db, other_tuple, relation, visitor); return self_tuple.has_relation_to_impl(
db,
other_tuple,
relation,
relation_visitor,
disjointness_visitor,
);
} }
let self_materialization_kind = self.materialization_kind(db); let self_materialization_kind = self.materialization_kind(db);
@ -853,14 +881,23 @@ impl<'db> Specialization<'db> {
other_type, other_type,
other_materialization_kind, other_materialization_kind,
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
),
TypeVarVariance::Covariant => self_type.has_relation_to_impl(
db,
*other_type,
relation,
relation_visitor,
disjointness_visitor,
),
TypeVarVariance::Contravariant => other_type.has_relation_to_impl(
db,
*self_type,
relation,
relation_visitor,
disjointness_visitor,
), ),
TypeVarVariance::Covariant => {
self_type.has_relation_to_impl(db, *other_type, relation, visitor)
}
TypeVarVariance::Contravariant => {
other_type.has_relation_to_impl(db, *self_type, relation, visitor)
}
TypeVarVariance::Bivariant => ConstraintSet::from(true), TypeVarVariance::Bivariant => ConstraintSet::from(true),
}; };
if result.intersect(db, compatible).is_never_satisfied() { if result.intersect(db, compatible).is_never_satisfied() {

View file

@ -122,14 +122,16 @@ impl<'db> Type<'db> {
db: &'db dyn Db, db: &'db dyn Db,
protocol: ProtocolInstanceType<'db>, protocol: ProtocolInstanceType<'db>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
let structurally_satisfied = if let Type::ProtocolInstance(self_protocol) = self { let structurally_satisfied = if let Type::ProtocolInstance(self_protocol) = self {
self_protocol.interface(db).extends_interface_of( self_protocol.interface(db).has_relation_to_impl(
db, db,
protocol.interface(db), protocol.interface(db),
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
) )
} else { } else {
protocol protocol
@ -137,7 +139,13 @@ impl<'db> Type<'db> {
.interface(db) .interface(db)
.members(db) .members(db)
.when_all(db, |member| { .when_all(db, |member| {
member.is_satisfied_by(db, self, relation, visitor) member.is_satisfied_by(
db,
self,
relation,
relation_visitor,
disjointness_visitor,
)
}) })
}; };
@ -149,11 +157,22 @@ impl<'db> Type<'db> {
// recognise `str` as a subtype of `Container[str]`. // recognise `str` as a subtype of `Container[str]`.
structurally_satisfied.or(db, || { structurally_satisfied.or(db, || {
if let Protocol::FromClass(class) = protocol.inner { if let Protocol::FromClass(class) = protocol.inner {
self.has_relation_to_impl( let type_to_test = if let Type::ProtocolInstance(ProtocolInstanceType {
inner: Protocol::FromClass(class),
..
}) = self
{
Type::non_tuple_instance(db, class)
} else {
self
};
type_to_test.has_relation_to_impl(
db, db,
Type::non_tuple_instance(db, class), Type::non_tuple_instance(db, class),
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
) )
} else { } else {
ConstraintSet::from(false) ConstraintSet::from(false)
@ -342,17 +361,28 @@ impl<'db> NominalInstanceType<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: Self, other: Self,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match (self.0, other.0) { match (self.0, other.0) {
(_, NominalInstanceInner::Object) => ConstraintSet::from(true), (_, NominalInstanceInner::Object) => ConstraintSet::from(true),
( (
NominalInstanceInner::ExactTuple(tuple1), NominalInstanceInner::ExactTuple(tuple1),
NominalInstanceInner::ExactTuple(tuple2), NominalInstanceInner::ExactTuple(tuple2),
) => tuple1.has_relation_to_impl(db, tuple2, relation, visitor), ) => tuple1.has_relation_to_impl(
_ => self db,
.class(db) tuple2,
.has_relation_to_impl(db, other.class(db), relation, visitor), relation,
relation_visitor,
disjointness_visitor,
),
_ => self.class(db).has_relation_to_impl(
db,
other.class(db),
relation,
relation_visitor,
disjointness_visitor,
),
} }
} }
@ -381,7 +411,8 @@ impl<'db> NominalInstanceType<'db> {
self, self,
db: &'db dyn Db, db: &'db dyn Db,
other: Self, other: Self,
visitor: &IsDisjointVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
if self.is_object() || other.is_object() { if self.is_object() || other.is_object() {
return ConstraintSet::from(false); return ConstraintSet::from(false);
@ -389,7 +420,12 @@ impl<'db> NominalInstanceType<'db> {
let mut result = ConstraintSet::from(false); let mut result = ConstraintSet::from(false);
if let Some(self_spec) = self.tuple_spec(db) { if let Some(self_spec) = self.tuple_spec(db) {
if let Some(other_spec) = other.tuple_spec(db) { if let Some(other_spec) = other.tuple_spec(db) {
let compatible = self_spec.is_disjoint_from_impl(db, &other_spec, visitor); let compatible = self_spec.is_disjoint_from_impl(
db,
&other_spec,
disjointness_visitor,
relation_visitor,
);
if result.union(db, compatible).is_always_satisfied() { if result.union(db, compatible).is_always_satisfied() {
return result; return result;
} }
@ -601,6 +637,7 @@ impl<'db> ProtocolInstanceType<'db> {
protocol, protocol,
TypeRelation::Subtyping, TypeRelation::Subtyping,
&HasRelationToVisitor::default(), &HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
) )
.is_always_satisfied() .is_always_satisfied()
} }

View file

@ -18,7 +18,7 @@ use crate::{
InstanceFallbackShadowsNonDataDescriptor, IsDisjointVisitor, KnownFunction, InstanceFallbackShadowsNonDataDescriptor, IsDisjointVisitor, KnownFunction,
MemberLookupPolicy, NormalizedVisitor, PropertyInstanceType, Signature, Type, TypeMapping, MemberLookupPolicy, NormalizedVisitor, PropertyInstanceType, Signature, Type, TypeMapping,
TypeQualifiers, TypeRelation, TypeVarVariance, VarianceInferable, TypeQualifiers, TypeRelation, TypeVarVariance, VarianceInferable,
constraints::{ConstraintSet, IteratorConstraintsExtension}, constraints::{ConstraintSet, IteratorConstraintsExtension, OptionConstraintsExtension},
context::InferContext, context::InferContext,
diagnostic::report_undeclared_protocol_member, diagnostic::report_undeclared_protocol_member,
signatures::{Parameter, Parameters}, signatures::{Parameter, Parameters},
@ -230,21 +230,98 @@ impl<'db> ProtocolInterface<'db> {
.unwrap_or_else(|| Type::object().member(db, name)) .unwrap_or_else(|| Type::object().member(db, name))
} }
/// Return `true` if `self` extends the interface of `other`, i.e., pub(super) fn has_relation_to_impl(
/// all members on `other` are also members of `self`.
///
/// TODO: this method should consider the types of the members as well as their names.
pub(super) fn extends_interface_of(
self, self,
db: &'db dyn Db, db: &'db dyn Db,
other: Self, other: Self,
_relation: TypeRelation, relation: TypeRelation,
_visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
// TODO: This could just return a bool as written, but this form is what will be needed to other.members(db).when_all(db, |other_member| {
// combine the constraints when we do assignability checks on each member. self.member_by_name(db, other_member.name)
other.inner(db).keys().when_all(db, |member_name| { .when_some_and(|our_member| match (our_member.kind, other_member.kind) {
ConstraintSet::from(self.inner(db).contains_key(member_name)) // Method members are always immutable;
// they can never be subtypes of/assignable to mutable attribute members.
(ProtocolMemberKind::Method(_), ProtocolMemberKind::Other(_)) => {
ConstraintSet::from(false)
}
// A property member can only be a subtype of an attribute member
// if the property is readable *and* writable.
//
// TODO: this should also consider the types of the members on both sides.
(ProtocolMemberKind::Property(property), ProtocolMemberKind::Other(_)) => {
ConstraintSet::from(
property.getter(db).is_some() && property.setter(db).is_some(),
)
}
// A `@property` member can never be a subtype of a method member, as it is not necessarily
// accessible on the meta-type, whereas a method member must be.
(ProtocolMemberKind::Property(_), ProtocolMemberKind::Method(_)) => {
ConstraintSet::from(false)
}
// But an attribute member *can* be a subtype of a method member,
// providing it is marked `ClassVar`
(
ProtocolMemberKind::Other(our_type),
ProtocolMemberKind::Method(other_type),
) => ConstraintSet::from(
our_member.qualifiers.contains(TypeQualifiers::CLASS_VAR),
)
.and(db, || {
our_type.has_relation_to_impl(
db,
Type::Callable(other_type.bind_self(db)),
relation,
relation_visitor,
disjointness_visitor,
)
}),
(
ProtocolMemberKind::Method(our_method),
ProtocolMemberKind::Method(other_method),
) => our_method.bind_self(db).has_relation_to_impl(
db,
other_method.bind_self(db),
relation,
relation_visitor,
disjointness_visitor,
),
(
ProtocolMemberKind::Other(our_type),
ProtocolMemberKind::Other(other_type),
) => our_type
.has_relation_to_impl(
db,
other_type,
relation,
relation_visitor,
disjointness_visitor,
)
.and(db, || {
other_type.has_relation_to_impl(
db,
our_type,
relation,
relation_visitor,
disjointness_visitor,
)
}),
// TODO: finish assignability/subtyping between two `@property` members,
// and between a `@property` member and a member of a different kind.
(
ProtocolMemberKind::Property(_)
| ProtocolMemberKind::Method(_)
| ProtocolMemberKind::Other(_),
ProtocolMemberKind::Property(_),
) => ConstraintSet::from(true),
})
}) })
} }
@ -518,14 +595,17 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
&self, &self,
db: &'db dyn Db, db: &'db dyn Db,
other: Type<'db>, other: Type<'db>,
visitor: &IsDisjointVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match &self.kind { match &self.kind {
// TODO: implement disjointness for property/method members as well as attribute members // TODO: implement disjointness for property/method members as well as attribute members
ProtocolMemberKind::Property(_) | ProtocolMemberKind::Method(_) => { ProtocolMemberKind::Property(_) | ProtocolMemberKind::Method(_) => {
ConstraintSet::from(false) ConstraintSet::from(false)
} }
ProtocolMemberKind::Other(ty) => ty.is_disjoint_from_impl(db, other, visitor), ProtocolMemberKind::Other(ty) => {
ty.is_disjoint_from_impl(db, other, disjointness_visitor, relation_visitor)
}
} }
} }
@ -536,7 +616,8 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: Type<'db>, other: Type<'db>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match &self.kind { match &self.kind {
ProtocolMemberKind::Method(method) => { ProtocolMemberKind::Method(method) => {
@ -570,7 +651,13 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
attribute_type attribute_type
}; };
attribute_type.has_relation_to_impl(db, method.bind_self(db), relation, visitor) attribute_type.has_relation_to_impl(
db,
Type::Callable(method.bind_self(db)),
relation,
relation_visitor,
disjointness_visitor,
)
} }
// TODO: consider the types of the attribute on `other` for property members // TODO: consider the types of the attribute on `other` for property members
ProtocolMemberKind::Property(_) => ConstraintSet::from(matches!( ProtocolMemberKind::Property(_) => ConstraintSet::from(matches!(
@ -584,9 +671,21 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
return ConstraintSet::from(false); return ConstraintSet::from(false);
}; };
member_type member_type
.has_relation_to_impl(db, attribute_type, relation, visitor) .has_relation_to_impl(
db,
attribute_type,
relation,
relation_visitor,
disjointness_visitor,
)
.and(db, || { .and(db, || {
attribute_type.has_relation_to_impl(db, *member_type, relation, visitor) attribute_type.has_relation_to_impl(
db,
*member_type,
relation,
relation_visitor,
disjointness_visitor,
)
}) })
} }
} }

View file

@ -21,8 +21,8 @@ use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::generics::{GenericContext, walk_generic_context};
use crate::types::{ use crate::types::{
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor,
HasRelationToVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind,
TypeMapping, TypeRelation, VarianceInferable, todo_type, NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, todo_type,
}; };
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name}; use ruff_python_ast::{self as ast, name::Name};
@ -136,6 +136,7 @@ impl<'db> CallableSignature<'db> {
other, other,
TypeRelation::Subtyping, TypeRelation::Subtyping,
&HasRelationToVisitor::default(), &HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
) )
} }
@ -148,6 +149,7 @@ impl<'db> CallableSignature<'db> {
other, other,
TypeRelation::Assignability, TypeRelation::Assignability,
&HasRelationToVisitor::default(), &HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
) )
.is_always_satisfied() .is_always_satisfied()
} }
@ -157,9 +159,17 @@ impl<'db> CallableSignature<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: &Self, other: &Self,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
Self::has_relation_to_inner(db, &self.overloads, &other.overloads, relation, visitor) Self::has_relation_to_inner(
db,
&self.overloads,
&other.overloads,
relation,
relation_visitor,
disjointness_visitor,
)
} }
/// Implementation of subtyping and assignability between two, possible overloaded, callable /// Implementation of subtyping and assignability between two, possible overloaded, callable
@ -169,12 +179,19 @@ impl<'db> CallableSignature<'db> {
self_signatures: &[Signature<'db>], self_signatures: &[Signature<'db>],
other_signatures: &[Signature<'db>], other_signatures: &[Signature<'db>],
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match (self_signatures, other_signatures) { match (self_signatures, other_signatures) {
([self_signature], [other_signature]) => { ([self_signature], [other_signature]) => {
// Base case: both callable types contain a single signature. // Base case: both callable types contain a single signature.
self_signature.has_relation_to_impl(db, other_signature, relation, visitor) self_signature.has_relation_to_impl(
db,
other_signature,
relation,
relation_visitor,
disjointness_visitor,
)
} }
// `self` is possibly overloaded while `other` is definitely not overloaded. // `self` is possibly overloaded while `other` is definitely not overloaded.
@ -184,7 +201,8 @@ impl<'db> CallableSignature<'db> {
std::slice::from_ref(self_signature), std::slice::from_ref(self_signature),
other_signatures, other_signatures,
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
) )
}), }),
@ -195,7 +213,8 @@ impl<'db> CallableSignature<'db> {
self_signatures, self_signatures,
std::slice::from_ref(other_signature), std::slice::from_ref(other_signature),
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
) )
}), }),
@ -206,7 +225,8 @@ impl<'db> CallableSignature<'db> {
self_signatures, self_signatures,
std::slice::from_ref(other_signature), std::slice::from_ref(other_signature),
relation, relation,
visitor, relation_visitor,
disjointness_visitor,
) )
}), }),
} }
@ -631,7 +651,8 @@ impl<'db> Signature<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: &Signature<'db>, other: &Signature<'db>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
/// A helper struct to zip two slices of parameters together that provides control over the /// A helper struct to zip two slices of parameters together that provides control over the
/// two iterators individually. It also keeps track of the current parameter in each /// two iterators individually. It also keeps track of the current parameter in each
@ -699,7 +720,16 @@ impl<'db> Signature<'db> {
let type1 = type1.unwrap_or(Type::unknown()); let type1 = type1.unwrap_or(Type::unknown());
let type2 = type2.unwrap_or(Type::unknown()); let type2 = type2.unwrap_or(Type::unknown());
!result !result
.intersect(db, type1.has_relation_to_impl(db, type2, relation, visitor)) .intersect(
db,
type1.has_relation_to_impl(
db,
type2,
relation,
relation_visitor,
disjointness_visitor,
),
)
.is_never_satisfied() .is_never_satisfied()
}; };

View file

@ -134,7 +134,8 @@ impl<'db> SubclassOfType<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: SubclassOfType<'db>, other: SubclassOfType<'db>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match (self.subclass_of, other.subclass_of) { match (self.subclass_of, other.subclass_of) {
(SubclassOfInner::Dynamic(_), SubclassOfInner::Dynamic(_)) => { (SubclassOfInner::Dynamic(_), SubclassOfInner::Dynamic(_)) => {
@ -150,9 +151,14 @@ impl<'db> SubclassOfType<'db> {
// For example, `type[bool]` describes all possible runtime subclasses of the class `bool`, // For example, `type[bool]` describes all possible runtime subclasses of the class `bool`,
// and `type[int]` describes all possible runtime subclasses of the class `int`. // and `type[int]` describes all possible runtime subclasses of the class `int`.
// The first set is a subset of the second set, because `bool` is itself a subclass of `int`. // The first set is a subset of the second set, because `bool` is itself a subclass of `int`.
(SubclassOfInner::Class(self_class), SubclassOfInner::Class(other_class)) => { (SubclassOfInner::Class(self_class), SubclassOfInner::Class(other_class)) => self_class
self_class.has_relation_to_impl(db, other_class, relation, visitor) .has_relation_to_impl(
} db,
other_class,
relation,
relation_visitor,
disjointness_visitor,
),
} }
} }

View file

@ -258,10 +258,16 @@ impl<'db> TupleType<'db> {
db: &'db dyn Db, db: &'db dyn Db,
other: Self, other: Self,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
self.tuple(db) self.tuple(db).has_relation_to_impl(
.has_relation_to_impl(db, other.tuple(db), relation, visitor) db,
other.tuple(db),
relation,
relation_visitor,
disjointness_visitor,
)
} }
pub(crate) fn is_equivalent_to_impl( pub(crate) fn is_equivalent_to_impl(
@ -416,13 +422,20 @@ impl<'db> FixedLengthTuple<Type<'db>> {
db: &'db dyn Db, db: &'db dyn Db,
other: &Tuple<Type<'db>>, other: &Tuple<Type<'db>>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match other { match other {
Tuple::Fixed(other) => { Tuple::Fixed(other) => {
ConstraintSet::from(self.0.len() == other.0.len()).and(db, || { ConstraintSet::from(self.0.len() == other.0.len()).and(db, || {
(self.0.iter().zip(&other.0)).when_all(db, |(self_ty, other_ty)| { (self.0.iter().zip(&other.0)).when_all(db, |(self_ty, other_ty)| {
self_ty.has_relation_to_impl(db, *other_ty, relation, visitor) self_ty.has_relation_to_impl(
db,
*other_ty,
relation,
relation_visitor,
disjointness_visitor,
)
}) })
}) })
} }
@ -436,8 +449,13 @@ impl<'db> FixedLengthTuple<Type<'db>> {
let Some(self_ty) = self_iter.next() else { let Some(self_ty) = self_iter.next() else {
return ConstraintSet::from(false); return ConstraintSet::from(false);
}; };
let element_constraints = let element_constraints = self_ty.has_relation_to_impl(
self_ty.has_relation_to_impl(db, *other_ty, relation, visitor); db,
*other_ty,
relation,
relation_visitor,
disjointness_visitor,
);
if result if result
.intersect(db, element_constraints) .intersect(db, element_constraints)
.is_never_satisfied() .is_never_satisfied()
@ -449,8 +467,13 @@ impl<'db> FixedLengthTuple<Type<'db>> {
let Some(self_ty) = self_iter.next_back() else { let Some(self_ty) = self_iter.next_back() else {
return ConstraintSet::from(false); return ConstraintSet::from(false);
}; };
let element_constraints = let element_constraints = self_ty.has_relation_to_impl(
self_ty.has_relation_to_impl(db, *other_ty, relation, visitor); db,
*other_ty,
relation,
relation_visitor,
disjointness_visitor,
);
if result if result
.intersect(db, element_constraints) .intersect(db, element_constraints)
.is_never_satisfied() .is_never_satisfied()
@ -463,7 +486,13 @@ impl<'db> FixedLengthTuple<Type<'db>> {
// variable-length portion of the other tuple. // variable-length portion of the other tuple.
result.and(db, || { result.and(db, || {
self_iter.when_all(db, |self_ty| { self_iter.when_all(db, |self_ty| {
self_ty.has_relation_to_impl(db, other.variable, relation, visitor) self_ty.has_relation_to_impl(
db,
other.variable,
relation,
relation_visitor,
disjointness_visitor,
)
}) })
}) })
} }
@ -743,7 +772,8 @@ impl<'db> VariableLengthTuple<Type<'db>> {
db: &'db dyn Db, db: &'db dyn Db,
other: &Tuple<Type<'db>>, other: &Tuple<Type<'db>>,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match other { match other {
Tuple::Fixed(other) => { Tuple::Fixed(other) => {
@ -771,8 +801,13 @@ impl<'db> VariableLengthTuple<Type<'db>> {
let Some(other_ty) = other_iter.next() else { let Some(other_ty) = other_iter.next() else {
return ConstraintSet::from(false); return ConstraintSet::from(false);
}; };
let element_constraints = let element_constraints = self_ty.has_relation_to_impl(
self_ty.has_relation_to_impl(db, other_ty, relation, visitor); db,
other_ty,
relation,
relation_visitor,
disjointness_visitor,
);
if result if result
.intersect(db, element_constraints) .intersect(db, element_constraints)
.is_never_satisfied() .is_never_satisfied()
@ -785,8 +820,13 @@ impl<'db> VariableLengthTuple<Type<'db>> {
let Some(other_ty) = other_iter.next_back() else { let Some(other_ty) = other_iter.next_back() else {
return ConstraintSet::from(false); return ConstraintSet::from(false);
}; };
let element_constraints = let element_constraints = self_ty.has_relation_to_impl(
self_ty.has_relation_to_impl(db, other_ty, relation, visitor); db,
other_ty,
relation,
relation_visitor,
disjointness_visitor,
);
if result if result
.intersect(db, element_constraints) .intersect(db, element_constraints)
.is_never_satisfied() .is_never_satisfied()
@ -820,12 +860,20 @@ impl<'db> VariableLengthTuple<Type<'db>> {
); );
for pair in pairwise { for pair in pairwise {
let pair_constraints = match pair { let pair_constraints = match pair {
EitherOrBoth::Both(self_ty, other_ty) => { EitherOrBoth::Both(self_ty, other_ty) => self_ty.has_relation_to_impl(
self_ty.has_relation_to_impl(db, other_ty, relation, visitor) db,
} other_ty,
EitherOrBoth::Left(self_ty) => { relation,
self_ty.has_relation_to_impl(db, other.variable, relation, visitor) relation_visitor,
} disjointness_visitor,
),
EitherOrBoth::Left(self_ty) => self_ty.has_relation_to_impl(
db,
other.variable,
relation,
relation_visitor,
disjointness_visitor,
),
EitherOrBoth::Right(_) => { EitherOrBoth::Right(_) => {
// The rhs has a required element that the lhs is not guaranteed to // The rhs has a required element that the lhs is not guaranteed to
// provide. // provide.
@ -846,12 +894,20 @@ impl<'db> VariableLengthTuple<Type<'db>> {
let pairwise = (self_suffix.iter().rev()).zip_longest(other_suffix.iter().rev()); let pairwise = (self_suffix.iter().rev()).zip_longest(other_suffix.iter().rev());
for pair in pairwise { for pair in pairwise {
let pair_constraints = match pair { let pair_constraints = match pair {
EitherOrBoth::Both(self_ty, other_ty) => { EitherOrBoth::Both(self_ty, other_ty) => self_ty.has_relation_to_impl(
self_ty.has_relation_to_impl(db, *other_ty, relation, visitor) db,
} *other_ty,
EitherOrBoth::Left(self_ty) => { relation,
self_ty.has_relation_to_impl(db, other.variable, relation, visitor) relation_visitor,
} disjointness_visitor,
),
EitherOrBoth::Left(self_ty) => self_ty.has_relation_to_impl(
db,
other.variable,
relation,
relation_visitor,
disjointness_visitor,
),
EitherOrBoth::Right(_) => { EitherOrBoth::Right(_) => {
// The rhs has a required element that the lhs is not guaranteed to // The rhs has a required element that the lhs is not guaranteed to
// provide. // provide.
@ -865,8 +921,13 @@ impl<'db> VariableLengthTuple<Type<'db>> {
// And lastly, the variable-length portions must satisfy the relation. // And lastly, the variable-length portions must satisfy the relation.
result.and(db, || { result.and(db, || {
self.variable self.variable.has_relation_to_impl(
.has_relation_to_impl(db, other.variable, relation, visitor) db,
other.variable,
relation,
relation_visitor,
disjointness_visitor,
)
}) })
} }
} }
@ -1088,15 +1149,24 @@ impl<'db> Tuple<Type<'db>> {
db: &'db dyn Db, db: &'db dyn Db,
other: &Self, other: &Self,
relation: TypeRelation, relation: TypeRelation,
visitor: &HasRelationToVisitor<'db>, relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
match self { match self {
Tuple::Fixed(self_tuple) => { Tuple::Fixed(self_tuple) => self_tuple.has_relation_to_impl(
self_tuple.has_relation_to_impl(db, other, relation, visitor) db,
} other,
Tuple::Variable(self_tuple) => { relation,
self_tuple.has_relation_to_impl(db, other, relation, visitor) relation_visitor,
} disjointness_visitor,
),
Tuple::Variable(self_tuple) => self_tuple.has_relation_to_impl(
db,
other,
relation,
relation_visitor,
disjointness_visitor,
),
} }
} }
@ -1123,7 +1193,8 @@ impl<'db> Tuple<Type<'db>> {
&self, &self,
db: &'db dyn Db, db: &'db dyn Db,
other: &Self, other: &Self,
visitor: &IsDisjointVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
// Two tuples with an incompatible number of required elements must always be disjoint. // Two tuples with an incompatible number of required elements must always be disjoint.
let (self_min, self_max) = self.len().size_hint(); let (self_min, self_max) = self.len().size_hint();
@ -1141,20 +1212,30 @@ impl<'db> Tuple<Type<'db>> {
db: &'db dyn Db, db: &'db dyn Db,
a: impl IntoIterator<Item = &'s Type<'db>>, a: impl IntoIterator<Item = &'s Type<'db>>,
b: impl IntoIterator<Item = &'s Type<'db>>, b: impl IntoIterator<Item = &'s Type<'db>>,
visitor: &IsDisjointVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> ) -> ConstraintSet<'db>
where where
'db: 's, 'db: 's,
{ {
(a.into_iter().zip(b)).when_any(db, |(self_element, other_element)| { (a.into_iter().zip(b)).when_any(db, |(self_element, other_element)| {
self_element.is_disjoint_from_impl(db, *other_element, visitor) self_element.is_disjoint_from_impl(
db,
*other_element,
disjointness_visitor,
relation_visitor,
)
}) })
} }
match (self, other) { match (self, other) {
(Tuple::Fixed(self_tuple), Tuple::Fixed(other_tuple)) => { (Tuple::Fixed(self_tuple), Tuple::Fixed(other_tuple)) => any_disjoint(
any_disjoint(db, self_tuple.elements(), other_tuple.elements(), visitor) db,
} self_tuple.elements(),
other_tuple.elements(),
disjointness_visitor,
relation_visitor,
),
// Note that we don't compare the variable-length portions; two pure homogeneous tuples // Note that we don't compare the variable-length portions; two pure homogeneous tuples
// `tuple[A, ...]` and `tuple[B, ...]` can never be disjoint even if A and B are // `tuple[A, ...]` and `tuple[B, ...]` can never be disjoint even if A and B are
@ -1163,31 +1244,36 @@ impl<'db> Tuple<Type<'db>> {
db, db,
self_tuple.prefix_elements(), self_tuple.prefix_elements(),
other_tuple.prefix_elements(), other_tuple.prefix_elements(),
visitor, disjointness_visitor,
relation_visitor,
) )
.or(db, || { .or(db, || {
any_disjoint( any_disjoint(
db, db,
self_tuple.suffix_elements().rev(), self_tuple.suffix_elements().rev(),
other_tuple.suffix_elements().rev(), other_tuple.suffix_elements().rev(),
visitor, disjointness_visitor,
relation_visitor,
) )
}), }),
(Tuple::Fixed(fixed), Tuple::Variable(variable)) (Tuple::Fixed(fixed), Tuple::Variable(variable))
| (Tuple::Variable(variable), Tuple::Fixed(fixed)) => { | (Tuple::Variable(variable), Tuple::Fixed(fixed)) => any_disjoint(
any_disjoint(db, fixed.elements(), variable.prefix_elements(), visitor).or( db,
fixed.elements(),
variable.prefix_elements(),
disjointness_visitor,
relation_visitor,
)
.or(db, || {
any_disjoint(
db, db,
|| { fixed.elements().rev(),
any_disjoint( variable.suffix_elements().rev(),
db, disjointness_visitor,
fixed.elements().rev(), relation_visitor,
variable.suffix_elements().rev(),
visitor,
)
},
) )
} }),
} }
} }