diff --git a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md index 800ae7c7bb..7218359750 100644 --- a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md +++ b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md @@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str): # this diagnostic is correct: the inferred type of `subject` is `str` assert_never(subject) # error: [type-assertion-failure] ``` + +## Exhaustiveness checking for methods of enums + +```py +from enum import Enum + +class Answer(Enum): + YES = "yes" + NO = "no" + + def is_yes(self) -> bool: + reveal_type(self) # revealed: Self@is_yes + + match self: + case Answer.YES: + return True + case Answer.NO: + return False +``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index 55772eab24..ee51d50af2 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -252,3 +252,51 @@ match x: reveal_type(x) # revealed: object ``` + +## Narrowing on `Self` in `match` statements + +When performing narrowing on `self` inside methods on enums, we take into account that `Self` might +refer to a subtype of the enum class, like `Literal[Answer.YES]`. This is why we do not simplify +`Self & ~Literal[Answer.YES]` to `Literal[Answer.NO, Answer.MAYBE]`. Otherwise, we wouldn't be able +to return `self` in the `assert_yes` method below: + +```py +from enum import Enum +from typing_extensions import Self, assert_never + +class Answer(Enum): + NO = 0 + YES = 1 + MAYBE = 2 + + def is_yes(self) -> bool: + reveal_type(self) # revealed: Self@is_yes + + match self: + case Answer.YES: + reveal_type(self) # revealed: Self@is_yes + return True + case Answer.NO | Answer.MAYBE: + reveal_type(self) # revealed: Self@is_yes & ~Literal[Answer.YES] + return False + case _: + assert_never(self) # no error + + def assert_yes(self) -> Self: + reveal_type(self) # revealed: Self@assert_yes + + match self: + case Answer.YES: + reveal_type(self) # revealed: Self@assert_yes + return self + case _: + reveal_type(self) # revealed: Self@assert_yes & ~Literal[Answer.YES] + raise ValueError("Answer is not YES") + +Answer.YES.is_yes() + +try: + reveal_type(Answer.MAYBE.assert_yes()) # revealed: Literal[Answer.MAYBE] +except ValueError: + pass +``` diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 3d09733324..af3ef642e3 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -802,10 +802,27 @@ impl ReachabilityConstraints { fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default()); - let narrowed_subject_ty = IntersectionBuilder::new(db) + let narrowed_subject = IntersectionBuilder::new(db) .add_positive(subject_ty) - .add_negative(type_excluded_by_previous_patterns(db, predicate)) + .add_negative(type_excluded_by_previous_patterns(db, predicate)); + + let narrowed_subject_ty = narrowed_subject.clone().build(); + + // Consider a case where we match on a subject type of `Self` with an upper bound of `Answer`, + // where `Answer` is a {YES, NO} enum. After a previous pattern matching on `NO`, the narrowed + // subject type is `Self & ~Literal[NO]`. This type is *not* equivalent to `Literal[YES]`, + // because `Self` could also specialize to `Literal[NO]` or `Never`, making the intersection + // empty. However, if the current pattern matches on `YES`, the *next* narrowed subject type + // will be `Self & ~Literal[NO] & ~Literal[YES]`, which *is* always equivalent to `Never`. This + // means that subsequent patterns can never match. And we know that if we reach this point, + // the current pattern will have to match. We return `AlwaysTrue` here, since the call to + // `analyze_single_pattern_predicate_kind` below would return `Ambiguous` in this case. + let next_narrowed_subject_ty = narrowed_subject + .add_negative(pattern_kind_to_type(db, predicate.kind(db))) .build(); + if !narrowed_subject_ty.is_never() && next_narrowed_subject_ty.is_never() { + return Truthiness::AlwaysTrue; + } let truthiness = Self::analyze_single_pattern_predicate_kind( db, diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 7f46a80239..6b555b6fdb 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -44,6 +44,7 @@ use crate::types::{ TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; +use rustc_hash::FxHashSet; use smallvec::SmallVec; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -422,9 +423,9 @@ impl<'db> UnionBuilder<'db> { .iter() .filter_map(UnionElement::to_type_element) .filter_map(Type::as_enum_literal) - .map(|literal| literal.name(self.db).clone()) - .chain(std::iter::once(enum_member_to_add.name(self.db).clone())) - .collect::>(); + .map(|literal| literal.name(self.db)) + .chain(std::iter::once(enum_member_to_add.name(self.db))) + .collect::>(); let all_members_are_in_union = metadata .members @@ -780,6 +781,37 @@ impl<'db> IntersectionBuilder<'db> { seen_aliases, ) } + Type::EnumLiteral(enum_literal) => { + let enum_class = enum_literal.enum_class(self.db); + let metadata = + enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum"); + + let enum_members_in_negative_part = self + .intersections + .iter() + .flat_map(|intersection| &intersection.negative) + .filter_map(|ty| ty.as_enum_literal()) + .filter(|lit| lit.enum_class(self.db) == enum_class) + .map(|lit| lit.name(self.db)) + .chain(std::iter::once(enum_literal.name(self.db))) + .collect::>(); + + let all_members_are_in_negative_part = metadata + .members + .keys() + .all(|name| enum_members_in_negative_part.contains(name)); + + if all_members_are_in_negative_part { + for inner in &mut self.intersections { + inner.add_negative(self.db, enum_literal.enum_class_instance(self.db)); + } + } else { + for inner in &mut self.intersections { + inner.add_negative(self.db, ty); + } + } + self + } _ => { for inner in &mut self.intersections { inner.add_negative(self.db, ty);