mirror of
https://github.com/astral-sh/ruff.git
synced 2025-07-24 13:33:50 +00:00
[ty] Protocols: Fixpoint iteration for fully-static check (#17880)
## Summary A recursive protocol like the following would previously lead to stack overflows when attempting to create the union type for the `P | None` member, because `UnionBuilder` checks if element types are fully static, and the fully-static check on `P` would in turn list all members and check whether all of them were fully static, leading to a cycle. ```py from __future__ import annotations from typing import Protocol class P(Protocol): parent: P | None ``` Here, we make the fully-static check on protocols a salsa query and add fixpoint iteration, starting with `true` as the initial value (assume that the recursive protocol is fully-static). If the recursive protocol has any non-fully-static members, we still return `false` when re-executing the query (see newly added tests). closes #17861 ## Test Plan Added regression test
This commit is contained in:
parent
a33d0d4bf4
commit
04457f99b6
5 changed files with 113 additions and 48 deletions
|
@ -1558,7 +1558,43 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized):
|
|||
c: Sized = some_sized
|
||||
```
|
||||
|
||||
## Regression test: narrowing with self-referential protocols
|
||||
## Recursive protocols
|
||||
|
||||
### Properties
|
||||
|
||||
```py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, Any
|
||||
from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to
|
||||
|
||||
class RecursiveFullyStatic(Protocol):
|
||||
parent: RecursiveFullyStatic | None
|
||||
x: int
|
||||
|
||||
class RecursiveNonFullyStatic(Protocol):
|
||||
parent: RecursiveNonFullyStatic | None
|
||||
x: Any
|
||||
|
||||
static_assert(is_fully_static(RecursiveFullyStatic))
|
||||
static_assert(not is_fully_static(RecursiveNonFullyStatic))
|
||||
|
||||
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
|
||||
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))
|
||||
|
||||
# TODO: currently leads to a stack overflow
|
||||
# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
|
||||
# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
|
||||
|
||||
class AlsoRecursiveFullyStatic(Protocol):
|
||||
parent: AlsoRecursiveFullyStatic | None
|
||||
x: int
|
||||
|
||||
# TODO: currently leads to a stack overflow
|
||||
# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
|
||||
```
|
||||
|
||||
### Regression test: narrowing with self-referential protocols
|
||||
|
||||
This snippet caused us to panic on an early version of the implementation for protocols.
|
||||
|
||||
|
|
|
@ -622,7 +622,7 @@ impl<'db> Bindings<'db> {
|
|||
db,
|
||||
protocol_class
|
||||
.interface(db)
|
||||
.members()
|
||||
.members(db)
|
||||
.map(|member| Type::string_literal(db, member.name()))
|
||||
.collect::<Box<[Type<'db>]>>(),
|
||||
)));
|
||||
|
|
|
@ -83,7 +83,8 @@ impl Display for DisplayRepresentation<'_> {
|
|||
Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f),
|
||||
Protocol::Synthesized(synthetic) => {
|
||||
f.write_str("<Protocol with members ")?;
|
||||
let member_list = synthetic.interface(self.db).members();
|
||||
let interface = synthetic.interface();
|
||||
let member_list = interface.members(self.db);
|
||||
let num_members = member_list.len();
|
||||
for (i, member) in member_list.enumerate() {
|
||||
let is_last = i == num_members - 1;
|
||||
|
|
|
@ -38,7 +38,7 @@ impl<'db> Type<'db> {
|
|||
protocol
|
||||
.0
|
||||
.interface(db)
|
||||
.members()
|
||||
.members(db)
|
||||
.all(|member| !self.member(db, member.name()).symbol.is_unbound())
|
||||
}
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ impl<'db> ProtocolInstanceType<'db> {
|
|||
}
|
||||
match self.0 {
|
||||
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
|
||||
SynthesizedProtocolType::new(db, self.0.interface(db).clone()),
|
||||
SynthesizedProtocolType::new(db, self.0.interface(db)),
|
||||
))),
|
||||
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
|
||||
}
|
||||
|
@ -205,7 +205,7 @@ impl<'db> ProtocolInstanceType<'db> {
|
|||
other
|
||||
.0
|
||||
.interface(db)
|
||||
.is_sub_interface_of(self.0.interface(db))
|
||||
.is_sub_interface_of(db, self.0.interface(db))
|
||||
}
|
||||
|
||||
/// Return `true` if this protocol type is equivalent to the protocol `other`.
|
||||
|
@ -237,8 +237,8 @@ impl<'db> ProtocolInstanceType<'db> {
|
|||
match self.inner() {
|
||||
Protocol::FromClass(class) => class.instance_member(db, name),
|
||||
Protocol::Synthesized(synthesized) => synthesized
|
||||
.interface(db)
|
||||
.member_by_name(name)
|
||||
.interface()
|
||||
.member_by_name(db, name)
|
||||
.map(|member| SymbolAndQualifiers {
|
||||
symbol: Symbol::bound(member.ty()),
|
||||
qualifiers: member.qualifiers(),
|
||||
|
@ -258,7 +258,7 @@ pub(super) enum Protocol<'db> {
|
|||
|
||||
impl<'db> Protocol<'db> {
|
||||
/// Return the members of this protocol type
|
||||
fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
|
||||
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
|
||||
match self {
|
||||
Self::FromClass(class) => class
|
||||
.class_literal(db)
|
||||
|
@ -266,7 +266,7 @@ impl<'db> Protocol<'db> {
|
|||
.into_protocol_class(db)
|
||||
.expect("Protocol class literal should be a protocol class")
|
||||
.interface(db),
|
||||
Self::Synthesized(synthesized) => synthesized.interface(db),
|
||||
Self::Synthesized(synthesized) => synthesized.interface(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -285,24 +285,15 @@ mod synthesized_protocol {
|
|||
/// The constructor method of this type maintains the invariant that a synthesized protocol type
|
||||
/// is always constructed from a *normalized* protocol interface.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
|
||||
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>);
|
||||
pub(in crate::types) struct SynthesizedProtocolType<'db>(ProtocolInterface<'db>);
|
||||
|
||||
impl<'db> SynthesizedProtocolType<'db> {
|
||||
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
|
||||
Self(SynthesizedProtocolTypeInner::new(
|
||||
db,
|
||||
interface.normalized(db),
|
||||
))
|
||||
Self(interface.normalized(db))
|
||||
}
|
||||
|
||||
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
|
||||
self.0.interface(db)
|
||||
pub(in crate::types) fn interface(self) -> ProtocolInterface<'db> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::interned(debug)]
|
||||
struct SynthesizedProtocolTypeInner<'db> {
|
||||
#[return_ref]
|
||||
interface: ProtocolInterface<'db>,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ impl<'db> ProtocolClassLiteral<'db> {
|
|||
/// It is illegal for a protocol class to have any instance attributes that are not declared
|
||||
/// in the protocol's class body. If any are assigned to, they are not taken into account in
|
||||
/// the protocol's list of members.
|
||||
pub(super) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
|
||||
pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
|
||||
let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered();
|
||||
cached_protocol_interface(db, *self)
|
||||
}
|
||||
|
@ -57,21 +57,31 @@ impl<'db> Deref for ProtocolClassLiteral<'db> {
|
|||
}
|
||||
|
||||
/// The interface of a protocol: the members of that protocol, and the types of those members.
|
||||
#[derive(Debug, PartialEq, Eq, salsa::Update, Default, Clone, Hash)]
|
||||
pub(super) struct ProtocolInterface<'db>(BTreeMap<Name, ProtocolMemberData<'db>>);
|
||||
#[salsa::interned(debug)]
|
||||
pub(super) struct ProtocolInterface<'db> {
|
||||
#[return_ref]
|
||||
_members: BTreeMap<Name, ProtocolMemberData<'db>>,
|
||||
}
|
||||
|
||||
impl<'db> ProtocolInterface<'db> {
|
||||
/// Iterate over the members of this protocol.
|
||||
pub(super) fn members<'a>(&'a self) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
|
||||
self.0.iter().map(|(name, data)| ProtocolMember {
|
||||
pub(super) fn members<'a>(
|
||||
&'a self,
|
||||
db: &'db dyn Db,
|
||||
) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
|
||||
self._members(db).iter().map(|(name, data)| ProtocolMember {
|
||||
name,
|
||||
ty: data.ty,
|
||||
qualifiers: data.qualifiers,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn member_by_name<'a>(&self, name: &'a str) -> Option<ProtocolMember<'a, 'db>> {
|
||||
self.0.get(name).map(|data| ProtocolMember {
|
||||
pub(super) fn member_by_name<'a>(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
name: &'a str,
|
||||
) -> Option<ProtocolMember<'a, 'db>> {
|
||||
self._members(db).get(name).map(|data| ProtocolMember {
|
||||
name,
|
||||
ty: data.ty,
|
||||
qualifiers: data.qualifiers,
|
||||
|
@ -79,42 +89,43 @@ impl<'db> ProtocolInterface<'db> {
|
|||
}
|
||||
|
||||
/// Return `true` if all members of this protocol are fully static.
|
||||
pub(super) fn is_fully_static(&self, db: &'db dyn Db) -> bool {
|
||||
self.members().all(|member| member.ty.is_fully_static(db))
|
||||
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
|
||||
cached_is_fully_static(db, self)
|
||||
}
|
||||
|
||||
/// Return `true` if if all members on `self` are also members of `other`.
|
||||
///
|
||||
/// TODO: this method should consider the types of the members as well as their names.
|
||||
pub(super) fn is_sub_interface_of(&self, other: &Self) -> bool {
|
||||
self.0
|
||||
pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool {
|
||||
self._members(db)
|
||||
.keys()
|
||||
.all(|member_name| other.0.contains_key(member_name))
|
||||
.all(|member_name| other._members(db).contains_key(member_name))
|
||||
}
|
||||
|
||||
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
|
||||
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool {
|
||||
self.members().any(|member| member.ty.contains_todo(db))
|
||||
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
|
||||
self.members(db).any(|member| member.ty.contains_todo(db))
|
||||
}
|
||||
|
||||
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
Self(
|
||||
self.0
|
||||
.into_iter()
|
||||
.map(|(name, data)| (name, data.normalized(db)))
|
||||
.collect(),
|
||||
Self::new(
|
||||
db,
|
||||
self._members(db)
|
||||
.iter()
|
||||
.map(|(name, data)| (name.clone(), data.normalized(db)))
|
||||
.collect::<BTreeMap<_, _>>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
|
||||
struct ProtocolMemberData<'db> {
|
||||
pub(super) struct ProtocolMemberData<'db> {
|
||||
ty: Type<'db>,
|
||||
qualifiers: TypeQualifiers,
|
||||
}
|
||||
|
||||
impl<'db> ProtocolMemberData<'db> {
|
||||
fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
fn normalized(&self, db: &'db dyn Db) -> Self {
|
||||
Self {
|
||||
ty: self.ty.normalized(db),
|
||||
qualifiers: self.qualifiers,
|
||||
|
@ -184,7 +195,7 @@ fn excluded_from_proto_members(member: &str) -> bool {
|
|||
}
|
||||
|
||||
/// Inner Salsa query for [`ProtocolClassLiteral::interface`].
|
||||
#[salsa::tracked(return_ref, cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
|
||||
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
|
||||
fn cached_protocol_interface<'db>(
|
||||
db: &'db dyn Db,
|
||||
class: ClassLiteral<'db>,
|
||||
|
@ -240,9 +251,10 @@ fn cached_protocol_interface<'db>(
|
|||
);
|
||||
}
|
||||
|
||||
ProtocolInterface(members)
|
||||
ProtocolInterface::new(db, members)
|
||||
}
|
||||
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||
fn proto_interface_cycle_recover<'db>(
|
||||
_db: &dyn Db,
|
||||
_value: &ProtocolInterface<'db>,
|
||||
|
@ -253,8 +265,33 @@ fn proto_interface_cycle_recover<'db>(
|
|||
}
|
||||
|
||||
fn proto_interface_cycle_initial<'db>(
|
||||
_db: &dyn Db,
|
||||
db: &'db dyn Db,
|
||||
_class: ClassLiteral<'db>,
|
||||
) -> ProtocolInterface<'db> {
|
||||
ProtocolInterface::default()
|
||||
ProtocolInterface::new(db, BTreeMap::default())
|
||||
}
|
||||
|
||||
#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)]
|
||||
fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool {
|
||||
interface
|
||||
.members(db)
|
||||
.all(|member| member.ty.is_fully_static(db))
|
||||
}
|
||||
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||
fn is_fully_static_cycle_recover(
|
||||
_db: &dyn Db,
|
||||
_value: &bool,
|
||||
_count: u32,
|
||||
_interface: ProtocolInterface<'_>,
|
||||
) -> salsa::CycleRecoveryAction<bool> {
|
||||
salsa::CycleRecoveryAction::Iterate
|
||||
}
|
||||
|
||||
fn is_fully_static_cycle_initial<'db>(
|
||||
_db: &'db dyn Db,
|
||||
_interface: ProtocolInterface<'db>,
|
||||
) -> bool {
|
||||
// Assume that the protocol is fully static until we find members that indicate otherwise.
|
||||
true
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue