[ty] Infer the correct type of Enum __eq__ and __ne__ comparisions (#19666)

## Summary

Resolves https://github.com/astral-sh/ty/issues/920

## Test Plan

Update `enums.md`

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Matthew Mckee 2025-08-18 18:45:44 +01:00 committed by GitHub
parent 3314cf90ed
commit 24f6d2dc13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 110 additions and 39 deletions

View file

@ -8019,6 +8019,48 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// language spec.
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
// - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal
let db = self.db();
let try_dunder = |inference: &mut TypeInferenceBuilder<'db, '_>,
policy: MemberLookupPolicy| {
let rich_comparison = |op| inference.infer_rich_comparison(left, right, op, policy);
let membership_test_comparison = |op, range: TextRange| {
inference.infer_membership_test_comparison(left, right, op, range)
};
match op {
ast::CmpOp::Eq => rich_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => rich_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => rich_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => rich_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => rich_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => rich_comparison(RichCompareOperator::Ge),
ast::CmpOp::In => {
membership_test_comparison(MembershipTestCompareOperator::In, range)
}
ast::CmpOp::NotIn => {
membership_test_comparison(MembershipTestCompareOperator::NotIn, range)
}
ast::CmpOp::Is => {
if left.is_disjoint_from(db, right) {
Ok(Type::BooleanLiteral(false))
} else if left.is_singleton(db) && left.is_equivalent_to(db, right) {
Ok(Type::BooleanLiteral(true))
} else {
Ok(KnownClass::Bool.to_instance(db))
}
}
ast::CmpOp::IsNot => {
if left.is_disjoint_from(db, right) {
Ok(Type::BooleanLiteral(true))
} else if left.is_singleton(db) && left.is_equivalent_to(db, right) {
Ok(Type::BooleanLiteral(false))
} else {
Ok(KnownClass::Bool.to_instance(db))
}
}
}
};
let comparison_result = match (left, right) {
(Type::Union(union), other) => {
let mut builder = UnionBuilder::new(self.db());
@ -8233,12 +8275,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
(Type::EnumLiteral(literal_1), Type::EnumLiteral(literal_2))
if op == ast::CmpOp::Eq =>
{
Some(Ok(Type::BooleanLiteral(literal_1 == literal_2)))
Some(Ok(match try_dunder(self, MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK) {
Ok(ty) => ty,
Err(_) => Type::BooleanLiteral(literal_1 == literal_2),
}))
}
(Type::EnumLiteral(literal_1), Type::EnumLiteral(literal_2))
if op == ast::CmpOp::NotEq =>
{
Some(Ok(Type::BooleanLiteral(literal_1 != literal_2)))
Some(Ok(match try_dunder(self, MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK) {
Ok(ty) => ty,
Err(_) => Type::BooleanLiteral(literal_1 != literal_2),
}))
}
(
@ -8320,39 +8368,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// Final generalized fallback: lookup the rich comparison `__dunder__` methods
let rich_comparison = |op| self.infer_rich_comparison(left, right, op);
let membership_test_comparison =
|op, range: TextRange| self.infer_membership_test_comparison(left, right, op, range);
match op {
ast::CmpOp::Eq => rich_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => rich_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => rich_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => rich_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => rich_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => rich_comparison(RichCompareOperator::Ge),
ast::CmpOp::In => membership_test_comparison(MembershipTestCompareOperator::In, range),
ast::CmpOp::NotIn => {
membership_test_comparison(MembershipTestCompareOperator::NotIn, range)
}
ast::CmpOp::Is => {
if left.is_disjoint_from(self.db(), right) {
Ok(Type::BooleanLiteral(false))
} else if left.is_singleton(self.db()) && left.is_equivalent_to(self.db(), right) {
Ok(Type::BooleanLiteral(true))
} else {
Ok(KnownClass::Bool.to_instance(self.db()))
}
}
ast::CmpOp::IsNot => {
if left.is_disjoint_from(self.db(), right) {
Ok(Type::BooleanLiteral(true))
} else if left.is_singleton(self.db()) && left.is_equivalent_to(self.db(), right) {
Ok(Type::BooleanLiteral(false))
} else {
Ok(KnownClass::Bool.to_instance(self.db()))
}
}
}
try_dunder(self, MemberLookupPolicy::default())
}
/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their
@ -8364,14 +8380,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
left: Type<'db>,
right: Type<'db>,
op: RichCompareOperator,
policy: MemberLookupPolicy,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
let db = self.db();
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
let call_dunder = |op: RichCompareOperator, left: Type<'db>, right: Type<'db>| {
left.try_call_dunder(db, op.dunder(), CallArguments::positional([right]))
.map(|outcome| outcome.return_type(db))
.ok()
left.try_call_dunder_with_policy(
db,
op.dunder(),
&mut CallArguments::positional([right]),
policy,
)
.map(|outcome| outcome.return_type(db))
.ok()
};
// The reflected dunder has priority if the right-hand side is a strict subclass of the left-hand side.
@ -8384,7 +8406,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// When no appropriate method returns any value other than NotImplemented,
// the `==` and `!=` operators will fall back to `is` and `is not`, respectively.
// refer to `<https://docs.python.org/3/reference/datamodel.html#object.__eq__>`
if matches!(op, RichCompareOperator::Eq | RichCompareOperator::Ne) {
if matches!(op, RichCompareOperator::Eq | RichCompareOperator::Ne)
// This branch implements specific behavior of the `__eq__` and `__ne__` methods
// on `object`, so it does not apply if we skip looking up attributes on `object`.
&& !policy.mro_no_object_fallback()
{
Some(KnownClass::Bool.to_instance(db))
} else {
None