[ty] Protocols: Fixpoint iteration for fully-static check (#17880)

## Summary

A recursive protocol like the following would previously lead to stack
overflows when attempting to create the union type for the `P | None`
member, because `UnionBuilder` checks if element types are fully static,
and the fully-static check on `P` would in turn list all members and
check whether all of them were fully static, leading to a cycle.

```py
from __future__ import annotations

from typing import Protocol

class P(Protocol):
    parent: P | None
```

Here, we make the fully-static check on protocols a salsa query and add
fixpoint iteration, starting with `true` as the initial value (assume
that the recursive protocol is fully-static). If the recursive protocol
has any non-fully-static members, we still return `false` when
re-executing the query (see newly added tests).

closes #17861

## Test Plan

Added regression test
This commit is contained in:
David Peter 2025-05-07 08:55:21 +02:00 committed by GitHub
parent a33d0d4bf4
commit 04457f99b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 113 additions and 48 deletions

View file

@ -1558,7 +1558,43 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized):
c: Sized = some_sized c: Sized = some_sized
``` ```
## Regression test: narrowing with self-referential protocols ## Recursive protocols
### Properties
```py
from __future__ import annotations
from typing import Protocol, Any
from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to
class RecursiveFullyStatic(Protocol):
parent: RecursiveFullyStatic | None
x: int
class RecursiveNonFullyStatic(Protocol):
parent: RecursiveNonFullyStatic | None
x: Any
static_assert(is_fully_static(RecursiveFullyStatic))
static_assert(not is_fully_static(RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))
# TODO: currently leads to a stack overflow
# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
class AlsoRecursiveFullyStatic(Protocol):
parent: AlsoRecursiveFullyStatic | None
x: int
# TODO: currently leads to a stack overflow
# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
```
### Regression test: narrowing with self-referential protocols
This snippet caused us to panic on an early version of the implementation for protocols. This snippet caused us to panic on an early version of the implementation for protocols.

View file

@ -622,7 +622,7 @@ impl<'db> Bindings<'db> {
db, db,
protocol_class protocol_class
.interface(db) .interface(db)
.members() .members(db)
.map(|member| Type::string_literal(db, member.name())) .map(|member| Type::string_literal(db, member.name()))
.collect::<Box<[Type<'db>]>>(), .collect::<Box<[Type<'db>]>>(),
))); )));

View file

