[ty] Recursive protocols (#17929)

## Summary

Use a self-reference "marker" ~~and fixpoint iteration~~ to solve the
stack overflow problems with recursive protocols. This is not pretty and
somewhat tedious, but seems to work fine. Much better than all my
fixpoint-iteration attempts anyway.

closes https://github.com/astral-sh/ty/issues/93

## Test Plan

New Markdown tests.
This commit is contained in:
David Peter 2025-05-09 14:54:02 +02:00 committed by GitHub
parent c1b875799b
commit 642eac452d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 315 additions and 64 deletions

View file

@ -587,6 +587,79 @@ impl<'db> Type<'db> {
matches!(self, Type::Dynamic(DynamicType::Todo(_)))
}
/// Return this type with every reference to the class `class` swapped for a
/// self-reference marker.
///
/// Protocol instances and callables delegate to their own `replace_self_reference`
/// methods; unions, intersections and tuples are rebuilt from their rewritten elements.
/// Every other variant is returned unchanged.
///
/// Recursive protocols are the current use case, but the same mechanism could probably be
/// extended to self-referential type aliases and similar constructs.
#[must_use]
pub fn replace_self_reference(&self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Type<'db> {
    // Element-wise rewrite shared by the container-like variants below.
    let rewrite = |element: &Type<'db>| element.replace_self_reference(db, class);

    match self {
        Self::ProtocolInstance(protocol) => {
            let replaced = protocol.replace_self_reference(db, class);
            Self::ProtocolInstance(replaced)
        }
        Self::Callable(callable) => {
            let replaced = callable.replace_self_reference(db, class);
            Self::Callable(replaced)
        }
        Self::Union(union) => {
            UnionType::from_elements(db, union.elements(db).iter().map(rewrite))
        }
        Self::Tuple(tuple) => {
            TupleType::from_elements(db, tuple.elements(db).iter().map(rewrite))
        }
        Self::Intersection(intersection) => IntersectionBuilder::new(db)
            .positive_elements(intersection.positive(db).iter().map(rewrite))
            .negative_elements(intersection.negative(db).iter().map(rewrite))
            .build(),
        // TODO: replace self-references in generic aliases and typevars
        Self::GenericAlias(_) | Self::TypeVar(_) => *self,
        // None of the remaining variants are rewritten; they are returned as-is.
        Self::Dynamic(_)
        | Self::Never
        | Self::AlwaysTruthy
        | Self::AlwaysFalsy
        | Self::IntLiteral(_)
        | Self::BooleanLiteral(_)
        | Self::StringLiteral(_)
        | Self::BytesLiteral(_)
        | Self::LiteralString
        | Self::FunctionLiteral(_)
        | Self::ModuleLiteral(_)
        | Self::ClassLiteral(_)
        | Self::SubclassOf(_)
        | Self::NominalInstance(_)
        | Self::KnownInstance(_)
        | Self::PropertyInstance(_)
        | Self::BoundMethod(_)
        | Self::BoundSuper(_)
        | Self::MethodWrapper(_)
        | Self::WrapperDescriptor(_)
        | Self::DataclassDecorator(_)
        | Self::DataclassTransformer(_) => *self,
    }
}
pub fn contains_todo(&self, db: &'db dyn Db) -> bool {
match self {
Self::Dynamic(DynamicType::Todo(_) | DynamicType::SubscriptedProtocol) => true,
@ -7272,6 +7345,17 @@ impl<'db> CallableType<'db> {
}
}
}
/// Build a new callable whose overloads are this callable's signatures with references to
/// `class` replaced by the self-reference marker.
///
/// See [`Type::replace_self_reference`].
fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
    // Each signature is cloned because `Signature::replace_self_reference` consumes it.
    let rewritten_overloads = self
        .signatures(db)
        .iter()
        .map(|overload| overload.clone().replace_self_reference(db, class));
    CallableType::from_overloads(db, rewritten_overloads)
}
}
/// Represents a specific instance of `types.MethodWrapperType`

View file

