diff --git a/crates/ty_python_semantic/resources/mdtest/conditional/match.md b/crates/ty_python_semantic/resources/mdtest/conditional/match.md index 1b92e7e8a7..52513d6698 100644 --- a/crates/ty_python_semantic/resources/mdtest/conditional/match.md +++ b/crates/ty_python_semantic/resources/mdtest/conditional/match.md @@ -201,8 +201,7 @@ def _(target: Literal[True, False]): case None: y = 4 - # TODO: with exhaustiveness checking, this should be Literal[2, 3] - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2, 3] def _(target: bool): y = 1 @@ -215,8 +214,7 @@ def _(target: bool): case None: y = 4 - # TODO: with exhaustiveness checking, this should be Literal[2, 3] - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2, 3] def _(target: None): y = 1 @@ -242,8 +240,7 @@ def _(target: None | Literal[True]): case None: y = 4 - # TODO: with exhaustiveness checking, this should be Literal[2, 4] - reveal_type(y) # revealed: Literal[1, 2, 4] + reveal_type(y) # revealed: Literal[2, 4] # bool is an int subclass def _(target: int): @@ -292,7 +289,7 @@ def _(answer: Answer): reveal_type(answer) # revealed: Literal[Answer.NO] y = 2 - reveal_type(y) # revealed: Literal[0, 1, 2] + reveal_type(y) # revealed: Literal[1, 2] ``` ## Or match @@ -311,8 +308,7 @@ def _(target: Literal["foo", "baz"]): case "baz": y = 3 - # TODO: with exhaustiveness, this should be Literal[2, 3] - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2, 3] def _(target: None): y = 1 diff --git a/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md b/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md index abb5117564..f9e4ced7d5 100644 --- a/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md +++ b/crates/ty_python_semantic/resources/mdtest/directives/assert_never.md @@ -119,8 +119,6 @@ def match_singletons_success(obj: Literal[1, "a"] | None): case None: pass case _ as obj: - # TODO: Ideally, we would not emit an error here - # error: [type-assertion-failure] "Argument does not have asserted type `Never`" assert_never(obj) def match_singletons_error(obj: Literal[1, "a"] | None): diff --git a/crates/ty_python_semantic/resources/mdtest/enums.md b/crates/ty_python_semantic/resources/mdtest/enums.md index 6981f5c06a..d1265225cf 100644 --- a/crates/ty_python_semantic/resources/mdtest/enums.md +++ b/crates/ty_python_semantic/resources/mdtest/enums.md @@ -720,8 +720,6 @@ def color_name(color: Color) -> str: case _: assert_never(color) -# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488 -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`" def color_name_without_assertion(color: Color) -> str: match color: case Color.RED: diff --git a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md index 9ecc26be0e..713cd05b2d 100644 --- a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md +++ b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md @@ -50,13 +50,10 @@ def match_exhaustive(x: Literal[0, 1, "a"]): case "a": pass case _: - # TODO: this should not be an error - no_diagnostic_here # error: [unresolved-reference] + no_diagnostic_here assert_never(x) -# TODO: there should be no error here -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`" def match_exhaustive_no_assertion(x: Literal[0, 1, "a"]) -> int: match x: case 0: @@ -130,13 +127,21 @@ def match_exhaustive(x: Color): case Color.BLUE: pass case _: - # TODO: this should not be an error - no_diagnostic_here # error: [unresolved-reference] + no_diagnostic_here + + assert_never(x) + +def match_exhaustive_2(x: Color): + match x: + case Color.RED: + pass + case Color.GREEN | Color.BLUE: + pass + case _: + no_diagnostic_here assert_never(x) -# TODO: there should be no error here -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`" def match_exhaustive_no_assertion(x: Color) -> int: match x: case Color.RED: @@ -208,13 +213,10 @@ def match_exhaustive(x: A | B | C): case C(): pass case _: - # TODO: this should not be an error - no_diagnostic_here # error: [unresolved-reference] + no_diagnostic_here assert_never(x) -# TODO: there should be no error here -# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`" def match_exhaustive_no_assertion(x: A | B | C) -> int: match x: case A(): diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index cac8287d49..b6079d42ec 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -734,7 +734,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { subject: Expression<'db>, pattern: &ast::Pattern, guard: Option<&ast::Expr>, - ) -> PredicateOrLiteral<'db> { + previous_pattern: Option>, + ) -> (PredicateOrLiteral<'db>, PatternPredicate<'db>) { // This is called for the top-level pattern of each match arm. We need to create a // standalone expression for each arm of a match statement, since they can introduce // constraints on the match subject. (Or more accurately, for the match arm's pattern, @@ -756,13 +757,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { subject, kind, guard, + previous_pattern.map(Box::new), ); let predicate = PredicateOrLiteral::Predicate(Predicate { node: PredicateNode::Pattern(pattern_predicate), is_positive: true, }); self.record_narrowing_constraint(predicate); - predicate + (predicate, pattern_predicate) } /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type @@ -1747,7 +1749,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); let mut post_case_snapshots = vec![]; - let mut match_predicate; + let mut previous_pattern: Option> = None; for (i, case) in cases.iter().enumerate() { self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); @@ -1757,11 +1759,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // here because the effects of visiting a pattern is binding // symbols, and this doesn't occur unless the pattern // actually matches - match_predicate = self.add_pattern_narrowing_constraint( - subject_expr, - &case.pattern, - case.guard.as_deref(), - ); + let (match_predicate, match_pattern_predicate) = self + .add_pattern_narrowing_constraint( + subject_expr, + &case.pattern, + case.guard.as_deref(), + previous_pattern, + ); + previous_pattern = Some(match_pattern_predicate); let reachability_constraint = self.record_reachability_constraint(match_predicate); diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index acd232ed05..35d738c6c4 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -150,6 +150,9 @@ pub(crate) struct PatternPredicate<'db> { pub(crate) kind: PatternPredicateKind<'db>, pub(crate) guard: Option>, + + /// A reference to the pattern of the previous match case + pub(crate) previous_predicate: Option>>, } // The Salsa heap is tracked separately. diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index caa0b7c247..03f8ac0ace 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -202,13 +202,14 @@ use crate::Db; use crate::dunder_all::dunder_all_names; use crate::place::{RequiresExplicitReExport, imported_symbol}; use crate::rank::RankBitBox; -use crate::semantic_index::expression::Expression; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, }; -use crate::types::{Truthiness, Type, infer_expression_type}; +use crate::types::{ + IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type, +}; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula /// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the @@ -311,6 +312,55 @@ const AMBIGUOUS: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId const ALWAYS_FALSE: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId::ALWAYS_FALSE; const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE; +fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type<'_> { + let ty = match singleton { + ruff_python_ast::Singleton::None => Type::none(db), + ruff_python_ast::Singleton::True => Type::BooleanLiteral(true), + ruff_python_ast::Singleton::False => Type::BooleanLiteral(false), + }; + debug_assert!(ty.is_singleton(db)); + ty +} + +/// Turn a `match` pattern kind into a type that represents the set of all values that would definitely +/// match that pattern. +fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> { + match kind { + PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton), + PatternPredicateKind::Value(value) => infer_expression_type(db, *value), + PatternPredicateKind::Class(class_expr, kind) => { + if kind.is_irrefutable() { + infer_expression_type(db, *class_expr) + .to_instance(db) + .unwrap_or(Type::Never) + } else { + Type::Never + } + } + PatternPredicateKind::Or(predicates) => { + UnionType::from_elements(db, predicates.iter().map(|p| pattern_kind_to_type(db, p))) + } + PatternPredicateKind::Unsupported => Type::Never, + } +} + +/// Go through the list of previous match cases, and accumulate a union of all types that were already +/// matched by these patterns. +fn type_excluded_by_previous_patterns<'db>( + db: &'db dyn Db, + mut predicate: PatternPredicate<'db>, +) -> Type<'db> { + let mut builder = UnionBuilder::new(db); + while let Some(previous) = predicate.previous_predicate(db) { + predicate = *previous; + + if predicate.guard(db).is_none() { + builder = builder.add(pattern_kind_to_type(db, predicate.kind(db))); + } + } + builder.build() +} + /// A collection of reachability constraints for a given scope. #[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub(crate) struct ReachabilityConstraints { @@ -637,11 +687,10 @@ impl ReachabilityConstraints { fn analyze_single_pattern_predicate_kind<'db>( db: &'db dyn Db, predicate_kind: &PatternPredicateKind<'db>, - subject: Expression<'db>, + subject_ty: Type<'db>, ) -> Truthiness { match predicate_kind { PatternPredicateKind::Value(value) => { - let subject_ty = infer_expression_type(db, subject); let value_ty = infer_expression_type(db, *value); if subject_ty.is_single_valued(db) { @@ -651,15 +700,7 @@ impl ReachabilityConstraints { } } PatternPredicateKind::Singleton(singleton) => { - let subject_ty = infer_expression_type(db, subject); - - let singleton_ty = match singleton { - ruff_python_ast::Singleton::None => Type::none(db), - ruff_python_ast::Singleton::True => Type::BooleanLiteral(true), - ruff_python_ast::Singleton::False => Type::BooleanLiteral(false), - }; - - debug_assert!(singleton_ty.is_singleton(db)); + let singleton_ty = singleton_to_type(db, *singleton); if subject_ty.is_equivalent_to(db, singleton_ty) { Truthiness::AlwaysTrue @@ -671,10 +712,21 @@ impl ReachabilityConstraints { } PatternPredicateKind::Or(predicates) => { use std::ops::ControlFlow; + + let mut excluded_types = vec![]; let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) = predicates .iter() - .map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject)) + .map(|p| { + let narrowed_subject_ty = IntersectionBuilder::new(db) + .add_positive(subject_ty) + .add_negative(UnionType::from_elements(db, excluded_types.iter())) + .build(); + + excluded_types.push(pattern_kind_to_type(db, p)); + + Self::analyze_single_pattern_predicate_kind(db, p, narrowed_subject_ty) + }) // this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there .try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) { (Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => { @@ -690,7 +742,6 @@ impl ReachabilityConstraints { truthiness } PatternPredicateKind::Class(class_expr, kind) => { - let subject_ty = infer_expression_type(db, subject); let class_ty = infer_expression_type(db, *class_expr).to_instance(db); class_ty.map_or(Truthiness::Ambiguous, |class_ty| { @@ -715,10 +766,17 @@ impl ReachabilityConstraints { } fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { + let subject_ty = infer_expression_type(db, predicate.subject(db)); + + let narrowed_subject_ty = IntersectionBuilder::new(db) + .add_positive(subject_ty) + .add_negative(type_excluded_by_previous_patterns(db, predicate)) + .build(); + let truthiness = Self::analyze_single_pattern_predicate_kind( db, predicate.kind(db), - predicate.subject(db), + narrowed_subject_ty, ); if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {