[ty] Reachability and narrowing for enum methods (#21130)

## Summary

Adds proper type narrowing and reachability analysis for matching on
non-inferable type variables bound to enums. For example:

```py
from enum import Enum

class Answer(Enum):
    NO = 0
    YES = 1

    def is_yes(self) -> bool:  # no error here!
        match self:
            case Answer.YES:
                return True
            case Answer.NO:
                return False
```

closes https://github.com/astral-sh/ty/issues/1404

## Test Plan

Added regression tests
This commit is contained in:
David Peter 2025-10-30 15:38:57 +01:00 committed by GitHub
parent 1b0ee4677e
commit e55bc943e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 121 additions and 5 deletions

View file

@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str):
# this diagnostic is correct: the inferred type of `subject` is `str`
assert_never(subject) # error: [type-assertion-failure]
```
## Exhaustiveness checking for methods of enums
```py
from enum import Enum
class Answer(Enum):
YES = "yes"
NO = "no"
def is_yes(self) -> bool:
reveal_type(self) # revealed: Self@is_yes
match self:
case Answer.YES:
return True
case Answer.NO:
return False
```

View file

@ -252,3 +252,51 @@ match x:
reveal_type(x) # revealed: object
```
## Narrowing on `Self` in `match` statements
When performing narrowing on `self` inside methods on enums, we take into account that `Self` might
refer to a subtype of the enum class, like `Literal[Answer.YES]`. This is why we do not simplify
`Self & ~Literal[Answer.YES]` to `Literal[Answer.NO, Answer.MAYBE]`. Otherwise, we wouldn't be able
to return `self` in the `assert_yes` method below:
```py
from enum import Enum
from typing_extensions import Self, assert_never
class Answer(Enum):
NO = 0
YES = 1
MAYBE = 2
def is_yes(self) -> bool:
reveal_type(self) # revealed: Self@is_yes
match self:
case Answer.YES:
reveal_type(self) # revealed: Self@is_yes
return True
case Answer.NO | Answer.MAYBE:
reveal_type(self) # revealed: Self@is_yes & ~Literal[Answer.YES]
return False
case _:
assert_never(self) # no error
def assert_yes(self) -> Self:
reveal_type(self) # revealed: Self@assert_yes
match self:
case Answer.YES:
reveal_type(self) # revealed: Self@assert_yes
return self
case _:
reveal_type(self) # revealed: Self@assert_yes & ~Literal[Answer.YES]
raise ValueError("Answer is not YES")
Answer.YES.is_yes()
try:
reveal_type(Answer.MAYBE.assert_yes()) # revealed: Literal[Answer.MAYBE]
except ValueError:
pass
```

View file

@ -802,10 +802,27 @@ impl ReachabilityConstraints {
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
let narrowed_subject_ty = IntersectionBuilder::new(db)
let narrowed_subject = IntersectionBuilder::new(db)
.add_positive(subject_ty)
.add_negative(type_excluded_by_previous_patterns(db, predicate))
.add_negative(type_excluded_by_previous_patterns(db, predicate));
let narrowed_subject_ty = narrowed_subject.clone().build();
// Consider a case where we match on a subject type of `Self` with an upper bound of `Answer`,
// where `Answer` is a {YES, NO} enum. After a previous pattern matching on `NO`, the narrowed
// subject type is `Self & ~Literal[NO]`. This type is *not* equivalent to `Literal[YES]`,
// because `Self` could also specialize to `Literal[NO]` or `Never`, making the intersection
// empty. However, if the current pattern matches on `YES`, the *next* narrowed subject type
// will be `Self & ~Literal[NO] & ~Literal[YES]`, which *is* always equivalent to `Never`. This
// means that subsequent patterns can never match. And we know that if we reach this point,
// the current pattern will have to match. We return `AlwaysTrue` here, since the call to
// `analyze_single_pattern_predicate_kind` below would return `Ambiguous` in this case.
let next_narrowed_subject_ty = narrowed_subject
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
.build();
if !narrowed_subject_ty.is_never() && next_narrowed_subject_ty.is_never() {
return Truthiness::AlwaysTrue;
}
let truthiness = Self::analyze_single_pattern_predicate_kind(
db,

View file

@ -44,6 +44,7 @@ use crate::types::{
TypeVarBoundOrConstraints, UnionType,
};
use crate::{Db, FxOrderSet};
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -422,9 +423,9 @@ impl<'db> UnionBuilder<'db> {
.iter()
.filter_map(UnionElement::to_type_element)
.filter_map(Type::as_enum_literal)
.map(|literal| literal.name(self.db).clone())
.chain(std::iter::once(enum_member_to_add.name(self.db).clone()))
.collect::<FxOrderSet<_>>();
.map(|literal| literal.name(self.db))
.chain(std::iter::once(enum_member_to_add.name(self.db)))
.collect::<FxHashSet<_>>();
let all_members_are_in_union = metadata
.members
@ -780,6 +781,37 @@ impl<'db> IntersectionBuilder<'db> {
seen_aliases,
)
}
Type::EnumLiteral(enum_literal) => {
let enum_class = enum_literal.enum_class(self.db);
let metadata =
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
let enum_members_in_negative_part = self
.intersections
.iter()
.flat_map(|intersection| &intersection.negative)
.filter_map(|ty| ty.as_enum_literal())
.filter(|lit| lit.enum_class(self.db) == enum_class)
.map(|lit| lit.name(self.db))
.chain(std::iter::once(enum_literal.name(self.db)))
.collect::<FxHashSet<_>>();
let all_members_are_in_negative_part = metadata
.members
.keys()
.all(|name| enum_members_in_negative_part.contains(name));
if all_members_are_in_negative_part {
for inner in &mut self.intersections {
inner.add_negative(self.db, enum_literal.enum_class_instance(self.db));
}
} else {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
}
}
self
}
_ => {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);