From 2a00eca66b17de90498445996712149a8fc0e3bb Mon Sep 17 00:00:00 2001 From: David Peter Date: Wed, 23 Jul 2025 22:45:45 +0200 Subject: [PATCH] [ty] Exhaustiveness checking & reachability for `match` statements (#19508) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Implements proper reachability analysis and — in effect — exhaustiveness checking for `match` statements. This allows us to check the following code without any errors (leads to *"can implicitly return `None`"* on `main`): ```py from enum import Enum, auto class Color(Enum): RED = auto() GREEN = auto() BLUE = auto() def hex(color: Color) -> str: match color: case Color.RED: return "#ff0000" case Color.GREEN: return "#00ff00" case Color.BLUE: return "#0000ff" ``` Note that code like this already worked fine if there was a `assert_never(color)` statement in a catch-all case, because we would then consider that `assert_never` call terminal. But now this also works without the wildcard case. Adding a member to the enum would still lead to an error here, if that case would not be handled in `hex`. What needed to happen to support this is a new way of evaluating match pattern constraints. Previously, we would simply compare the type of the subject expression against the patterns. For the last case here, the subject type would still be `Color` and the value type would be `Literal[Color.BLUE]`, so we would infer an ambiguous truthiness. Now, before we compare the subject type against the pattern, we first generate a union type that corresponds to the set of all values that would have *definitely been matched* by previous patterns. Then, we build a "narrowed" subject type by computing `subject_type & ~already_matched_type`, and compare *that* against the pattern type. For the example here, `already_matched_type = Literal[Color.RED] | Literal[Color.GREEN]`, and so we have a narrowed subject type of `Color & ~(Literal[Color.RED] | Literal[Color.GREEN]) = Literal[Color.BLUE]`, which allows us to infer a reachability of `AlwaysTrue`.
A note on negated reachability constraints It might seem that we now perform duplicate work, because we also record *negated* reachability constraints. But that is still important for cases like the following (and possibly also for more realistic scenarios): ```py from typing import Literal def _(x: int | str): match x: case None: pass # never reachable case _: y = 1 y ```
closes https://github.com/astral-sh/ty/issues/99 ## Test Plan * I verified that this solves all examples from the linked ticket (the first example needs a PEP 695 type alias, because we don't support legacy type aliases yet) * Verified that the ecosystem changes are all because of removed false positives * Updated tests --- .../resources/mdtest/conditional/match.md | 14 ++- .../mdtest/directives/assert_never.md | 2 - .../resources/mdtest/enums.md | 2 - .../mdtest/exhaustiveness_checking.md | 26 +++--- .../src/semantic_index/builder.rs | 21 +++-- .../src/semantic_index/predicate.rs | 3 + .../reachability_constraints.rs | 90 +++++++++++++++---- 7 files changed, 109 insertions(+), 49 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/conditional/match.md b/crates/ty_python_semantic/resources/mdtest/conditional/match.md index 1b92e7e8a7..52513d6698 100644 --- a/crates/ty_python_semantic/resources/mdtest/conditional/match.md +++ b/crates/ty_python_semantic/resources/mdtest/conditional/match.md @@ -201,8 +201,7 @@ def _(target: Literal[True, False]): case None: y = 4 - # TODO: with exhaustiveness checking, this should be Literal[2, 3] - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2, 3] def _(target: bool): y = 1 @@ -215,8 +214,7 @@ def _(target: bool): case None: y = 4 - # TODO: with exhaustiveness checking, this should be Literal[2, 3] - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2, 3] def _(target: None): y = 1 @@ -242,8 +240,7 @@ def _(target: None | Literal[True]): case None: y = 4 - # TODO: with exhaustiveness checking, this should be Literal[2, 4] - reveal_type(y) # revealed: Literal[1, 2, 4] + reveal_type(y) # revealed: Literal[2, 4] # bool is an int subclass def _(target: int): @@ -292,7 +289,7 @@ def _(answer: Answer): reveal_type(answer) # revealed: Literal[Answer.NO] y = 2 - reveal_type(y) # revealed: Literal[0, 1, 2] + reveal_type(y) # revealed: Literal[1, 2] ``` ## Or match @@ -311,8 +308,7 @@ def _(target: Literal["foo", "baz"]): case "baz": y = 3 - # TODO: with exhaustiveness, this should be Literal[2, 3] - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2, 3] def _(target: None): y = 1 diff --git a/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md b/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md index abb5117564..f9e4ced7d5 100644 --- a/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md +++ b/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md @@ -119,8 +119,6 @@ def match_singletons_success(obj: Literal[1, "a"] | None): case None: pass case _ as obj: - # TODO: Ideally, we would not emit an error here - # error: [type-assertion-failure] "Argument does not have asserted type `Never`" assert_never(obj) def match_singletons_error(obj: Literal[1, "a"] | None): diff --git a/crates/ty_python_semantic/resources/mdtest/enums.md b/crates/ty_python_semantic/resources/mdtest/enums.md index 6981f5c06a..d1265225cf 100644 --- a/crates/ty_python_semantic/resources/mdtest/enums.md +++ b/crates/ty_python_semantic/resources/mdtest/enums.md @@ -720,8 +720,6 @@ def color_name(color: Color) -> str: case _: assert_never(color) -# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488 -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`" def color_name_without_assertion(color: Color) -> str: match color: case Color.RED: diff --git a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md index 9ecc26be0e..713cd05b2d 100644 --- a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md +++ b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md @@ -50,13 +50,10 @@ def match_exhaustive(x: Literal[0, 1, "a"]): case "a": pass case _: - # TODO: this should not be an error - no_diagnostic_here # error: [unresolved-reference] + no_diagnostic_here assert_never(x) -# TODO: there should be no error here -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`" def match_exhaustive_no_assertion(x: Literal[0, 1, "a"]) -> int: match x: case 0: @@ -130,13 +127,21 @@ def match_exhaustive(x: Color): case Color.BLUE: pass case _: - # TODO: this should not be an error - no_diagnostic_here # error: [unresolved-reference] + no_diagnostic_here + + assert_never(x) + +def match_exhaustive_2(x: Color): + match x: + case Color.RED: + pass + case Color.GREEN | Color.BLUE: + pass + case _: + no_diagnostic_here assert_never(x) -# TODO: there should be no error here -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`" def match_exhaustive_no_assertion(x: Color) -> int: match x: case Color.RED: @@ -208,13 +213,10 @@ def match_exhaustive(x: A | B | C): case C(): pass case _: - # TODO: this should not be an error - no_diagnostic_here # error: [unresolved-reference] + no_diagnostic_here assert_never(x) -# TODO: there should be no error here -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`" def match_exhaustive_no_assertion(x: A | B | C) -> int: match x: case A(): diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index cac8287d49..b6079d42ec 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -734,7 +734,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { subject: Expression<'db>, pattern: &ast::Pattern, guard: Option<&ast::Expr>, - ) -> PredicateOrLiteral<'db> { + previous_pattern: Option>, + ) -> (PredicateOrLiteral<'db>, PatternPredicate<'db>) { // This is called for the top-level pattern of each match arm. We need to create a // standalone expression for each arm of a match statement, since they can introduce // constraints on the match subject. (Or more accurately, for the match arm's pattern, @@ -756,13 +757,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { subject, kind, guard, + previous_pattern.map(Box::new), ); let predicate = PredicateOrLiteral::Predicate(Predicate { node: PredicateNode::Pattern(pattern_predicate), is_positive: true, }); self.record_narrowing_constraint(predicate); - predicate + (predicate, pattern_predicate) } /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type @@ -1747,7 +1749,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); let mut post_case_snapshots = vec![]; - let mut match_predicate; + let mut previous_pattern: Option> = None; for (i, case) in cases.iter().enumerate() { self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); @@ -1757,11 +1759,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // here because the effects of visiting a pattern is binding // symbols, and this doesn't occur unless the pattern // actually matches - match_predicate = self.add_pattern_narrowing_constraint( - subject_expr, - &case.pattern, - case.guard.as_deref(), - ); + let (match_predicate, match_pattern_predicate) = self + .add_pattern_narrowing_constraint( + subject_expr, + &case.pattern, + case.guard.as_deref(), + previous_pattern, + ); + previous_pattern = Some(match_pattern_predicate); let reachability_constraint = self.record_reachability_constraint(match_predicate); diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index acd232ed05..35d738c6c4 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -150,6 +150,9 @@ pub(crate) struct PatternPredicate<'db> { pub(crate) kind: PatternPredicateKind<'db>, pub(crate) guard: Option>, + + /// A reference to the pattern of the previous match case + pub(crate) previous_predicate: Option>>, } // The Salsa heap is tracked separately. diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index caa0b7c247..03f8ac0ace 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -202,13 +202,14 @@ use crate::Db; use crate::dunder_all::dunder_all_names; use crate::place::{RequiresExplicitReExport, imported_symbol}; use crate::rank::RankBitBox; -use crate::semantic_index::expression::Expression; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, }; -use crate::types::{Truthiness, Type, infer_expression_type}; +use crate::types::{ + IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type, +}; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula /// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the @@ -311,6 +312,55 @@ const AMBIGUOUS: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId const ALWAYS_FALSE: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId::ALWAYS_FALSE; const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE; +fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type<'_> { + let ty = match singleton { + ruff_python_ast::Singleton::None => Type::none(db), + ruff_python_ast::Singleton::True => Type::BooleanLiteral(true), + ruff_python_ast::Singleton::False => Type::BooleanLiteral(false), + }; + debug_assert!(ty.is_singleton(db)); + ty +} + +/// Turn a `match` pattern kind into a type that represents the set of all values that would definitely +/// match that pattern. +fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> { + match kind { + PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton), + PatternPredicateKind::Value(value) => infer_expression_type(db, *value), + PatternPredicateKind::Class(class_expr, kind) => { + if kind.is_irrefutable() { + infer_expression_type(db, *class_expr) + .to_instance(db) + .unwrap_or(Type::Never) + } else { + Type::Never + } + } + PatternPredicateKind::Or(predicates) => { + UnionType::from_elements(db, predicates.iter().map(|p| pattern_kind_to_type(db, p))) + } + PatternPredicateKind::Unsupported => Type::Never, + } +} + +/// Go through the list of previous match cases, and accumulate a union of all types that were already +/// matched by these patterns. +fn type_excluded_by_previous_patterns<'db>( + db: &'db dyn Db, + mut predicate: PatternPredicate<'db>, +) -> Type<'db> { + let mut builder = UnionBuilder::new(db); + while let Some(previous) = predicate.previous_predicate(db) { + predicate = *previous; + + if predicate.guard(db).is_none() { + builder = builder.add(pattern_kind_to_type(db, predicate.kind(db))); + } + } + builder.build() +} + /// A collection of reachability constraints for a given scope. #[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub(crate) struct ReachabilityConstraints { @@ -637,11 +687,10 @@ impl ReachabilityConstraints { fn analyze_single_pattern_predicate_kind<'db>( db: &'db dyn Db, predicate_kind: &PatternPredicateKind<'db>, - subject: Expression<'db>, + subject_ty: Type<'db>, ) -> Truthiness { match predicate_kind { PatternPredicateKind::Value(value) => { - let subject_ty = infer_expression_type(db, subject); let value_ty = infer_expression_type(db, *value); if subject_ty.is_single_valued(db) { @@ -651,15 +700,7 @@ impl ReachabilityConstraints { } } PatternPredicateKind::Singleton(singleton) => { - let subject_ty = infer_expression_type(db, subject); - - let singleton_ty = match singleton { - ruff_python_ast::Singleton::None => Type::none(db), - ruff_python_ast::Singleton::True => Type::BooleanLiteral(true), - ruff_python_ast::Singleton::False => Type::BooleanLiteral(false), - }; - - debug_assert!(singleton_ty.is_singleton(db)); + let singleton_ty = singleton_to_type(db, *singleton); if subject_ty.is_equivalent_to(db, singleton_ty) { Truthiness::AlwaysTrue @@ -671,10 +712,21 @@ impl ReachabilityConstraints { } PatternPredicateKind::Or(predicates) => { use std::ops::ControlFlow; + + let mut excluded_types = vec![]; let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) = predicates .iter() - .map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject)) + .map(|p| { + let narrowed_subject_ty = IntersectionBuilder::new(db) + .add_positive(subject_ty) + .add_negative(UnionType::from_elements(db, excluded_types.iter())) + .build(); + + excluded_types.push(pattern_kind_to_type(db, p)); + + Self::analyze_single_pattern_predicate_kind(db, p, narrowed_subject_ty) + }) // this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there .try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) { (Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => { @@ -690,7 +742,6 @@ impl ReachabilityConstraints { truthiness } PatternPredicateKind::Class(class_expr, kind) => { - let subject_ty = infer_expression_type(db, subject); let class_ty = infer_expression_type(db, *class_expr).to_instance(db); class_ty.map_or(Truthiness::Ambiguous, |class_ty| { @@ -715,10 +766,17 @@ impl ReachabilityConstraints { } fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { + let subject_ty = infer_expression_type(db, predicate.subject(db)); + + let narrowed_subject_ty = IntersectionBuilder::new(db) + .add_positive(subject_ty) + .add_negative(type_excluded_by_previous_patterns(db, predicate)) + .build(); + let truthiness = Self::analyze_single_pattern_predicate_kind( db, predicate.kind(db), - predicate.subject(db), + narrowed_subject_ty, ); if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {