mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-01 09:22:19 +00:00
[red-knot] feat: implement integer comparison (#13571)
## Summary Implements the comparison operator for `[Type::IntLiteral]` and `[Type::BooleanLiteral]` (as an artifact of special handling of `True` and `False` in Python). Sets up the framework to implement more comparisons for types known at static time (e.g. `BooleanLiteral`, `StringLiteral`), allowing us to only implement cases of the triplet `<left> Type`, `<right> Type`, `CmpOp`. Contributes to #12701 (without checking off an item yet). ## Test Plan - Added a test for the comparison of literals that should include most cases of note. - Added a test for the comparison of int instances. Please note that the cases do not cover 100% of the branches, as there are many and the current testing strategy with variables makes this fairly confusing once we have too many in one test. --------- Co-authored-by: Carl Meyer <carl@astral.sh> Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
parent
d726f09cf0
commit
888930b7d3
3 changed files with 320 additions and 23 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2083,6 +2083,7 @@ dependencies = [
|
|||
"countme",
|
||||
"hashbrown",
|
||||
"insta",
|
||||
"itertools 0.13.0",
|
||||
"ordermap",
|
||||
"red_knot_vendored",
|
||||
"ruff_db",
|
||||
|
|
|
@ -24,6 +24,7 @@ bitflags = { workspace = true }
|
|||
camino = { workspace = true }
|
||||
compact_str = { workspace = true }
|
||||
countme = { workspace = true }
|
||||
itertools = { workspace = true}
|
||||
ordermap = { workspace = true }
|
||||
salsa = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
//! stringified annotations. We have a fourth Salsa query for inferring the deferred types
|
||||
//! associated with a particular definition. Scope-level inference infers deferred types for all
|
||||
//! definitions once the rest of the types in the scope have been inferred.
|
||||
use itertools::Itertools;
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
use ruff_db::files::File;
|
||||
|
@ -328,6 +329,14 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
matches!(self.region, InferenceRegion::Deferred(_))
|
||||
}
|
||||
|
||||
/// Get the already-inferred type of an expression node.
|
||||
///
|
||||
/// PANIC if no type has been inferred for this node.
|
||||
fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
|
||||
self.types
|
||||
.expression_ty(expr.scoped_ast_id(self.db, self.scope))
|
||||
}
|
||||
|
||||
/// Infers types in the given [`InferenceRegion`].
|
||||
fn infer_region(&mut self) {
|
||||
match self.region {
|
||||
|
@ -984,9 +993,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
|
||||
// method of the context manager.
|
||||
let context_expr_ty = self
|
||||
.types
|
||||
.expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope));
|
||||
let context_expr_ty = self.expression_ty(&with_item.context_expr);
|
||||
|
||||
self.types
|
||||
.expressions
|
||||
|
@ -1151,9 +1158,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let expression = self.index.expression(assignment.value.as_ref());
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
self.extend(result);
|
||||
let value_ty = self
|
||||
.types
|
||||
.expression_ty(assignment.value.scoped_ast_id(self.db, self.scope));
|
||||
let value_ty = self.expression_ty(&assignment.value);
|
||||
self.add_binding(assignment.into(), definition, value_ty);
|
||||
self.types
|
||||
.expressions
|
||||
|
@ -1349,9 +1354,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let expression = self.index.expression(iterable);
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
self.extend(result);
|
||||
let iterable_ty = self
|
||||
.types
|
||||
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
|
||||
let iterable_ty = self.expression_ty(iterable);
|
||||
|
||||
let loop_var_value_ty = if is_async {
|
||||
// TODO(Alex): async iterables/iterators!
|
||||
|
@ -2434,28 +2437,41 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
op,
|
||||
values,
|
||||
} = bool_op;
|
||||
Self::infer_chained_boolean_types(
|
||||
self.db,
|
||||
*op,
|
||||
values.iter().map(|value| self.infer_expression(value)),
|
||||
values.len(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Computes the output of a chain of (one) boolean operation, consuming as input an iterator
|
||||
/// of types. The iterator is consumed even if the boolean evaluation can be short-circuited,
|
||||
/// in order to ensure the invariant that all expressions are evaluated when inferring types.
|
||||
fn infer_chained_boolean_types(
|
||||
db: &'db dyn Db,
|
||||
op: ast::BoolOp,
|
||||
values: impl IntoIterator<Item = Type<'db>>,
|
||||
n_values: usize,
|
||||
) -> Type<'db> {
|
||||
let mut done = false;
|
||||
UnionType::from_elements(
|
||||
self.db,
|
||||
values.iter().enumerate().map(|(i, value)| {
|
||||
// We need to infer the type of every expression (that's an invariant maintained by
|
||||
// type inference), even if we can short-circuit boolean evaluation of some of
|
||||
// those types.
|
||||
let value_ty = self.infer_expression(value);
|
||||
db,
|
||||
values.into_iter().enumerate().map(|(i, ty)| {
|
||||
if done {
|
||||
Type::Never
|
||||
} else {
|
||||
let is_last = i == values.len() - 1;
|
||||
match (value_ty.bool(self.db), is_last, op) {
|
||||
(Truthiness::Ambiguous, _, _) => value_ty,
|
||||
let is_last = i == n_values - 1;
|
||||
match (ty.bool(db), is_last, op) {
|
||||
(Truthiness::Ambiguous, _, _) => ty,
|
||||
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
|
||||
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
|
||||
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
|
||||
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
|
||||
done = true;
|
||||
value_ty
|
||||
ty
|
||||
}
|
||||
(_, true, _) => value_ty,
|
||||
(_, true, _) => ty,
|
||||
}
|
||||
}
|
||||
}),
|
||||
|
@ -2466,16 +2482,138 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let ast::ExprCompare {
|
||||
range: _,
|
||||
left,
|
||||
ops: _,
|
||||
ops,
|
||||
comparators,
|
||||
} = compare;
|
||||
|
||||
self.infer_expression(left);
|
||||
// TODO actually handle ops and return correct type
|
||||
for right in comparators.as_ref() {
|
||||
self.infer_expression(right);
|
||||
}
|
||||
Type::Todo
|
||||
|
||||
// https://docs.python.org/3/reference/expressions.html#comparisons
|
||||
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
|
||||
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and
|
||||
// > ... y opN z`, except that each expression is evaluated at most once.
|
||||
//
|
||||
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
|
||||
// is shared with the one in `infer_binary_type_comparison`.
|
||||
Self::infer_chained_boolean_types(
|
||||
self.db,
|
||||
ast::BoolOp::And,
|
||||
std::iter::once(left.as_ref())
|
||||
.chain(comparators.as_ref().iter())
|
||||
.tuple_windows::<(_, _)>()
|
||||
.zip(ops.iter())
|
||||
.map(|((left, right), op)| {
|
||||
let left_ty = self.expression_ty(left);
|
||||
let right_ty = self.expression_ty(right);
|
||||
|
||||
self.infer_binary_type_comparison(left_ty, *op, right_ty)
|
||||
.unwrap_or_else(|| {
|
||||
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
|
||||
self.add_diagnostic(
|
||||
AnyNodeRef::ExprCompare(compare),
|
||||
"operator-unsupported",
|
||||
format_args!(
|
||||
"Operator `{}` is not supported for types `{}` and `{}`",
|
||||
op,
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
match op {
|
||||
// `in, not in, is, is not` always return bool instances
|
||||
ast::CmpOp::In
|
||||
| ast::CmpOp::NotIn
|
||||
| ast::CmpOp::Is
|
||||
| ast::CmpOp::IsNot => {
|
||||
builtins_symbol_ty(self.db, "bool").to_instance(self.db)
|
||||
}
|
||||
// Other operators can return arbitrary types
|
||||
_ => Type::Unknown,
|
||||
}
|
||||
})
|
||||
}),
|
||||
ops.len(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Infers the type of a binary comparison (e.g. 'left == right'). See
|
||||
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
|
||||
/// expressions.
|
||||
///
|
||||
/// If the operation is not supported, return None (we need upstream context to emit a
|
||||
/// diagnostic).
|
||||
fn infer_binary_type_comparison(
|
||||
&mut self,
|
||||
left: Type<'db>,
|
||||
op: ast::CmpOp,
|
||||
right: Type<'db>,
|
||||
) -> Option<Type<'db>> {
|
||||
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the
|
||||
// language spec.
|
||||
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
|
||||
// - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal
|
||||
match (left, right) {
|
||||
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
|
||||
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
|
||||
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
|
||||
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
|
||||
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
|
||||
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
|
||||
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
|
||||
ast::CmpOp::Is => {
|
||||
if n == m {
|
||||
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(false))
|
||||
}
|
||||
}
|
||||
ast::CmpOp::IsNot => {
|
||||
if n == m {
|
||||
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(true))
|
||||
}
|
||||
}
|
||||
// Undefined for (int, int)
|
||||
ast::CmpOp::In | ast::CmpOp::NotIn => None,
|
||||
},
|
||||
(Type::IntLiteral(_), Type::Instance(_)) => {
|
||||
self.infer_binary_type_comparison(Type::builtin_int_instance(self.db), op, right)
|
||||
}
|
||||
(Type::Instance(_), Type::IntLiteral(_)) => {
|
||||
self.infer_binary_type_comparison(left, op, Type::builtin_int_instance(self.db))
|
||||
}
|
||||
// Booleans are coded as integers (False = 0, True = 1)
|
||||
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
|
||||
Type::IntLiteral(n),
|
||||
op,
|
||||
Type::IntLiteral(i64::from(b)),
|
||||
),
|
||||
(Type::BooleanLiteral(b), Type::IntLiteral(m)) => self.infer_binary_type_comparison(
|
||||
Type::IntLiteral(i64::from(b)),
|
||||
op,
|
||||
Type::IntLiteral(m),
|
||||
),
|
||||
(Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => self
|
||||
.infer_binary_type_comparison(
|
||||
Type::IntLiteral(i64::from(a)),
|
||||
op,
|
||||
Type::IntLiteral(i64::from(b)),
|
||||
),
|
||||
// Lookup the rich comparison `__dunder__` methods on instances
|
||||
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
|
||||
ast::CmpOp::Lt => {
|
||||
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
|
||||
}
|
||||
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
|
||||
_ => Some(Type::Todo),
|
||||
},
|
||||
// TODO: handle more types
|
||||
_ => Some(Type::Todo),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
|
||||
|
@ -2995,6 +3133,36 @@ impl StringPartsCollector {
|
|||
}
|
||||
}
|
||||
|
||||
/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their
|
||||
/// behaviour can be edited for classes by implementing corresponding dunder methods.
|
||||
/// This function performs rich comparison between two instances and returns the resulting type.
|
||||
/// see `<https://docs.python.org/3/reference/datamodel.html#object.__lt__>`
|
||||
fn perform_rich_comparison<'db>(
|
||||
db: &'db dyn Db,
|
||||
left: ClassType<'db>,
|
||||
right: ClassType<'db>,
|
||||
dunder_name: &str,
|
||||
) -> Option<Type<'db>> {
|
||||
// The following resource has details about the rich comparison algorithm:
|
||||
// https://snarky.ca/unravelling-rich-comparison-operators/
|
||||
//
|
||||
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
|
||||
// l.h.s.
|
||||
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined
|
||||
|
||||
let dunder = left.class_member(db, dunder_name);
|
||||
if !dunder.is_unbound() {
|
||||
// TODO: this currently gives the return type even if the arg types are invalid
|
||||
// (e.g. int.__lt__ with string instance should be None, currently bool)
|
||||
return dunder
|
||||
.call(db, &[Type::Instance(left), Type::Instance(right)])
|
||||
.return_ty(db);
|
||||
}
|
||||
|
||||
// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
@ -3879,6 +4047,133 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
// Checks the public types inferred for comparisons between int and bool literals, including
// chained comparisons and the identity operators `is` / `is not`.
#[test]
|
||||
fn comparison_integer_literals() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
r#"
|
||||
a = 1 == 1 == True
|
||||
b = 1 == 1 == 2 == 4
|
||||
c = False < True <= 2 < 3 != 6
|
||||
d = 1 < 1
|
||||
e = 1 > 1
|
||||
f = 1 is 1
|
||||
g = 1 is not 1
|
||||
h = 1 is 2
|
||||
i = 1 is not 7
|
||||
j = 1 <= "" and 0 < 1
|
||||
"#,
|
||||
)?;
|
||||
|
||||
// Ordering and equality between literals fold to a boolean literal.
assert_public_ty(&db, "src/a.py", "a", "Literal[True]");
|
||||
assert_public_ty(&db, "src/a.py", "b", "Literal[False]");
|
||||
assert_public_ty(&db, "src/a.py", "c", "Literal[True]");
|
||||
assert_public_ty(&db, "src/a.py", "d", "Literal[False]");
|
||||
assert_public_ty(&db, "src/a.py", "e", "Literal[False]");
|
||||
// Identity between equal int literals is only known to be `bool`; between unequal
// literals the outcome is a definite boolean literal.
assert_public_ty(&db, "src/a.py", "f", "bool");
|
||||
assert_public_ty(&db, "src/a.py", "g", "bool");
|
||||
assert_public_ty(&db, "src/a.py", "h", "Literal[False]");
|
||||
assert_public_ty(&db, "src/a.py", "i", "Literal[True]");
|
||||
// The int/str comparison is expected to infer as `@Todo`, unioned with the second operand
// of the `and`.
assert_public_ty(&db, "src/a.py", "j", "@Todo | Literal[True]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks the public types inferred when at least one comparison operand is an `int` instance
// (the return value of a function annotated `-> int`) rather than an int literal.
#[test]
|
||||
fn comparison_integer_instance() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
r#"
|
||||
def int_instance() -> int: ...
|
||||
a = 1 == int_instance()
|
||||
b = 9 < int_instance()
|
||||
c = int_instance() < int_instance()
|
||||
"#,
|
||||
)?;
|
||||
|
||||
// TODO: implement lookup of `__eq__` on typeshed `int` stub
|
||||
assert_public_ty(&db, "src/a.py", "a", "@Todo");
|
||||
// `<` on int instances is expected to infer as `bool`.
assert_public_ty(&db, "src/a.py", "b", "bool");
|
||||
assert_public_ty(&db, "src/a.py", "c", "bool");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that comparisons with unsupported operand types produce a diagnostic, and checks
// the fallback type each such comparison is given.
#[test]
|
||||
fn comparison_unsupported_operators() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
r#"
|
||||
a = 1 in 7
|
||||
b = 0 not in 10
|
||||
c = object() < 5
|
||||
d = 5 < object()
|
||||
"#,
|
||||
)?;
|
||||
|
||||
// Only `a`, `b` and `c` are expected to produce a diagnostic; `d` currently does not (see
// the TODO below).
assert_file_diagnostics(
|
||||
&db,
|
||||
"src/a.py",
|
||||
&[
|
||||
"Operator `in` is not supported for types `Literal[1]` and `Literal[7]`",
|
||||
"Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`",
|
||||
"Operator `<` is not supported for types `object` and `Literal[5]`",
|
||||
],
|
||||
);
|
||||
// Unsupported `in` / `not in` are expected to fall back to `bool`, unsupported `<` to
// `Unknown`.
assert_public_ty(&db, "src/a.py", "a", "bool");
|
||||
assert_public_ty(&db, "src/a.py", "b", "bool");
|
||||
assert_public_ty(&db, "src/a.py", "c", "Unknown");
|
||||
// TODO: this should be `Unknown` but we don't check if __lt__ signature is valid for right
|
||||
// operand type
|
||||
assert_public_ty(&db, "src/a.py", "d", "bool");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks chained comparisons whose `__lt__` methods return non-`bool` types: the inferred
// type is the union of the outcomes the implicit `and` chain can evaluate to.
#[test]
|
||||
fn comparison_non_bool_returns() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented(
|
||||
"src/a.py",
|
||||
r#"
|
||||
from __future__ import annotations
|
||||
class A:
|
||||
def __lt__(self, other) -> A: ...
|
||||
class B:
|
||||
def __lt__(self, other) -> B: ...
|
||||
class C:
|
||||
def __lt__(self, other) -> C: ...
|
||||
|
||||
a = A() < B() < C()
|
||||
b = 0 < 1 < A() < 3
|
||||
c = 10 < 0 < A() < B() < C()
|
||||
"#,
|
||||
)?;
|
||||
|
||||
// Walking through the example
|
||||
// 1. A() < B() < C()
|
||||
// 2. A() < B() and B() < C() - split into N pairwise comparisons
|
||||
// 3. A and B - evaluate outcome types
|
||||
// 4. bool and bool - evaluate truthiness
|
||||
// 5. A | B - union of "first true" types
|
||||
assert_public_ty(&db, "src/a.py", "a", "A | B");
|
||||
// Walking through the example
|
||||
// 1. 0 < 1 < A() < 3
|
||||
// 2. 0 < 1 and 1 < A() and A() < 3 - split into N pairwise comparisons
|
||||
// 3. True and bool and A - evaluate outcome types
|
||||
// 4. True and bool and bool - evaluate truthiness
|
||||
// 5. bool | A - union of "true" types
|
||||
assert_public_ty(&db, "src/a.py", "b", "bool | A");
|
||||
// Short-circuit to False: the first comparison `10 < 0` is `Literal[False]`
|
||||
assert_public_ty(&db, "src/a.py", "c", "Literal[False]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bytes_type() -> anyhow::Result<()> {
|
||||
let mut db = setup_db();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue