[red-knot] Type inference for comparisons between arbitrary instances (#13903)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz (push) Blocked by required conditions
CI / Fuzz the parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
cake-monotone 2024-10-27 03:19:56 +09:00 committed by GitHub
parent 35f007f17f
commit b6ffa51c16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 678 additions and 78 deletions

View file

@ -0,0 +1,165 @@
# Comparison: Membership Test
In Python, the term "membership test operators" refers to the operators
`in` and `not in`. To customize their behavior, classes can implement one of
the special methods `__contains__`, `__iter__`, or `__getitem__`.
For references, see:
- <https://docs.python.org/3/reference/expressions.html#membership-test-details>
- <https://docs.python.org/3/reference/datamodel.html#object.__contains__>
- <https://snarky.ca/unravelling-membership-testing/>
## Implements `__contains__`
Classes can support membership tests by implementing the `__contains__` method:
```py
class A:
def __contains__(self, item: str) -> bool:
return True
reveal_type("hello" in A()) # revealed: bool
reveal_type("hello" not in A()) # revealed: bool
# TODO: should emit diagnostic, need to check arg type, will fail
reveal_type(42 in A()) # revealed: bool
reveal_type(42 not in A()) # revealed: bool
```
## Implements `__iter__`
Classes that don't implement `__contains__`, but do implement `__iter__`, also
support containment checks; the needle will be sought in their iterated items:
```py
class StringIterator:
def __next__(self) -> str:
return "foo"
class A:
def __iter__(self) -> StringIterator:
return StringIterator()
reveal_type("hello" in A()) # revealed: bool
reveal_type("hello" not in A()) # revealed: bool
reveal_type(42 in A()) # revealed: bool
reveal_type(42 not in A()) # revealed: bool
```
## Implements `__getitems__`
The final fallback is to implement `__getitem__` for integer keys. Python will
call `__getitem__` with `0`, `1`, `2`... until either the needle is found
(leading the membership test to evaluate to `True`) or `__getitem__` raises
`IndexError` (the raised exception is swallowed, but results in the membership
test evaluating to `False`).
```py
class A:
def __getitem__(self, key: int) -> str:
return "foo"
reveal_type("hello" in A()) # revealed: bool
reveal_type("hello" not in A()) # revealed: bool
reveal_type(42 in A()) # revealed: bool
reveal_type(42 not in A()) # revealed: bool
```
## Wrong Return Type
Python coerces the results of containment checks to `bool`, even if `__contains__`
returns a non-bool:
```py
class A:
def __contains__(self, item: str) -> str:
return "foo"
reveal_type("hello" in A()) # revealed: bool
reveal_type("hello" not in A()) # revealed: bool
```
## Literal Result for `in` and `not in`
`__contains__` with a literal return type may result in a `BooleanLiteral`
outcome.
```py
from typing import Literal
class AlwaysTrue:
def __contains__(self, item: int) -> Literal[1]:
return 1
class AlwaysFalse:
def __contains__(self, item: int) -> Literal[""]:
return ""
# TODO: it should be Literal[True] and Literal[False]
reveal_type(42 in AlwaysTrue()) # revealed: @Todo
reveal_type(42 not in AlwaysTrue()) # revealed: @Todo
# TODO: it should be Literal[False] and Literal[True]
reveal_type(42 in AlwaysFalse()) # revealed: @Todo
reveal_type(42 not in AlwaysFalse()) # revealed: @Todo
```
## No Fallback for `__contains__`
If `__contains__` is implemented, checking membership of a type it doesn't
accept is an error; it doesn't result in a fallback to `__iter__` or
`__getitem__`:
```py
class CheckContains: ...
class CheckIter: ...
class CheckGetItem: ...
class CheckIterIterator:
def __next__(self) -> CheckIter:
return CheckIter()
class A:
def __contains__(self, item: CheckContains) -> bool:
return True
def __iter__(self) -> CheckIterIterator:
return CheckIterIterator()
def __getitem__(self, key: int) -> CheckGetItem:
return CheckGetItem()
reveal_type(CheckContains() in A()) # revealed: bool
# TODO: should emit diagnostic, need to check arg type,
# should not fall back to __iter__ or __getitem__
reveal_type(CheckIter() in A()) # revealed: bool
reveal_type(CheckGetItem() in A()) # revealed: bool
class B:
def __iter__(self) -> CheckIterIterator:
return CheckIterIterator()
def __getitem__(self, key: int) -> CheckGetItem:
return CheckGetItem()
reveal_type(CheckIter() in B()) # revealed: bool
# Always use `__iter__`, regardless of iterated type; there's no NotImplemented
# in this case, so there's no fallback to `__getitem__`
reveal_type(CheckGetItem() in B()) # revealed: bool
```
## Invalid Old-Style Iteration
If `__getitem__` is implemented but does not accept integer arguments, then
the membership test is not supported and should trigger a diagnostic.
```py
class A:
def __getitem__(self, key: str) -> str:
return "foo"
# TODO should emit a diagnostic
reveal_type(42 in A()) # revealed: bool
reveal_type("hello" in A()) # revealed: bool
```

View file

@ -0,0 +1,336 @@
# Comparison: Rich Comparison
Rich comparison operations (`==`, `!=`, `<`, `<=`, `>`, `>=`) in Python are
implemented through double-underscore methods that allow customization of
comparison behavior.
For references, see:
- <https://docs.python.org/3/reference/datamodel.html#object.__lt__>
- <https://snarky.ca/unravelling-rich-comparison-operators/>
## Rich Comparison Dunder Implementations For Same Class
Classes can support rich comparison by implementing dunder methods like
`__eq__`, `__ne__`, etc. The most common case involves implementing these
methods for the same type:
```py
from __future__ import annotations
class A:
def __eq__(self, other: A) -> int:
return 42
def __ne__(self, other: A) -> float:
return 42.0
def __lt__(self, other: A) -> str:
return "42"
def __le__(self, other: A) -> bytes:
return b"42"
def __gt__(self, other: A) -> list:
return [42]
def __ge__(self, other: A) -> set:
return {42}
reveal_type(A() == A()) # revealed: int
reveal_type(A() != A()) # revealed: float
reveal_type(A() < A()) # revealed: str
reveal_type(A() <= A()) # revealed: bytes
reveal_type(A() > A()) # revealed: list
reveal_type(A() >= A()) # revealed: set
```
## Rich Comparison Dunder Implementations for Other Class
In some cases, classes may implement rich comparison dunder methods for
comparisons with a different type:
```py
from __future__ import annotations
class A:
def __eq__(self, other: B) -> int:
return 42
def __ne__(self, other: B) -> float:
return 42.0
def __lt__(self, other: B) -> str:
return "42"
def __le__(self, other: B) -> bytes:
return b"42"
def __gt__(self, other: B) -> list:
return [42]
def __ge__(self, other: B) -> set:
return {42}
class B: ...
reveal_type(A() == B()) # revealed: int
reveal_type(A() != B()) # revealed: float
reveal_type(A() < B()) # revealed: str
reveal_type(A() <= B()) # revealed: bytes
reveal_type(A() > B()) # revealed: list
reveal_type(A() >= B()) # revealed: set
```
## Reflected Comparisons
Fallback to the right-hand sides comparison methods occurs when the left-hand
side does not define them. Note: class `B` has its own `__eq__` and `__ne__`
methods to override those of `object`, but these methods will be ignored here
because they require a mismatched operand type.
```py
from __future__ import annotations
class A:
def __eq__(self, other: B) -> int:
return 42
def __ne__(self, other: B) -> float:
return 42.0
def __lt__(self, other: B) -> str:
return "42"
def __le__(self, other: B) -> bytes:
return b"42"
def __gt__(self, other: B) -> list:
return [42]
def __ge__(self, other: B) -> set:
return {42}
class B:
# To override builtins.object.__eq__ and builtins.object.__ne__
# TODO these should emit an invalid override diagnostic
def __eq__(self, other: str) -> B:
return B()
def __ne__(self, other: str) -> B:
return B()
# TODO: should be `int` and `float`.
# Need to check arg type and fall back to `rhs.__eq__` and `rhs.__ne__`.
#
# Because `object.__eq__` and `object.__ne__` accept `object` in typeshed,
# this can only happen with an invalid override of these methods,
# but we still support it.
reveal_type(B() == A()) # revealed: B
reveal_type(B() != A()) # revealed: B
reveal_type(B() < A()) # revealed: list
reveal_type(B() <= A()) # revealed: set
reveal_type(B() > A()) # revealed: str
reveal_type(B() >= A()) # revealed: bytes
class C:
def __gt__(self, other: C) -> int:
return 42
def __ge__(self, other: C) -> float:
return 42.0
reveal_type(C() < C()) # revealed: int
reveal_type(C() <= C()) # revealed: float
```
## Reflected Comparisons with Subclasses
When subclasses override comparison methods, these overridden methods take
precedence over those in the parent class. Class `B` inherits from `A` and
redefines comparison methods to return types other than `A`.
```py
from __future__ import annotations
class A:
def __eq__(self, other: A) -> A:
return A()
def __ne__(self, other: A) -> A:
return A()
def __lt__(self, other: A) -> A:
return A()
def __le__(self, other: A) -> A:
return A()
def __gt__(self, other: A) -> A:
return A()
def __ge__(self, other: A) -> A:
return A()
class B(A):
def __eq__(self, other: A) -> int:
return 42
def __ne__(self, other: A) -> float:
return 42.0
def __lt__(self, other: A) -> str:
return "42"
def __le__(self, other: A) -> bytes:
return b"42"
def __gt__(self, other: A) -> list:
return [42]
def __ge__(self, other: A) -> set:
return {42}
reveal_type(A() == B()) # revealed: int
reveal_type(A() != B()) # revealed: float
reveal_type(A() < B()) # revealed: list
reveal_type(A() <= B()) # revealed: set
reveal_type(A() > B()) # revealed: str
reveal_type(A() >= B()) # revealed: bytes
```
## Reflected Comparisons with Subclass But Falls Back to LHS
In the case of a subclass, the right-hand side has priority. However, if the
overridden dunder method has an mismatched type to operand, the comparison will
fall back to the left-hand side.
```py
from __future__ import annotations
class A:
def __lt__(self, other: A) -> A:
return A()
def __gt__(self, other: A) -> A:
return A()
class B(A):
def __lt__(self, other: int) -> B:
return B()
def __gt__(self, other: int) -> B:
return B()
# TODO: should be `A`, need to check argument type and fall back to LHS method
reveal_type(A() < B()) # revealed: B
reveal_type(A() > B()) # revealed: B
```
## Operations involving instances of classes inheriting from `Any`
`Any` and `Unknown` represent a set of possible runtime objects, wherein the
bounds of the set are unknown. Whether the left-hand operand's dunder or the
right-hand operand's reflected dunder depends on whether the right-hand operand
is an instance of a class that is a subclass of the left-hand operand's class
and overrides the reflected dunder. In the following example, because of the
unknowable nature of `Any`/`Unknown`, we must consider both possibilities:
`Any`/`Unknown` might resolve to an unknown third class that inherits from `X`
and overrides `__gt__`; but it also might not. Thus, the correct answer here for
the `reveal_type` is `int | Unknown`.
(This test is referenced from `mdtest/binary/instances.md`)
```py
from does_not_exist import Foo # error: [unresolved-import]
reveal_type(Foo) # revealed: Unknown
class X:
def __lt__(self, other: object) -> int:
return 42
class Y(Foo): ...
# TODO: Should be `int | Unknown`; see above discussion.
reveal_type(X() < Y()) # revealed: int
```
## Equality and Inequality Fallback
This test confirms that `==` and `!=` comparisons default to identity
comparisons (`is`, `is not`) when argument types do not match the method
signature.
Please refer to the
[docs](https://docs.python.org/3/reference/datamodel.html#object.__eq__)
```py
from __future__ import annotations
class A:
# TODO both these overrides should emit invalid-override diagnostic
def __eq__(self, other: int) -> A:
return A()
def __ne__(self, other: int) -> A:
return A()
# TODO: it should be `bool`, need to check arg type and fall back to `is` and `is not`
reveal_type(A() == A()) # revealed: A
reveal_type(A() != A()) # revealed: A
```
## Object Comparisons with Typeshed
```py
class A: ...
reveal_type(A() == object()) # revealed: bool
reveal_type(A() != object()) # revealed: bool
reveal_type(object() == A()) # revealed: bool
reveal_type(object() != A()) # revealed: bool
# error: [operator-unsupported] "Operator `<` is not supported for types `A` and `object`"
# revealed: Unknown
reveal_type(A() < object())
```
## Numbers Comparison with typeshed
```py
reveal_type(1 == 1.0) # revealed: bool
reveal_type(1 != 1.0) # revealed: bool
reveal_type(1 < 1.0) # revealed: bool
reveal_type(1 <= 1.0) # revealed: bool
reveal_type(1 > 1.0) # revealed: bool
reveal_type(1 >= 1.0) # revealed: bool
reveal_type(1 == 2j) # revealed: bool
reveal_type(1 != 2j) # revealed: bool
# TODO: should be Unknown and emit diagnostic,
# need to check arg type and should be failed
reveal_type(1 < 2j) # revealed: bool
reveal_type(1 <= 2j) # revealed: bool
reveal_type(1 > 2j) # revealed: bool
reveal_type(1 >= 2j) # revealed: bool
def bool_instance() -> bool:
return True
def int_instance() -> int:
return 42
x = bool_instance()
y = int_instance()
reveal_type(x < y) # revealed: bool
reveal_type(y < x) # revealed: bool
reveal_type(4.2 < x) # revealed: bool
reveal_type(x < 4.2) # revealed: bool
```

View file

@ -12,7 +12,8 @@ reveal_type(1 is 1) # revealed: bool
reveal_type(1 is not 1) # revealed: bool
reveal_type(1 is 2) # revealed: Literal[False]
reveal_type(1 is not 7) # revealed: Literal[True]
reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True]
# TODO: should be Unknown, and emit diagnostic, once we check call argument types
reveal_type(1 <= "" and 0 < 1) # revealed: bool
```
## Integer instance
@ -22,7 +23,7 @@ reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True]
def int_instance() -> int:
return 42
reveal_type(1 == int_instance()) # revealed: @Todo
reveal_type(1 == int_instance()) # revealed: bool
reveal_type(9 < int_instance()) # revealed: bool
reveal_type(int_instance() < int_instance()) # revealed: bool
```

View file

@ -65,20 +65,19 @@ def int_instance() -> int:
a = (bool_instance(),)
b = (int_instance(),)
# TODO: All @Todo should be `bool`
reveal_type(a == a) # revealed: @Todo
reveal_type(a != a) # revealed: @Todo
reveal_type(a < a) # revealed: @Todo
reveal_type(a <= a) # revealed: @Todo
reveal_type(a > a) # revealed: @Todo
reveal_type(a >= a) # revealed: @Todo
reveal_type(a == a) # revealed: bool
reveal_type(a != a) # revealed: bool
reveal_type(a < a) # revealed: bool
reveal_type(a <= a) # revealed: bool
reveal_type(a > a) # revealed: bool
reveal_type(a >= a) # revealed: bool
reveal_type(a == b) # revealed: @Todo
reveal_type(a != b) # revealed: @Todo
reveal_type(a < b) # revealed: @Todo
reveal_type(a <= b) # revealed: @Todo
reveal_type(a > b) # revealed: @Todo
reveal_type(a >= b) # revealed: @Todo
reveal_type(a == b) # revealed: bool
reveal_type(a != b) # revealed: bool
reveal_type(a < b) # revealed: bool
reveal_type(a <= b) # revealed: bool
reveal_type(a > b) # revealed: bool
reveal_type(a >= b) # revealed: bool
```
#### Comparison Unsupported
@ -90,17 +89,17 @@ However, `==` and `!=` are exceptions and can still provide definite results.
a = (1, 2)
b = (1, "hello")
# TODO: should be Literal[False]
reveal_type(a == b) # revealed: @Todo
# TODO: should be Literal[False], once we implement (in)equality for mismatched literals
reveal_type(a == b) # revealed: bool
# TODO: should be Literal[True]
reveal_type(a != b) # revealed: @Todo
# TODO: should be Literal[True], once we implement (in)equality for mismatched literals
reveal_type(a != b) # revealed: bool
# TODO: should be Unknown and add more informative diagnostics
reveal_type(a < b) # revealed: @Todo
reveal_type(a <= b) # revealed: @Todo
reveal_type(a > b) # revealed: @Todo
reveal_type(a >= b) # revealed: @Todo
reveal_type(a < b) # revealed: bool
reveal_type(a <= b) # revealed: bool
reveal_type(a > b) # revealed: bool
reveal_type(a >= b) # revealed: bool
```
However, if the lexicographic comparison completes without reaching a point where str and int are compared,
@ -146,13 +145,12 @@ class A:
a = (A(), A())
# TODO: All @Todo should be bool
reveal_type(a == a) # revealed: @Todo
reveal_type(a != a) # revealed: @Todo
reveal_type(a < a) # revealed: @Todo
reveal_type(a <= a) # revealed: @Todo
reveal_type(a > a) # revealed: @Todo
reveal_type(a >= a) # revealed: @Todo
reveal_type(a == a) # revealed: bool
reveal_type(a != a) # revealed: bool
reveal_type(a < a) # revealed: bool
reveal_type(a <= a) # revealed: bool
reveal_type(a > a) # revealed: bool
reveal_type(a >= a) # revealed: bool
```
### Membership Test Comparisons
@ -174,9 +172,8 @@ reveal_type(a not in b) # revealed: Literal[False]
reveal_type(a in c) # revealed: Literal[False]
reveal_type(a not in c) # revealed: Literal[True]
# TODO: All @Todo should be bool
reveal_type(a in d) # revealed: @Todo
reveal_type(a not in d) # revealed: @Todo
reveal_type(a in d) # revealed: bool
reveal_type(a not in d) # revealed: bool
```
### Identity Comparisons
@ -191,10 +188,10 @@ c = (1, 2, 3)
reveal_type(a is (1, 2)) # revealed: bool
reveal_type(a is not (1, 2)) # revealed: bool
# TODO: Update to Literal[False] once str == int comparison is implemented
reveal_type(a is b) # revealed: @Todo
# TODO: Update to Literal[True] once str == int comparison is implemented
reveal_type(a is not b) # revealed: @Todo
# TODO should be Literal[False] once we implement comparison of mismatched literal types
reveal_type(a is b) # revealed: bool
# TODO should be Literal[True] once we implement comparison of mismatched literal types
reveal_type(a is not b) # revealed: bool
reveal_type(a is c) # revealed: Literal[False]
reveal_type(a is not c) # revealed: Literal[True]

View file

@ -10,12 +10,16 @@ reveal_type(a) # revealed: bool
b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`"
reveal_type(b) # revealed: bool
c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`"
reveal_type(c) # revealed: Unknown
# TODO: should error, once operand type check is implemented
# ("Operator `<` is not supported for types `object` and `int`")
c = object() < 5
# TODO: should be Unknown, once operand type check is implemented
reveal_type(c) # revealed: bool
# TODO should error, need to check if __lt__ signature is valid for right operand
# TODO: should error, once operand type check is implemented
# ("Operator `<` is not supported for types `int` and `object`")
d = 5 < object()
# TODO: should be `Unknown`
# TODO: should be Unknown, once operand type check is implemented
reveal_type(d) # revealed: bool
flag = bool_instance()
@ -27,5 +31,6 @@ reveal_type(e) # revealed: bool
# TODO: should error, need to check if __lt__ signature is valid for right operand
# error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]`
f = (1, 2) < (1, "hello")
reveal_type(f) # revealed: @Todo
# TODO: should be Unknown, once operand type check is implemented
reveal_type(f) # revealed: bool
```

View file

@ -58,7 +58,7 @@ use crate::types::{
use crate::util::subscript::PythonSubscript;
use crate::Db;
use super::{KnownClass, UnionBuilder};
use super::{IterationOutcome, KnownClass, UnionBuilder};
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@ -3101,16 +3101,26 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => perform_rich_comparison(
self.db,
left_class_ty,
right_class_ty,
RichCompareOperator::Lt,
),
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Ok(Type::Todo),
},
(Type::Instance(left_class), Type::Instance(right_class)) => {
let rich_comparison =
|op| perform_rich_comparison(self.db, left_class, right_class, op);
let membership_test_comparison =
|op| perform_membership_test_comparison(self.db, left_class, right_class, op);
match op {
ast::CmpOp::Eq => rich_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => rich_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => rich_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => rich_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => rich_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => rich_comparison(RichCompareOperator::Ge),
ast::CmpOp::In => membership_test_comparison(MembershipTestCompareOperator::In),
ast::CmpOp::NotIn => {
membership_test_comparison(MembershipTestCompareOperator::NotIn)
}
ast::CmpOp::Is => Ok(KnownClass::Bool.to_instance(self.db)),
ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
}
}
// TODO: handle more types
_ => match op {
ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
@ -3623,7 +3633,8 @@ impl From<RichCompareOperator> for ast::CmpOp {
}
impl RichCompareOperator {
const fn dunder_name(self) -> &'static str {
#[must_use]
const fn dunder(self) -> &'static str {
match self {
RichCompareOperator::Eq => "__eq__",
RichCompareOperator::Ne => "__ne__",
@ -3633,6 +3644,33 @@ impl RichCompareOperator {
RichCompareOperator::Ge => "__ge__",
}
}
#[must_use]
const fn reflect(self) -> Self {
match self {
RichCompareOperator::Eq => RichCompareOperator::Eq,
RichCompareOperator::Ne => RichCompareOperator::Ne,
RichCompareOperator::Lt => RichCompareOperator::Gt,
RichCompareOperator::Le => RichCompareOperator::Ge,
RichCompareOperator::Gt => RichCompareOperator::Lt,
RichCompareOperator::Ge => RichCompareOperator::Le,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MembershipTestCompareOperator {
In,
NotIn,
}
impl From<MembershipTestCompareOperator> for ast::CmpOp {
fn from(value: MembershipTestCompareOperator) -> Self {
match value {
MembershipTestCompareOperator::In => ast::CmpOp::In,
MembershipTestCompareOperator::NotIn => ast::CmpOp::NotIn,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -3716,41 +3754,99 @@ impl StringPartsCollector {
/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their
/// behaviour can be edited for classes by implementing corresponding dunder methods.
/// This function performs rich comparison between two instances and returns the resulting type.
/// This function performs rich comparison between two instances and returns the resulting type.
/// see `<https://docs.python.org/3/reference/datamodel.html#object.__lt__>`
fn perform_rich_comparison<'db>(
db: &'db dyn Db,
left: ClassType<'db>,
right: ClassType<'db>,
left_class: ClassType<'db>,
right_class: ClassType<'db>,
op: RichCompareOperator,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
//
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
// l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be errored, currently bool)
let dunder = left.class_member(db, op.dunder_name());
if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db)
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
});
let call_dunder =
|op: RichCompareOperator, left_class: ClassType<'db>, right_class: ClassType<'db>| {
left_class
.class_member(db, op.dunder())
.call(
db,
&[Type::Instance(left_class), Type::Instance(right_class)],
)
.return_ty(db)
};
// The reflected dunder has priority if the right-hand side is a strict subclass of the left-hand side.
if left_class != right_class && right_class.is_subclass_of(db, left_class) {
call_dunder(op.reflect(), right_class, left_class)
.or_else(|| call_dunder(op, left_class, right_class))
} else {
call_dunder(op, left_class, right_class)
.or_else(|| call_dunder(op.reflect(), right_class, left_class))
}
// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
Err(CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
.or_else(|| {
// When no appropriate method returns any value other than NotImplemented,
// the `==` and `!=` operators will fall back to `is` and `is not`, respectively.
// refer to `<https://docs.python.org/3/reference/datamodel.html#object.__eq__>`
if matches!(op, RichCompareOperator::Eq | RichCompareOperator::Ne) {
Some(KnownClass::Bool.to_instance(db))
} else {
None
}
})
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left_class),
right_ty: Type::Instance(right_class),
})
}
/// Performs a membership test (`in` and `not in`) between two instances and returns the resulting type, or `None` if the test is unsupported.
/// The behavior can be customized in Python by implementing `__contains__`, `__iter__`, or `__getitem__` methods.
/// See `<https://docs.python.org/3/reference/datamodel.html#object.__contains__>`
/// and `<https://docs.python.org/3/reference/expressions.html#membership-test-details>`
fn perform_membership_test_comparison<'db>(
db: &'db dyn Db,
left_class: ClassType<'db>,
right_class: ClassType<'db>,
op: MembershipTestCompareOperator,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
let (left_instance, right_instance) = (Type::Instance(left_class), Type::Instance(right_class));
let contains_dunder = right_class.class_member(db, "__contains__");
let compare_result_opt = if contains_dunder.is_unbound() {
// iteration-based membership test
match right_instance.iterate(db) {
IterationOutcome::Iterable { .. } => Some(KnownClass::Bool.to_instance(db)),
IterationOutcome::NotIterable { .. } => None,
}
} else {
// If `__contains__` is available, it is used directly for the membership test.
contains_dunder
.call(db, &[right_instance, left_instance])
.return_ty(db)
};
compare_result_opt
.map(|ty| {
if matches!(ty, Type::Todo) {
return Type::Todo;
}
match op {
MembershipTestCompareOperator::In => ty.bool(db).into_type(db),
MembershipTestCompareOperator::NotIn => ty.bool(db).negate().into_type(db),
}
})
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: left_instance,
right_ty: right_instance,
})
}
#[cfg(test)]