[ty] Protocols: Fixpoint iteration for fully-static check (#17880)

## Summary

A recursive protocol like the following would previously lead to stack
overflows when attempting to create the union type for the `P | None`
member, because `UnionBuilder` checks if element types are fully static,
and the fully-static check on `P` would in turn list all members and
check whether all of them were fully static, leading to a cycle.

```py
from __future__ import annotations

from typing import Protocol

class P(Protocol):
    parent: P | None
```

Here, we make the fully-static check on protocols a salsa query and add
fixpoint iteration, starting with `true` as the initial value (assume
that the recursive protocol is fully-static). If the recursive protocol
has any non-fully-static members, we still return `false` when
re-executing the query (see newly added tests).

closes #17861

## Test Plan

Added regression test
This commit is contained in:
David Peter 2025-05-07 08:55:21 +02:00 committed by GitHub
parent a33d0d4bf4
commit 04457f99b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 113 additions and 48 deletions

View file

@ -1558,7 +1558,43 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized):
c: Sized = some_sized
```
## Regression test: narrowing with self-referential protocols
## Recursive protocols
### Properties
```py
from __future__ import annotations
from typing import Protocol, Any
from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to
class RecursiveFullyStatic(Protocol):
parent: RecursiveFullyStatic | None
x: int
class RecursiveNonFullyStatic(Protocol):
parent: RecursiveNonFullyStatic | None
x: Any
static_assert(is_fully_static(RecursiveFullyStatic))
static_assert(not is_fully_static(RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))
# TODO: currently leads to a stack overflow
# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
class AlsoRecursiveFullyStatic(Protocol):
parent: AlsoRecursiveFullyStatic | None
x: int
# TODO: currently leads to a stack overflow
# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
```
### Regression test: narrowing with self-referential protocols
This snippet caused us to panic on an early version of the implementation for protocols.

View file

@ -622,7 +622,7 @@ impl<'db> Bindings<'db> {
db,
protocol_class
.interface(db)
.members()
.members(db)
.map(|member| Type::string_literal(db, member.name()))
.collect::<Box<[Type<'db>]>>(),
)));

View file

@ -83,7 +83,8 @@ impl Display for DisplayRepresentation<'_> {
Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f),
Protocol::Synthesized(synthetic) => {
f.write_str("<Protocol with members ")?;
let member_list = synthetic.interface(self.db).members();
let interface = synthetic.interface();
let member_list = interface.members(self.db);
let num_members = member_list.len();
for (i, member) in member_list.enumerate() {
let is_last = i == num_members - 1;

View file

@ -38,7 +38,7 @@ impl<'db> Type<'db> {
protocol
.0
.interface(db)
.members()
.members(db)
.all(|member| !self.member(db, member.name()).symbol.is_unbound())
}
}
@ -177,7 +177,7 @@ impl<'db> ProtocolInstanceType<'db> {
}
match self.0 {
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
SynthesizedProtocolType::new(db, self.0.interface(db).clone()),
SynthesizedProtocolType::new(db, self.0.interface(db)),
))),
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
}
@ -205,7 +205,7 @@ impl<'db> ProtocolInstanceType<'db> {
other
.0
.interface(db)
.is_sub_interface_of(self.0.interface(db))
.is_sub_interface_of(db, self.0.interface(db))
}
/// Return `true` if this protocol type is equivalent to the protocol `other`.
@ -237,8 +237,8 @@ impl<'db> ProtocolInstanceType<'db> {
match self.inner() {
Protocol::FromClass(class) => class.instance_member(db, name),
Protocol::Synthesized(synthesized) => synthesized
.interface(db)
.member_by_name(name)
.interface()
.member_by_name(db, name)
.map(|member| SymbolAndQualifiers {
symbol: Symbol::bound(member.ty()),
qualifiers: member.qualifiers(),
@ -258,7 +258,7 @@ pub(super) enum Protocol<'db> {
impl<'db> Protocol<'db> {
/// Return the members of this protocol type
fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
match self {
Self::FromClass(class) => class
.class_literal(db)
@ -266,7 +266,7 @@ impl<'db> Protocol<'db> {
.into_protocol_class(db)
.expect("Protocol class literal should be a protocol class")
.interface(db),
Self::Synthesized(synthesized) => synthesized.interface(db),
Self::Synthesized(synthesized) => synthesized.interface(),
}
}
}
@ -285,24 +285,15 @@ mod synthesized_protocol {
/// The constructor method of this type maintains the invariant that a synthesized protocol type
/// is always constructed from a *normalized* protocol interface.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>);
pub(in crate::types) struct SynthesizedProtocolType<'db>(ProtocolInterface<'db>);
impl<'db> SynthesizedProtocolType<'db> {
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
Self(SynthesizedProtocolTypeInner::new(
db,
interface.normalized(db),
))
Self(interface.normalized(db))
}
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
self.0.interface(db)
pub(in crate::types) fn interface(self) -> ProtocolInterface<'db> {
self.0
}
}
#[salsa::interned(debug)]
struct SynthesizedProtocolTypeInner<'db> {
#[return_ref]
interface: ProtocolInterface<'db>,
}
}

View file

@ -37,7 +37,7 @@ impl<'db> ProtocolClassLiteral<'db> {
/// It is illegal for a protocol class to have any instance attributes that are not declared
/// in the protocol's class body. If any are assigned to, they are not taken into account in
/// the protocol's list of members.
pub(super) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered();
cached_protocol_interface(db, *self)
}
@ -57,21 +57,31 @@ impl<'db> Deref for ProtocolClassLiteral<'db> {
}
/// The interface of a protocol: the members of that protocol, and the types of those members.
#[derive(Debug, PartialEq, Eq, salsa::Update, Default, Clone, Hash)]
pub(super) struct ProtocolInterface<'db>(BTreeMap<Name, ProtocolMemberData<'db>>);
#[salsa::interned(debug)]
pub(super) struct ProtocolInterface<'db> {
#[return_ref]
_members: BTreeMap<Name, ProtocolMemberData<'db>>,
}
impl<'db> ProtocolInterface<'db> {
/// Iterate over the members of this protocol.
pub(super) fn members<'a>(&'a self) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
self.0.iter().map(|(name, data)| ProtocolMember {
pub(super) fn members<'a>(
&'a self,
db: &'db dyn Db,
) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
self._members(db).iter().map(|(name, data)| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
})
}
pub(super) fn member_by_name<'a>(&self, name: &'a str) -> Option<ProtocolMember<'a, 'db>> {
self.0.get(name).map(|data| ProtocolMember {
pub(super) fn member_by_name<'a>(
self,
db: &'db dyn Db,
name: &'a str,
) -> Option<ProtocolMember<'a, 'db>> {
self._members(db).get(name).map(|data| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
@ -79,42 +89,43 @@ impl<'db> ProtocolInterface<'db> {
}
/// Return `true` if all members of this protocol are fully static.
pub(super) fn is_fully_static(&self, db: &'db dyn Db) -> bool {
self.members().all(|member| member.ty.is_fully_static(db))
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
cached_is_fully_static(db, self)
}
/// Return `true` if if all members on `self` are also members of `other`.
///
/// TODO: this method should consider the types of the members as well as their names.
pub(super) fn is_sub_interface_of(&self, other: &Self) -> bool {
self.0
pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool {
self._members(db)
.keys()
.all(|member_name| other.0.contains_key(member_name))
.all(|member_name| other._members(db).contains_key(member_name))
}
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool {
self.members().any(|member| member.ty.contains_todo(db))
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
self.members(db).any(|member| member.ty.contains_todo(db))
}
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self(
self.0
.into_iter()
.map(|(name, data)| (name, data.normalized(db)))
.collect(),
Self::new(
db,
self._members(db)
.iter()
.map(|(name, data)| (name.clone(), data.normalized(db)))
.collect::<BTreeMap<_, _>>(),
)
}
}
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
struct ProtocolMemberData<'db> {
pub(super) struct ProtocolMemberData<'db> {
ty: Type<'db>,
qualifiers: TypeQualifiers,
}
impl<'db> ProtocolMemberData<'db> {
fn normalized(self, db: &'db dyn Db) -> Self {
fn normalized(&self, db: &'db dyn Db) -> Self {
Self {
ty: self.ty.normalized(db),
qualifiers: self.qualifiers,
@ -184,7 +195,7 @@ fn excluded_from_proto_members(member: &str) -> bool {
}
/// Inner Salsa query for [`ProtocolClassLiteral::interface`].
#[salsa::tracked(return_ref, cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
fn cached_protocol_interface<'db>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
@ -240,9 +251,10 @@ fn cached_protocol_interface<'db>(
);
}
ProtocolInterface(members)
ProtocolInterface::new(db, members)
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn proto_interface_cycle_recover<'db>(
_db: &dyn Db,
_value: &ProtocolInterface<'db>,
@ -253,8 +265,33 @@ fn proto_interface_cycle_recover<'db>(
}
fn proto_interface_cycle_initial<'db>(
_db: &dyn Db,
db: &'db dyn Db,
_class: ClassLiteral<'db>,
) -> ProtocolInterface<'db> {
ProtocolInterface::default()
ProtocolInterface::new(db, BTreeMap::default())
}
#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)]
fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool {
interface
.members(db)
.all(|member| member.ty.is_fully_static(db))
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_fully_static_cycle_recover(
_db: &dyn Db,
_value: &bool,
_count: u32,
_interface: ProtocolInterface<'_>,
) -> salsa::CycleRecoveryAction<bool> {
salsa::CycleRecoveryAction::Iterate
}
fn is_fully_static_cycle_initial<'db>(
_db: &'db dyn Db,
_interface: ProtocolInterface<'db>,
) -> bool {
// Assume that the protocol is fully static until we find members that indicate otherwise.
true
}