From 04457f99b6ba69dff38c1163a1668b6332ebd8b3 Mon Sep 17 00:00:00 2001 From: David Peter Date: Wed, 7 May 2025 08:55:21 +0200 Subject: [PATCH] [ty] Protocols: Fixpoint iteration for fully-static check (#17880) ## Summary A recursive protocol like the following would previously lead to stack overflows when attempting to create the union type for the `P | None` member, because `UnionBuilder` checks if element types are fully static, and the fully-static check on `P` would in turn list all members and check whether all of them were fully static, leading to a cycle. ```py from __future__ import annotations from typing import Protocol class P(Protocol): parent: P | None ``` Here, we make the fully-static check on protocols a salsa query and add fixpoint iteration, starting with `true` as the initial value (assume that the recursive protocol is fully-static). If the recursive protocol has any non-fully-static members, we still return `false` when re-executing the query (see newly added tests). closes #17861 ## Test Plan Added regression test --- .../resources/mdtest/protocols.md | 38 +++++++- .../ty_python_semantic/src/types/call/bind.rs | 2 +- .../ty_python_semantic/src/types/display.rs | 3 +- .../ty_python_semantic/src/types/instance.rs | 31 +++---- .../src/types/protocol_class.rs | 87 +++++++++++++------ 5 files changed, 113 insertions(+), 48 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index b1d3860334..06f8cc9844 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -1558,7 +1558,43 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized): c: Sized = some_sized ``` -## Regression test: narrowing with self-referential protocols +## Recursive protocols + +### Properties + +```py +from __future__ import annotations + +from typing import Protocol, Any +from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to + +class RecursiveFullyStatic(Protocol): + parent: RecursiveFullyStatic | None + x: int + +class RecursiveNonFullyStatic(Protocol): + parent: RecursiveNonFullyStatic | None + x: Any + +static_assert(is_fully_static(RecursiveFullyStatic)) +static_assert(not is_fully_static(RecursiveNonFullyStatic)) + +static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic)) +static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic)) + +# TODO: currently leads to a stack overflow +# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic)) +# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic)) + +class AlsoRecursiveFullyStatic(Protocol): + parent: AlsoRecursiveFullyStatic | None + x: int + +# TODO: currently leads to a stack overflow +# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic)) +``` + +### Regression test: narrowing with self-referential protocols This snippet caused us to panic on an early version of the implementation for protocols. diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 3f93985a43..2bb1bc0730 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -622,7 +622,7 @@ impl<'db> Bindings<'db> { db, protocol_class .interface(db) - .members() + .members(db) .map(|member| Type::string_literal(db, member.name())) .collect::]>>(), ))); diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index fbd4cddbd9..1da5a6eb61 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -83,7 +83,8 @@ impl Display for DisplayRepresentation<'_> { Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f), Protocol::Synthesized(synthetic) => { f.write_str(" Type<'db> { protocol .0 .interface(db) - .members() + .members(db) .all(|member| !self.member(db, member.name()).symbol.is_unbound()) } } @@ -177,7 +177,7 @@ impl<'db> ProtocolInstanceType<'db> { } match self.0 { Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized( - SynthesizedProtocolType::new(db, self.0.interface(db).clone()), + SynthesizedProtocolType::new(db, self.0.interface(db)), ))), Protocol::Synthesized(_) => Type::ProtocolInstance(self), } @@ -205,7 +205,7 @@ impl<'db> ProtocolInstanceType<'db> { other .0 .interface(db) - .is_sub_interface_of(self.0.interface(db)) + .is_sub_interface_of(db, self.0.interface(db)) } /// Return `true` if this protocol type is equivalent to the protocol `other`. @@ -237,8 +237,8 @@ impl<'db> ProtocolInstanceType<'db> { match self.inner() { Protocol::FromClass(class) => class.instance_member(db, name), Protocol::Synthesized(synthesized) => synthesized - .interface(db) - .member_by_name(name) + .interface() + .member_by_name(db, name) .map(|member| SymbolAndQualifiers { symbol: Symbol::bound(member.ty()), qualifiers: member.qualifiers(), @@ -258,7 +258,7 @@ pub(super) enum Protocol<'db> { impl<'db> Protocol<'db> { /// Return the members of this protocol type - fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { + fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { match self { Self::FromClass(class) => class .class_literal(db) @@ -266,7 +266,7 @@ impl<'db> Protocol<'db> { .into_protocol_class(db) .expect("Protocol class literal should be a protocol class") .interface(db), - Self::Synthesized(synthesized) => synthesized.interface(db), + Self::Synthesized(synthesized) => synthesized.interface(), } } } @@ -285,24 +285,15 @@ mod synthesized_protocol { /// The constructor method of this type maintains the invariant that a synthesized protocol type /// is always constructed from a *normalized* protocol interface. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)] - pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>); + pub(in crate::types) struct SynthesizedProtocolType<'db>(ProtocolInterface<'db>); impl<'db> SynthesizedProtocolType<'db> { pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self { - Self(SynthesizedProtocolTypeInner::new( - db, - interface.normalized(db), - )) + Self(interface.normalized(db)) } - pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { - self.0.interface(db) + pub(in crate::types) fn interface(self) -> ProtocolInterface<'db> { + self.0 } } - - #[salsa::interned(debug)] - struct SynthesizedProtocolTypeInner<'db> { - #[return_ref] - interface: ProtocolInterface<'db>, - } } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 743ca2f7a7..2bd0682a84 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -37,7 +37,7 @@ impl<'db> ProtocolClassLiteral<'db> { /// It is illegal for a protocol class to have any instance attributes that are not declared /// in the protocol's class body. If any are assigned to, they are not taken into account in /// the protocol's list of members. - pub(super) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { + pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered(); cached_protocol_interface(db, *self) } @@ -57,21 +57,31 @@ impl<'db> Deref for ProtocolClassLiteral<'db> { } /// The interface of a protocol: the members of that protocol, and the types of those members. -#[derive(Debug, PartialEq, Eq, salsa::Update, Default, Clone, Hash)] -pub(super) struct ProtocolInterface<'db>(BTreeMap>); +#[salsa::interned(debug)] +pub(super) struct ProtocolInterface<'db> { + #[return_ref] + _members: BTreeMap>, +} impl<'db> ProtocolInterface<'db> { /// Iterate over the members of this protocol. - pub(super) fn members<'a>(&'a self) -> impl ExactSizeIterator> { - self.0.iter().map(|(name, data)| ProtocolMember { + pub(super) fn members<'a>( + &'a self, + db: &'db dyn Db, + ) -> impl ExactSizeIterator> { + self._members(db).iter().map(|(name, data)| ProtocolMember { name, ty: data.ty, qualifiers: data.qualifiers, }) } - pub(super) fn member_by_name<'a>(&self, name: &'a str) -> Option> { - self.0.get(name).map(|data| ProtocolMember { + pub(super) fn member_by_name<'a>( + self, + db: &'db dyn Db, + name: &'a str, + ) -> Option> { + self._members(db).get(name).map(|data| ProtocolMember { name, ty: data.ty, qualifiers: data.qualifiers, @@ -79,42 +89,43 @@ impl<'db> ProtocolInterface<'db> { } /// Return `true` if all members of this protocol are fully static. - pub(super) fn is_fully_static(&self, db: &'db dyn Db) -> bool { - self.members().all(|member| member.ty.is_fully_static(db)) + pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool { + cached_is_fully_static(db, self) } /// Return `true` if if all members on `self` are also members of `other`. /// /// TODO: this method should consider the types of the members as well as their names. - pub(super) fn is_sub_interface_of(&self, other: &Self) -> bool { - self.0 + pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool { + self._members(db) .keys() - .all(|member_name| other.0.contains_key(member_name)) + .all(|member_name| other._members(db).contains_key(member_name)) } /// Return `true` if any of the members of this protocol type contain any `Todo` types. - pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool { - self.members().any(|member| member.ty.contains_todo(db)) + pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool { + self.members(db).any(|member| member.ty.contains_todo(db)) } pub(super) fn normalized(self, db: &'db dyn Db) -> Self { - Self( - self.0 - .into_iter() - .map(|(name, data)| (name, data.normalized(db))) - .collect(), + Self::new( + db, + self._members(db) + .iter() + .map(|(name, data)| (name.clone(), data.normalized(db))) + .collect::>(), ) } } #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] -struct ProtocolMemberData<'db> { +pub(super) struct ProtocolMemberData<'db> { ty: Type<'db>, qualifiers: TypeQualifiers, } impl<'db> ProtocolMemberData<'db> { - fn normalized(self, db: &'db dyn Db) -> Self { + fn normalized(&self, db: &'db dyn Db) -> Self { Self { ty: self.ty.normalized(db), qualifiers: self.qualifiers, @@ -184,7 +195,7 @@ fn excluded_from_proto_members(member: &str) -> bool { } /// Inner Salsa query for [`ProtocolClassLiteral::interface`]. -#[salsa::tracked(return_ref, cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)] +#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)] fn cached_protocol_interface<'db>( db: &'db dyn Db, class: ClassLiteral<'db>, @@ -240,9 +251,10 @@ fn cached_protocol_interface<'db>( ); } - ProtocolInterface(members) + ProtocolInterface::new(db, members) } +#[allow(clippy::trivially_copy_pass_by_ref)] fn proto_interface_cycle_recover<'db>( _db: &dyn Db, _value: &ProtocolInterface<'db>, @@ -253,8 +265,33 @@ fn proto_interface_cycle_recover<'db>( } fn proto_interface_cycle_initial<'db>( - _db: &dyn Db, + db: &'db dyn Db, _class: ClassLiteral<'db>, ) -> ProtocolInterface<'db> { - ProtocolInterface::default() + ProtocolInterface::new(db, BTreeMap::default()) +} + +#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)] +fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool { + interface + .members(db) + .all(|member| member.ty.is_fully_static(db)) +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn is_fully_static_cycle_recover( + _db: &dyn Db, + _value: &bool, + _count: u32, + _interface: ProtocolInterface<'_>, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +fn is_fully_static_cycle_initial<'db>( + _db: &'db dyn Db, + _interface: ProtocolInterface<'db>, +) -> bool { + // Assume that the protocol is fully static until we find members that indicate otherwise. + true }