mirror of
https://github.com/astral-sh/ruff.git
synced 2025-11-19 12:16:43 +00:00
[ty] Reachability and narrowing for enum methods (#21130)
## Summary
Adds proper type narrowing and reachability analysis for matching on
non-inferable type variables bound to enums. For example:
```py
from enum import Enum
class Answer(Enum):
NO = 0
YES = 1
def is_yes(self) -> bool: # no error here!
match self:
case Answer.YES:
return True
case Answer.NO:
return False
```
closes https://github.com/astral-sh/ty/issues/1404
## Test Plan
Added regression tests
This commit is contained in:
parent
1b0ee4677e
commit
e55bc943e5
4 changed files with 121 additions and 5 deletions
|
|
@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str):
|
||||||
# this diagnostic is correct: the inferred type of `subject` is `str`
|
# this diagnostic is correct: the inferred type of `subject` is `str`
|
||||||
assert_never(subject) # error: [type-assertion-failure]
|
assert_never(subject) # error: [type-assertion-failure]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Exhaustiveness checking for methods of enums
|
||||||
|
|
||||||
|
```py
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class Answer(Enum):
|
||||||
|
YES = "yes"
|
||||||
|
NO = "no"
|
||||||
|
|
||||||
|
def is_yes(self) -> bool:
|
||||||
|
reveal_type(self) # revealed: Self@is_yes
|
||||||
|
|
||||||
|
match self:
|
||||||
|
case Answer.YES:
|
||||||
|
return True
|
||||||
|
case Answer.NO:
|
||||||
|
return False
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -252,3 +252,51 @@ match x:
|
||||||
|
|
||||||
reveal_type(x) # revealed: object
|
reveal_type(x) # revealed: object
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Narrowing on `Self` in `match` statements
|
||||||
|
|
||||||
|
When performing narrowing on `self` inside methods on enums, we take into account that `Self` might
|
||||||
|
refer to a subtype of the enum class, like `Literal[Answer.YES]`. This is why we do not simplify
|
||||||
|
`Self & ~Literal[Answer.YES]` to `Literal[Answer.NO, Answer.MAYBE]`. Otherwise, we wouldn't be able
|
||||||
|
to return `self` in the `assert_yes` method below:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from enum import Enum
|
||||||
|
from typing_extensions import Self, assert_never
|
||||||
|
|
||||||
|
class Answer(Enum):
|
||||||
|
NO = 0
|
||||||
|
YES = 1
|
||||||
|
MAYBE = 2
|
||||||
|
|
||||||
|
def is_yes(self) -> bool:
|
||||||
|
reveal_type(self) # revealed: Self@is_yes
|
||||||
|
|
||||||
|
match self:
|
||||||
|
case Answer.YES:
|
||||||
|
reveal_type(self) # revealed: Self@is_yes
|
||||||
|
return True
|
||||||
|
case Answer.NO | Answer.MAYBE:
|
||||||
|
reveal_type(self) # revealed: Self@is_yes & ~Literal[Answer.YES]
|
||||||
|
return False
|
||||||
|
case _:
|
||||||
|
assert_never(self) # no error
|
||||||
|
|
||||||
|
def assert_yes(self) -> Self:
|
||||||
|
reveal_type(self) # revealed: Self@assert_yes
|
||||||
|
|
||||||
|
match self:
|
||||||
|
case Answer.YES:
|
||||||
|
reveal_type(self) # revealed: Self@assert_yes
|
||||||
|
return self
|
||||||
|
case _:
|
||||||
|
reveal_type(self) # revealed: Self@assert_yes & ~Literal[Answer.YES]
|
||||||
|
raise ValueError("Answer is not YES")
|
||||||
|
|
||||||
|
Answer.YES.is_yes()
|
||||||
|
|
||||||
|
try:
|
||||||
|
reveal_type(Answer.MAYBE.assert_yes()) # revealed: Literal[Answer.MAYBE]
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -802,10 +802,27 @@ impl ReachabilityConstraints {
|
||||||
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
|
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
|
||||||
let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
|
let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
|
||||||
|
|
||||||
let narrowed_subject_ty = IntersectionBuilder::new(db)
|
let narrowed_subject = IntersectionBuilder::new(db)
|
||||||
.add_positive(subject_ty)
|
.add_positive(subject_ty)
|
||||||
.add_negative(type_excluded_by_previous_patterns(db, predicate))
|
.add_negative(type_excluded_by_previous_patterns(db, predicate));
|
||||||
|
|
||||||
|
let narrowed_subject_ty = narrowed_subject.clone().build();
|
||||||
|
|
||||||
|
// Consider a case where we match on a subject type of `Self` with an upper bound of `Answer`,
|
||||||
|
// where `Answer` is a {YES, NO} enum. After a previous pattern matching on `NO`, the narrowed
|
||||||
|
// subject type is `Self & ~Literal[NO]`. This type is *not* equivalent to `Literal[YES]`,
|
||||||
|
// because `Self` could also specialize to `Literal[NO]` or `Never`, making the intersection
|
||||||
|
// empty. However, if the current pattern matches on `YES`, the *next* narrowed subject type
|
||||||
|
// will be `Self & ~Literal[NO] & ~Literal[YES]`, which *is* always equivalent to `Never`. This
|
||||||
|
// means that subsequent patterns can never match. And we know that if we reach this point,
|
||||||
|
// the current pattern will have to match. We return `AlwaysTrue` here, since the call to
|
||||||
|
// `analyze_single_pattern_predicate_kind` below would return `Ambiguous` in this case.
|
||||||
|
let next_narrowed_subject_ty = narrowed_subject
|
||||||
|
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
|
||||||
.build();
|
.build();
|
||||||
|
if !narrowed_subject_ty.is_never() && next_narrowed_subject_ty.is_never() {
|
||||||
|
return Truthiness::AlwaysTrue;
|
||||||
|
}
|
||||||
|
|
||||||
let truthiness = Self::analyze_single_pattern_predicate_kind(
|
let truthiness = Self::analyze_single_pattern_predicate_kind(
|
||||||
db,
|
db,
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ use crate::types::{
|
||||||
TypeVarBoundOrConstraints, UnionType,
|
TypeVarBoundOrConstraints, UnionType,
|
||||||
};
|
};
|
||||||
use crate::{Db, FxOrderSet};
|
use crate::{Db, FxOrderSet};
|
||||||
|
use rustc_hash::FxHashSet;
|
||||||
use smallvec::SmallVec;
|
use smallvec::SmallVec;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
|
@ -422,9 +423,9 @@ impl<'db> UnionBuilder<'db> {
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(UnionElement::to_type_element)
|
.filter_map(UnionElement::to_type_element)
|
||||||
.filter_map(Type::as_enum_literal)
|
.filter_map(Type::as_enum_literal)
|
||||||
.map(|literal| literal.name(self.db).clone())
|
.map(|literal| literal.name(self.db))
|
||||||
.chain(std::iter::once(enum_member_to_add.name(self.db).clone()))
|
.chain(std::iter::once(enum_member_to_add.name(self.db)))
|
||||||
.collect::<FxOrderSet<_>>();
|
.collect::<FxHashSet<_>>();
|
||||||
|
|
||||||
let all_members_are_in_union = metadata
|
let all_members_are_in_union = metadata
|
||||||
.members
|
.members
|
||||||
|
|
@ -780,6 +781,37 @@ impl<'db> IntersectionBuilder<'db> {
|
||||||
seen_aliases,
|
seen_aliases,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
Type::EnumLiteral(enum_literal) => {
|
||||||
|
let enum_class = enum_literal.enum_class(self.db);
|
||||||
|
let metadata =
|
||||||
|
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
|
||||||
|
|
||||||
|
let enum_members_in_negative_part = self
|
||||||
|
.intersections
|
||||||
|
.iter()
|
||||||
|
.flat_map(|intersection| &intersection.negative)
|
||||||
|
.filter_map(|ty| ty.as_enum_literal())
|
||||||
|
.filter(|lit| lit.enum_class(self.db) == enum_class)
|
||||||
|
.map(|lit| lit.name(self.db))
|
||||||
|
.chain(std::iter::once(enum_literal.name(self.db)))
|
||||||
|
.collect::<FxHashSet<_>>();
|
||||||
|
|
||||||
|
let all_members_are_in_negative_part = metadata
|
||||||
|
.members
|
||||||
|
.keys()
|
||||||
|
.all(|name| enum_members_in_negative_part.contains(name));
|
||||||
|
|
||||||
|
if all_members_are_in_negative_part {
|
||||||
|
for inner in &mut self.intersections {
|
||||||
|
inner.add_negative(self.db, enum_literal.enum_class_instance(self.db));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for inner in &mut self.intersections {
|
||||||
|
inner.add_negative(self.db, ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
for inner in &mut self.intersections {
|
for inner in &mut self.intersections {
|
||||||
inner.add_negative(self.db, ty);
|
inner.add_negative(self.db, ty);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue