mirror of
https://github.com/astral-sh/ruff.git
synced 2025-07-24 05:25:17 +00:00
[ty] Protocols: Fixpoint iteration for fully-static check (#17880)
## Summary A recursive protocol like the following would previously lead to stack overflows when attempting to create the union type for the `P | None` member, because `UnionBuilder` checks if element types are fully static, and the fully-static check on `P` would in turn list all members and check whether all of them were fully static, leading to a cycle. ```py from __future__ import annotations from typing import Protocol class P(Protocol): parent: P | None ``` Here, we make the fully-static check on protocols a salsa query and add fixpoint iteration, starting with `true` as the initial value (assume that the recursive protocol is fully-static). If the recursive protocol has any non-fully-static members, we still return `false` when re-executing the query (see newly added tests). closes #17861 ## Test Plan Added regression test
This commit is contained in:
parent
a33d0d4bf4
commit
04457f99b6
5 changed files with 113 additions and 48 deletions
|
@ -1558,7 +1558,43 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized):
|
||||||
c: Sized = some_sized
|
c: Sized = some_sized
|
||||||
```
|
```
|
||||||
|
|
||||||
## Regression test: narrowing with self-referential protocols
|
## Recursive protocols
|
||||||
|
|
||||||
|
### Properties
|
||||||
|
|
||||||
|
```py
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol, Any
|
||||||
|
from ty_extensions import is_fully_static, static_assert, is_assignable_to, is_subtype_of, is_equivalent_to
|
||||||
|
|
||||||
|
class RecursiveFullyStatic(Protocol):
|
||||||
|
parent: RecursiveFullyStatic | None
|
||||||
|
x: int
|
||||||
|
|
||||||
|
class RecursiveNonFullyStatic(Protocol):
|
||||||
|
parent: RecursiveNonFullyStatic | None
|
||||||
|
x: Any
|
||||||
|
|
||||||
|
static_assert(is_fully_static(RecursiveFullyStatic))
|
||||||
|
static_assert(not is_fully_static(RecursiveNonFullyStatic))
|
||||||
|
|
||||||
|
static_assert(not is_subtype_of(RecursiveFullyStatic, RecursiveNonFullyStatic))
|
||||||
|
static_assert(not is_subtype_of(RecursiveNonFullyStatic, RecursiveFullyStatic))
|
||||||
|
|
||||||
|
# TODO: currently leads to a stack overflow
|
||||||
|
# static_assert(is_assignable_to(RecursiveFullyStatic, RecursiveNonFullyStatic))
|
||||||
|
# static_assert(is_assignable_to(RecursiveNonFullyStatic, RecursiveFullyStatic))
|
||||||
|
|
||||||
|
class AlsoRecursiveFullyStatic(Protocol):
|
||||||
|
parent: AlsoRecursiveFullyStatic | None
|
||||||
|
x: int
|
||||||
|
|
||||||
|
# TODO: currently leads to a stack overflow
|
||||||
|
# static_assert(is_equivalent_to(AlsoRecursiveFullyStatic, RecursiveFullyStatic))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Regression test: narrowing with self-referential protocols
|
||||||
|
|
||||||
This snippet caused us to panic on an early version of the implementation for protocols.
|
This snippet caused us to panic on an early version of the implementation for protocols.
|
||||||
|
|
||||||
|
|
|
@ -622,7 +622,7 @@ impl<'db> Bindings<'db> {
|
||||||
db,
|
db,
|
||||||
protocol_class
|
protocol_class
|
||||||
.interface(db)
|
.interface(db)
|
||||||
.members()
|
.members(db)
|
||||||
.map(|member| Type::string_literal(db, member.name()))
|
.map(|member| Type::string_literal(db, member.name()))
|
||||||
.collect::<Box<[Type<'db>]>>(),
|
.collect::<Box<[Type<'db>]>>(),
|
||||||
)));
|
)));
|
||||||
|
|
|
@ -83,7 +83,8 @@ impl Display for DisplayRepresentation<'_> {
|
||||||
Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f),
|
Protocol::FromClass(ClassType::Generic(alias)) => alias.display(self.db).fmt(f),
|
||||||
Protocol::Synthesized(synthetic) => {
|
Protocol::Synthesized(synthetic) => {
|
||||||
f.write_str("<Protocol with members ")?;
|
f.write_str("<Protocol with members ")?;
|
||||||
let member_list = synthetic.interface(self.db).members();
|
let interface = synthetic.interface();
|
||||||
|
let member_list = interface.members(self.db);
|
||||||
let num_members = member_list.len();
|
let num_members = member_list.len();
|
||||||
for (i, member) in member_list.enumerate() {
|
for (i, member) in member_list.enumerate() {
|
||||||
let is_last = i == num_members - 1;
|
let is_last = i == num_members - 1;
|
||||||
|
|
|
@ -38,7 +38,7 @@ impl<'db> Type<'db> {
|
||||||
protocol
|
protocol
|
||||||
.0
|
.0
|
||||||
.interface(db)
|
.interface(db)
|
||||||
.members()
|
.members(db)
|
||||||
.all(|member| !self.member(db, member.name()).symbol.is_unbound())
|
.all(|member| !self.member(db, member.name()).symbol.is_unbound())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -177,7 +177,7 @@ impl<'db> ProtocolInstanceType<'db> {
|
||||||
}
|
}
|
||||||
match self.0 {
|
match self.0 {
|
||||||
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
|
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
|
||||||
SynthesizedProtocolType::new(db, self.0.interface(db).clone()),
|
SynthesizedProtocolType::new(db, self.0.interface(db)),
|
||||||
))),
|
))),
|
||||||
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
|
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
|
||||||
}
|
}
|
||||||
|
@ -205,7 +205,7 @@ impl<'db> ProtocolInstanceType<'db> {
|
||||||
other
|
other
|
||||||
.0
|
.0
|
||||||
.interface(db)
|
.interface(db)
|
||||||
.is_sub_interface_of(self.0.interface(db))
|
.is_sub_interface_of(db, self.0.interface(db))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return `true` if this protocol type is equivalent to the protocol `other`.
|
/// Return `true` if this protocol type is equivalent to the protocol `other`.
|
||||||
|
@ -237,8 +237,8 @@ impl<'db> ProtocolInstanceType<'db> {
|
||||||
match self.inner() {
|
match self.inner() {
|
||||||
Protocol::FromClass(class) => class.instance_member(db, name),
|
Protocol::FromClass(class) => class.instance_member(db, name),
|
||||||
Protocol::Synthesized(synthesized) => synthesized
|
Protocol::Synthesized(synthesized) => synthesized
|
||||||
.interface(db)
|
.interface()
|
||||||
.member_by_name(name)
|
.member_by_name(db, name)
|
||||||
.map(|member| SymbolAndQualifiers {
|
.map(|member| SymbolAndQualifiers {
|
||||||
symbol: Symbol::bound(member.ty()),
|
symbol: Symbol::bound(member.ty()),
|
||||||
qualifiers: member.qualifiers(),
|
qualifiers: member.qualifiers(),
|
||||||
|
@ -258,7 +258,7 @@ pub(super) enum Protocol<'db> {
|
||||||
|
|
||||||
impl<'db> Protocol<'db> {
|
impl<'db> Protocol<'db> {
|
||||||
/// Return the members of this protocol type
|
/// Return the members of this protocol type
|
||||||
fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
|
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
|
||||||
match self {
|
match self {
|
||||||
Self::FromClass(class) => class
|
Self::FromClass(class) => class
|
||||||
.class_literal(db)
|
.class_literal(db)
|
||||||
|
@ -266,7 +266,7 @@ impl<'db> Protocol<'db> {
|
||||||
.into_protocol_class(db)
|
.into_protocol_class(db)
|
||||||
.expect("Protocol class literal should be a protocol class")
|
.expect("Protocol class literal should be a protocol class")
|
||||||
.interface(db),
|
.interface(db),
|
||||||
Self::Synthesized(synthesized) => synthesized.interface(db),
|
Self::Synthesized(synthesized) => synthesized.interface(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -285,24 +285,15 @@ mod synthesized_protocol {
|
||||||
/// The constructor method of this type maintains the invariant that a synthesized protocol type
|
/// The constructor method of this type maintains the invariant that a synthesized protocol type
|
||||||
/// is always constructed from a *normalized* protocol interface.
|
/// is always constructed from a *normalized* protocol interface.
|
||||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
|
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
|
||||||
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>);
|
pub(in crate::types) struct SynthesizedProtocolType<'db>(ProtocolInterface<'db>);
|
||||||
|
|
||||||
impl<'db> SynthesizedProtocolType<'db> {
|
impl<'db> SynthesizedProtocolType<'db> {
|
||||||
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
|
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
|
||||||
Self(SynthesizedProtocolTypeInner::new(
|
Self(interface.normalized(db))
|
||||||
db,
|
|
||||||
interface.normalized(db),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
|
pub(in crate::types) fn interface(self) -> ProtocolInterface<'db> {
|
||||||
self.0.interface(db)
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[salsa::interned(debug)]
|
|
||||||
struct SynthesizedProtocolTypeInner<'db> {
|
|
||||||
#[return_ref]
|
|
||||||
interface: ProtocolInterface<'db>,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ impl<'db> ProtocolClassLiteral<'db> {
|
||||||
/// It is illegal for a protocol class to have any instance attributes that are not declared
|
/// It is illegal for a protocol class to have any instance attributes that are not declared
|
||||||
/// in the protocol's class body. If any are assigned to, they are not taken into account in
|
/// in the protocol's class body. If any are assigned to, they are not taken into account in
|
||||||
/// the protocol's list of members.
|
/// the protocol's list of members.
|
||||||
pub(super) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
|
pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
|
||||||
let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered();
|
let _span = tracing::trace_span!("protocol_members", "class='{}'", self.name(db)).entered();
|
||||||
cached_protocol_interface(db, *self)
|
cached_protocol_interface(db, *self)
|
||||||
}
|
}
|
||||||
|
@ -57,21 +57,31 @@ impl<'db> Deref for ProtocolClassLiteral<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The interface of a protocol: the members of that protocol, and the types of those members.
|
/// The interface of a protocol: the members of that protocol, and the types of those members.
|
||||||
#[derive(Debug, PartialEq, Eq, salsa::Update, Default, Clone, Hash)]
|
#[salsa::interned(debug)]
|
||||||
pub(super) struct ProtocolInterface<'db>(BTreeMap<Name, ProtocolMemberData<'db>>);
|
pub(super) struct ProtocolInterface<'db> {
|
||||||
|
#[return_ref]
|
||||||
|
_members: BTreeMap<Name, ProtocolMemberData<'db>>,
|
||||||
|
}
|
||||||
|
|
||||||
impl<'db> ProtocolInterface<'db> {
|
impl<'db> ProtocolInterface<'db> {
|
||||||
/// Iterate over the members of this protocol.
|
/// Iterate over the members of this protocol.
|
||||||
pub(super) fn members<'a>(&'a self) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
|
pub(super) fn members<'a>(
|
||||||
self.0.iter().map(|(name, data)| ProtocolMember {
|
&'a self,
|
||||||
|
db: &'db dyn Db,
|
||||||
|
) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
|
||||||
|
self._members(db).iter().map(|(name, data)| ProtocolMember {
|
||||||
name,
|
name,
|
||||||
ty: data.ty,
|
ty: data.ty,
|
||||||
qualifiers: data.qualifiers,
|
qualifiers: data.qualifiers,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn member_by_name<'a>(&self, name: &'a str) -> Option<ProtocolMember<'a, 'db>> {
|
pub(super) fn member_by_name<'a>(
|
||||||
self.0.get(name).map(|data| ProtocolMember {
|
self,
|
||||||
|
db: &'db dyn Db,
|
||||||
|
name: &'a str,
|
||||||
|
) -> Option<ProtocolMember<'a, 'db>> {
|
||||||
|
self._members(db).get(name).map(|data| ProtocolMember {
|
||||||
name,
|
name,
|
||||||
ty: data.ty,
|
ty: data.ty,
|
||||||
qualifiers: data.qualifiers,
|
qualifiers: data.qualifiers,
|
||||||
|
@ -79,42 +89,43 @@ impl<'db> ProtocolInterface<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return `true` if all members of this protocol are fully static.
|
/// Return `true` if all members of this protocol are fully static.
|
||||||
pub(super) fn is_fully_static(&self, db: &'db dyn Db) -> bool {
|
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
|
||||||
self.members().all(|member| member.ty.is_fully_static(db))
|
cached_is_fully_static(db, self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return `true` if if all members on `self` are also members of `other`.
|
/// Return `true` if if all members on `self` are also members of `other`.
|
||||||
///
|
///
|
||||||
/// TODO: this method should consider the types of the members as well as their names.
|
/// TODO: this method should consider the types of the members as well as their names.
|
||||||
pub(super) fn is_sub_interface_of(&self, other: &Self) -> bool {
|
pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool {
|
||||||
self.0
|
self._members(db)
|
||||||
.keys()
|
.keys()
|
||||||
.all(|member_name| other.0.contains_key(member_name))
|
.all(|member_name| other._members(db).contains_key(member_name))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
|
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
|
||||||
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool {
|
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
|
||||||
self.members().any(|member| member.ty.contains_todo(db))
|
self.members(db).any(|member| member.ty.contains_todo(db))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
|
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
|
||||||
Self(
|
Self::new(
|
||||||
self.0
|
db,
|
||||||
.into_iter()
|
self._members(db)
|
||||||
.map(|(name, data)| (name, data.normalized(db)))
|
.iter()
|
||||||
.collect(),
|
.map(|(name, data)| (name.clone(), data.normalized(db)))
|
||||||
|
.collect::<BTreeMap<_, _>>(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
|
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
|
||||||
struct ProtocolMemberData<'db> {
|
pub(super) struct ProtocolMemberData<'db> {
|
||||||
ty: Type<'db>,
|
ty: Type<'db>,
|
||||||
qualifiers: TypeQualifiers,
|
qualifiers: TypeQualifiers,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'db> ProtocolMemberData<'db> {
|
impl<'db> ProtocolMemberData<'db> {
|
||||||
fn normalized(self, db: &'db dyn Db) -> Self {
|
fn normalized(&self, db: &'db dyn Db) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ty: self.ty.normalized(db),
|
ty: self.ty.normalized(db),
|
||||||
qualifiers: self.qualifiers,
|
qualifiers: self.qualifiers,
|
||||||
|
@ -184,7 +195,7 @@ fn excluded_from_proto_members(member: &str) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inner Salsa query for [`ProtocolClassLiteral::interface`].
|
/// Inner Salsa query for [`ProtocolClassLiteral::interface`].
|
||||||
#[salsa::tracked(return_ref, cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
|
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial)]
|
||||||
fn cached_protocol_interface<'db>(
|
fn cached_protocol_interface<'db>(
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
class: ClassLiteral<'db>,
|
class: ClassLiteral<'db>,
|
||||||
|
@ -240,9 +251,10 @@ fn cached_protocol_interface<'db>(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
ProtocolInterface(members)
|
ProtocolInterface::new(db, members)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||||
fn proto_interface_cycle_recover<'db>(
|
fn proto_interface_cycle_recover<'db>(
|
||||||
_db: &dyn Db,
|
_db: &dyn Db,
|
||||||
_value: &ProtocolInterface<'db>,
|
_value: &ProtocolInterface<'db>,
|
||||||
|
@ -253,8 +265,33 @@ fn proto_interface_cycle_recover<'db>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn proto_interface_cycle_initial<'db>(
|
fn proto_interface_cycle_initial<'db>(
|
||||||
_db: &dyn Db,
|
db: &'db dyn Db,
|
||||||
_class: ClassLiteral<'db>,
|
_class: ClassLiteral<'db>,
|
||||||
) -> ProtocolInterface<'db> {
|
) -> ProtocolInterface<'db> {
|
||||||
ProtocolInterface::default()
|
ProtocolInterface::new(db, BTreeMap::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)]
|
||||||
|
fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool {
|
||||||
|
interface
|
||||||
|
.members(db)
|
||||||
|
.all(|member| member.ty.is_fully_static(db))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||||
|
fn is_fully_static_cycle_recover(
|
||||||
|
_db: &dyn Db,
|
||||||
|
_value: &bool,
|
||||||
|
_count: u32,
|
||||||
|
_interface: ProtocolInterface<'_>,
|
||||||
|
) -> salsa::CycleRecoveryAction<bool> {
|
||||||
|
salsa::CycleRecoveryAction::Iterate
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_fully_static_cycle_initial<'db>(
|
||||||
|
_db: &'db dyn Db,
|
||||||
|
_interface: ProtocolInterface<'db>,
|
||||||
|
) -> bool {
|
||||||
|
// Assume that the protocol is fully static until we find members that indicate otherwise.
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue