[ty] detect cycles in binary comparison inference (#20446)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Catch infinite recursion in binary-compare inference.

Fixes the stack overflow in `graphql-core` in mypy-primer.

## Test Plan

Added two tests that stack-overflowed before this PR.
This commit is contained in:
Carl Meyer 2025-09-17 00:45:25 -07:00 committed by GitHub
parent 9f0b942b9e
commit 99ec4d2c69
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 28 deletions

View file

@ -392,3 +392,16 @@ class A:
# error: [unsupported-bool-conversion] # error: [unsupported-bool-conversion]
(A(),) == (A(),) (A(),) == (A(),)
``` ```
## Recursive NamedTuple
```py
from __future__ import annotations
from typing import NamedTuple
class Node(NamedTuple):
parent: Node | None
def _(n: Node):
reveal_type(n.parent is n) # revealed: bool
```

View file

@ -351,3 +351,12 @@ def f(x: A):
for item in x: for item in x:
reveal_type(item) # revealed: list[A | str | None] | str | None reveal_type(item) # revealed: list[A | str | None] | str | None
``` ```
### Tuple comparison
```py
type X = tuple[X, int]
def _(x: X):
reveal_type(x is x) # revealed: bool
```

View file

@ -45,6 +45,7 @@ use crate::semantic_index::{
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
use crate::types::context::{InNoTypeCheck, InferContext}; use crate::types::context::{InNoTypeCheck, InferContext};
use crate::types::cyclic::CycleDetector;
use crate::types::diagnostic::{ use crate::types::diagnostic::{
CALL_NON_CALLABLE, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION, CALL_NON_CALLABLE, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION,
DIVISION_BY_ZERO, DUPLICATE_KW_ONLY, INCONSISTENT_MRO, INVALID_ARGUMENT_TYPE, DIVISION_BY_ZERO, DUPLICATE_KW_ONLY, INCONSISTENT_MRO, INVALID_ARGUMENT_TYPE,
@ -132,6 +133,13 @@ impl<'db> DeclaredAndInferredType<'db> {
} }
} }
/// A [`CycleDetector`] that is used in `infer_binary_type_comparison`.
type BinaryComparisonVisitor<'db> = CycleDetector<
ast::CmpOp,
(Type<'db>, ast::CmpOp, Type<'db>),
Result<Type<'db>, CompareUnsupportedError<'db>>,
>;
/// Builder to infer all types in a region. /// Builder to infer all types in a region.
/// ///
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling /// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
@ -7438,7 +7446,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let range = TextRange::new(left.start(), right.end()); let range = TextRange::new(left.start(), right.end());
let ty = builder let ty = builder
.infer_binary_type_comparison(left_ty, *op, right_ty, range) .infer_binary_type_comparison(
left_ty,
*op,
right_ty,
range,
&BinaryComparisonVisitor::new(Ok(Type::BooleanLiteral(true))),
)
.unwrap_or_else(|error| { .unwrap_or_else(|error| {
if let Some(diagnostic_builder) = if let Some(diagnostic_builder) =
builder.context.report_lint(&UNSUPPORTED_OPERATOR, range) builder.context.report_lint(&UNSUPPORTED_OPERATOR, range)
@ -7484,6 +7498,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
other: Type<'db>, other: Type<'db>,
intersection_on: IntersectionOn, intersection_on: IntersectionOn,
range: TextRange, range: TextRange,
visitor: &BinaryComparisonVisitor<'db>,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> { ) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
enum State<'db> { enum State<'db> {
// We have not seen any positive elements (yet) // We have not seen any positive elements (yet)
@ -7500,8 +7515,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// intersection type, which is even more specific. // intersection type, which is even more specific.
for pos in intersection.positive(self.db()) { for pos in intersection.positive(self.db()) {
let result = match intersection_on { let result = match intersection_on {
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other, range), IntersectionOn::Left => {
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos, range), self.infer_binary_type_comparison(*pos, op, other, range, visitor)
}
IntersectionOn::Right => {
self.infer_binary_type_comparison(other, op, *pos, range, visitor)
}
}; };
if let Ok(Type::BooleanLiteral(_)) = result { if let Ok(Type::BooleanLiteral(_)) = result {
@ -7514,10 +7533,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for neg in intersection.negative(self.db()) { for neg in intersection.negative(self.db()) {
let result = match intersection_on { let result = match intersection_on {
IntersectionOn::Left => self IntersectionOn::Left => self
.infer_binary_type_comparison(*neg, op, other, range) .infer_binary_type_comparison(*neg, op, other, range, visitor)
.ok(), .ok(),
IntersectionOn::Right => self IntersectionOn::Right => self
.infer_binary_type_comparison(other, op, *neg, range) .infer_binary_type_comparison(other, op, *neg, range, visitor)
.ok(), .ok(),
}; };
@ -7578,8 +7597,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for pos in intersection.positive(self.db()) { for pos in intersection.positive(self.db()) {
let result = match intersection_on { let result = match intersection_on {
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other, range), IntersectionOn::Left => {
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos, range), self.infer_binary_type_comparison(*pos, op, other, range, visitor)
}
IntersectionOn::Right => {
self.infer_binary_type_comparison(other, op, *pos, range, visitor)
}
}; };
match result { match result {
@ -7614,10 +7637,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// We didn't see any positive elements, check if the operation is supported on `object`: // We didn't see any positive elements, check if the operation is supported on `object`:
match intersection_on { match intersection_on {
IntersectionOn::Left => { IntersectionOn::Left => {
self.infer_binary_type_comparison(Type::object(), op, other, range) self.infer_binary_type_comparison(Type::object(), op, other, range, visitor)
} }
IntersectionOn::Right => { IntersectionOn::Right => {
self.infer_binary_type_comparison(other, op, Type::object(), range) self.infer_binary_type_comparison(other, op, Type::object(), range, visitor)
} }
} }
} }
@ -7637,6 +7660,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op: ast::CmpOp, op: ast::CmpOp,
right: Type<'db>, right: Type<'db>,
range: TextRange, range: TextRange,
visitor: &BinaryComparisonVisitor<'db>,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> { ) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the // Note: identity (is, is not) for equal builtin types is unreliable and not part of the
// language spec. // language spec.
@ -7689,7 +7713,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut builder = UnionBuilder::new(self.db()); let mut builder = UnionBuilder::new(self.db());
for element in union.elements(self.db()) { for element in union.elements(self.db()) {
builder = builder =
builder.add(self.infer_binary_type_comparison(*element, op, other, range)?); builder.add(self.infer_binary_type_comparison(*element, op, other, range, visitor)?);
} }
Some(Ok(builder.build())) Some(Ok(builder.build()))
} }
@ -7697,7 +7721,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut builder = UnionBuilder::new(self.db()); let mut builder = UnionBuilder::new(self.db());
for element in union.elements(self.db()) { for element in union.elements(self.db()) {
builder = builder =
builder.add(self.infer_binary_type_comparison(other, op, *element, range)?); builder.add(self.infer_binary_type_comparison(other, op, *element, range, visitor)?);
} }
Some(Ok(builder.build())) Some(Ok(builder.build()))
} }
@ -7709,6 +7733,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
right, right,
IntersectionOn::Left, IntersectionOn::Left,
range, range,
visitor,
)) ))
} }
(left, Type::Intersection(intersection)) => { (left, Type::Intersection(intersection)) => {
@ -7718,22 +7743,29 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
left, left,
IntersectionOn::Right, IntersectionOn::Right,
range, range,
visitor,
)) ))
} }
(Type::TypeAlias(alias), right) => Some(self.infer_binary_type_comparison( (Type::TypeAlias(alias), right) => Some(
alias.value_type(self.db()), visitor.visit((left, op, right), || { self.infer_binary_type_comparison(
op, alias.value_type(self.db()),
right, op,
range, right,
)), range,
visitor,
)
})),
(left, Type::TypeAlias(alias)) => Some(self.infer_binary_type_comparison( (left, Type::TypeAlias(alias)) => Some(
left, visitor.visit((left, op, right), || { self.infer_binary_type_comparison(
op, left,
alias.value_type(self.db()), op,
range, alias.value_type(self.db()),
)), range,
visitor,
)
})),
(Type::IntLiteral(n), Type::IntLiteral(m)) => Some(match op { (Type::IntLiteral(n), Type::IntLiteral(m)) => Some(match op {
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)), ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
@ -7771,6 +7803,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
right, right,
range, range,
visitor,
)) ))
} }
(Type::NominalInstance(_), Type::IntLiteral(_)) => { (Type::NominalInstance(_), Type::IntLiteral(_)) => {
@ -7779,6 +7812,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
KnownClass::Int.to_instance(self.db()), KnownClass::Int.to_instance(self.db()),
range, range,
visitor,
)) ))
} }
@ -7789,6 +7823,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
Type::IntLiteral(i64::from(b)), Type::IntLiteral(i64::from(b)),
range, range,
visitor,
)) ))
} }
(Type::BooleanLiteral(b), Type::IntLiteral(m)) => { (Type::BooleanLiteral(b), Type::IntLiteral(m)) => {
@ -7797,6 +7832,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
Type::IntLiteral(m), Type::IntLiteral(m),
range, range,
visitor,
)) ))
} }
(Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => { (Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => {
@ -7805,6 +7841,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
Type::IntLiteral(i64::from(b)), Type::IntLiteral(i64::from(b)),
range, range,
visitor,
)) ))
} }
@ -7842,12 +7879,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
right, right,
range, range,
visitor,
)), )),
(_, Type::StringLiteral(_)) => Some(self.infer_binary_type_comparison( (_, Type::StringLiteral(_)) => Some(self.infer_binary_type_comparison(
left, left,
op, op,
KnownClass::Str.to_instance(self.db()), KnownClass::Str.to_instance(self.db()),
range, range,
visitor,
)), )),
(Type::LiteralString, _) => Some(self.infer_binary_type_comparison( (Type::LiteralString, _) => Some(self.infer_binary_type_comparison(
@ -7855,12 +7894,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
right, right,
range, range,
visitor,
)), )),
(_, Type::LiteralString) => Some(self.infer_binary_type_comparison( (_, Type::LiteralString) => Some(self.infer_binary_type_comparison(
left, left,
op, op,
KnownClass::Str.to_instance(self.db()), KnownClass::Str.to_instance(self.db()),
range, range,
visitor,
)), )),
(Type::BytesLiteral(salsa_b1), Type::BytesLiteral(salsa_b2)) => { (Type::BytesLiteral(salsa_b1), Type::BytesLiteral(salsa_b2)) => {
@ -7901,12 +7942,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
right, right,
range, range,
visitor,
)), )),
(_, Type::BytesLiteral(_)) => Some(self.infer_binary_type_comparison( (_, Type::BytesLiteral(_)) => Some(self.infer_binary_type_comparison(
left, left,
op, op,
KnownClass::Bytes.to_instance(self.db()), KnownClass::Bytes.to_instance(self.db()),
range, range,
visitor,
)), )),
(Type::EnumLiteral(literal_1), Type::EnumLiteral(literal_2)) (Type::EnumLiteral(literal_1), Type::EnumLiteral(literal_2))
@ -7933,7 +7976,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.and_then(|lhs_tuple| Some((lhs_tuple, nominal2.tuple_spec(self.db())?))) .and_then(|lhs_tuple| Some((lhs_tuple, nominal2.tuple_spec(self.db())?)))
.map(|(lhs_tuple, rhs_tuple)| { .map(|(lhs_tuple, rhs_tuple)| {
let mut tuple_rich_comparison = let mut tuple_rich_comparison =
|op| self.infer_tuple_rich_comparison(&lhs_tuple, op, &rhs_tuple, range); |rich_op| visitor.visit((left, op, right), || {
self.infer_tuple_rich_comparison(&lhs_tuple, rich_op, &rhs_tuple, range, visitor)
});
match op { match op {
ast::CmpOp::Eq => tuple_rich_comparison(RichCompareOperator::Eq), ast::CmpOp::Eq => tuple_rich_comparison(RichCompareOperator::Eq),
@ -7952,6 +7997,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::CmpOp::Eq, ast::CmpOp::Eq,
ty, ty,
range, range,
visitor
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
match eq_result { match eq_result {
@ -8125,6 +8171,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op: RichCompareOperator, op: RichCompareOperator,
right: &TupleSpec<'db>, right: &TupleSpec<'db>,
range: TextRange, range: TextRange,
visitor: &BinaryComparisonVisitor<'db>,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> { ) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// If either tuple is variable length, we can make no assumptions about the relative // If either tuple is variable length, we can make no assumptions about the relative
// lengths of the tuples, and therefore neither about how they compare lexicographically. // lengths of the tuples, and therefore neither about how they compare lexicographically.
@ -8141,7 +8188,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for (l_ty, r_ty) in left_iter.zip(right_iter) { for (l_ty, r_ty) in left_iter.zip(right_iter) {
let pairwise_eq_result = self let pairwise_eq_result = self
.infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty, range) .infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty, range, visitor)
.expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); .expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
match pairwise_eq_result match pairwise_eq_result
@ -8166,9 +8213,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
RichCompareOperator::Lt RichCompareOperator::Lt
| RichCompareOperator::Le | RichCompareOperator::Le
| RichCompareOperator::Gt | RichCompareOperator::Gt
| RichCompareOperator::Ge => { | RichCompareOperator::Ge => self.infer_binary_type_comparison(
self.infer_binary_type_comparison(l_ty, op.into(), r_ty, range)? l_ty,
} op.into(),
r_ty,
range,
visitor,
)?,
// For `==` and `!=`, we already figure out the result from `pairwise_eq_result` // For `==` and `!=`, we already figure out the result from `pairwise_eq_result`
// NOTE: The CPython implementation does not account for non-boolean return types // NOTE: The CPython implementation does not account for non-boolean return types
// or cases where `!=` is not the negation of `==`, we also do not consider these cases. // or cases where `!=` is not the negation of `==`, we also do not consider these cases.