mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 10:48:32 +00:00
[red-knot] Implement Type::Tuple
Comparisons (#13712)
## Summary This PR implements comparisons for (tuple, tuple). It will close #13688 and complete an item in #13618 once merged. ## Test Plan Basic tests are included for (tuple, tuple) comparisons. --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
8f5b2aac9a
commit
2ffc3fad47
2 changed files with 338 additions and 0 deletions
|
@ -0,0 +1,205 @@
|
|||
# Comparison - Tuples
|
||||
|
||||
## Heterogeneous
|
||||
|
||||
For tuples like `tuple[int, str, Literal[1]]`
|
||||
|
||||
### Value Comparisons
|
||||
|
||||
"Value Comparisons" refers to the operators: `==`, `!=`, `<`, `<=`, `>`, `>=`
|
||||
|
||||
#### Results without Ambiguity
|
||||
|
||||
Cases where the result can be definitively inferred as a `BooleanLiteral`.
|
||||
|
||||
```py
|
||||
a = (1, "test", (3, 13), True)
|
||||
b = (1, "test", (3, 14), False)
|
||||
|
||||
reveal_type(a == a) # revealed: Literal[True]
|
||||
reveal_type(a != a) # revealed: Literal[False]
|
||||
reveal_type(a < a) # revealed: Literal[False]
|
||||
reveal_type(a <= a) # revealed: Literal[True]
|
||||
reveal_type(a > a) # revealed: Literal[False]
|
||||
reveal_type(a >= a) # revealed: Literal[True]
|
||||
|
||||
reveal_type(a == b) # revealed: Literal[False]
|
||||
reveal_type(a != b) # revealed: Literal[True]
|
||||
reveal_type(a < b) # revealed: Literal[True]
|
||||
reveal_type(a <= b) # revealed: Literal[True]
|
||||
reveal_type(a > b) # revealed: Literal[False]
|
||||
reveal_type(a >= b) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
Even when tuples have different lengths, comparisons should be handled appropriately.
|
||||
|
||||
```py path=different_length.py
|
||||
a = (1, 2, 3)
|
||||
b = (1, 2, 3, 4)
|
||||
|
||||
reveal_type(a == b) # revealed: Literal[False]
|
||||
reveal_type(a != b) # revealed: Literal[True]
|
||||
reveal_type(a < b) # revealed: Literal[True]
|
||||
reveal_type(a <= b) # revealed: Literal[True]
|
||||
reveal_type(a > b) # revealed: Literal[False]
|
||||
reveal_type(a >= b) # revealed: Literal[False]
|
||||
|
||||
c = ("a", "b", "c", "d")
|
||||
d = ("a", "b", "c")
|
||||
|
||||
reveal_type(c == d) # revealed: Literal[False]
|
||||
reveal_type(c != d) # revealed: Literal[True]
|
||||
reveal_type(c < d) # revealed: Literal[False]
|
||||
reveal_type(c <= d) # revealed: Literal[False]
|
||||
reveal_type(c > d) # revealed: Literal[True]
|
||||
reveal_type(c >= d) # revealed: Literal[True]
|
||||
```
|
||||
|
||||
#### Results with Ambiguity
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool: ...
|
||||
def int_instance() -> int: ...
|
||||
|
||||
a = (bool_instance(),)
|
||||
b = (int_instance(),)
|
||||
|
||||
# TODO: All @Todo should be `bool`
|
||||
reveal_type(a == a) # revealed: @Todo
|
||||
reveal_type(a != a) # revealed: @Todo
|
||||
reveal_type(a < a) # revealed: @Todo
|
||||
reveal_type(a <= a) # revealed: @Todo
|
||||
reveal_type(a > a) # revealed: @Todo
|
||||
reveal_type(a >= a) # revealed: @Todo
|
||||
|
||||
reveal_type(a == b) # revealed: @Todo
|
||||
reveal_type(a != b) # revealed: @Todo
|
||||
reveal_type(a < b) # revealed: @Todo
|
||||
reveal_type(a <= b) # revealed: @Todo
|
||||
reveal_type(a > b) # revealed: @Todo
|
||||
reveal_type(a >= b) # revealed: @Todo
|
||||
```
|
||||
|
||||
#### Comparison Unsupported
|
||||
|
||||
If two tuples contain types that do not support comparison, the result may be `Unknown`.
|
||||
However, `==` and `!=` are exceptions and can still provide definite results.
|
||||
|
||||
```py
|
||||
a = (1, 2)
|
||||
b = (1, "hello")
|
||||
|
||||
# TODO: should be Literal[False]
|
||||
reveal_type(a == b) # revealed: @Todo
|
||||
|
||||
# TODO: should be Literal[True]
|
||||
reveal_type(a != b) # revealed: @Todo
|
||||
|
||||
# TODO: should be Unknown and add more informative diagnostics
|
||||
reveal_type(a < b) # revealed: @Todo
|
||||
reveal_type(a <= b) # revealed: @Todo
|
||||
reveal_type(a > b) # revealed: @Todo
|
||||
reveal_type(a >= b) # revealed: @Todo
|
||||
```
|
||||
|
||||
However, if the lexicographic comparison completes without reaching a point where str and int are compared,
|
||||
Python will still produce a result based on the prior elements.
|
||||
|
||||
```py path=short_circuit.py
|
||||
a = (1, 2)
|
||||
b = (999999, "hello")
|
||||
|
||||
reveal_type(a == b) # revealed: Literal[False]
|
||||
reveal_type(a != b) # revealed: Literal[True]
|
||||
reveal_type(a < b) # revealed: Literal[True]
|
||||
reveal_type(a <= b) # revealed: Literal[True]
|
||||
reveal_type(a > b) # revealed: Literal[False]
|
||||
reveal_type(a >= b) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
#### Matryoshka Tuples
|
||||
|
||||
```py
|
||||
a = (1, True, "Hello")
|
||||
b = (a, a, a)
|
||||
c = (b, b, b)
|
||||
|
||||
reveal_type(c == c) # revealed: Literal[True]
|
||||
reveal_type(c != c) # revealed: Literal[False]
|
||||
reveal_type(c < c) # revealed: Literal[False]
|
||||
reveal_type(c <= c) # revealed: Literal[True]
|
||||
reveal_type(c > c) # revealed: Literal[False]
|
||||
reveal_type(c >= c) # revealed: Literal[True]
|
||||
```
|
||||
|
||||
#### Non Boolean Rich Comparisons
|
||||
|
||||
```py
|
||||
class A():
|
||||
def __eq__(self, o) -> str: ...
|
||||
def __ne__(self, o) -> int: ...
|
||||
def __lt__(self, o) -> float: ...
|
||||
def __le__(self, o) -> object: ...
|
||||
def __gt__(self, o) -> tuple: ...
|
||||
def __ge__(self, o) -> list: ...
|
||||
|
||||
a = (A(), A())
|
||||
|
||||
# TODO: All @Todo should be bool
|
||||
reveal_type(a == a) # revealed: @Todo
|
||||
reveal_type(a != a) # revealed: @Todo
|
||||
reveal_type(a < a) # revealed: @Todo
|
||||
reveal_type(a <= a) # revealed: @Todo
|
||||
reveal_type(a > a) # revealed: @Todo
|
||||
reveal_type(a >= a) # revealed: @Todo
|
||||
```
|
||||
|
||||
### Membership Test Comparisons
|
||||
|
||||
"Membership Test Comparisons" refers to the operators `in` and `not in`.
|
||||
|
||||
```py
|
||||
def int_instance() -> int: ...
|
||||
|
||||
a = (1, 2)
|
||||
b = ((3, 4), (1, 2))
|
||||
c = ((1, 2, 3), (4, 5, 6))
|
||||
d = ((int_instance(), int_instance()), (int_instance(), int_instance()))
|
||||
|
||||
reveal_type(a in b) # revealed: Literal[True]
|
||||
reveal_type(a not in b) # revealed: Literal[False]
|
||||
|
||||
reveal_type(a in c) # revealed: Literal[False]
|
||||
reveal_type(a not in c) # revealed: Literal[True]
|
||||
|
||||
# TODO: All @Todo should be bool
|
||||
reveal_type(a in d) # revealed: @Todo
|
||||
reveal_type(a not in d) # revealed: @Todo
|
||||
```
|
||||
|
||||
### Identity Comparisons
|
||||
|
||||
"Identity Comparisons" refers to `is` and `is not`.
|
||||
|
||||
```py
|
||||
a = (1, 2)
|
||||
b = ("a", "b")
|
||||
c = (1, 2, 3)
|
||||
|
||||
reveal_type(a is (1, 2)) # revealed: bool
|
||||
reveal_type(a is not (1, 2)) # revealed: bool
|
||||
|
||||
# TODO: Update to Literal[False] once str == int comparison is implemented
|
||||
reveal_type(a is b) # revealed: @Todo
|
||||
# TODO: Update to Literal[True] once str == int comparison is implemented
|
||||
reveal_type(a is not b) # revealed: @Todo
|
||||
|
||||
reveal_type(a is c) # revealed: Literal[False]
|
||||
reveal_type(a is not c) # revealed: Literal[True]
|
||||
```
|
||||
|
||||
## Homogeneous
|
||||
|
||||
For tuples like `tuple[int, ...]`, `tuple[Any, ...]`
|
||||
|
||||
// TODO
|
|
@ -2831,7 +2831,68 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
(_, Type::BytesLiteral(_)) => {
|
||||
self.infer_binary_type_comparison(left, op, KnownClass::Bytes.to_instance(self.db))
|
||||
}
|
||||
(Type::Tuple(lhs), Type::Tuple(rhs)) => {
|
||||
// Note: This only works on heterogeneous tuple types.
|
||||
let lhs_elements = lhs.elements(self.db).as_ref();
|
||||
let rhs_elements = rhs.elements(self.db).as_ref();
|
||||
|
||||
let mut lexicographic_type_comparison =
|
||||
|op| self.infer_lexicographic_type_comparison(lhs_elements, op, rhs_elements);
|
||||
|
||||
match op {
|
||||
ast::CmpOp::Eq => lexicographic_type_comparison(RichCompareOperator::Eq),
|
||||
ast::CmpOp::NotEq => lexicographic_type_comparison(RichCompareOperator::Ne),
|
||||
ast::CmpOp::Lt => lexicographic_type_comparison(RichCompareOperator::Lt),
|
||||
ast::CmpOp::LtE => lexicographic_type_comparison(RichCompareOperator::Le),
|
||||
ast::CmpOp::Gt => lexicographic_type_comparison(RichCompareOperator::Gt),
|
||||
ast::CmpOp::GtE => lexicographic_type_comparison(RichCompareOperator::Ge),
|
||||
ast::CmpOp::In | ast::CmpOp::NotIn => {
|
||||
let mut eq_count = 0usize;
|
||||
let mut not_eq_count = 0usize;
|
||||
|
||||
for ty in rhs_elements {
|
||||
let eq_result = self.infer_binary_type_comparison(
|
||||
Type::Tuple(lhs),
|
||||
ast::CmpOp::Eq,
|
||||
*ty,
|
||||
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
|
||||
|
||||
match eq_result {
|
||||
Type::Todo => return Some(Type::Todo),
|
||||
ty => match ty.bool(self.db) {
|
||||
Truthiness::AlwaysTrue => eq_count += 1,
|
||||
Truthiness::AlwaysFalse => not_eq_count += 1,
|
||||
Truthiness::Ambiguous => (),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if eq_count >= 1 {
|
||||
Some(Type::BooleanLiteral(op.is_in()))
|
||||
} else if not_eq_count == rhs_elements.len() {
|
||||
Some(Type::BooleanLiteral(op.is_not_in()))
|
||||
} else {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Is | ast::CmpOp::IsNot => {
|
||||
// - `[ast::CmpOp::Is]`: returns `false` if the elements are definitely unequal, otherwise `bool`
|
||||
// - `[ast::CmpOp::IsNot]`: returns `true` if the elements are definitely unequal, otherwise `bool`
|
||||
let eq_result = lexicographic_type_comparison(RichCompareOperator::Eq)
|
||||
.expect(
|
||||
"infer_binary_type_comparison should never return None for `CmpOp::Eq`",
|
||||
);
|
||||
|
||||
Some(match eq_result {
|
||||
Type::Todo => Type::Todo,
|
||||
ty => match ty.bool(self.db) {
|
||||
Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()),
|
||||
_ => KnownClass::Bool.to_instance(self.db),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
// Lookup the rich comparison `__dunder__` methods on instances
|
||||
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
|
||||
ast::CmpOp::Lt => {
|
||||
|
@ -2845,6 +2906,55 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Performs lexicographic comparison between two slices of types.
|
||||
///
|
||||
/// For lexicographic comparison, elements from both slices are compared pairwise using
|
||||
/// `infer_binary_type_comparison`. If a conclusive result cannot be determined as a `BoolLiteral`,
|
||||
/// it returns `bool`. Returns `None` if the comparison is not supported.
|
||||
fn infer_lexicographic_type_comparison(
|
||||
&mut self,
|
||||
left: &[Type<'db>],
|
||||
op: RichCompareOperator,
|
||||
right: &[Type<'db>],
|
||||
) -> Option<Type<'db>> {
|
||||
// Compare paired elements from left and right slices
|
||||
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
|
||||
let eq_result = self
|
||||
.infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty)
|
||||
.expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
|
||||
|
||||
match eq_result {
|
||||
// If propagation is required, return the result as is
|
||||
Type::Todo => return Some(Type::Todo),
|
||||
ty => match ty.bool(self.db) {
|
||||
// Types are equal, continue to the next pair
|
||||
Truthiness::AlwaysTrue => continue,
|
||||
// Types are not equal, perform the specified comparison and return the result
|
||||
Truthiness::AlwaysFalse => {
|
||||
return self.infer_binary_type_comparison(l_ty, op.into(), r_ty)
|
||||
}
|
||||
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
|
||||
// In this case, we simply return a bool instance.
|
||||
Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, the lengths of the two slices may be different, but the prefix of
|
||||
// left and right slices is entirely identical.
|
||||
// We return a comparison of the slice lengths based on the operator.
|
||||
let (left_len, right_len) = (left.len(), right.len());
|
||||
|
||||
Some(Type::BooleanLiteral(match op {
|
||||
RichCompareOperator::Eq => left_len == right_len,
|
||||
RichCompareOperator::Ne => left_len != right_len,
|
||||
RichCompareOperator::Lt => left_len < right_len,
|
||||
RichCompareOperator::Le => left_len <= right_len,
|
||||
RichCompareOperator::Gt => left_len > right_len,
|
||||
RichCompareOperator::Ge => left_len >= right_len,
|
||||
}))
|
||||
}
|
||||
|
||||
fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
|
||||
let ast::ExprSubscript {
|
||||
range: _,
|
||||
|
@ -3286,6 +3396,29 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum RichCompareOperator {
|
||||
Eq,
|
||||
Ne,
|
||||
Gt,
|
||||
Ge,
|
||||
Lt,
|
||||
Le,
|
||||
}
|
||||
|
||||
impl From<RichCompareOperator> for ast::CmpOp {
|
||||
fn from(value: RichCompareOperator) -> Self {
|
||||
match value {
|
||||
RichCompareOperator::Eq => ast::CmpOp::Eq,
|
||||
RichCompareOperator::Ne => ast::CmpOp::NotEq,
|
||||
RichCompareOperator::Lt => ast::CmpOp::Lt,
|
||||
RichCompareOperator::Le => ast::CmpOp::LtE,
|
||||
RichCompareOperator::Gt => ast::CmpOp::Gt,
|
||||
RichCompareOperator::Ge => ast::CmpOp::GtE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_import_from_module(level: u32, module: Option<&str>) -> String {
|
||||
format!(
|
||||
"{}{}",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue