[red-knot] feat: implement integer comparison (#13571)

## Summary

Implements the comparison operator for `[Type::IntLiteral]` and
`[Type::BooleanLiteral]` (as an artifact of special handling of `True` and
`False` in python).
Sets the framework to implement more comparison for types known at
static time (e.g. `BooleanLiteral`, `StringLiteral`), allowing us to only
implement cases of the triplet `<left> Type`, `<right> Type`, `CmpOp`.
Contributes to #12701 (without checking off an item yet).

## Test Plan

- Added a test for the comparison of literals that should include most
cases of note.
- Added a test for the comparison of int instances

Please note that the cases do not cover 100% of the branches as there
are many and the current testing strategy with variables make this
fairly confusing once we have too many in one test.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Simon 2024-10-04 19:40:59 +02:00 committed by GitHub
parent d726f09cf0
commit 888930b7d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 320 additions and 23 deletions

1
Cargo.lock generated
View file

@ -2083,6 +2083,7 @@ dependencies = [
"countme",
"hashbrown",
"insta",
"itertools 0.13.0",
"ordermap",
"red_knot_vendored",
"ruff_db",

View file

@ -24,6 +24,7 @@ bitflags = { workspace = true }
camino = { workspace = true }
compact_str = { workspace = true }
countme = { workspace = true }
itertools = { workspace = true}
ordermap = { workspace = true }
salsa = { workspace = true }
thiserror = { workspace = true }

View file

@ -26,6 +26,7 @@
//! stringified annotations. We have a fourth Salsa query for inferring the deferred types
//! associated with a particular definition. Scope-level inference infers deferred types for all
//! definitions once the rest of the types in the scope have been inferred.
use itertools::Itertools;
use std::num::NonZeroU32;
use ruff_db::files::File;
@ -328,6 +329,14 @@ impl<'db> TypeInferenceBuilder<'db> {
matches!(self.region, InferenceRegion::Deferred(_))
}
/// Get the already-inferred type of an expression node.
///
/// PANIC if no type has been inferred for this node.
fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
self.types
.expression_ty(expr.scoped_ast_id(self.db, self.scope))
}
/// Infers types in the given [`InferenceRegion`].
fn infer_region(&mut self) {
match self.region {
@ -984,9 +993,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
// method of the context manager.
let context_expr_ty = self
.types
.expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope));
let context_expr_ty = self.expression_ty(&with_item.context_expr);
self.types
.expressions
@ -1151,9 +1158,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let expression = self.index.expression(assignment.value.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);
let value_ty = self
.types
.expression_ty(assignment.value.scoped_ast_id(self.db, self.scope));
let value_ty = self.expression_ty(&assignment.value);
self.add_binding(assignment.into(), definition, value_ty);
self.types
.expressions
@ -1349,9 +1354,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
let iterable_ty = self.expression_ty(iterable);
let loop_var_value_ty = if is_async {
// TODO(Alex): async iterables/iterators!
@ -2434,28 +2437,41 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
values,
} = bool_op;
Self::infer_chained_boolean_types(
self.db,
*op,
values.iter().map(|value| self.infer_expression(value)),
values.len(),
)
}
/// Computes the output of a chain of (one) boolean operation, consuming as input an iterator
/// of types. The iterator is consumed even if the boolean evaluation can be short-circuited,
/// in order to ensure the invariant that all expressions are evaluated when inferring types.
fn infer_chained_boolean_types(
db: &'db dyn Db,
op: ast::BoolOp,
values: impl IntoIterator<Item = Type<'db>>,
n_values: usize,
) -> Type<'db> {
let mut done = false;
UnionType::from_elements(
self.db,
values.iter().enumerate().map(|(i, value)| {
// We need to infer the type of every expression (that's an invariant maintained by
// type inference), even if we can short-circuit boolean evaluation of some of
// those types.
let value_ty = self.infer_expression(value);
db,
values.into_iter().enumerate().map(|(i, ty)| {
if done {
Type::Never
} else {
let is_last = i == values.len() - 1;
match (value_ty.bool(self.db), is_last, op) {
(Truthiness::Ambiguous, _, _) => value_ty,
let is_last = i == n_values - 1;
match (ty.bool(db), is_last, op) {
(Truthiness::Ambiguous, _, _) => ty,
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
value_ty
ty
}
(_, true, _) => value_ty,
(_, true, _) => ty,
}
}
}),
@ -2466,16 +2482,138 @@ impl<'db> TypeInferenceBuilder<'db> {
let ast::ExprCompare {
range: _,
left,
ops: _,
ops,
comparators,
} = compare;
self.infer_expression(left);
// TODO actually handle ops and return correct type
for right in comparators.as_ref() {
self.infer_expression(right);
}
Type::Todo
// https://docs.python.org/3/reference/expressions.html#comparisons
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and
// ... > y opN z`, except that each expression is evaluated at most once.
//
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
// is shared with the one in `infer_binary_type_comparison`.
Self::infer_chained_boolean_types(
self.db,
ast::BoolOp::And,
std::iter::once(left.as_ref())
.chain(comparators.as_ref().iter())
.tuple_windows::<(_, _)>()
.zip(ops.iter())
.map(|((left, right), op)| {
let left_ty = self.expression_ty(left);
let right_ty = self.expression_ty(right);
self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator `{}` is not supported for types `{}` and `{}`",
op,
left_ty.display(self.db),
right_ty.display(self.db)
),
);
match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
| ast::CmpOp::NotIn
| ast::CmpOp::Is
| ast::CmpOp::IsNot => {
builtins_symbol_ty(self.db, "bool").to_instance(self.db)
}
// Other operators can return arbitrary types
_ => Type::Unknown,
}
})
}),
ops.len(),
)
}
/// Infers the type of a binary comparison (e.g. 'left == right'). See
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
/// expressions.
///
/// If the operation is not supported, return None (we need upstream context to emit a
/// diagnostic).
fn infer_binary_type_comparison(
&mut self,
left: Type<'db>,
op: ast::CmpOp,
right: Type<'db>,
) -> Option<Type<'db>> {
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the
// language spec.
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
// - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal
match (left, right) {
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
}
}
// Undefined for (int, int)
ast::CmpOp::In | ast::CmpOp::NotIn => None,
},
(Type::IntLiteral(_), Type::Instance(_)) => {
self.infer_binary_type_comparison(Type::builtin_int_instance(self.db), op, right)
}
(Type::Instance(_), Type::IntLiteral(_)) => {
self.infer_binary_type_comparison(left, op, Type::builtin_int_instance(self.db))
}
// Booleans are coded as integers (False = 0, True = 1)
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
Type::IntLiteral(n),
op,
Type::IntLiteral(i64::from(b)),
),
(Type::BooleanLiteral(b), Type::IntLiteral(m)) => self.infer_binary_type_comparison(
Type::IntLiteral(i64::from(b)),
op,
Type::IntLiteral(m),
),
(Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => self
.infer_binary_type_comparison(
Type::IntLiteral(i64::from(a)),
op,
Type::IntLiteral(i64::from(b)),
),
// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
}
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Some(Type::Todo),
},
// TODO: handle more types
_ => Some(Type::Todo),
}
}
fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
@ -2995,6 +3133,36 @@ impl StringPartsCollector {
}
}
/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their
/// behaviour can be edited for classes by implementing corresponding dunder methods.
/// This function performs rich comparison between two instances and returns the resulting type.
/// see `<https://docs.python.org/3/reference/datamodel.html#object.__lt__>`
fn perform_rich_comparison<'db>(
db: &'db dyn Db,
left: ClassType<'db>,
right: ClassType<'db>,
dunder_name: &str,
) -> Option<Type<'db>> {
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
//
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
// l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined
let dunder = left.class_member(db, dunder_name);
if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db);
}
// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
None
}
#[cfg(test)]
mod tests {
@ -3879,6 +4047,133 @@ mod tests {
Ok(())
}
#[test]
fn comparison_integer_literals() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
a = 1 == 1 == True
b = 1 == 1 == 2 == 4
c = False < True <= 2 < 3 != 6
d = 1 < 1
e = 1 > 1
f = 1 is 1
g = 1 is not 1
h = 1 is 2
i = 1 is not 7
j = 1 <= "" and 0 < 1
"#,
)?;
assert_public_ty(&db, "src/a.py", "a", "Literal[True]");
assert_public_ty(&db, "src/a.py", "b", "Literal[False]");
assert_public_ty(&db, "src/a.py", "c", "Literal[True]");
assert_public_ty(&db, "src/a.py", "d", "Literal[False]");
assert_public_ty(&db, "src/a.py", "e", "Literal[False]");
assert_public_ty(&db, "src/a.py", "f", "bool");
assert_public_ty(&db, "src/a.py", "g", "bool");
assert_public_ty(&db, "src/a.py", "h", "Literal[False]");
assert_public_ty(&db, "src/a.py", "i", "Literal[True]");
assert_public_ty(&db, "src/a.py", "j", "@Todo | Literal[True]");
Ok(())
}
#[test]
fn comparison_integer_instance() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
def int_instance() -> int: ...
a = 1 == int_instance()
b = 9 < int_instance()
c = int_instance() < int_instance()
"#,
)?;
// TODO: implement lookup of `__eq__` on typeshed `int` stub
assert_public_ty(&db, "src/a.py", "a", "@Todo");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_public_ty(&db, "src/a.py", "c", "bool");
Ok(())
}
#[test]
fn comparison_unsupported_operators() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
a = 1 in 7
b = 0 not in 10
c = object() < 5
d = 5 < object()
"#,
)?;
assert_file_diagnostics(
&db,
"src/a.py",
&[
"Operator `in` is not supported for types `Literal[1]` and `Literal[7]`",
"Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`",
"Operator `<` is not supported for types `object` and `Literal[5]`",
],
);
assert_public_ty(&db, "src/a.py", "a", "bool");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_public_ty(&db, "src/a.py", "c", "Unknown");
// TODO: this should be `Unknown` but we don't check if __lt__ signature is valid for right
// operand type
assert_public_ty(&db, "src/a.py", "d", "bool");
Ok(())
}
#[test]
fn comparison_non_bool_returns() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
from __future__ import annotations
class A:
def __lt__(self, other) -> A: ...
class B:
def __lt__(self, other) -> B: ...
class C:
def __lt__(self, other) -> C: ...
a = A() < B() < C()
b = 0 < 1 < A() < 3
c = 10 < 0 < A() < B() < C()
"#,
)?;
// Walking through the example
// 1. A() < B() < C()
// 2. A() < B() and B() < C() - split in N comparison
// 3. A() and B() - evaluate outcome types
// 4. bool and bool - evaluate truthiness
// 5. A | B - union of "first true" types
assert_public_ty(&db, "src/a.py", "a", "A | B");
// Walking through the example
// 1. 0 < 1 < A() < 3
// 2. 0 < 1 and 1 < A() and A() < 3 - split in N comparison
// 3. True and bool and A - evaluate outcome types
// 4. True and bool and bool - evaluate truthiness
// 5. bool | A - union of "true" types
assert_public_ty(&db, "src/a.py", "b", "bool | A");
// Short-cicuit to False
assert_public_ty(&db, "src/a.py", "c", "Literal[False]");
Ok(())
}
#[test]
fn bytes_type() -> anyhow::Result<()> {
let mut db = setup_db();