[red-knot] Type inference for comparisons between arbitrary instances (#13903)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
cake-monotone 2024-10-27 03:19:56 +09:00 committed by GitHub
parent 35f007f17f
commit b6ffa51c16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 678 additions and 78 deletions

View file

@ -58,7 +58,7 @@ use crate::types::{
use crate::util::subscript::PythonSubscript;
use crate::Db;
use super::{KnownClass, UnionBuilder};
use super::{IterationOutcome, KnownClass, UnionBuilder};
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@ -3101,16 +3101,26 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => perform_rich_comparison(
self.db,
left_class_ty,
right_class_ty,
RichCompareOperator::Lt,
),
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Ok(Type::Todo),
},
(Type::Instance(left_class), Type::Instance(right_class)) => {
let rich_comparison =
|op| perform_rich_comparison(self.db, left_class, right_class, op);
let membership_test_comparison =
|op| perform_membership_test_comparison(self.db, left_class, right_class, op);
match op {
ast::CmpOp::Eq => rich_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => rich_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => rich_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => rich_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => rich_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => rich_comparison(RichCompareOperator::Ge),
ast::CmpOp::In => membership_test_comparison(MembershipTestCompareOperator::In),
ast::CmpOp::NotIn => {
membership_test_comparison(MembershipTestCompareOperator::NotIn)
}
ast::CmpOp::Is => Ok(KnownClass::Bool.to_instance(self.db)),
ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
}
}
// TODO: handle more types
_ => match op {
ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
@ -3623,7 +3633,8 @@ impl From<RichCompareOperator> for ast::CmpOp {
}
impl RichCompareOperator {
const fn dunder_name(self) -> &'static str {
#[must_use]
const fn dunder(self) -> &'static str {
match self {
RichCompareOperator::Eq => "__eq__",
RichCompareOperator::Ne => "__ne__",
@ -3633,6 +3644,33 @@ impl RichCompareOperator {
RichCompareOperator::Ge => "__ge__",
}
}
#[must_use]
const fn reflect(self) -> Self {
match self {
RichCompareOperator::Eq => RichCompareOperator::Eq,
RichCompareOperator::Ne => RichCompareOperator::Ne,
RichCompareOperator::Lt => RichCompareOperator::Gt,
RichCompareOperator::Le => RichCompareOperator::Ge,
RichCompareOperator::Gt => RichCompareOperator::Lt,
RichCompareOperator::Ge => RichCompareOperator::Le,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MembershipTestCompareOperator {
In,
NotIn,
}
impl From<MembershipTestCompareOperator> for ast::CmpOp {
fn from(value: MembershipTestCompareOperator) -> Self {
match value {
MembershipTestCompareOperator::In => ast::CmpOp::In,
MembershipTestCompareOperator::NotIn => ast::CmpOp::NotIn,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -3716,41 +3754,99 @@ impl StringPartsCollector {
/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their
/// behaviour can be edited for classes by implementing corresponding dunder methods.
/// This function performs rich comparison between two instances and returns the resulting type.
/// This function performs rich comparison between two instances and returns the resulting type.
/// see `<https://docs.python.org/3/reference/datamodel.html#object.__lt__>`
fn perform_rich_comparison<'db>(
db: &'db dyn Db,
left: ClassType<'db>,
right: ClassType<'db>,
left_class: ClassType<'db>,
right_class: ClassType<'db>,
op: RichCompareOperator,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
//
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
// l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be errored, currently bool)
let dunder = left.class_member(db, op.dunder_name());
if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db)
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
});
let call_dunder =
|op: RichCompareOperator, left_class: ClassType<'db>, right_class: ClassType<'db>| {
left_class
.class_member(db, op.dunder())
.call(
db,
&[Type::Instance(left_class), Type::Instance(right_class)],
)
.return_ty(db)
};
// The reflected dunder has priority if the right-hand side is a strict subclass of the left-hand side.
if left_class != right_class && right_class.is_subclass_of(db, left_class) {
call_dunder(op.reflect(), right_class, left_class)
.or_else(|| call_dunder(op, left_class, right_class))
} else {
call_dunder(op, left_class, right_class)
.or_else(|| call_dunder(op.reflect(), right_class, left_class))
}
// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
Err(CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
.or_else(|| {
// When no appropriate method returns any value other than NotImplemented,
// the `==` and `!=` operators will fall back to `is` and `is not`, respectively.
// refer to `<https://docs.python.org/3/reference/datamodel.html#object.__eq__>`
if matches!(op, RichCompareOperator::Eq | RichCompareOperator::Ne) {
Some(KnownClass::Bool.to_instance(db))
} else {
None
}
})
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left_class),
right_ty: Type::Instance(right_class),
})
}
/// Performs a membership test (`in` and `not in`) between two instances and returns the resulting type, or `None` if the test is unsupported.
/// The behavior can be customized in Python by implementing `__contains__`, `__iter__`, or `__getitem__` methods.
/// See `<https://docs.python.org/3/reference/datamodel.html#object.__contains__>`
/// and `<https://docs.python.org/3/reference/expressions.html#membership-test-details>`
fn perform_membership_test_comparison<'db>(
db: &'db dyn Db,
left_class: ClassType<'db>,
right_class: ClassType<'db>,
op: MembershipTestCompareOperator,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
let (left_instance, right_instance) = (Type::Instance(left_class), Type::Instance(right_class));
let contains_dunder = right_class.class_member(db, "__contains__");
let compare_result_opt = if contains_dunder.is_unbound() {
// iteration-based membership test
match right_instance.iterate(db) {
IterationOutcome::Iterable { .. } => Some(KnownClass::Bool.to_instance(db)),
IterationOutcome::NotIterable { .. } => None,
}
} else {
// If `__contains__` is available, it is used directly for the membership test.
contains_dunder
.call(db, &[right_instance, left_instance])
.return_ty(db)
};
compare_result_opt
.map(|ty| {
if matches!(ty, Type::Todo) {
return Type::Todo;
}
match op {
MembershipTestCompareOperator::In => ty.bool(db).into_type(db),
MembershipTestCompareOperator::NotIn => ty.bool(db).negate().into_type(db),
}
})
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: left_instance,
right_ty: right_instance,
})
}
#[cfg(test)]