From 9768abe5abc80d5d3e936ac009946bbf77fc8980 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sun, 21 Dec 2025 21:02:36 -0500 Subject: [PATCH] [ty] Add support for narrowing on tuple match cases --- .../resources/mdtest/narrow/match.md | 104 ++++++++++ .../src/semantic_index/builder.rs | 8 + .../src/semantic_index/predicate.rs | 1 + .../reachability_constraints.rs | 55 +++++- crates/ty_python_semantic/src/types.rs | 3 +- .../ty_python_semantic/src/types/instance.rs | 2 +- crates/ty_python_semantic/src/types/narrow.rs | 177 ++++++++++++++++++ 7 files changed, 346 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index f0c107851b..2cd18f72ed 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -375,3 +375,107 @@ try: except ValueError: pass ``` + +## Sequence patterns + +Sequence patterns narrow tuple element types based on the patterns matched against each element. + +```py +def _(subj: tuple[int | str, int | str]): + match subj: + case (x, str()): + reveal_type(subj) # revealed: tuple[int | str, str] + case (int(), y): + reveal_type(subj) # revealed: tuple[int, int | str] + +def _(subj: tuple[int | str, int | str]): + match subj: + case (int(), str()): + reveal_type(subj) # revealed: tuple[int, str] + +def _(subj: tuple[int | str | None, int | str | None]): + match subj: + case (None, _): + reveal_type(subj) # revealed: tuple[None, int | str | None] + case (_, None): + reveal_type(subj) # revealed: tuple[int | str | None, None] +``` + +## Sequence patterns with nested tuples + +```py +def _(subj: tuple[tuple[int | str, int], int | str]): + match subj: + case ((str(), _), _): + # The inner tuple is narrowed by intersecting with the pattern's constraint + reveal_type(subj) # revealed: tuple[tuple[int | str, int] & tuple[str, object], int | str] +``` + +## Sequence patterns with or patterns + +```py +def _(subj: tuple[int | str | bytes, int | str]): + match subj: + case (int() | str(), _): + reveal_type(subj) # revealed: tuple[int | str, int | str] +``` + +## Sequence patterns with wildcards + +Wildcards (`_`) and name patterns don't narrow the element type. + +```py +def _(subj: tuple[int | str, int | str]): + match subj: + case (_, _): + reveal_type(subj) # revealed: tuple[int | str, int | str] + +def _(subj: tuple[int | str, int | str]): + match subj: + case (x, y): + reveal_type(subj) # revealed: tuple[int | str, int | str] +``` + +## Sequence pattern negative narrowing + +Negative narrowing for sequence patterns is not currently supported. When a sequence pattern doesn't +match, subsequent cases see the original type. + +```py +def _(subj: tuple[int | str, int | str]): + match subj: + case (int(), int()): + reveal_type(subj) # revealed: tuple[int, int] + case _: + reveal_type(subj) # revealed: tuple[int | str, int | str] +``` + +## Sequence pattern exhaustiveness + +When a sequence pattern exhaustively matches all possible tuple values, subsequent cases should be +unreachable (`Never`). + +```py +def _(subj: tuple[int, str]): + match subj: + case (int(), str()): + reveal_type(subj) # revealed: tuple[int, str] + case _: + reveal_type(subj) # revealed: Never +``` + +## Sequence patterns with homogeneous tuples + +Sequence patterns on homogeneous tuples narrow to a fixed-length tuple with the specified length. + +```py +def _(subj: tuple[int | str, ...]): + match subj: + case (x, str()): + reveal_type(subj) # revealed: tuple[int | str, str] + +def _(subj: tuple[int | str, ...]): + match subj: + case (int(), int(), y): + reveal_type(subj) # revealed: tuple[int, int, int | str] +``` diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 67745d08aa..5f6eb40444 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -948,6 +948,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { .map(|p| Box::new(self.predicate_kind(p))), pattern.name.as_ref().map(|name| name.id.clone()), ), + ast::Pattern::MatchSequence(pattern) => { + let predicates = pattern + .patterns + .iter() + .map(|pattern| self.predicate_kind(pattern)) + .collect(); + PatternPredicateKind::Sequence(predicates) + } _ => PatternPredicateKind::Unsupported, } } diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index abefcc34b4..d1e239e829 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -137,6 +137,7 @@ pub(crate) enum PatternPredicateKind<'db> { Or(Vec>), Class(Expression<'db>, ClassPatternKind), As(Option>>, Option), + Sequence(Vec>), Unsupported, } diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 6f723f4679..194e9707cd 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -208,8 +208,8 @@ use crate::semantic_index::predicate::{ Predicates, ScopedPredicateId, }; use crate::types::{ - CallableTypes, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType, - infer_expression_type, static_expression_truthiness, + CallableTypes, IntersectionBuilder, Truthiness, TupleSpec, Type, TypeContext, UnionBuilder, + UnionType, infer_expression_type, static_expression_truthiness, }; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula @@ -348,6 +348,13 @@ fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) .as_deref() .map(|p| pattern_kind_to_type(db, p)) .unwrap_or_else(Type::object), + PatternPredicateKind::Sequence(patterns) => { + let elements: Vec<_> = patterns + .iter() + .map(|p| pattern_kind_to_type(db, p)) + .collect(); + Type::heterogeneous_tuple(db, elements) + } PatternPredicateKind::Unsupported => Type::Never, } } @@ -852,6 +859,50 @@ impl ReachabilityConstraints { .as_deref() .map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject_ty)) .unwrap_or(Truthiness::AlwaysTrue), + PatternPredicateKind::Sequence(patterns) => { + // Check if the subject is a tuple with matching length. + let tuple_spec = match subject_ty { + Type::NominalInstance(instance) => instance.tuple_spec(db), + _ => None, + }; + + let Some(tuple_spec) = tuple_spec else { + // Subject is not a tuple type; can't determine if it matches. + return Truthiness::Ambiguous; + }; + + match tuple_spec.as_ref() { + TupleSpec::Fixed(fixed) => { + if fixed.len() != patterns.len() { + // Length mismatch; pattern definitely can't match. + return Truthiness::AlwaysFalse; + } + + // Check each element pattern against its corresponding element type. + let mut result = Truthiness::AlwaysTrue; + + for (element_ty, pattern) in fixed.elements().zip(patterns.iter()) { + let element_result = Self::analyze_single_pattern_predicate_kind( + db, + pattern, + *element_ty, + ); + + match element_result { + Truthiness::AlwaysFalse => return Truthiness::AlwaysFalse, + Truthiness::Ambiguous => result = Truthiness::Ambiguous, + Truthiness::AlwaysTrue => {} + } + } + + result + } + TupleSpec::Variable(_) => { + // Variable-length tuples could match patterns of various lengths. + Truthiness::Ambiguous + } + } + } PatternPredicateKind::Unsupported => Truthiness::Ambiguous, } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index da46f40709..edeb9248a1 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -34,6 +34,7 @@ pub(crate) use self::infer::{ pub use self::signatures::ParameterKind; pub(crate) use self::signatures::{CallableSignature, Signature}; pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType}; +pub(crate) use self::tuple::TupleSpec; pub use crate::diagnostic::add_inferred_python_version_hint_to_diagnostic; use crate::place::{ Definedness, Place, PlaceAndQualifiers, TypeOrigin, imported_symbol, known_module_symbol, @@ -68,7 +69,7 @@ pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::newtype::NewType; pub(crate) use crate::types::signatures::{Parameter, Parameters}; use crate::types::signatures::{ParameterForm, walk_signature}; -use crate::types::tuple::{Tuple, TupleSpec, TupleSpecBuilder}; +use crate::types::tuple::{Tuple, TupleSpecBuilder}; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; pub use crate::types::variance::TypeVarVariance; use crate::types::variance::VarianceInferable; diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 9e674065b9..5725df417a 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -260,7 +260,7 @@ impl<'db> NominalInstanceType<'db> { /// /// I.e., for the type `tuple[int, str]`, this will return the tuple spec `[int, str]`. /// For a subclass of `tuple[int, str]`, it will return the same tuple spec. - pub(super) fn tuple_spec(&self, db: &'db dyn Db) -> Option>> { + pub(crate) fn tuple_spec(&self, db: &'db dyn Db) -> Option>> { match self.0 { NominalInstanceInner::ExactTuple(tuple) => Some(Cow::Borrowed(tuple.tuple(db))), NominalInstanceInner::NonTuple(class) => { diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index cc9c0ca0f6..8dba6deb27 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -26,6 +26,7 @@ use rustc_hash::FxHashMap; use std::collections::hash_map::Entry; use super::UnionType; +use super::tuple::TupleSpec; /// Return the type constraint that `test` (if true) would place on `symbol`, if any. /// @@ -421,6 +422,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { PatternPredicateKind::As(pattern, _) => pattern .as_deref() .and_then(|p| self.evaluate_pattern_predicate_kind(p, subject, is_positive)), + PatternPredicateKind::Sequence(element_patterns) => { + self.evaluate_match_pattern_sequence(subject, element_patterns, is_positive) + } PatternPredicateKind::Unsupported => None, } } @@ -1153,6 +1157,179 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }) } + /// Evaluate a sequence pattern like `case (x, y, z):` or `case [a, b]:`. + /// + /// For each element pattern, we narrow the corresponding element of the tuple subject. + fn evaluate_match_pattern_sequence( + &mut self, + subject: Expression<'db>, + element_patterns: &[PatternPredicateKind<'db>], + is_positive: bool, + ) -> Option> { + // Get the subject expression's place. + let place_expr = place_expr(subject.node_ref(self.db, self.module))?; + let place = self.expect_place(&place_expr); + + // Get the subject's type. + let subject_ty = + infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module); + + // Get the tuple spec, if it's a tuple type. + let tuple_spec = match subject_ty { + Type::NominalInstance(instance) => instance.tuple_spec(self.db)?, + _ => return None, + }; + + // Check if any element pattern provides narrowing constraints. + let has_any_constraint = element_patterns + .iter() + .any(|pattern| self.pattern_to_type_constraint(pattern).is_some()); + + // If no element pattern provides constraints (e.g., all wildcards), don't narrow. + if !has_any_constraint { + return None; + } + + // Negative narrowing for sequences is not supported. It would produce types like + // `tuple[int | str, int | str] & ~tuple[int, int]` which the type system can't + // simplify, making them impractical for actual use. + if !is_positive { + return None; + } + + // Positive narrowing: narrow each element based on its pattern. + let narrowed_elements: Vec> = match tuple_spec.as_ref() { + TupleSpec::Fixed(fixed) => { + // Require exact length match for fixed-length tuples. + if fixed.len() != element_patterns.len() { + return None; + } + + let elements = fixed.elements().collect::>(); + + // Narrow each element based on its pattern. + elements + .iter() + .zip(element_patterns.iter()) + .map(|(element_ty, pattern)| { + if let Some(constraint_ty) = self.pattern_to_type_constraint(pattern) { + // Positive case: intersect element type with pattern constraint. + return IntersectionBuilder::new(self.db) + .add_positive(**element_ty) + .add_positive(constraint_ty) + .build(); + } + // No constraint from this pattern (e.g., wildcard). + **element_ty + }) + .collect() + } + TupleSpec::Variable(variable) => { + // For variable-length tuples like `tuple[int | str, ...]`, a pattern like + // `(x, str())` narrows to a fixed-length tuple with the pattern's length. + // + // The tuple structure is: prefix + variable* + suffix. + let pattern_len = element_patterns.len(); + let prefix_len = variable.prefix.len(); + let suffix_len = variable.suffix.len(); + + // Pattern must have at least as many elements as prefix + suffix. + if pattern_len < prefix_len + suffix_len { + return None; + } + + // Build element types for a fixed-length tuple matching the pattern. + element_patterns + .iter() + .enumerate() + .map(|(i, pattern)| { + // Determine which part of the tuple this element comes from. + let element_ty = if i < prefix_len { + variable.prefix[i] + } else if i >= pattern_len - suffix_len { + variable.suffix[i - (pattern_len - suffix_len)] + } else { + variable.variable + }; + + // Apply pattern constraint if present. + if let Some(constraint_ty) = self.pattern_to_type_constraint(pattern) { + return IntersectionBuilder::new(self.db) + .add_positive(element_ty) + .add_positive(constraint_ty) + .build(); + } + element_ty + }) + .collect() + } + }; + + // Build the narrowed tuple type. + let narrowed_tuple = Type::heterogeneous_tuple(self.db, narrowed_elements); + + Some(NarrowingConstraints::from_iter([(place, narrowed_tuple)])) + } + + /// Convert a pattern kind to the type it constrains to. + /// + /// Returns `None` for patterns that don't constrain the type (like wildcards or name patterns). + fn pattern_to_type_constraint(&self, pattern: &PatternPredicateKind<'db>) -> Option> { + match pattern { + PatternPredicateKind::Singleton(singleton) => Some(match singleton { + ast::Singleton::None => Type::none(self.db), + ast::Singleton::True => Type::BooleanLiteral(true), + ast::Singleton::False => Type::BooleanLiteral(false), + }), + PatternPredicateKind::Class(cls, _) => { + let class_ty = infer_same_file_expression_type( + self.db, + *cls, + TypeContext::default(), + self.module, + ); + match class_ty { + Type::ClassLiteral(class) => { + Some(Type::instance(self.db, class.top_materialization(self.db))) + } + dynamic @ Type::Dynamic(_) => Some(dynamic), + Type::SpecialForm(SpecialFormType::Any) => Some(Type::any()), + _ => None, + } + } + PatternPredicateKind::Value(expr) => Some(infer_same_file_expression_type( + self.db, + *expr, + TypeContext::default(), + self.module, + )), + PatternPredicateKind::Or(patterns) => { + // Union of all pattern constraints. + let elements: Vec<_> = patterns + .iter() + .filter_map(|p| self.pattern_to_type_constraint(p)) + .collect(); + if elements.is_empty() { + None + } else { + Some(UnionType::from_elements(self.db, elements)) + } + } + PatternPredicateKind::As(inner, _) => inner + .as_deref() + .and_then(|p| self.pattern_to_type_constraint(p)), + PatternPredicateKind::Sequence(patterns) => { + // For nested sequences, create a tuple type. + let elements: Vec<_> = patterns + .iter() + .map(|p| self.pattern_to_type_constraint(p).unwrap_or(Type::object())) + .collect(); + Some(Type::heterogeneous_tuple(self.db, elements)) + } + PatternPredicateKind::Unsupported => None, + } + } + fn evaluate_bool_op( &mut self, expr_bool_op: &ExprBoolOp,