[red-knot] Refactoring the inference logic of lexicographic comparisons (#14422)

## Summary

closes #14279

### Limitations of the Current Implementation
#### Incorrect Error Propagation

In the current implementation of lexicographic comparisons, if the
result of an Eq operation is Ambiguous, the comparison stops
immediately, returning a bool instance. While this may yield correct
inferences, it fails to capture unsupported-operation errors that might
occur in subsequent comparisons.
```py
class A: ...

(int_instance(), A()) < (int_instance(), A())  # should error
```

#### Weak Inference in Specific Cases

> Example: `(int_instance(), "foo") == (int_instance(), "bar")`
> Current result: `bool`
> Expected result: `Literal[False]`

`Eq` and `NotEq` have unique behavior in lexicographic comparisons
compared to other operators. Specifically:
- For `Eq`, if any non-equal pair exists within the tuples being
compared, we can immediately conclude that the tuples are not equal.
- For `NotEq`, if any equal pair exists, we can conclude that the tuples
are unequal.

```py
a = (str_instance(), int_instance(), "foo")

reveal_type(a == a)  # revealed: bool
reveal_type(a != a)  # revealed: bool

b = (str_instance(), int_instance(), "bar")

reveal_type(a == b)  # revealed: bool  # should be Literal[False]
reveal_type(a != b)  # revealed: bool  # should be Literal[True]
```
#### Incorrect Support for Non-Boolean Rich Comparisons

In CPython, aside from `==` and `!=`, tuple comparisons return a
non-boolean result as-is. Tuples do not convert the value into `bool`.

Note: If all pairwise `==` comparisons between elements in the tuples
return Truthy, the comparison then considers the tuples' lengths.
Regardless of the return type of the dunder methods, the final result
can still be a boolean.

```py
from __future__ import annotations

class A:
    def __eq__(self, o: object) -> str:
        return "hello"

    def __ne__(self, o: object) -> bytes:
        return b"world"

    def __lt__(self, o: A) -> float:
        return 3.14

a = (A(), A())

reveal_type(a == a)  # revealed: bool
reveal_type(a != a)  # revealed: bool
reveal_type(a < a)  # revealed: bool # should be: `float | Literal[False]`

```

### Key Changes
One of the major changes is that comparisons no longer end with a `bool`
result when a pairwise `Eq` result is `Ambiguous`. Instead, the function
attempts to infer all possible cases and unions the results. This
improvement allows for more robust type inference and better error
detection.

Additionally, as the function is now optimized for tuple comparisons,
the name has been changed from the more general
`infer_lexicographic_comparison` to `infer_tuple_rich_comparison`.

## Test Plan

mdtest included
This commit is contained in:
cake-monotone 2024-11-20 10:32:43 +09:00 committed by GitHub
parent 42c35b6f44
commit 6a4d207db7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 208 additions and 40 deletions

View file

@ -3643,16 +3643,16 @@ impl<'db> TypeInferenceBuilder<'db> {
let lhs_elements = lhs.elements(self.db);
let rhs_elements = rhs.elements(self.db);
let mut lexicographic_type_comparison =
|op| self.infer_lexicographic_type_comparison(lhs_elements, op, rhs_elements);
let mut tuple_rich_comparison =
|op| self.infer_tuple_rich_comparison(lhs_elements, op, rhs_elements);
match op {
ast::CmpOp::Eq => lexicographic_type_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => lexicographic_type_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => lexicographic_type_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => lexicographic_type_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => lexicographic_type_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => lexicographic_type_comparison(RichCompareOperator::Ge),
ast::CmpOp::Eq => tuple_rich_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => tuple_rich_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => tuple_rich_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => tuple_rich_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => tuple_rich_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => tuple_rich_comparison(RichCompareOperator::Ge),
ast::CmpOp::In | ast::CmpOp::NotIn => {
let mut eq_count = 0usize;
let mut not_eq_count = 0usize;
@ -3685,8 +3685,7 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::CmpOp::Is | ast::CmpOp::IsNot => {
// - `[ast::CmpOp::Is]`: returns `false` if the elements are definitely unequal, otherwise `bool`
// - `[ast::CmpOp::IsNot]`: returns `true` if the elements are definitely unequal, otherwise `bool`
let eq_result = lexicographic_type_comparison(RichCompareOperator::Eq)
.expect(
let eq_result = tuple_rich_comparison(RichCompareOperator::Eq).expect(
"infer_binary_type_comparison should never return None for `CmpOp::Eq`",
);
@ -3751,53 +3750,80 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
/// Performs lexicographic comparison between two slices of types.
/// Simulates rich comparison between tuples and returns the inferred result.
/// This performs a lexicographic comparison, returning a union of all possible return types that could result from the comparison.
///
/// For lexicographic comparison, elements from both slices are compared pairwise using
/// `infer_binary_type_comparison`. If a conclusive result cannot be determined as a `BooleanLiteral`,
/// it returns `bool`. Returns `None` if the comparison is not supported.
fn infer_lexicographic_type_comparison(
/// basically it's based on cpython's `tuple_richcompare`
/// see `<https://github.com/python/cpython/blob/9d6366b60d01305fc5e45100e0cd13e358aa397d/Objects/tupleobject.c#L637>`
fn infer_tuple_rich_comparison(
&mut self,
left: &[Type<'db>],
op: RichCompareOperator,
right: &[Type<'db>],
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// Compare paired elements from left and right slices
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
let eq_result = self
let left_iter = left.iter().copied();
let right_iter = right.iter().copied();
let mut builder = UnionBuilder::new(self.db);
for (l_ty, r_ty) in left_iter.zip(right_iter) {
let pairwise_eq_result = self
.infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty)
.expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
match eq_result {
match pairwise_eq_result {
// If propagation is required, return the result as is
Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) {
// Types are equal, continue to the next pair
// - AlwaysTrue : Continue to the next pair for lexicographic comparison
Truthiness::AlwaysTrue => continue,
// Types are not equal, perform the specified comparison and return the result
Truthiness::AlwaysFalse => {
return self.infer_binary_type_comparison(l_ty, op.into(), r_ty)
// - AlwaysFalse:
// Lexicographic comparisons will always terminate with this pair.
// Complete the comparison and return the result.
// - Ambiguous:
// Lexicographic comparisons might continue to the next pair (if eq_result is true),
// or terminate here (if eq_result is false).
// To account for cases where the comparison terminates here, add the pairwise comparison result to the union builder.
eq_truthiness @ (Truthiness::AlwaysFalse | Truthiness::Ambiguous) => {
let pairwise_compare_result = match op {
RichCompareOperator::Lt
| RichCompareOperator::Le
| RichCompareOperator::Gt
| RichCompareOperator::Ge => {
self.infer_binary_type_comparison(l_ty, op.into(), r_ty)?
}
// For `==` and `!=`, we already figure out the result from `pairwise_eq_result`
// NOTE: The CPython implementation does not account for non-boolean return types
// or cases where `!=` is not the negation of `==`, we also do not consider these cases.
RichCompareOperator::Eq => Type::BooleanLiteral(false),
RichCompareOperator::Ne => Type::BooleanLiteral(true),
};
builder = builder.add(pairwise_compare_result);
if eq_truthiness.is_ambiguous() {
continue;
}
return Ok(builder.build());
}
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
// In this case, we simply return a bool instance.
Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)),
},
}
}
// At this point, the lengths of the two slices may be different, but the prefix of
// left and right slices is entirely identical.
// We return a comparison of the slice lengths based on the operator.
// if no more items to compare, we just compare sizes
let (left_len, right_len) = (left.len(), right.len());
Ok(Type::BooleanLiteral(match op {
builder = builder.add(Type::BooleanLiteral(match op {
RichCompareOperator::Eq => left_len == right_len,
RichCompareOperator::Ne => left_len != right_len,
RichCompareOperator::Lt => left_len < right_len,
RichCompareOperator::Le => left_len <= right_len,
RichCompareOperator::Gt => left_len > right_len,
RichCompareOperator::Ge => left_len >= right_len,
}))
}));
Ok(builder.build())
}
fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {