diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index 797c25e7e8..47cd3632ef 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -1569,11 +1569,11 @@ from typing import Protocol, Any from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to class RecursiveFullyStatic(Protocol): - parent: RecursiveFullyStatic | None + parent: RecursiveFullyStatic x: int class RecursiveNonFullyStatic(Protocol): - parent: RecursiveNonFullyStatic | None + parent: RecursiveNonFullyStatic x: Any static_assert(is_fully_static(RecursiveFullyStatic)) @@ -1582,16 +1582,111 @@ static_assert(not is_fully_static(RecursiveNonFullyStatic)) static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic)) static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic)) -# TODO: currently leads to a stack overflow -# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic)) -# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic)) +static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveNonFullyStatic)) +static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic)) +static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic)) class AlsoRecursiveFullyStatic(Protocol): - parent: AlsoRecursiveFullyStatic | None + parent: AlsoRecursiveFullyStatic x: int -# TODO: currently leads to a stack overflow -# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic)) +static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic)) + +class RecursiveOptionalParent(Protocol): + parent: RecursiveOptionalParent | None + +static_assert(is_fully_static(RecursiveOptionalParent)) + +static_assert(is_assignable_to(RecursiveOptionalParent, RecursiveOptionalParent)) + +static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveOptionalParent)) +static_assert(not is_assignable_to(RecursiveOptionalParent, RecursiveNonFullyStatic)) + +class Other(Protocol): + z: str + +def _(rec: RecursiveFullyStatic, other: Other): + reveal_type(rec.parent.parent.parent) # revealed: RecursiveFullyStatic + + rec.parent.parent.parent = rec + rec = rec.parent.parent.parent + + rec.parent.parent.parent = other # error: [invalid-assignment] + other = rec.parent.parent.parent # error: [invalid-assignment] + +class Foo(Protocol): + @property + def x(self) -> "Foo": ... + +class Bar(Protocol): + @property + def x(self) -> "Bar": ... + +# TODO: this should pass +# error: [static-assert-error] +static_assert(is_equivalent_to(Foo, Bar)) +``` + +### Nested occurrences of self-reference + +Make sure that we handle self-reference correctly, even if the self-reference appears deeply nested +within the type of a protocol member: + +```toml +[environment] +python-version = "3.12" +``` + +```py +from __future__ import annotations + +from typing import Protocol, Callable +from ty_extensions import Intersection, Not, is_fully_static, is_assignable_to, is_equivalent_to, static_assert + +class C: ... + +class GenericC[T](Protocol): + pass + +class Recursive(Protocol): + direct: Recursive + + union: None | Recursive + + intersection1: Intersection[C, Recursive] + intersection2: Intersection[C, Not[Recursive]] + + t: tuple[int, tuple[str, Recursive]] + + callable1: Callable[[int], Recursive] + callable2: Callable[[Recursive], int] + + subtype_of: type[Recursive] + + generic: GenericC[Recursive] + + def method(self, x: Recursive) -> Recursive: ... + + nested: Recursive | Callable[[Recursive | Recursive, tuple[Recursive, Recursive]], Recursive | Recursive] + +static_assert(is_fully_static(Recursive)) +static_assert(is_equivalent_to(Recursive, Recursive)) +static_assert(is_assignable_to(Recursive, Recursive)) + +def _(r: Recursive): + reveal_type(r.direct) # revealed: Recursive + reveal_type(r.union) # revealed: None | Recursive + reveal_type(r.intersection1) # revealed: C & Recursive + reveal_type(r.intersection2) # revealed: C & ~Recursive + reveal_type(r.t) # revealed: tuple[int, tuple[str, Recursive]] + reveal_type(r.callable1) # revealed: (int, /) -> Recursive + reveal_type(r.callable2) # revealed: (Recursive, /) -> int + reveal_type(r.subtype_of) # revealed: type[Recursive] + reveal_type(r.generic) # revealed: GenericC[Recursive] + reveal_type(r.method(r)) # revealed: Recursive + reveal_type(r.nested) # revealed: Recursive | ((Recursive, tuple[Recursive, Recursive], /) -> Recursive) + + reveal_type(r.method(r).callable1(1).direct.t[1][1]) # revealed: Recursive ``` ### Regression test: narrowing with self-referential protocols diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 490e267c01..ed2f79231e 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -587,6 +587,79 @@ impl<'db> Type<'db> { matches!(self, Type::Dynamic(DynamicType::Todo(_))) } + /// Replace references to the class `class` with a self-reference marker. This is currently + /// used for recursive protocols, but could probably be extended to self-referential type- + /// aliases and similar. + #[must_use] + pub fn replace_self_reference(&self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Type<'db> { + match self { + Self::ProtocolInstance(protocol) => { + Self::ProtocolInstance(protocol.replace_self_reference(db, class)) + } + + Self::Union(union) => UnionType::from_elements( + db, + union + .elements(db) + .iter() + .map(|ty| ty.replace_self_reference(db, class)), + ), + + Self::Intersection(intersection) => IntersectionBuilder::new(db) + .positive_elements( + intersection + .positive(db) + .iter() + .map(|ty| ty.replace_self_reference(db, class)), + ) + .negative_elements( + intersection + .negative(db) + .iter() + .map(|ty| ty.replace_self_reference(db, class)), + ) + .build(), + + Self::Tuple(tuple) => TupleType::from_elements( + db, + tuple + .elements(db) + .iter() + .map(|ty| ty.replace_self_reference(db, class)), + ), + + Self::Callable(callable) => Self::Callable(callable.replace_self_reference(db, class)), + + Self::GenericAlias(_) | Self::TypeVar(_) => { + // TODO: replace self-references in generic aliases and typevars + *self + } + + Self::Dynamic(_) + | Self::AlwaysFalsy + | Self::AlwaysTruthy + | Self::Never + | Self::BooleanLiteral(_) + | Self::BytesLiteral(_) + | Self::StringLiteral(_) + | Self::IntLiteral(_) + | Self::LiteralString + | Self::FunctionLiteral(_) + | Self::ModuleLiteral(_) + | Self::ClassLiteral(_) + | Self::NominalInstance(_) + | Self::KnownInstance(_) + | Self::PropertyInstance(_) + | Self::BoundMethod(_) + | Self::WrapperDescriptor(_) + | Self::MethodWrapper(_) + | Self::DataclassDecorator(_) + | Self::DataclassTransformer(_) + | Self::SubclassOf(_) + | Self::BoundSuper(_) => *self, + } + } + pub fn contains_todo(&self, db: &'db dyn Db) -> bool { match self { Self::Dynamic(DynamicType::Todo(_) | DynamicType::SubscriptedProtocol) => true, @@ -7272,6 +7345,17 @@ impl<'db> CallableType<'db> { } } } + + /// See [`Type::replace_self_reference`]. + fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self { + CallableType::from_overloads( + db, + self.signatures(db) + .iter() + .cloned() + .map(|signature| signature.replace_self_reference(db, class)), + ) + } } /// Represents a specific instance of `types.MethodWrapperType` diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 3efbd2db0a..01b5fa8c46 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -529,6 +529,28 @@ impl<'db> IntersectionBuilder<'db> { } } + pub(crate) fn positive_elements(mut self, elements: I) -> Self + where + I: IntoIterator, + T: Into>, + { + for element in elements { + self = self.add_positive(element.into()); + } + self + } + + pub(crate) fn negative_elements(mut self, elements: I) -> Self + where + I: IntoIterator, + T: Into>, + { + for element in elements { + self = self.add_negative(element.into()); + } + self + } + pub(crate) fn build(mut self) -> Type<'db> { // Avoid allocating the UnionBuilder unnecessarily if we have just one intersection: if self.intersections.len() == 1 { diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 25cd4d6c51..c834927d6b 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -4,6 +4,7 @@ use super::protocol_class::ProtocolInterface; use super::{ClassType, KnownClass, SubclassOfType, Type}; use crate::symbol::{Symbol, SymbolAndQualifiers}; use crate::types::generics::TypeMapping; +use crate::types::ClassLiteral; use crate::Db; pub(super) use synthesized_protocol::SynthesizedProtocolType; @@ -183,6 +184,19 @@ impl<'db> ProtocolInstanceType<'db> { } } + /// Replace references to `class` with a self-reference marker + pub(super) fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self { + match self.0 { + Protocol::FromClass(class_type) if class_type.class_literal(db).0 == class => { + ProtocolInstanceType(Protocol::Synthesized(SynthesizedProtocolType::new( + db, + ProtocolInterface::SelfReference, + ))) + } + _ => self, + } + } + /// Return `true` if any of the members of this protocol type contain any `Todo` types. pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool { self.0.interface(db).contains_todo(db) diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 75f9e966de..da43422d14 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, ops::Deref}; -use itertools::Itertools; +use itertools::{Either, Itertools}; use ruff_python_ast::name::Name; @@ -56,24 +56,41 @@ impl<'db> Deref for ProtocolClassLiteral<'db> { } } -/// The interface of a protocol: the members of that protocol, and the types of those members. #[salsa::interned(debug)] -pub(super) struct ProtocolInterface<'db> { +pub(super) struct ProtocolInterfaceMembers<'db> { #[returns(ref)] - _members: BTreeMap>, + inner: BTreeMap>, +} + +/// The interface of a protocol: the members of that protocol, and the types of those members. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)] +pub(super) enum ProtocolInterface<'db> { + Members(ProtocolInterfaceMembers<'db>), + SelfReference, } impl<'db> ProtocolInterface<'db> { - /// Iterate over the members of this protocol. + fn empty(db: &'db dyn Db) -> Self { + Self::Members(ProtocolInterfaceMembers::new(db, BTreeMap::default())) + } + pub(super) fn members<'a>( - &'a self, + self, db: &'db dyn Db, - ) -> impl ExactSizeIterator> { - self._members(db).iter().map(|(name, data)| ProtocolMember { - name, - ty: data.ty, - qualifiers: data.qualifiers, - }) + ) -> impl ExactSizeIterator> + where + 'db: 'a, + { + match self { + Self::Members(members) => { + Either::Left(members.inner(db).iter().map(|(name, data)| ProtocolMember { + name, + ty: data.ty, + qualifiers: data.qualifiers, + })) + } + Self::SelfReference => Either::Right(std::iter::empty()), + } } pub(super) fn member_by_name<'a>( @@ -81,25 +98,34 @@ impl<'db> ProtocolInterface<'db> { db: &'db dyn Db, name: &'a str, ) -> Option> { - self._members(db).get(name).map(|data| ProtocolMember { - name, - ty: data.ty, - qualifiers: data.qualifiers, - }) + match self { + Self::Members(members) => members.inner(db).get(name).map(|data| ProtocolMember { + name, + ty: data.ty, + qualifiers: data.qualifiers, + }), + Self::SelfReference => None, + } } /// Return `true` if all members of this protocol are fully static. pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool { - cached_is_fully_static(db, self) + self.members(db).all(|member| member.ty.is_fully_static(db)) } /// Return `true` if if all members on `self` are also members of `other`. /// /// TODO: this method should consider the types of the members as well as their names. pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool { - self._members(db) - .keys() - .all(|member_name| other._members(db).contains_key(member_name)) + match (self, other) { + (Self::Members(self_members), Self::Members(other_members)) => self_members + .inner(db) + .keys() + .all(|member_name| other_members.inner(db).contains_key(member_name)), + _ => { + unreachable!("Enclosing protocols should never be a self-reference marker") + } + } } /// Return `true` if any of the members of this protocol type contain any `Todo` types. @@ -108,13 +134,17 @@ impl<'db> ProtocolInterface<'db> { } pub(super) fn normalized(self, db: &'db dyn Db) -> Self { - Self::new( - db, - self._members(db) - .iter() - .map(|(name, data)| (name.clone(), data.normalized(db))) - .collect::>(), - ) + match self { + Self::Members(members) => Self::Members(ProtocolInterfaceMembers::new( + db, + members + .inner(db) + .iter() + .map(|(name, data)| (name.clone(), data.normalized(db))) + .collect::>(), + )), + Self::SelfReference => Self::SelfReference, + } } } @@ -245,13 +275,14 @@ fn cached_protocol_interface<'db>( }) .filter(|(name, _, _)| !excluded_from_proto_members(name)) .map(|(name, ty, qualifiers)| { + let ty = ty.replace_self_reference(db, class); let member = ProtocolMemberData { ty, qualifiers }; (name.clone(), member) }), ); } - ProtocolInterface::new(db, members) + ProtocolInterface::Members(ProtocolInterfaceMembers::new(db, members)) } #[allow(clippy::trivially_copy_pass_by_ref)] @@ -268,30 +299,5 @@ fn proto_interface_cycle_initial<'db>( db: &'db dyn Db, _class: ClassLiteral<'db>, ) -> ProtocolInterface<'db> { - ProtocolInterface::new(db, BTreeMap::default()) -} - -#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)] -fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool { - interface - .members(db) - .all(|member| member.ty.is_fully_static(db)) -} - -#[allow(clippy::trivially_copy_pass_by_ref)] -fn is_fully_static_cycle_recover( - _db: &dyn Db, - _value: &bool, - _count: u32, - _interface: ProtocolInterface<'_>, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - -fn is_fully_static_cycle_initial<'db>( - _db: &'db dyn Db, - _interface: ProtocolInterface<'db>, -) -> bool { - // Assume that the protocol is fully static until we find members that indicate otherwise. - true + ProtocolInterface::empty(db) } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 9994f37167..7cbead0b75 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -18,7 +18,7 @@ use smallvec::{smallvec, SmallVec}; use super::{definition_expression_type, DynamicType, Type}; use crate::semantic_index::definition::Definition; use crate::types::generics::{GenericContext, Specialization, TypeMapping}; -use crate::types::{todo_type, TypeVarInstance}; +use crate::types::{todo_type, ClassLiteral, TypeVarInstance}; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -876,6 +876,28 @@ impl<'db> Signature<'db> { true } + + /// See [`Type::replace_self_reference`]. + pub(crate) fn replace_self_reference( + mut self, + db: &'db dyn Db, + class: ClassLiteral<'db>, + ) -> Self { + // TODO: also replace self references in generic context + + self.parameters = self + .parameters + .iter() + .cloned() + .map(|param| param.replace_self_reference(db, class)) + .collect(); + + if let Some(ty) = self.return_ty.as_mut() { + *ty = ty.replace_self_reference(db, class); + } + + self + } } #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] @@ -1388,6 +1410,14 @@ impl<'db> Parameter<'db> { ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => None, } } + + /// See [`Type::replace_self_reference`]. + fn replace_self_reference(mut self, db: &'db (dyn Db), class: ClassLiteral<'db>) -> Self { + if let Some(ty) = self.annotated_type.as_mut() { + *ty = ty.replace_self_reference(db, class); + } + self + } } #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]