@ -83,7 +83,8 @@ impl Display for DisplayRepresentation<'_> {
Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f), Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f),
Protocol::Synthesized(synthetic) => { Protocol::Synthesized(synthetic) => {
f.write_str("<Protocol with members ")?; f.write_str("<Protocol with members ")?;
let member_list = synthetic.interface(self.db).members(); let interface = synthetic.interface();
let member_list = interface.members(self.db);
let num_members = member_list.len(); let num_members = member_list.len();
for (i, member) in member_list.enumerate() { for (i, member) in member_list.enumerate() {
let is_last = i == num_members - 1; let is_last = i == num_members - 1;

View file

@ -38,7 +38,7 @@ impl<'db> Type<'db> {
protocol protocol
.0 .0
.interface(db) .interface(db)
.members() .members(db)
.all(|member| !self.member(db, member.name()).symbol.is_unbound()) .all(|member| !self.member(db, member.name()).symbol.is_unbound())
} }
} }
@ -177,7 +177,7 @@ impl<'db> ProtocolInstanceType<'db> {
} }
match self.0 { match self.0 {
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized( Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
SynthesizedProtocolType::new(db, self.0.interface(db).clone()), SynthesizedProtocolType::new(db, self.0.interface(db)),
))), ))),
Protocol::Synthesized(_) => Type::ProtocolInstance(self), Protocol::Synthesized(_) => Type::ProtocolInstance(self),
} }
@ -205,7 +205,7 @@ impl<'db> ProtocolInstanceType<'db> {
other other
.0 .0
.interface(db) .interface(db)
.is_sub_interface_of(self.0.interface(db)) .is_sub_interface_of(db, self.0.interface(db))
} }
/// Return `true` if this protocol type is equivalent to the protocol `other`. /// Return `true` if this protocol type is equivalent to the protocol `other`.
@ -237,8 +237,8 @@ impl<'db> ProtocolInstanceType<'db> {
match self.inner() { match self.inner() {
Protocol::FromClass(class) => class.instance_member(db, name), Protocol::FromClass(class) => class.instance_member(db, name),
Protocol::Synthesized(synthesized) => synthesized Protocol::Synthesized(synthesized) => synthesized
.interface(db) .interface()
.member_by_name(name) .member_by_name(db, name)
.map(|member| SymbolAndQualifiers { .map(|member| SymbolAndQualifiers {
symbol: Symbol::bound(member.ty()), symbol: Symbol::bound(member.ty()),
qualifiers: member.qualifiers(), qualifiers: member.qualifiers(),
@ -258,7 +258,7 @@ pub(super) enum Protocol<'db> {
impl<'db> Protocol<'db> { impl<'db> Protocol<'db> {
/// Return the members of this protocol type /// Return the members of this protocol type
fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
match self { match self {
Self::FromClass(class) => class Self::FromClass(class) => class
.class_literal(db) .class_literal(db)
@ -266,7 +266,7 @@ impl<'db> Protocol<'db> {
.into_protocol_class(db) .into_protocol_class(db)
.expect("Protocol class literal should be a protocol class") .expect("Protocol class literal should be a protocol class")
.interface(db), .interface(db),
Self::Synthesized(synthesized) => synthesized.interface(db), Self::Synthesized(synthesized) => synthesized.interface(),
} }
} }
} }
@ -285,24 +285,15 @@ mod synthesized_protocol {
/// The constructor method of this type maintains the invariant that a synthesized protocol type /// The constructor method of this type maintains the invariant that a synthesized protocol type
/// is always constructed from a *normalized* protocol interface. /// is always constructed from a *normalized* protocol interface.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)] #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>); pub(in crate::types) struct SynthesizedProtocolType<'db>(ProtocolInterface<'db>);
impl<'db> SynthesizedProtocolType<'db> { impl<'db> SynthesizedProtocolType<'db> {
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self { pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
Self(SynthesizedProtocolTypeInner::new( Self(interface.normalized(db))
db,
interface.normalized(db),
))
} }
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { pub(in crate::types) fn interface(self) -> ProtocolInterface<'db> {
self.0.interface(db) self.0
} }
} }
#[salsa::interned(debug)]
struct SynthesizedProtocolTypeInner<'db> {
#[return_ref]
interface: ProtocolInterface<'db>,
}
} }

View file

@ -37,7 +37,7 @@ impl<'db> ProtocolClassLiteral<'db> {
/// It is illegal for a protocol class to have any instance attributes that are not declared /// It is illegal for a protocol class to have any instance attributes that are not declared
/// in the protocol's class body. If any are assigned to, they are not taken into account in /// in the protocol's class body. If any are assigned to, they are not taken into account in
/// the protocol's list of members. /// the protocol's list of members.
pub(super) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered(); let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered();
cached_protocol_interface(db, *self) cached_protocol_interface(db, *self)
} }
@ -57,21 +57,31 @@ impl<'db> Deref for ProtocolClassLiteral<'db> {
} }
/// The interface of a protocol: the members of that protocol, and the types of those members. /// The interface of a protocol: the members of that protocol, and the types of those members.
#[derive(Debug, PartialEq, Eq, salsa::Update, Default, Clone, Hash)] #[salsa::interned(debug)]
pub(super) struct ProtocolInterface<'db>(BTreeMap<Name, ProtocolMemberData<'db>>); pub(super) struct ProtocolInterface<'db> {
#[return_ref]
_members: BTreeMap<Name, ProtocolMemberData<'db>>,
}
impl<'db> ProtocolInterface<'db> { impl<'db> ProtocolInterface<'db> {
/// Iterate over the members of this protocol. /// Iterate over the members of this protocol.
pub(super) fn members<'a>(&'a self) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> { pub(super) fn members<'a>(
self.0.iter().map(|(name, data)| ProtocolMember { &'a self,
db: &'db dyn Db,
) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
self._members(db).iter().map(|(name, data)| ProtocolMember {
name, name,
ty: data.ty, ty: data.ty,
qualifiers: data.qualifiers, qualifiers: data.qualifiers,
}) })
} }
pub(super) fn member_by_name<'a>(&self, name: &'a str) -> Option<ProtocolMember<'a, 'db>> { pub(super) fn member_by_name<'a>(
self.0.get(name).map(|data| ProtocolMember { self,
db: &'db dyn Db,
name: &'a str,
) -> Option<ProtocolMember<'a, 'db>> {
self._members(db).get(name).map(|data| ProtocolMember {
name, name,
ty: data.ty, ty: data.ty,
qualifiers: data.qualifiers, qualifiers: data.qualifiers,
@ -79,42 +89,43 @@ impl<'db> ProtocolInterface<'db> {
} }
/// Return `true` if all members of this protocol are fully static. /// Return `true` if all members of this protocol are fully static.
pub(super) fn is_fully_static(&self, db: &'db dyn Db) -> bool { pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
self.members().all(|member| member.ty.is_fully_static(db)) cached_is_fully_static(db, self)
} }
/// Return `true` if if all members on `self` are also members of `other`. /// Return `true` if if all members on `self` are also members of `other`.
/// ///
/// TODO: this method should consider the types of the members as well as their names. /// TODO: this method should consider the types of the members as well as their names.
pub(super) fn is_sub_interface_of(&self, other: &Self) -> bool { pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool {
self.0 self._members(db)
.keys() .keys()
.all(|member_name| other.0.contains_key(member_name)) .all(|member_name| other._members(db).contains_key(member_name))
} }
/// Return `true` if any of the members of this protocol type contain any `Todo` types. /// Return `true` if any of the members of this protocol type contain any `Todo` types.
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool { pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
self.members().any(|member| member.ty.contains_todo(db)) self.members(db).any(|member| member.ty.contains_todo(db))
} }
pub(super) fn normalized(self, db: &'db dyn Db) -> Self { pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self( Self::new(
self.0 db,
.into_iter() self._members(db)
.map(|(name, data)| (name, data.normalized(db))) .iter()
.collect(), .map(|(name, data)| (name.clone(), data.normalized(db)))
.collect::<BTreeMap<_, _>>(),
) )
} }
} }
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
struct ProtocolMemberData<'db> { pub(super) struct ProtocolMemberData<'db> {
ty: Type<'db>, ty: Type<'db>,
qualifiers: TypeQualifiers, qualifiers: TypeQualifiers,
} }
impl<'db> ProtocolMemberData<'db> { impl<'db> ProtocolMemberData<'db> {
fn normalized(self, db: &'db dyn Db) -> Self { fn normalized(&self, db: &'db dyn Db) -> Self {
Self { Self {
ty: self.ty.normalized(db), ty: self.ty.normalized(db),
qualifiers: self.qualifiers, qualifiers: self.qualifiers,
@ -184,7 +195,7 @@ fn excluded_from_proto_members(member: &str) -> bool {
} }
/// Inner Salsa query for [`ProtocolClassLiteral::interface`]. /// Inner Salsa query for [`ProtocolClassLiteral::interface`].
#[salsa::tracked(return_ref, cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)] #[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
fn cached_protocol_interface<'db>( fn cached_protocol_interface<'db>(
db: &'db dyn Db, db: &'db dyn Db,
class: ClassLiteral<'db>, class: ClassLiteral<'db>,
@ -240,9 +251,10 @@ fn cached_protocol_interface<'db>(
); );
} }
ProtocolInterface(members) ProtocolInterface::new(db, members)
} }
#[allow(clippy::trivially_copy_pass_by_ref)]
fn proto_interface_cycle_recover<'db>( fn proto_interface_cycle_recover<'db>(
_db: &dyn Db, _db: &dyn Db,
_value: &ProtocolInterface<'db>, _value: &ProtocolInterface<'db>,
@ -253,8 +265,33 @@ fn proto_interface_cycle_recover<'db>(
} }
fn proto_interface_cycle_initial<'db>( fn proto_interface_cycle_initial<'db>(
_db: &dyn Db, db: &'db dyn Db,
_class: ClassLiteral<'db>, _class: ClassLiteral<'db>,
) -> ProtocolInterface<'db> { ) -> ProtocolInterface<'db> {
ProtocolInterface::default() ProtocolInterface::new(db, BTreeMap::default())
}
#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)]
fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool {
interface
.members(db)
.all(|member| member.ty.is_fully_static(db))
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_fully_static_cycle_recover(
_db: &dyn Db,
_value: &bool,
_count: u32,
_interface: ProtocolInterface<'_>,
) -> salsa::CycleRecoveryAction<bool> {
salsa::CycleRecoveryAction::Iterate
}
fn is_fully_static_cycle_initial<'db>(
_db: &'db dyn Db,
_interface: ProtocolInterface<'db>,
) -> bool {
// Assume that the protocol is fully static until we find members that indicate otherwise.
true
} }