@ -529,6 +529,28 @@ impl<'db> IntersectionBuilder<'db> {
}
}
/// Add every item of `elements` as a positive element of the intersection, in iteration
/// order, by calling `add_positive` for each one.
pub(crate) fn positive_elements<I, T>(self, elements: I) -> Self
where
    I: IntoIterator<Item = T>,
    T: Into<Type<'db>>,
{
    elements
        .into_iter()
        .fold(self, |builder, element| builder.add_positive(element.into()))
}
/// Add every item of `elements` as a negative element of the intersection, in iteration
/// order, by calling `add_negative` for each one.
pub(crate) fn negative_elements<I, T>(self, elements: I) -> Self
where
    I: IntoIterator<Item = T>,
    T: Into<Type<'db>>,
{
    elements
        .into_iter()
        .fold(self, |builder, element| builder.add_negative(element.into()))
}
pub(crate) fn build(mut self) -> Type<'db> {
// Avoid allocating the UnionBuilder unnecessarily if we have just one intersection:
if self.intersections.len() == 1 {

View file

@ -4,6 +4,7 @@ use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type};
use crate::symbol::{Symbol, SymbolAndQualifiers};
use crate::types::generics::TypeMapping;
use crate::types::ClassLiteral;
use crate::Db;
pub(super) use synthesized_protocol::SynthesizedProtocolType;
@ -183,6 +184,19 @@ impl<'db> ProtocolInstanceType<'db> {
}
}
/// Swap this protocol instance for a self-reference marker if it refers to `class`.
///
/// Only a class-based protocol (`Protocol::FromClass`) whose class literal equals `class`
/// is replaced; the replacement is a synthesized protocol whose interface is
/// `ProtocolInterface::SelfReference`. Any other protocol instance is returned unchanged.
pub(super) fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
    if let Protocol::FromClass(class_type) = self.0 {
        if class_type.class_literal(db).0 == class {
            let marker = SynthesizedProtocolType::new(db, ProtocolInterface::SelfReference);
            return ProtocolInstanceType(Protocol::Synthesized(marker));
        }
    }
    self
}
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
self.0.interface(db).contains_todo(db)

View file

@ -1,6 +1,6 @@
use std::{collections::BTreeMap, ops::Deref};
use itertools::Itertools;
use itertools::{Either, Itertools};
use ruff_python_ast::name::Name;
@ -56,24 +56,41 @@ impl<'db> Deref for ProtocolClassLiteral<'db> {
}
}
/// The interface of a protocol: the members of that protocol, and the types of those members.
#[salsa::interned(debug)]
pub(super) struct ProtocolInterface<'db> {
pub(super) struct ProtocolInterfaceMembers<'db> {
#[returns(ref)]
_members: BTreeMap<Name, ProtocolMemberData<'db>>,
inner: BTreeMap<Name, ProtocolMemberData<'db>>,
}
/// The interface of a protocol: the members of that protocol, and the types of those members.
///
/// An interface is either a concrete set of members or a marker that stands in for a
/// reference back to a protocol class (see `ProtocolInstanceType::replace_self_reference`).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
pub(super) enum ProtocolInterface<'db> {
/// The named members of the protocol, stored in an interned `ProtocolInterfaceMembers`.
Members(ProtocolInterfaceMembers<'db>),
/// Self-reference marker. It exposes no members: `members` yields nothing and
/// `member_by_name` returns `None` for this variant.
SelfReference,
}
impl<'db> ProtocolInterface<'db> {
/// Create an interface with no members: a `Members` variant wrapping an empty map.
fn empty(db: &'db dyn Db) -> Self {
Self::Members(ProtocolInterfaceMembers::new(db, BTreeMap::default()))
}
/// Iterate over the members of this protocol.
pub(super) fn members<'a>(
&'a self,
self,
db: &'db dyn Db,
) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>> {
self._members(db).iter().map(|(name, data)| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
})
) -> impl ExactSizeIterator<Item = ProtocolMember<'a, 'db>>
where
'db: 'a,
{
match self {
Self::Members(members) => {
Either::Left(members.inner(db).iter().map(|(name, data)| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
}))
}
Self::SelfReference => Either::Right(std::iter::empty()),
}
}
pub(super) fn member_by_name<'a>(
@ -81,25 +98,34 @@ impl<'db> ProtocolInterface<'db> {
db: &'db dyn Db,
name: &'a str,
) -> Option<ProtocolMember<'a, 'db>> {
self._members(db).get(name).map(|data| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
})
match self {
Self::Members(members) => members.inner(db).get(name).map(|data| ProtocolMember {
name,
ty: data.ty,
qualifiers: data.qualifiers,
}),
Self::SelfReference => None,
}
}
/// Return `true` if all members of this protocol are fully static.
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
cached_is_fully_static(db, self)
self.members(db).all(|member| member.ty.is_fully_static(db))
}
/// Return `true` if all members on `self` are also members of `other`.
///
/// TODO: this method should consider the types of the members as well as their names.
pub(super) fn is_sub_interface_of(self, db: &'db dyn Db, other: Self) -> bool {
self._members(db)
.keys()
.all(|member_name| other._members(db).contains_key(member_name))
match (self, other) {
(Self::Members(self_members), Self::Members(other_members)) => self_members
.inner(db)
.keys()
.all(|member_name| other_members.inner(db).contains_key(member_name)),
_ => {
unreachable!("Enclosing protocols should never be a self-reference marker")
}
}
}
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
@ -108,13 +134,17 @@ impl<'db> ProtocolInterface<'db> {
}
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self::new(
db,
self._members(db)
.iter()
.map(|(name, data)| (name.clone(), data.normalized(db)))
.collect::<BTreeMap<_, _>>(),
)
match self {
Self::Members(members) => Self::Members(ProtocolInterfaceMembers::new(
db,
members
.inner(db)
.iter()
.map(|(name, data)| (name.clone(), data.normalized(db)))
.collect::<BTreeMap<_, _>>(),
)),
Self::SelfReference => Self::SelfReference,
}
}
}
@ -245,13 +275,14 @@ fn cached_protocol_interface<'db>(
})
.filter(|(name, _, _)| !excluded_from_proto_members(name))
.map(|(name, ty, qualifiers)| {
let ty = ty.replace_self_reference(db, class);
let member = ProtocolMemberData { ty, qualifiers };
(name.clone(), member)
}),
);
}
ProtocolInterface::new(db, members)
ProtocolInterface::Members(ProtocolInterfaceMembers::new(db, members))
}
#[allow(clippy::trivially_copy_pass_by_ref)]
@ -268,30 +299,5 @@ fn proto_interface_cycle_initial<'db>(
db: &'db dyn Db,
_class: ClassLiteral<'db>,
) -> ProtocolInterface<'db> {
ProtocolInterface::new(db, BTreeMap::default())
}
#[salsa::tracked(cycle_fn=is_fully_static_cycle_recover, cycle_initial=is_fully_static_cycle_initial)]
fn cached_is_fully_static<'db>(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> bool {
interface
.members(db)
.all(|member| member.ty.is_fully_static(db))
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_fully_static_cycle_recover(
_db: &dyn Db,
_value: &bool,
_count: u32,
_interface: ProtocolInterface<'_>,
) -> salsa::CycleRecoveryAction<bool> {
salsa::CycleRecoveryAction::Iterate
}
fn is_fully_static_cycle_initial<'db>(
_db: &'db dyn Db,
_interface: ProtocolInterface<'db>,
) -> bool {
// Assume that the protocol is fully static until we find members that indicate otherwise.
true
ProtocolInterface::empty(db)
}

