[ty] Fall back to Divergent for deeply nested specializations (#20988)

## Summary

Fall back to `C[Divergent]` if we are trying to specialize `C[T]` with a
type that itself already contains deeply nested specialized generic
classes. This is a way to prevent infinite recursion for cases like
`self.x = [self.x]` where type inference for the implicit instance
attribute would not converge.

closes https://github.com/astral-sh/ty/issues/1383
closes https://github.com/astral-sh/ty/issues/837

## Test Plan

Regression tests.
This commit is contained in:
David Peter 2025-10-22 14:29:10 +02:00 committed by GitHub
parent 2c9433796a
commit 58a68f1bbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 317 additions and 26 deletions

View file

@ -2457,6 +2457,48 @@ class Counter:
reveal_type(Counter().count) # revealed: Unknown | int reveal_type(Counter().count) # revealed: Unknown | int
``` ```
We also handle infinitely nested generics:
```py
class NestedLists:
def __init__(self: "NestedLists"):
self.x = 1
def f(self: "NestedLists"):
self.x = [self.x]
reveal_type(NestedLists().x) # revealed: Unknown | Literal[1] | list[Divergent]
class NestedMixed:
def f(self: "NestedMixed"):
self.x = [self.x]
def g(self: "NestedMixed"):
self.x = {self.x}
def h(self: "NestedMixed"):
self.x = {"a": self.x}
reveal_type(NestedMixed().x) # revealed: Unknown | list[Divergent] | set[Divergent] | dict[Unknown | str, Divergent]
```
And cases where the types originate from annotations:
```py
from typing import TypeVar
T = TypeVar("T")
def make_list(value: T) -> list[T]:
return [value]
class NestedLists2:
def f(self: "NestedLists2"):
self.x = make_list(self.x)
reveal_type(NestedLists2().x) # revealed: Unknown | list[Divergent]
```
### Builtin types attributes ### Builtin types attributes
This test can probably be removed eventually, but we currently include it because we do not yet This test can probably be removed eventually, but we currently include it because we do not yet
@ -2551,13 +2593,54 @@ reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown]
## Divergent inferred implicit instance attribute types ## Divergent inferred implicit instance attribute types
```py ```py
# TODO: This test currently panics, see https://github.com/astral-sh/ty/issues/837 class C:
def f(self, other: "C"):
self.x = (other.x, 1)
# class C: reveal_type(C().x) # revealed: Unknown | tuple[Divergent, Literal[1]]
# def f(self, other: "C"): ```
# self.x = (other.x, 1)
# This also works if the tuple is not constructed directly:
# reveal_type(C().x) # revealed: Unknown | tuple[Divergent, Literal[1]]
```py
from typing import TypeVar, Literal
T = TypeVar("T")
def make_tuple(x: T) -> tuple[T, Literal[1]]:
return (x, 1)
class D:
def f(self, other: "D"):
self.x = make_tuple(other.x)
reveal_type(D().x) # revealed: Unknown | tuple[Divergent, Literal[1]]
```
The tuple type may also expand exponentially "in breadth":
```py
def duplicate(x: T) -> tuple[T, T]:
return (x, x)
class E:
def f(self: "E"):
self.x = duplicate(self.x)
reveal_type(E().x) # revealed: Unknown | tuple[Divergent, Divergent]
```
And it also works for homogeneous tuples:
```py
def make_homogeneous_tuple(x: T) -> tuple[T, ...]:
return (x, x)
class E:
def f(self, other: "E"):
self.x = make_homogeneous_tuple(other.x)
reveal_type(E().x) # revealed: Unknown | tuple[Divergent, ...]
``` ```
## Attributes of standard library modules that aren't yet defined ## Attributes of standard library modules that aren't yet defined

View file

@ -0,0 +1,17 @@
# PEP 613 type aliases
We do not support PEP 613 type aliases yet. For now, just make sure that we don't panic:
```py
from typing import TypeAlias
RecursiveTuple: TypeAlias = tuple[int | "RecursiveTuple", str]
def _(rec: RecursiveTuple):
reveal_type(rec) # revealed: tuple[Divergent, str]
RecursiveHomogeneousTuple: TypeAlias = tuple[int | "RecursiveHomogeneousTuple", ...]
def _(rec: RecursiveHomogeneousTuple):
reveal_type(rec) # revealed: tuple[Divergent, ...]
```

View file

@ -69,7 +69,7 @@ use crate::types::tuple::{TupleSpec, TupleSpecBuilder};
pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type};
pub use crate::types::variance::TypeVarVariance; pub use crate::types::variance::TypeVarVariance;
use crate::types::variance::VarianceInferable; use crate::types::variance::VarianceInferable;
use crate::types::visitor::any_over_type; use crate::types::visitor::{any_over_type, exceeds_max_specialization_depth};
use crate::unpack::EvaluationMode; use crate::unpack::EvaluationMode;
use crate::{Db, FxOrderSet, Module, Program}; use crate::{Db, FxOrderSet, Module, Program};
pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass}; pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass};
@ -827,10 +827,14 @@ impl<'db> Type<'db> {
Self::Dynamic(DynamicType::Unknown) Self::Dynamic(DynamicType::Unknown)
} }
pub(crate) fn divergent(scope: ScopeId<'db>) -> Self { pub(crate) fn divergent(scope: Option<ScopeId<'db>>) -> Self {
Self::Dynamic(DynamicType::Divergent(DivergentType { scope })) Self::Dynamic(DynamicType::Divergent(DivergentType { scope }))
} }
pub(crate) const fn is_divergent(&self) -> bool {
matches!(self, Type::Dynamic(DynamicType::Divergent(_)))
}
pub const fn is_unknown(&self) -> bool { pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Dynamic(DynamicType::Unknown)) matches!(self, Type::Dynamic(DynamicType::Unknown))
} }
@ -6652,7 +6656,7 @@ impl<'db> Type<'db> {
match self { match self {
Type::TypeVar(bound_typevar) => match type_mapping { Type::TypeVar(bound_typevar) => match type_mapping {
TypeMapping::Specialization(specialization) => { TypeMapping::Specialization(specialization) => {
specialization.get(db, bound_typevar).unwrap_or(self) specialization.get(db, bound_typevar).unwrap_or(self).fallback_to_divergent(db)
} }
TypeMapping::PartialSpecialization(partial) => { TypeMapping::PartialSpecialization(partial) => {
partial.get(db, bound_typevar).unwrap_or(self) partial.get(db, bound_typevar).unwrap_or(self)
@ -7214,6 +7218,16 @@ impl<'db> Type<'db> {
pub(super) fn has_divergent_type(self, db: &'db dyn Db, div: Type<'db>) -> bool { pub(super) fn has_divergent_type(self, db: &'db dyn Db, div: Type<'db>) -> bool {
any_over_type(db, self, &|ty| ty == div, false) any_over_type(db, self, &|ty| ty == div, false)
} }
/// If the specialization depth of `self` exceeds the maximum limit allowed,
/// return `Divergent`. Otherwise, return `self`.
pub(super) fn fallback_to_divergent(self, db: &'db dyn Db) -> Type<'db> {
if exceeds_max_specialization_depth(db, self) {
Type::divergent(None)
} else {
self
}
}
} }
impl<'db> From<&Type<'db>> for Type<'db> { impl<'db> From<&Type<'db>> for Type<'db> {
@ -7659,7 +7673,7 @@ impl<'db> KnownInstanceType<'db> {
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)]
pub struct DivergentType<'db> { pub struct DivergentType<'db> {
/// The scope where this divergence was detected. /// The scope where this divergence was detected.
scope: ScopeId<'db>, scope: Option<ScopeId<'db>>,
} }
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)]
@ -11772,7 +11786,7 @@ pub(crate) mod tests {
let file_scope_id = FileScopeId::global(); let file_scope_id = FileScopeId::global();
let scope = file_scope_id.to_scope_id(&db, file); let scope = file_scope_id.to_scope_id(&db, file);
let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope })); let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope: Some(scope) }));
// The `Divergent` type must not be eliminated in union with other dynamic types, // The `Divergent` type must not be eliminated in union with other dynamic types,
// as this would prevent detection of divergent type inference using `Divergent`. // as this would prevent detection of divergent type inference using `Divergent`.

View file

@ -37,7 +37,8 @@ use crate::types::{
IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType,
MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType,
TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable, TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable,
declaration_type, determine_upper_bound, infer_definition_types, declaration_type, determine_upper_bound, exceeds_max_specialization_depth,
infer_definition_types,
}; };
use crate::{ use crate::{
Db, FxIndexMap, FxIndexSet, FxOrderSet, Program, Db, FxIndexMap, FxIndexSet, FxOrderSet, Program,
@ -1612,7 +1613,18 @@ impl<'db> ClassLiteral<'db> {
match self.generic_context(db) { match self.generic_context(db) {
None => ClassType::NonGeneric(self), None => ClassType::NonGeneric(self),
Some(generic_context) => { Some(generic_context) => {
let specialization = f(generic_context); let mut specialization = f(generic_context);
for (idx, ty) in specialization.types(db).iter().enumerate() {
if exceeds_max_specialization_depth(db, *ty) {
specialization = specialization.with_replaced_type(
db,
idx,
Type::divergent(Some(self.body_scope(db))),
);
}
}
ClassType::Generic(GenericAlias::new(db, self, specialization)) ClassType::Generic(GenericAlias::new(db, self, specialization))
} }
} }

View file

@ -1264,6 +1264,25 @@ impl<'db> Specialization<'db> {
// A tuple's specialization will include all of its element types, so we don't need to also // A tuple's specialization will include all of its element types, so we don't need to also
// look in `self.tuple`. // look in `self.tuple`.
} }
/// Returns a copy of this specialization with the type at a given index replaced.
pub(crate) fn with_replaced_type(
self,
db: &'db dyn Db,
index: usize,
new_type: Type<'db>,
) -> Self {
let mut new_types: Box<[_]> = self.types(db).to_vec().into_boxed_slice();
new_types[index] = new_type;
Self::new(
db,
self.generic_context(db),
new_types,
self.materialization_kind(db),
self.tuple_inner(db),
)
}
} }
/// A mapping between type variables and types. /// A mapping between type variables and types.

View file

@ -567,7 +567,7 @@ impl<'db> CycleRecovery<'db> {
fn fallback_type(self) -> Type<'db> { fn fallback_type(self) -> Type<'db> {
match self { match self {
Self::Initial => Type::Never, Self::Initial => Type::Never,
Self::Divergent(scope) => Type::divergent(scope), Self::Divergent(scope) => Type::divergent(Some(scope)),
} }
} }
} }

View file

@ -5968,16 +5968,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements); let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements);
let db = self.db(); let db = self.db();
let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| { let element_types = elts.iter().map(|element| {
let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied(); let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty)); self.infer_expression(element, TypeContext::new(annotated_elt_ty))
if element_type.has_divergent_type(self.db(), divergent) {
divergent
} else {
element_type
}
}); });
Type::heterogeneous_tuple(db, element_types) Type::heterogeneous_tuple(db, element_types)

View file

@ -22,7 +22,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
/// Infer the type of a type expression. /// Infer the type of a type expression.
pub(super) fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> { pub(super) fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
let mut ty = self.infer_type_expression_no_store(expression); let mut ty = self.infer_type_expression_no_store(expression);
let divergent = Type::divergent(self.scope()); let divergent = Type::divergent(Some(self.scope()));
if ty.has_divergent_type(self.db(), divergent) { if ty.has_divergent_type(self.db(), divergent) {
ty = divergent; ty = divergent;
} }
@ -588,7 +588,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
// TODO: emit a diagnostic // TODO: emit a diagnostic
} }
} else { } else {
element_types.push(element_ty); element_types.push(element_ty.fallback_to_divergent(self.db()));
} }
} }

View file

@ -72,7 +72,10 @@ impl<'db> Type<'db> {
{ {
Type::tuple(TupleType::heterogeneous( Type::tuple(TupleType::heterogeneous(
db, db,
elements.into_iter().map(Into::into), elements
.into_iter()
.map(Into::into)
.map(|element| element.fallback_to_divergent(db)),
)) ))
} }

View file

@ -1,3 +1,5 @@
use rustc_hash::FxHashMap;
use crate::{ use crate::{
Db, FxIndexSet, Db, FxIndexSet,
types::{ types::{
@ -16,7 +18,10 @@ use crate::{
walk_typed_dict_type, walk_typeis_type, walk_union, walk_typed_dict_type, walk_typeis_type, walk_union,
}, },
}; };
use std::cell::{Cell, RefCell}; use std::{
cell::{Cell, RefCell},
collections::hash_map::Entry,
};
/// A visitor trait that recurses into nested types. /// A visitor trait that recurses into nested types.
/// ///
@ -295,3 +300,148 @@ pub(super) fn any_over_type<'db>(
visitor.visit_type(db, ty); visitor.visit_type(db, ty);
visitor.found_matching_type.get() visitor.found_matching_type.get()
} }
/// Returns the maximum number of layers of generic specializations for a given type.
///
/// For example, `int` has a depth of `0`, `list[int]` has a depth of `1`, and `list[set[int]]`
/// has a depth of `2`. A set-theoretic type like `list[int] | list[list[int]]` has a maximum
/// depth of `2`.
fn specialization_depth(db: &dyn Db, ty: Type<'_>) -> usize {
#[derive(Debug, Default)]
struct SpecializationDepthVisitor<'db> {
seen_types: RefCell<FxHashMap<NonAtomicType<'db>, Option<usize>>>,
max_depth: Cell<usize>,
}
impl<'db> TypeVisitor<'db> for SpecializationDepthVisitor<'db> {
fn should_visit_lazy_type_attributes(&self) -> bool {
false
}
fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
match TypeKind::from(ty) {
TypeKind::Atomic => {
if ty.is_divergent() {
self.max_depth.set(usize::MAX);
}
}
TypeKind::NonAtomic(non_atomic_type) => {
match self.seen_types.borrow_mut().entry(non_atomic_type) {
Entry::Occupied(cached_depth) => {
self.max_depth
.update(|current| current.max(cached_depth.get().unwrap_or(0)));
return;
}
Entry::Vacant(entry) => {
entry.insert(None);
}
}
let self_depth: usize =
matches!(non_atomic_type, NonAtomicType::GenericAlias(_)).into();
let previous_max_depth = self.max_depth.replace(0);
walk_non_atomic_type(db, non_atomic_type, self);
self.max_depth.update(|max_child_depth| {
previous_max_depth.max(max_child_depth.saturating_add(self_depth))
});
self.seen_types
.borrow_mut()
.insert(non_atomic_type, Some(self.max_depth.get()));
}
}
}
}
let visitor = SpecializationDepthVisitor::default();
visitor.visit_type(db, ty);
visitor.max_depth.get()
}
pub(super) fn exceeds_max_specialization_depth(db: &dyn Db, ty: Type<'_>) -> bool {
// To prevent infinite recursion during type inference for infinite types, we fall back to
// `C[Divergent]` once a certain amount of levels of specialization have occurred. For
// example:
//
// ```py
// x = 1
// while random_bool():
// x = [x]
//
// reveal_type(x) # Unknown | Literal[1] | list[Divergent]
// ```
const MAX_SPECIALIZATION_DEPTH: usize = 10;
specialization_depth(db, ty) > MAX_SPECIALIZATION_DEPTH
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{db::tests::setup_db, types::KnownClass};
#[test]
fn test_generics_layering_depth() {
let db = setup_db();
let int = || KnownClass::Int.to_instance(&db);
let list = |element| KnownClass::List.to_specialized_instance(&db, [element]);
let dict = |key, value| KnownClass::Dict.to_specialized_instance(&db, [key, value]);
let set = |element| KnownClass::Set.to_specialized_instance(&db, [element]);
let str = || KnownClass::Str.to_instance(&db);
let bytes = || KnownClass::Bytes.to_instance(&db);
let list_of_int = list(int());
assert_eq!(specialization_depth(&db, list_of_int), 1);
let list_of_list_of_int = list(list_of_int);
assert_eq!(specialization_depth(&db, list_of_list_of_int), 2);
let list_of_list_of_list_of_int = list(list_of_list_of_int);
assert_eq!(specialization_depth(&db, list_of_list_of_list_of_int), 3);
assert_eq!(specialization_depth(&db, set(dict(str(), list_of_int))), 3);
assert_eq!(
specialization_depth(
&db,
UnionType::from_elements(&db, [list_of_list_of_list_of_int, list_of_list_of_int])
),
3
);
assert_eq!(
specialization_depth(
&db,
UnionType::from_elements(&db, [list_of_list_of_int, list_of_list_of_list_of_int])
),
3
);
assert_eq!(
specialization_depth(
&db,
Type::heterogeneous_tuple(&db, [Type::heterogeneous_tuple(&db, [int()])])
),
2
);
assert_eq!(
specialization_depth(&db, Type::heterogeneous_tuple(&db, [list_of_int, str()])),
2
);
assert_eq!(
specialization_depth(
&db,
list(UnionType::from_elements(
&db,
[list(int()), list(str()), list(bytes())]
))
),
2
);
}
}