diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md new file mode 100644 index 0000000000..21595db3f5 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md @@ -0,0 +1,165 @@ +# Comparison: Membership Test + +In Python, the term "membership test operators" refers to the operators +`in` and `not in`. To customize their behavior, classes can implement one of +the special methods `__contains__`, `__iter__`, or `__getitem__`. + +For references, see: + +- +- +- + +## Implements `__contains__` + +Classes can support membership tests by implementing the `__contains__` method: + +```py +class A: + def __contains__(self, item: str) -> bool: + return True + +reveal_type("hello" in A()) # revealed: bool +reveal_type("hello" not in A()) # revealed: bool +# TODO: should emit diagnostic, need to check arg type, will fail +reveal_type(42 in A()) # revealed: bool +reveal_type(42 not in A()) # revealed: bool +``` + +## Implements `__iter__` + +Classes that don't implement `__contains__`, but do implement `__iter__`, also +support containment checks; the needle will be sought in their iterated items: + +```py +class StringIterator: + def __next__(self) -> str: + return "foo" + +class A: + def __iter__(self) -> StringIterator: + return StringIterator() + +reveal_type("hello" in A()) # revealed: bool +reveal_type("hello" not in A()) # revealed: bool +reveal_type(42 in A()) # revealed: bool +reveal_type(42 not in A()) # revealed: bool +``` + +## Implements `__getitems__` + +The final fallback is to implement `__getitem__` for integer keys. Python will +call `__getitem__` with `0`, `1`, `2`... until either the needle is found +(leading the membership test to evaluate to `True`) or `__getitem__` raises +`IndexError` (the raised exception is swallowed, but results in the membership +test evaluating to `False`). + +```py +class A: + def __getitem__(self, key: int) -> str: + return "foo" + +reveal_type("hello" in A()) # revealed: bool +reveal_type("hello" not in A()) # revealed: bool +reveal_type(42 in A()) # revealed: bool +reveal_type(42 not in A()) # revealed: bool +``` + +## Wrong Return Type + +Python coerces the results of containment checks to `bool`, even if `__contains__` +returns a non-bool: + +```py +class A: + def __contains__(self, item: str) -> str: + return "foo" + +reveal_type("hello" in A()) # revealed: bool +reveal_type("hello" not in A()) # revealed: bool +``` + +## Literal Result for `in` and `not in` + +`__contains__` with a literal return type may result in a `BooleanLiteral` +outcome. + +```py +from typing import Literal + +class AlwaysTrue: + def __contains__(self, item: int) -> Literal[1]: + return 1 + +class AlwaysFalse: + def __contains__(self, item: int) -> Literal[""]: + return "" + +# TODO: it should be Literal[True] and Literal[False] +reveal_type(42 in AlwaysTrue()) # revealed: @Todo +reveal_type(42 not in AlwaysTrue()) # revealed: @Todo + +# TODO: it should be Literal[False] and Literal[True] +reveal_type(42 in AlwaysFalse()) # revealed: @Todo +reveal_type(42 not in AlwaysFalse()) # revealed: @Todo +``` + +## No Fallback for `__contains__` + +If `__contains__` is implemented, checking membership of a type it doesn't +accept is an error; it doesn't result in a fallback to `__iter__` or +`__getitem__`: + +```py +class CheckContains: ... +class CheckIter: ... +class CheckGetItem: ... + +class CheckIterIterator: + def __next__(self) -> CheckIter: + return CheckIter() + +class A: + def __contains__(self, item: CheckContains) -> bool: + return True + + def __iter__(self) -> CheckIterIterator: + return CheckIterIterator() + + def __getitem__(self, key: int) -> CheckGetItem: + return CheckGetItem() + +reveal_type(CheckContains() in A()) # revealed: bool + +# TODO: should emit diagnostic, need to check arg type, +# should not fall back to __iter__ or __getitem__ +reveal_type(CheckIter() in A()) # revealed: bool +reveal_type(CheckGetItem() in A()) # revealed: bool + +class B: + def __iter__(self) -> CheckIterIterator: + return CheckIterIterator() + + def __getitem__(self, key: int) -> CheckGetItem: + return CheckGetItem() + +reveal_type(CheckIter() in B()) # revealed: bool +# Always use `__iter__`, regardless of iterated type; there's no NotImplemented +# in this case, so there's no fallback to `__getitem__` +reveal_type(CheckGetItem() in B()) # revealed: bool +``` + +## Invalid Old-Style Iteration + +If `__getitem__` is implemented but does not accept integer arguments, then +the membership test is not supported and should trigger a diagnostic. + +```py +class A: + def __getitem__(self, key: str) -> str: + return "foo" + +# TODO should emit a diagnostic +reveal_type(42 in A()) # revealed: bool +reveal_type("hello" in A()) # revealed: bool +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md new file mode 100644 index 0000000000..db47ac9ae9 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md @@ -0,0 +1,336 @@ +# Comparison: Rich Comparison + +Rich comparison operations (`==`, `!=`, `<`, `<=`, `>`, `>=`) in Python are +implemented through double-underscore methods that allow customization of +comparison behavior. + +For references, see: + +- +- + +## Rich Comparison Dunder Implementations For Same Class + +Classes can support rich comparison by implementing dunder methods like +`__eq__`, `__ne__`, etc. The most common case involves implementing these +methods for the same type: + +```py +from __future__ import annotations + +class A: + def __eq__(self, other: A) -> int: + return 42 + + def __ne__(self, other: A) -> float: + return 42.0 + + def __lt__(self, other: A) -> str: + return "42" + + def __le__(self, other: A) -> bytes: + return b"42" + + def __gt__(self, other: A) -> list: + return [42] + + def __ge__(self, other: A) -> set: + return {42} + +reveal_type(A() == A()) # revealed: int +reveal_type(A() != A()) # revealed: float +reveal_type(A() < A()) # revealed: str +reveal_type(A() <= A()) # revealed: bytes +reveal_type(A() > A()) # revealed: list +reveal_type(A() >= A()) # revealed: set +``` + +## Rich Comparison Dunder Implementations for Other Class + +In some cases, classes may implement rich comparison dunder methods for +comparisons with a different type: + +```py +from __future__ import annotations + +class A: + def __eq__(self, other: B) -> int: + return 42 + + def __ne__(self, other: B) -> float: + return 42.0 + + def __lt__(self, other: B) -> str: + return "42" + + def __le__(self, other: B) -> bytes: + return b"42" + + def __gt__(self, other: B) -> list: + return [42] + + def __ge__(self, other: B) -> set: + return {42} + +class B: ... + +reveal_type(A() == B()) # revealed: int +reveal_type(A() != B()) # revealed: float +reveal_type(A() < B()) # revealed: str +reveal_type(A() <= B()) # revealed: bytes +reveal_type(A() > B()) # revealed: list +reveal_type(A() >= B()) # revealed: set +``` + +## Reflected Comparisons + +Fallback to the right-hand side’s comparison methods occurs when the left-hand +side does not define them. Note: class `B` has its own `__eq__` and `__ne__` +methods to override those of `object`, but these methods will be ignored here +because they require a mismatched operand type. + +```py +from __future__ import annotations + +class A: + def __eq__(self, other: B) -> int: + return 42 + + def __ne__(self, other: B) -> float: + return 42.0 + + def __lt__(self, other: B) -> str: + return "42" + + def __le__(self, other: B) -> bytes: + return b"42" + + def __gt__(self, other: B) -> list: + return [42] + + def __ge__(self, other: B) -> set: + return {42} + +class B: + # To override builtins.object.__eq__ and builtins.object.__ne__ + # TODO these should emit an invalid override diagnostic + def __eq__(self, other: str) -> B: + return B() + + def __ne__(self, other: str) -> B: + return B() + +# TODO: should be `int` and `float`. +# Need to check arg type and fall back to `rhs.__eq__` and `rhs.__ne__`. +# +# Because `object.__eq__` and `object.__ne__` accept `object` in typeshed, +# this can only happen with an invalid override of these methods, +# but we still support it. +reveal_type(B() == A()) # revealed: B +reveal_type(B() != A()) # revealed: B + +reveal_type(B() < A()) # revealed: list +reveal_type(B() <= A()) # revealed: set + +reveal_type(B() > A()) # revealed: str +reveal_type(B() >= A()) # revealed: bytes + +class C: + def __gt__(self, other: C) -> int: + return 42 + + def __ge__(self, other: C) -> float: + return 42.0 + +reveal_type(C() < C()) # revealed: int +reveal_type(C() <= C()) # revealed: float +``` + +## Reflected Comparisons with Subclasses + +When subclasses override comparison methods, these overridden methods take +precedence over those in the parent class. Class `B` inherits from `A` and +redefines comparison methods to return types other than `A`. + +```py +from __future__ import annotations + +class A: + def __eq__(self, other: A) -> A: + return A() + + def __ne__(self, other: A) -> A: + return A() + + def __lt__(self, other: A) -> A: + return A() + + def __le__(self, other: A) -> A: + return A() + + def __gt__(self, other: A) -> A: + return A() + + def __ge__(self, other: A) -> A: + return A() + +class B(A): + def __eq__(self, other: A) -> int: + return 42 + + def __ne__(self, other: A) -> float: + return 42.0 + + def __lt__(self, other: A) -> str: + return "42" + + def __le__(self, other: A) -> bytes: + return b"42" + + def __gt__(self, other: A) -> list: + return [42] + + def __ge__(self, other: A) -> set: + return {42} + +reveal_type(A() == B()) # revealed: int +reveal_type(A() != B()) # revealed: float + +reveal_type(A() < B()) # revealed: list +reveal_type(A() <= B()) # revealed: set + +reveal_type(A() > B()) # revealed: str +reveal_type(A() >= B()) # revealed: bytes +``` + +## Reflected Comparisons with Subclass But Falls Back to LHS + +In the case of a subclass, the right-hand side has priority. However, if the +overridden dunder method has an mismatched type to operand, the comparison will +fall back to the left-hand side. + +```py +from __future__ import annotations + +class A: + def __lt__(self, other: A) -> A: + return A() + + def __gt__(self, other: A) -> A: + return A() + +class B(A): + def __lt__(self, other: int) -> B: + return B() + + def __gt__(self, other: int) -> B: + return B() + +# TODO: should be `A`, need to check argument type and fall back to LHS method +reveal_type(A() < B()) # revealed: B +reveal_type(A() > B()) # revealed: B +``` + +## Operations involving instances of classes inheriting from `Any` + +`Any` and `Unknown` represent a set of possible runtime objects, wherein the +bounds of the set are unknown. Whether the left-hand operand's dunder or the +right-hand operand's reflected dunder depends on whether the right-hand operand +is an instance of a class that is a subclass of the left-hand operand's class +and overrides the reflected dunder. In the following example, because of the +unknowable nature of `Any`/`Unknown`, we must consider both possibilities: +`Any`/`Unknown` might resolve to an unknown third class that inherits from `X` +and overrides `__gt__`; but it also might not. Thus, the correct answer here for +the `reveal_type` is `int | Unknown`. + +(This test is referenced from `mdtest/binary/instances.md`) + +```py +from does_not_exist import Foo # error: [unresolved-import] + +reveal_type(Foo) # revealed: Unknown + +class X: + def __lt__(self, other: object) -> int: + return 42 + +class Y(Foo): ... + +# TODO: Should be `int | Unknown`; see above discussion. +reveal_type(X() < Y()) # revealed: int +``` + +## Equality and Inequality Fallback + +This test confirms that `==` and `!=` comparisons default to identity +comparisons (`is`, `is not`) when argument types do not match the method +signature. + +Please refer to the +[docs](https://docs.python.org/3/reference/datamodel.html#object.__eq__) + +```py +from __future__ import annotations + +class A: + # TODO both these overrides should emit invalid-override diagnostic + def __eq__(self, other: int) -> A: + return A() + + def __ne__(self, other: int) -> A: + return A() + +# TODO: it should be `bool`, need to check arg type and fall back to `is` and `is not` +reveal_type(A() == A()) # revealed: A +reveal_type(A() != A()) # revealed: A +``` + +## Object Comparisons with Typeshed + +```py +class A: ... + +reveal_type(A() == object()) # revealed: bool +reveal_type(A() != object()) # revealed: bool +reveal_type(object() == A()) # revealed: bool +reveal_type(object() != A()) # revealed: bool + +# error: [operator-unsupported] "Operator `<` is not supported for types `A` and `object`" +# revealed: Unknown +reveal_type(A() < object()) +``` + +## Numbers Comparison with typeshed + +```py +reveal_type(1 == 1.0) # revealed: bool +reveal_type(1 != 1.0) # revealed: bool +reveal_type(1 < 1.0) # revealed: bool +reveal_type(1 <= 1.0) # revealed: bool +reveal_type(1 > 1.0) # revealed: bool +reveal_type(1 >= 1.0) # revealed: bool + +reveal_type(1 == 2j) # revealed: bool +reveal_type(1 != 2j) # revealed: bool + +# TODO: should be Unknown and emit diagnostic, +# need to check arg type and should be failed +reveal_type(1 < 2j) # revealed: bool +reveal_type(1 <= 2j) # revealed: bool +reveal_type(1 > 2j) # revealed: bool +reveal_type(1 >= 2j) # revealed: bool + +def bool_instance() -> bool: + return True + +def int_instance() -> int: + return 42 + +x = bool_instance() +y = int_instance() + +reveal_type(x < y) # revealed: bool +reveal_type(y < x) # revealed: bool +reveal_type(4.2 < x) # revealed: bool +reveal_type(x < 4.2) # revealed: bool +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md index b576ce318d..a2092af109 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md @@ -12,7 +12,8 @@ reveal_type(1 is 1) # revealed: bool reveal_type(1 is not 1) # revealed: bool reveal_type(1 is 2) # revealed: Literal[False] reveal_type(1 is not 7) # revealed: Literal[True] -reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True] +# TODO: should be Unknown, and emit diagnostic, once we check call argument types +reveal_type(1 <= "" and 0 < 1) # revealed: bool ``` ## Integer instance @@ -22,7 +23,7 @@ reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True] def int_instance() -> int: return 42 -reveal_type(1 == int_instance()) # revealed: @Todo +reveal_type(1 == int_instance()) # revealed: bool reveal_type(9 < int_instance()) # revealed: bool reveal_type(int_instance() < int_instance()) # revealed: bool ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md index c80ee4d601..7d3d789e07 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -65,20 +65,19 @@ def int_instance() -> int: a = (bool_instance(),) b = (int_instance(),) -# TODO: All @Todo should be `bool` -reveal_type(a == a) # revealed: @Todo -reveal_type(a != a) # revealed: @Todo -reveal_type(a < a) # revealed: @Todo -reveal_type(a <= a) # revealed: @Todo -reveal_type(a > a) # revealed: @Todo -reveal_type(a >= a) # revealed: @Todo +reveal_type(a == a) # revealed: bool +reveal_type(a != a) # revealed: bool +reveal_type(a < a) # revealed: bool +reveal_type(a <= a) # revealed: bool +reveal_type(a > a) # revealed: bool +reveal_type(a >= a) # revealed: bool -reveal_type(a == b) # revealed: @Todo -reveal_type(a != b) # revealed: @Todo -reveal_type(a < b) # revealed: @Todo -reveal_type(a <= b) # revealed: @Todo -reveal_type(a > b) # revealed: @Todo -reveal_type(a >= b) # revealed: @Todo +reveal_type(a == b) # revealed: bool +reveal_type(a != b) # revealed: bool +reveal_type(a < b) # revealed: bool +reveal_type(a <= b) # revealed: bool +reveal_type(a > b) # revealed: bool +reveal_type(a >= b) # revealed: bool ``` #### Comparison Unsupported @@ -90,17 +89,17 @@ However, `==` and `!=` are exceptions and can still provide definite results. a = (1, 2) b = (1, "hello") -# TODO: should be Literal[False] -reveal_type(a == b) # revealed: @Todo +# TODO: should be Literal[False], once we implement (in)equality for mismatched literals +reveal_type(a == b) # revealed: bool -# TODO: should be Literal[True] -reveal_type(a != b) # revealed: @Todo +# TODO: should be Literal[True], once we implement (in)equality for mismatched literals +reveal_type(a != b) # revealed: bool # TODO: should be Unknown and add more informative diagnostics -reveal_type(a < b) # revealed: @Todo -reveal_type(a <= b) # revealed: @Todo -reveal_type(a > b) # revealed: @Todo -reveal_type(a >= b) # revealed: @Todo +reveal_type(a < b) # revealed: bool +reveal_type(a <= b) # revealed: bool +reveal_type(a > b) # revealed: bool +reveal_type(a >= b) # revealed: bool ``` However, if the lexicographic comparison completes without reaching a point where str and int are compared, @@ -146,13 +145,12 @@ class A: a = (A(), A()) -# TODO: All @Todo should be bool -reveal_type(a == a) # revealed: @Todo -reveal_type(a != a) # revealed: @Todo -reveal_type(a < a) # revealed: @Todo -reveal_type(a <= a) # revealed: @Todo -reveal_type(a > a) # revealed: @Todo -reveal_type(a >= a) # revealed: @Todo +reveal_type(a == a) # revealed: bool +reveal_type(a != a) # revealed: bool +reveal_type(a < a) # revealed: bool +reveal_type(a <= a) # revealed: bool +reveal_type(a > a) # revealed: bool +reveal_type(a >= a) # revealed: bool ``` ### Membership Test Comparisons @@ -174,9 +172,8 @@ reveal_type(a not in b) # revealed: Literal[False] reveal_type(a in c) # revealed: Literal[False] reveal_type(a not in c) # revealed: Literal[True] -# TODO: All @Todo should be bool -reveal_type(a in d) # revealed: @Todo -reveal_type(a not in d) # revealed: @Todo +reveal_type(a in d) # revealed: bool +reveal_type(a not in d) # revealed: bool ``` ### Identity Comparisons @@ -191,10 +188,10 @@ c = (1, 2, 3) reveal_type(a is (1, 2)) # revealed: bool reveal_type(a is not (1, 2)) # revealed: bool -# TODO: Update to Literal[False] once str == int comparison is implemented -reveal_type(a is b) # revealed: @Todo -# TODO: Update to Literal[True] once str == int comparison is implemented -reveal_type(a is not b) # revealed: @Todo +# TODO should be Literal[False] once we implement comparison of mismatched literal types +reveal_type(a is b) # revealed: bool +# TODO should be Literal[True] once we implement comparison of mismatched literal types +reveal_type(a is not b) # revealed: bool reveal_type(a is c) # revealed: Literal[False] reveal_type(a is not c) # revealed: Literal[True] diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md index f073a19dc5..c1f23b1d48 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md @@ -10,12 +10,16 @@ reveal_type(a) # revealed: bool b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`" reveal_type(b) # revealed: bool -c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`" -reveal_type(c) # revealed: Unknown +# TODO: should error, once operand type check is implemented +# ("Operator `<` is not supported for types `object` and `int`") +c = object() < 5 +# TODO: should be Unknown, once operand type check is implemented +reveal_type(c) # revealed: bool -# TODO should error, need to check if __lt__ signature is valid for right operand +# TODO: should error, once operand type check is implemented +# ("Operator `<` is not supported for types `int` and `object`") d = 5 < object() -# TODO: should be `Unknown` +# TODO: should be Unknown, once operand type check is implemented reveal_type(d) # revealed: bool flag = bool_instance() @@ -27,5 +31,6 @@ reveal_type(e) # revealed: bool # TODO: should error, need to check if __lt__ signature is valid for right operand # error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]` f = (1, 2) < (1, "hello") -reveal_type(f) # revealed: @Todo +# TODO: should be Unknown, once operand type check is implemented +reveal_type(f) # revealed: bool ``` diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f3c6b10691..e9e1f99a43 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -58,7 +58,7 @@ use crate::types::{ use crate::util::subscript::PythonSubscript; use crate::Db; -use super::{KnownClass, UnionBuilder}; +use super::{IterationOutcome, KnownClass, UnionBuilder}; /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the @@ -3101,16 +3101,26 @@ impl<'db> TypeInferenceBuilder<'db> { } // Lookup the rich comparison `__dunder__` methods on instances - (Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { - ast::CmpOp::Lt => perform_rich_comparison( - self.db, - left_class_ty, - right_class_ty, - RichCompareOperator::Lt, - ), - // TODO: implement mapping from `ast::CmpOp` to rich comparison methods - _ => Ok(Type::Todo), - }, + (Type::Instance(left_class), Type::Instance(right_class)) => { + let rich_comparison = + |op| perform_rich_comparison(self.db, left_class, right_class, op); + let membership_test_comparison = + |op| perform_membership_test_comparison(self.db, left_class, right_class, op); + match op { + ast::CmpOp::Eq => rich_comparison(RichCompareOperator::Eq), + ast::CmpOp::NotEq => rich_comparison(RichCompareOperator::Ne), + ast::CmpOp::Lt => rich_comparison(RichCompareOperator::Lt), + ast::CmpOp::LtE => rich_comparison(RichCompareOperator::Le), + ast::CmpOp::Gt => rich_comparison(RichCompareOperator::Gt), + ast::CmpOp::GtE => rich_comparison(RichCompareOperator::Ge), + ast::CmpOp::In => membership_test_comparison(MembershipTestCompareOperator::In), + ast::CmpOp::NotIn => { + membership_test_comparison(MembershipTestCompareOperator::NotIn) + } + ast::CmpOp::Is => Ok(KnownClass::Bool.to_instance(self.db)), + ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)), + } + } // TODO: handle more types _ => match op { ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)), @@ -3623,7 +3633,8 @@ impl From for ast::CmpOp { } impl RichCompareOperator { - const fn dunder_name(self) -> &'static str { + #[must_use] + const fn dunder(self) -> &'static str { match self { RichCompareOperator::Eq => "__eq__", RichCompareOperator::Ne => "__ne__", @@ -3633,6 +3644,33 @@ impl RichCompareOperator { RichCompareOperator::Ge => "__ge__", } } + + #[must_use] + const fn reflect(self) -> Self { + match self { + RichCompareOperator::Eq => RichCompareOperator::Eq, + RichCompareOperator::Ne => RichCompareOperator::Ne, + RichCompareOperator::Lt => RichCompareOperator::Gt, + RichCompareOperator::Le => RichCompareOperator::Ge, + RichCompareOperator::Gt => RichCompareOperator::Lt, + RichCompareOperator::Ge => RichCompareOperator::Le, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MembershipTestCompareOperator { + In, + NotIn, +} + +impl From for ast::CmpOp { + fn from(value: MembershipTestCompareOperator) -> Self { + match value { + MembershipTestCompareOperator::In => ast::CmpOp::In, + MembershipTestCompareOperator::NotIn => ast::CmpOp::NotIn, + } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -3716,41 +3754,99 @@ impl StringPartsCollector { /// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their /// behaviour can be edited for classes by implementing corresponding dunder methods. -/// This function performs rich comparison between two instances and returns the resulting type. +/// This function performs rich comparison between two instances and returns the resulting type. /// see `` fn perform_rich_comparison<'db>( db: &'db dyn Db, - left: ClassType<'db>, - right: ClassType<'db>, + left_class: ClassType<'db>, + right_class: ClassType<'db>, op: RichCompareOperator, ) -> Result, CompareUnsupportedError<'db>> { // The following resource has details about the rich comparison algorithm: // https://snarky.ca/unravelling-rich-comparison-operators/ // - // TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the - // l.h.s. - // TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined + // TODO: this currently gives the return type even if the arg types are invalid + // (e.g. int.__lt__ with string instance should be errored, currently bool) - let dunder = left.class_member(db, op.dunder_name()); - if !dunder.is_unbound() { - // TODO: this currently gives the return type even if the arg types are invalid - // (e.g. int.__lt__ with string instance should be None, currently bool) - return dunder - .call(db, &[Type::Instance(left), Type::Instance(right)]) - .return_ty(db) - .ok_or_else(|| CompareUnsupportedError { - op: op.into(), - left_ty: Type::Instance(left), - right_ty: Type::Instance(right), - }); + let call_dunder = + |op: RichCompareOperator, left_class: ClassType<'db>, right_class: ClassType<'db>| { + left_class + .class_member(db, op.dunder()) + .call( + db, + &[Type::Instance(left_class), Type::Instance(right_class)], + ) + .return_ty(db) + }; + + // The reflected dunder has priority if the right-hand side is a strict subclass of the left-hand side. + if left_class != right_class && right_class.is_subclass_of(db, left_class) { + call_dunder(op.reflect(), right_class, left_class) + .or_else(|| call_dunder(op, left_class, right_class)) + } else { + call_dunder(op, left_class, right_class) + .or_else(|| call_dunder(op.reflect(), right_class, left_class)) } - - // TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=) - Err(CompareUnsupportedError { - op: op.into(), - left_ty: Type::Instance(left), - right_ty: Type::Instance(right), + .or_else(|| { + // When no appropriate method returns any value other than NotImplemented, + // the `==` and `!=` operators will fall back to `is` and `is not`, respectively. + // refer to `` + if matches!(op, RichCompareOperator::Eq | RichCompareOperator::Ne) { + Some(KnownClass::Bool.to_instance(db)) + } else { + None + } }) + .ok_or_else(|| CompareUnsupportedError { + op: op.into(), + left_ty: Type::Instance(left_class), + right_ty: Type::Instance(right_class), + }) +} + +/// Performs a membership test (`in` and `not in`) between two instances and returns the resulting type, or `None` if the test is unsupported. +/// The behavior can be customized in Python by implementing `__contains__`, `__iter__`, or `__getitem__` methods. +/// See `` +/// and `` +fn perform_membership_test_comparison<'db>( + db: &'db dyn Db, + left_class: ClassType<'db>, + right_class: ClassType<'db>, + op: MembershipTestCompareOperator, +) -> Result, CompareUnsupportedError<'db>> { + let (left_instance, right_instance) = (Type::Instance(left_class), Type::Instance(right_class)); + + let contains_dunder = right_class.class_member(db, "__contains__"); + + let compare_result_opt = if contains_dunder.is_unbound() { + // iteration-based membership test + match right_instance.iterate(db) { + IterationOutcome::Iterable { .. } => Some(KnownClass::Bool.to_instance(db)), + IterationOutcome::NotIterable { .. } => None, + } + } else { + // If `__contains__` is available, it is used directly for the membership test. + contains_dunder + .call(db, &[right_instance, left_instance]) + .return_ty(db) + }; + + compare_result_opt + .map(|ty| { + if matches!(ty, Type::Todo) { + return Type::Todo; + } + + match op { + MembershipTestCompareOperator::In => ty.bool(db).into_type(db), + MembershipTestCompareOperator::NotIn => ty.bool(db).negate().into_type(db), + } + }) + .ok_or_else(|| CompareUnsupportedError { + op: op.into(), + left_ty: left_instance, + right_ty: right_instance, + }) } #[cfg(test)]