View file

@ -18,7 +18,7 @@ use smallvec::{smallvec, SmallVec};
use super::{definition_expression_type, DynamicType, Type};
use crate::semantic_index::definition::Definition;
use crate::types::generics::{GenericContext, Specialization, TypeMapping};
use crate::types::{todo_type, TypeVarInstance};
use crate::types::{todo_type, ClassLiteral, TypeVarInstance};
use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name};
@ -876,6 +876,28 @@ impl<'db> Signature<'db> {
true
}
/// Return this signature with references to `class` replaced by the self-reference marker
/// in every parameter and in the return type (when one is present).
///
/// See [`Type::replace_self_reference`].
pub(crate) fn replace_self_reference(
    mut self,
    db: &'db dyn Db,
    class: ClassLiteral<'db>,
) -> Self {
    // TODO: also replace self references in generic context

    // Each parameter is cloned because `Parameter::replace_self_reference` consumes it.
    self.parameters = self
        .parameters
        .iter()
        .map(|parameter| parameter.clone().replace_self_reference(db, class))
        .collect();

    if let Some(return_ty) = &mut self.return_ty {
        *return_ty = return_ty.replace_self_reference(db, class);
    }

    self
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
@ -1388,6 +1410,14 @@ impl<'db> Parameter<'db> {
ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => None,
}
}
/// Return this parameter with references to `class` in its annotated type (if it has one)
/// replaced by the self-reference marker.
///
/// See [`Type::replace_self_reference`].
fn replace_self_reference(mut self, db: &'db (dyn Db), class: ClassLiteral<'db>) -> Self {
    self.annotated_type = self
        .annotated_type
        .map(|annotation| annotation.replace_self_reference(db, class));
    self
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]