mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-01 17:32:25 +00:00
[red-knot] Enhancing Diagnostics for Compare Expression Inference (#13819)
## Summary - Refactored comparison type inference functions in `infer.rs`: Changed the return type from `Option` to `Result` to lay the groundwork for providing more detailed diagnostics. - Updated diagnostic messages. This is a small step toward improving diagnostics in the future. Please refer to #13787 ## Test Plan mdtest included! --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
55bccf6680
commit
fb66f715f3
2 changed files with 120 additions and 64 deletions
|
@ -7,11 +7,21 @@ reveal_type(a) # revealed: bool
|
|||
b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`"
|
||||
reveal_type(b) # revealed: bool
|
||||
|
||||
c = object() < 5 # error: "Operator `<` is not supported for types `object` and `Literal[5]`"
|
||||
c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`"
|
||||
reveal_type(c) # revealed: Unknown
|
||||
|
||||
# TODO should error, need to check if __lt__ signature is valid for right operand
|
||||
d = 5 < object()
|
||||
# TODO: should be `Unknown`
|
||||
reveal_type(d) # revealed: bool
|
||||
|
||||
int_literal_or_str_literal = 1 if flag else "foo"
|
||||
# error: "Operator `in` is not supported for types `Literal[42]` and `Literal[1]`, in comparing `Literal[42]` with `Literal[1] | Literal["foo"]`"
|
||||
e = 42 in int_literal_or_str_literal
|
||||
reveal_type(e) # revealed: bool
|
||||
|
||||
# TODO: should error, need to check if __lt__ signature is valid for right operand
|
||||
# error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]`
|
||||
f = (1, 2) < (1, "hello")
|
||||
reveal_type(f) # revealed: @Todo
|
||||
```
|
||||
|
|
|
@ -2776,18 +2776,28 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let right_ty = self.expression_ty(right);
|
||||
|
||||
self.infer_binary_type_comparison(left_ty, *op, right_ty)
|
||||
.unwrap_or_else(|| {
|
||||
.unwrap_or_else(|error| {
|
||||
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
|
||||
self.add_diagnostic(
|
||||
AnyNodeRef::ExprCompare(compare),
|
||||
"operator-unsupported",
|
||||
format_args!(
|
||||
"Operator `{}` is not supported for types `{}` and `{}`",
|
||||
op,
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
"Operator `{}` is not supported for types `{}` and `{}`{}",
|
||||
error.op,
|
||||
error.left_ty.display(self.db),
|
||||
error.right_ty.display(self.db),
|
||||
if (left_ty, right_ty) == (error.left_ty, error.right_ty) {
|
||||
String::new()
|
||||
} else {
|
||||
format!(
|
||||
", in comparing `{}` with `{}`",
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
)
|
||||
}
|
||||
),
|
||||
);
|
||||
|
||||
match op {
|
||||
// `in, not in, is, is not` always return bool instances
|
||||
ast::CmpOp::In
|
||||
|
@ -2814,7 +2824,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
left: Type<'db>,
|
||||
op: ast::CmpOp,
|
||||
right: Type<'db>,
|
||||
) -> Option<Type<'db>> {
|
||||
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
|
||||
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the
|
||||
// language spec.
|
||||
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
|
||||
|
@ -2825,39 +2835,43 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
for element in union.elements(self.db) {
|
||||
builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?);
|
||||
}
|
||||
Some(builder.build())
|
||||
Ok(builder.build())
|
||||
}
|
||||
(other, Type::Union(union)) => {
|
||||
let mut builder = UnionBuilder::new(self.db);
|
||||
for element in union.elements(self.db) {
|
||||
builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?);
|
||||
}
|
||||
Some(builder.build())
|
||||
Ok(builder.build())
|
||||
}
|
||||
|
||||
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
|
||||
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
|
||||
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
|
||||
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
|
||||
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
|
||||
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
|
||||
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
|
||||
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
|
||||
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),
|
||||
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(n < m)),
|
||||
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(n <= m)),
|
||||
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(n > m)),
|
||||
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(n >= m)),
|
||||
ast::CmpOp::Is => {
|
||||
if n == m {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(false))
|
||||
Ok(Type::BooleanLiteral(false))
|
||||
}
|
||||
}
|
||||
ast::CmpOp::IsNot => {
|
||||
if n == m {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(true))
|
||||
Ok(Type::BooleanLiteral(true))
|
||||
}
|
||||
}
|
||||
// Undefined for (int, int)
|
||||
ast::CmpOp::In | ast::CmpOp::NotIn => None,
|
||||
ast::CmpOp::In | ast::CmpOp::NotIn => Err(CompareUnsupportedError {
|
||||
op,
|
||||
left_ty: left,
|
||||
right_ty: right,
|
||||
}),
|
||||
},
|
||||
(Type::IntLiteral(_), Type::Instance(_)) => {
|
||||
self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right)
|
||||
|
@ -2888,26 +2902,26 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let s1 = salsa_s1.value(self.db);
|
||||
let s2 = salsa_s2.value(self.db);
|
||||
match op {
|
||||
ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)),
|
||||
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)),
|
||||
ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)),
|
||||
ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)),
|
||||
ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)),
|
||||
ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)),
|
||||
ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
|
||||
ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
|
||||
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(s1 == s2)),
|
||||
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(s1 != s2)),
|
||||
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(s1 < s2)),
|
||||
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(s1 <= s2)),
|
||||
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(s1 > s2)),
|
||||
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(s1 >= s2)),
|
||||
ast::CmpOp::In => Ok(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
|
||||
ast::CmpOp::NotIn => Ok(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
|
||||
ast::CmpOp::Is => {
|
||||
if s1 == s2 {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(false))
|
||||
Ok(Type::BooleanLiteral(false))
|
||||
}
|
||||
}
|
||||
ast::CmpOp::IsNot => {
|
||||
if s1 == s2 {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(true))
|
||||
Ok(Type::BooleanLiteral(true))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2930,30 +2944,30 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let b1 = &**salsa_b1.value(self.db);
|
||||
let b2 = &**salsa_b2.value(self.db);
|
||||
match op {
|
||||
ast::CmpOp::Eq => Some(Type::BooleanLiteral(b1 == b2)),
|
||||
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(b1 != b2)),
|
||||
ast::CmpOp::Lt => Some(Type::BooleanLiteral(b1 < b2)),
|
||||
ast::CmpOp::LtE => Some(Type::BooleanLiteral(b1 <= b2)),
|
||||
ast::CmpOp::Gt => Some(Type::BooleanLiteral(b1 > b2)),
|
||||
ast::CmpOp::GtE => Some(Type::BooleanLiteral(b1 >= b2)),
|
||||
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(b1 == b2)),
|
||||
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(b1 != b2)),
|
||||
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(b1 < b2)),
|
||||
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(b1 <= b2)),
|
||||
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(b1 > b2)),
|
||||
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(b1 >= b2)),
|
||||
ast::CmpOp::In => {
|
||||
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
|
||||
Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
|
||||
}
|
||||
ast::CmpOp::NotIn => {
|
||||
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
|
||||
Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
|
||||
}
|
||||
ast::CmpOp::Is => {
|
||||
if b1 == b2 {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(false))
|
||||
Ok(Type::BooleanLiteral(false))
|
||||
}
|
||||
}
|
||||
ast::CmpOp::IsNot => {
|
||||
if b1 == b2 {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
} else {
|
||||
Some(Type::BooleanLiteral(true))
|
||||
Ok(Type::BooleanLiteral(true))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2991,7 +3005,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");
|
||||
|
||||
match eq_result {
|
||||
Type::Todo => return Some(Type::Todo),
|
||||
Type::Todo => return Ok(Type::Todo),
|
||||
ty => match ty.bool(self.db) {
|
||||
Truthiness::AlwaysTrue => eq_count += 1,
|
||||
Truthiness::AlwaysFalse => not_eq_count += 1,
|
||||
|
@ -3001,11 +3015,11 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
|
||||
if eq_count >= 1 {
|
||||
Some(Type::BooleanLiteral(op.is_in()))
|
||||
Ok(Type::BooleanLiteral(op.is_in()))
|
||||
} else if not_eq_count == rhs_elements.len() {
|
||||
Some(Type::BooleanLiteral(op.is_not_in()))
|
||||
Ok(Type::BooleanLiteral(op.is_not_in()))
|
||||
} else {
|
||||
Some(KnownClass::Bool.to_instance(self.db))
|
||||
Ok(KnownClass::Bool.to_instance(self.db))
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Is | ast::CmpOp::IsNot => {
|
||||
|
@ -3016,7 +3030,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
"infer_binary_type_comparison should never return None for `CmpOp::Eq`",
|
||||
);
|
||||
|
||||
Some(match eq_result {
|
||||
Ok(match eq_result {
|
||||
Type::Todo => Type::Todo,
|
||||
ty => match ty.bool(self.db) {
|
||||
Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()),
|
||||
|
@ -3029,16 +3043,19 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
// Lookup the rich comparison `__dunder__` methods on instances
|
||||
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
|
||||
ast::CmpOp::Lt => {
|
||||
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
|
||||
}
|
||||
ast::CmpOp::Lt => perform_rich_comparison(
|
||||
self.db,
|
||||
left_class_ty,
|
||||
right_class_ty,
|
||||
RichCompareOperator::Lt,
|
||||
),
|
||||
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
|
||||
_ => Some(Type::Todo),
|
||||
_ => Ok(Type::Todo),
|
||||
},
|
||||
// TODO: handle more types
|
||||
_ => match op {
|
||||
ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)),
|
||||
_ => Some(Type::Todo),
|
||||
ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
|
||||
_ => Ok(Type::Todo),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -3053,7 +3070,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
left: &[Type<'db>],
|
||||
op: RichCompareOperator,
|
||||
right: &[Type<'db>],
|
||||
) -> Option<Type<'db>> {
|
||||
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
|
||||
// Compare paired elements from left and right slices
|
||||
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
|
||||
let eq_result = self
|
||||
|
@ -3062,7 +3079,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
match eq_result {
|
||||
// If propagation is required, return the result as is
|
||||
Type::Todo => return Some(Type::Todo),
|
||||
Type::Todo => return Ok(Type::Todo),
|
||||
ty => match ty.bool(self.db) {
|
||||
// Types are equal, continue to the next pair
|
||||
Truthiness::AlwaysTrue => continue,
|
||||
|
@ -3072,7 +3089,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
|
||||
// In this case, we simply return a bool instance.
|
||||
Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)),
|
||||
Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -3082,7 +3099,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
// We return a comparison of the slice lengths based on the operator.
|
||||
let (left_len, right_len) = (left.len(), right.len());
|
||||
|
||||
Some(Type::BooleanLiteral(match op {
|
||||
Ok(Type::BooleanLiteral(match op {
|
||||
RichCompareOperator::Eq => left_len == right_len,
|
||||
RichCompareOperator::Ne => left_len != right_len,
|
||||
RichCompareOperator::Lt => left_len < right_len,
|
||||
|
@ -3556,6 +3573,26 @@ impl From<RichCompareOperator> for ast::CmpOp {
|
|||
}
|
||||
}
|
||||
|
||||
impl RichCompareOperator {
|
||||
const fn dunder_name(self) -> &'static str {
|
||||
match self {
|
||||
RichCompareOperator::Eq => "__eq__",
|
||||
RichCompareOperator::Ne => "__ne__",
|
||||
RichCompareOperator::Lt => "__lt__",
|
||||
RichCompareOperator::Le => "__le__",
|
||||
RichCompareOperator::Gt => "__gt__",
|
||||
RichCompareOperator::Ge => "__ge__",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct CompareUnsupportedError<'db> {
|
||||
op: ast::CmpOp,
|
||||
left_ty: Type<'db>,
|
||||
right_ty: Type<'db>,
|
||||
}
|
||||
|
||||
fn format_import_from_module(level: u32, module: Option<&str>) -> String {
|
||||
format!(
|
||||
"{}{}",
|
||||
|
@ -3636,8 +3673,8 @@ fn perform_rich_comparison<'db>(
|
|||
db: &'db dyn Db,
|
||||
left: ClassType<'db>,
|
||||
right: ClassType<'db>,
|
||||
dunder_name: &str,
|
||||
) -> Option<Type<'db>> {
|
||||
op: RichCompareOperator,
|
||||
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
|
||||
// The following resource has details about the rich comparison algorithm:
|
||||
// https://snarky.ca/unravelling-rich-comparison-operators/
|
||||
//
|
||||
|
@ -3645,17 +3682,26 @@ fn perform_rich_comparison<'db>(
|
|||
// l.h.s.
|
||||
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined
|
||||
|
||||
let dunder = left.class_member(db, dunder_name);
|
||||
let dunder = left.class_member(db, op.dunder_name());
|
||||
if !dunder.is_unbound() {
|
||||
// TODO: this currently gives the return type even if the arg types are invalid
|
||||
// (e.g. int.__lt__ with string instance should be None, currently bool)
|
||||
return dunder
|
||||
.call(db, &[Type::Instance(left), Type::Instance(right)])
|
||||
.return_ty(db);
|
||||
.return_ty(db)
|
||||
.ok_or_else(|| CompareUnsupportedError {
|
||||
op: op.into(),
|
||||
left_ty: Type::Instance(left),
|
||||
right_ty: Type::Instance(right),
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
|
||||
None
|
||||
Err(CompareUnsupportedError {
|
||||
op: op.into(),
|
||||
left_ty: Type::Instance(left),
|
||||
right_ty: Type::Instance(right),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue