[red-knot] Enhancing Diagnostics for Compare Expression Inference (#13819)

## Summary

- Refactored comparison type inference functions in `infer.rs`: Changed
the return type from `Option` to `Result` to lay the groundwork for
providing more detailed diagnostics.
- Updated diagnostic messages.

This is a small step toward improving diagnostics in the future.

Please refer to #13787

## Test Plan

mdtest included!

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
cake-monotone 2024-10-20 03:17:01 +09:00 committed by GitHub
parent 55bccf6680
commit fb66f715f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 120 additions and 64 deletions

View file

@ -7,11 +7,21 @@ reveal_type(a) # revealed: bool
b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`" b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`"
reveal_type(b) # revealed: bool reveal_type(b) # revealed: bool
c = object() < 5 # error: "Operator `<` is not supported for types `object` and `Literal[5]`" c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`"
reveal_type(c) # revealed: Unknown reveal_type(c) # revealed: Unknown
# TODO should error, need to check if __lt__ signature is valid for right operand # TODO should error, need to check if __lt__ signature is valid for right operand
d = 5 < object() d = 5 < object()
# TODO: should be `Unknown` # TODO: should be `Unknown`
reveal_type(d) # revealed: bool reveal_type(d) # revealed: bool
int_literal_or_str_literal = 1 if flag else "foo"
# error: "Operator `in` is not supported for types `Literal[42]` and `Literal[1]`, in comparing `Literal[42]` with `Literal[1] | Literal["foo"]`"
e = 42 in int_literal_or_str_literal
reveal_type(e) # revealed: bool
# TODO: should error, need to check if __lt__ signature is valid for right operand
# error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]`
f = (1, 2) < (1, "hello")
reveal_type(f) # revealed: @Todo
``` ```

View file

@ -2776,18 +2776,28 @@ impl<'db> TypeInferenceBuilder<'db> {
let right_ty = self.expression_ty(right); let right_ty = self.expression_ty(right);
self.infer_binary_type_comparison(left_ty, *op, right_ty) self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| { .unwrap_or_else(|error| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome) // Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.add_diagnostic( self.add_diagnostic(
AnyNodeRef::ExprCompare(compare), AnyNodeRef::ExprCompare(compare),
"operator-unsupported", "operator-unsupported",
format_args!( format_args!(
"Operator `{}` is not supported for types `{}` and `{}`", "Operator `{}` is not supported for types `{}` and `{}`{}",
op, error.op,
error.left_ty.display(self.db),
error.right_ty.display(self.db),
if (left_ty, right_ty) == (error.left_ty, error.right_ty) {
String::new()
} else {
format!(
", in comparing `{}` with `{}`",
left_ty.display(self.db), left_ty.display(self.db),
right_ty.display(self.db) right_ty.display(self.db)
)
}
), ),
); );
match op { match op {
// `in, not in, is, is not` always return bool instances // `in, not in, is, is not` always return bool instances
ast::CmpOp::In ast::CmpOp::In
@ -2814,7 +2824,7 @@ impl<'db> TypeInferenceBuilder<'db> {
left: Type<'db>, left: Type<'db>,
op: ast::CmpOp, op: ast::CmpOp,
right: Type<'db>, right: Type<'db>,
) -> Option<Type<'db>> { ) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the // Note: identity (is, is not) for equal builtin types is unreliable and not part of the
// language spec. // language spec.
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal // - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
@ -2825,39 +2835,43 @@ impl<'db> TypeInferenceBuilder<'db> {
for element in union.elements(self.db) { for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?); builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?);
} }
Some(builder.build()) Ok(builder.build())
} }
(other, Type::Union(union)) => { (other, Type::Union(union)) => {
let mut builder = UnionBuilder::new(self.db); let mut builder = UnionBuilder::new(self.db);
for element in union.elements(self.db) { for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?); builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?);
} }
Some(builder.build()) Ok(builder.build())
} }
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op { (Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)), ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)), ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)), ast::CmpOp::Lt => Ok(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)), ast::CmpOp::LtE => Ok(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)), ast::CmpOp::Gt => Ok(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)), ast::CmpOp::GtE => Ok(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => { ast::CmpOp::Is => {
if n == m { if n == m {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} else { } else {
Some(Type::BooleanLiteral(false)) Ok(Type::BooleanLiteral(false))
} }
} }
ast::CmpOp::IsNot => { ast::CmpOp::IsNot => {
if n == m { if n == m {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} else { } else {
Some(Type::BooleanLiteral(true)) Ok(Type::BooleanLiteral(true))
} }
} }
// Undefined for (int, int) // Undefined for (int, int)
ast::CmpOp::In | ast::CmpOp::NotIn => None, ast::CmpOp::In | ast::CmpOp::NotIn => Err(CompareUnsupportedError {
op,
left_ty: left,
right_ty: right,
}),
}, },
(Type::IntLiteral(_), Type::Instance(_)) => { (Type::IntLiteral(_), Type::Instance(_)) => {
self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right) self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right)
@ -2888,26 +2902,26 @@ impl<'db> TypeInferenceBuilder<'db> {
let s1 = salsa_s1.value(self.db); let s1 = salsa_s1.value(self.db);
let s2 = salsa_s2.value(self.db); let s2 = salsa_s2.value(self.db);
match op { match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)), ast::CmpOp::Eq => Ok(Type::BooleanLiteral(s1 == s2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)), ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(s1 != s2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)), ast::CmpOp::Lt => Ok(Type::BooleanLiteral(s1 < s2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)), ast::CmpOp::LtE => Ok(Type::BooleanLiteral(s1 <= s2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)), ast::CmpOp::Gt => Ok(Type::BooleanLiteral(s1 > s2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)), ast::CmpOp::GtE => Ok(Type::BooleanLiteral(s1 >= s2)),
ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))), ast::CmpOp::In => Ok(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))), ast::CmpOp::NotIn => Ok(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
ast::CmpOp::Is => { ast::CmpOp::Is => {
if s1 == s2 { if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} else { } else {
Some(Type::BooleanLiteral(false)) Ok(Type::BooleanLiteral(false))
} }
} }
ast::CmpOp::IsNot => { ast::CmpOp::IsNot => {
if s1 == s2 { if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} else { } else {
Some(Type::BooleanLiteral(true)) Ok(Type::BooleanLiteral(true))
} }
} }
} }
@ -2930,30 +2944,30 @@ impl<'db> TypeInferenceBuilder<'db> {
let b1 = &**salsa_b1.value(self.db); let b1 = &**salsa_b1.value(self.db);
let b2 = &**salsa_b2.value(self.db); let b2 = &**salsa_b2.value(self.db);
match op { match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(b1 == b2)), ast::CmpOp::Eq => Ok(Type::BooleanLiteral(b1 == b2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(b1 != b2)), ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(b1 != b2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(b1 < b2)), ast::CmpOp::Lt => Ok(Type::BooleanLiteral(b1 < b2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(b1 <= b2)), ast::CmpOp::LtE => Ok(Type::BooleanLiteral(b1 <= b2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(b1 > b2)), ast::CmpOp::Gt => Ok(Type::BooleanLiteral(b1 > b2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(b1 >= b2)), ast::CmpOp::GtE => Ok(Type::BooleanLiteral(b1 >= b2)),
ast::CmpOp::In => { ast::CmpOp::In => {
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some())) Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
} }
ast::CmpOp::NotIn => { ast::CmpOp::NotIn => {
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none())) Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
} }
ast::CmpOp::Is => { ast::CmpOp::Is => {
if b1 == b2 { if b1 == b2 {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} else { } else {
Some(Type::BooleanLiteral(false)) Ok(Type::BooleanLiteral(false))
} }
} }
ast::CmpOp::IsNot => { ast::CmpOp::IsNot => {
if b1 == b2 { if b1 == b2 {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} else { } else {
Some(Type::BooleanLiteral(true)) Ok(Type::BooleanLiteral(true))
} }
} }
} }
@ -2991,7 +3005,7 @@ impl<'db> TypeInferenceBuilder<'db> {
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
match eq_result { match eq_result {
Type::Todo => return Some(Type::Todo), Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) { ty => match ty.bool(self.db) {
Truthiness::AlwaysTrue => eq_count += 1, Truthiness::AlwaysTrue => eq_count += 1,
Truthiness::AlwaysFalse => not_eq_count += 1, Truthiness::AlwaysFalse => not_eq_count += 1,
@ -3001,11 +3015,11 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
if eq_count >= 1 { if eq_count >= 1 {
Some(Type::BooleanLiteral(op.is_in())) Ok(Type::BooleanLiteral(op.is_in()))
} else if not_eq_count == rhs_elements.len() { } else if not_eq_count == rhs_elements.len() {
Some(Type::BooleanLiteral(op.is_not_in())) Ok(Type::BooleanLiteral(op.is_not_in()))
} else { } else {
Some(KnownClass::Bool.to_instance(self.db)) Ok(KnownClass::Bool.to_instance(self.db))
} }
} }
ast::CmpOp::Is | ast::CmpOp::IsNot => { ast::CmpOp::Is | ast::CmpOp::IsNot => {
@ -3016,7 +3030,7 @@ impl<'db> TypeInferenceBuilder<'db> {
"infer_binary_type_comparison should never return None for `CmpOp::Eq`", "infer_binary_type_comparison should never return None for `CmpOp::Eq`",
); );
Some(match eq_result { Ok(match eq_result {
Type::Todo => Type::Todo, Type::Todo => Type::Todo,
ty => match ty.bool(self.db) { ty => match ty.bool(self.db) {
Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()), Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()),
@ -3029,16 +3043,19 @@ impl<'db> TypeInferenceBuilder<'db> {
// Lookup the rich comparison `__dunder__` methods on instances // Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { (Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => { ast::CmpOp::Lt => perform_rich_comparison(
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__") self.db,
} left_class_ty,
right_class_ty,
RichCompareOperator::Lt,
),
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods // TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Some(Type::Todo), _ => Ok(Type::Todo),
}, },
// TODO: handle more types // TODO: handle more types
_ => match op { _ => match op {
ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)), ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
_ => Some(Type::Todo), _ => Ok(Type::Todo),
}, },
} }
} }
@ -3053,7 +3070,7 @@ impl<'db> TypeInferenceBuilder<'db> {
left: &[Type<'db>], left: &[Type<'db>],
op: RichCompareOperator, op: RichCompareOperator,
right: &[Type<'db>], right: &[Type<'db>],
) -> Option<Type<'db>> { ) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// Compare paired elements from left and right slices // Compare paired elements from left and right slices
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) { for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
let eq_result = self let eq_result = self
@ -3062,7 +3079,7 @@ impl<'db> TypeInferenceBuilder<'db> {
match eq_result { match eq_result {
// If propagation is required, return the result as is // If propagation is required, return the result as is
Type::Todo => return Some(Type::Todo), Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) { ty => match ty.bool(self.db) {
// Types are equal, continue to the next pair // Types are equal, continue to the next pair
Truthiness::AlwaysTrue => continue, Truthiness::AlwaysTrue => continue,
@ -3072,7 +3089,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral. // If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
// In this case, we simply return a bool instance. // In this case, we simply return a bool instance.
Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)), Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)),
}, },
} }
} }
@ -3082,7 +3099,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// We return a comparison of the slice lengths based on the operator. // We return a comparison of the slice lengths based on the operator.
let (left_len, right_len) = (left.len(), right.len()); let (left_len, right_len) = (left.len(), right.len());
Some(Type::BooleanLiteral(match op { Ok(Type::BooleanLiteral(match op {
RichCompareOperator::Eq => left_len == right_len, RichCompareOperator::Eq => left_len == right_len,
RichCompareOperator::Ne => left_len != right_len, RichCompareOperator::Ne => left_len != right_len,
RichCompareOperator::Lt => left_len < right_len, RichCompareOperator::Lt => left_len < right_len,
@ -3556,6 +3573,26 @@ impl From<RichCompareOperator> for ast::CmpOp {
} }
} }
impl RichCompareOperator {
const fn dunder_name(self) -> &'static str {
match self {
RichCompareOperator::Eq => "__eq__",
RichCompareOperator::Ne => "__ne__",
RichCompareOperator::Lt => "__lt__",
RichCompareOperator::Le => "__le__",
RichCompareOperator::Gt => "__gt__",
RichCompareOperator::Ge => "__ge__",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CompareUnsupportedError<'db> {
op: ast::CmpOp,
left_ty: Type<'db>,
right_ty: Type<'db>,
}
fn format_import_from_module(level: u32, module: Option<&str>) -> String { fn format_import_from_module(level: u32, module: Option<&str>) -> String {
format!( format!(
"{}{}", "{}{}",
@ -3636,8 +3673,8 @@ fn perform_rich_comparison<'db>(
db: &'db dyn Db, db: &'db dyn Db,
left: ClassType<'db>, left: ClassType<'db>,
right: ClassType<'db>, right: ClassType<'db>,
dunder_name: &str, op: RichCompareOperator,
) -> Option<Type<'db>> { ) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// The following resource has details about the rich comparison algorithm: // The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/ // https://snarky.ca/unravelling-rich-comparison-operators/
// //
@ -3645,17 +3682,26 @@ fn perform_rich_comparison<'db>(
// l.h.s. // l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined // TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined
let dunder = left.class_member(db, dunder_name); let dunder = left.class_member(db, op.dunder_name());
if !dunder.is_unbound() { if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid // TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool) // (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)]) .call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db); .return_ty(db)
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
});
} }
// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=) // TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
None Err(CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
})
} }
#[cfg(test)] #[cfg(test)]