[ty] Add support for narrowing on tuple match cases

This commit is contained in:
Charlie Marsh 2025-12-21 21:02:36 -05:00
parent fee4e2d72a
commit 9768abe5ab
7 changed files with 346 additions and 4 deletions

View file

@ -375,3 +375,107 @@ try:
except ValueError:
pass
```
## Sequence patterns
Sequence patterns narrow tuple element types based on the patterns matched against each element.
```py
def _(subj: tuple[int | str, int | str]):
match subj:
case (x, str()):
reveal_type(subj) # revealed: tuple[int | str, str]
case (int(), y):
reveal_type(subj) # revealed: tuple[int, int | str]
def _(subj: tuple[int | str, int | str]):
match subj:
case (int(), str()):
reveal_type(subj) # revealed: tuple[int, str]
def _(subj: tuple[int | str | None, int | str | None]):
match subj:
case (None, _):
reveal_type(subj) # revealed: tuple[None, int | str | None]
case (_, None):
reveal_type(subj) # revealed: tuple[int | str | None, None]
```
## Sequence patterns with nested tuples
```py
def _(subj: tuple[tuple[int | str, int], int | str]):
match subj:
case ((str(), _), _):
# The inner tuple is narrowed by intersecting with the pattern's constraint
reveal_type(subj) # revealed: tuple[tuple[int | str, int] & tuple[str, object], int | str]
```
## Sequence patterns with or patterns
```py
def _(subj: tuple[int | str | bytes, int | str]):
match subj:
case (int() | str(), _):
reveal_type(subj) # revealed: tuple[int | str, int | str]
```
## Sequence patterns with wildcards
Wildcards (`_`) and name patterns don't narrow the element type.
```py
def _(subj: tuple[int | str, int | str]):
match subj:
case (_, _):
reveal_type(subj) # revealed: tuple[int | str, int | str]
def _(subj: tuple[int | str, int | str]):
match subj:
case (x, y):
reveal_type(subj) # revealed: tuple[int | str, int | str]
```
## Sequence pattern negative narrowing
Negative narrowing for sequence patterns is not currently supported. When a sequence pattern doesn't
match, subsequent cases see the original type.
```py
def _(subj: tuple[int | str, int | str]):
match subj:
case (int(), int()):
reveal_type(subj) # revealed: tuple[int, int]
case _:
reveal_type(subj) # revealed: tuple[int | str, int | str]
```
## Sequence pattern exhaustiveness
When a sequence pattern exhaustively matches all possible tuple values, subsequent cases should be
unreachable (`Never`).
```py
def _(subj: tuple[int, str]):
match subj:
case (int(), str()):
reveal_type(subj) # revealed: tuple[int, str]
case _:
reveal_type(subj) # revealed: Never
```
## Sequence patterns with homogeneous tuples
Sequence patterns on homogeneous tuples narrow to a fixed-length tuple with the specified length.
```py
def _(subj: tuple[int | str, ...]):
match subj:
case (x, str()):
reveal_type(subj) # revealed: tuple[int | str, str]
def _(subj: tuple[int | str, ...]):
match subj:
case (int(), int(), y):
reveal_type(subj) # revealed: tuple[int, int, int | str]
```

View file

@ -948,6 +948,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
.map(|p| Box::new(self.predicate_kind(p))),
pattern.name.as_ref().map(|name| name.id.clone()),
),
ast::Pattern::MatchSequence(pattern) => {
let predicates = pattern
.patterns
.iter()
.map(|pattern| self.predicate_kind(pattern))
.collect();
PatternPredicateKind::Sequence(predicates)
}
_ => PatternPredicateKind::Unsupported,
}
}

View file

@ -137,6 +137,7 @@ pub(crate) enum PatternPredicateKind<'db> {
Or(Vec<PatternPredicateKind<'db>>),
Class(Expression<'db>, ClassPatternKind),
As(Option<Box<PatternPredicateKind<'db>>>, Option<Name>),
Sequence(Vec<PatternPredicateKind<'db>>),
Unsupported,
}

View file

@ -208,8 +208,8 @@ use crate::semantic_index::predicate::{
Predicates, ScopedPredicateId,
};
use crate::types::{
CallableTypes, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType,
infer_expression_type, static_expression_truthiness,
CallableTypes, IntersectionBuilder, Truthiness, TupleSpec, Type, TypeContext, UnionBuilder,
UnionType, infer_expression_type, static_expression_truthiness,
};
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@ -348,6 +348,13 @@ fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>)
.as_deref()
.map(|p| pattern_kind_to_type(db, p))
.unwrap_or_else(Type::object),
PatternPredicateKind::Sequence(patterns) => {
let elements: Vec<_> = patterns
.iter()
.map(|p| pattern_kind_to_type(db, p))
.collect();
Type::heterogeneous_tuple(db, elements)
}
PatternPredicateKind::Unsupported => Type::Never,
}
}
@ -852,6 +859,50 @@ impl ReachabilityConstraints {
.as_deref()
.map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject_ty))
.unwrap_or(Truthiness::AlwaysTrue),
PatternPredicateKind::Sequence(patterns) => {
// Check if the subject is a tuple with matching length.
let tuple_spec = match subject_ty {
Type::NominalInstance(instance) => instance.tuple_spec(db),
_ => None,
};
let Some(tuple_spec) = tuple_spec else {
// Subject is not a tuple type; can't determine if it matches.
return Truthiness::Ambiguous;
};
match tuple_spec.as_ref() {
TupleSpec::Fixed(fixed) => {
if fixed.len() != patterns.len() {
// Length mismatch; pattern definitely can't match.
return Truthiness::AlwaysFalse;
}
// Check each element pattern against its corresponding element type.
let mut result = Truthiness::AlwaysTrue;
for (element_ty, pattern) in fixed.elements().zip(patterns.iter()) {
let element_result = Self::analyze_single_pattern_predicate_kind(
db,
pattern,
*element_ty,
);
match element_result {
Truthiness::AlwaysFalse => return Truthiness::AlwaysFalse,
Truthiness::Ambiguous => result = Truthiness::Ambiguous,
Truthiness::AlwaysTrue => {}
}
}
result
}
TupleSpec::Variable(_) => {
// Variable-length tuples could match patterns of various lengths.
Truthiness::Ambiguous
}
}
}
PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
}
}

View file

@ -34,6 +34,7 @@ pub(crate) use self::infer::{
pub use self::signatures::ParameterKind;
pub(crate) use self::signatures::{CallableSignature, Signature};
pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType};
pub(crate) use self::tuple::TupleSpec;
pub use crate::diagnostic::add_inferred_python_version_hint_to_diagnostic;
use crate::place::{
Definedness, Place, PlaceAndQualifiers, TypeOrigin, imported_symbol, known_module_symbol,
@ -68,7 +69,7 @@ pub(crate) use crate::types::narrow::infer_narrowing_constraint;
use crate::types::newtype::NewType;
pub(crate) use crate::types::signatures::{Parameter, Parameters};
use crate::types::signatures::{ParameterForm, walk_signature};
use crate::types::tuple::{Tuple, TupleSpec, TupleSpecBuilder};
use crate::types::tuple::{Tuple, TupleSpecBuilder};
pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type};
pub use crate::types::variance::TypeVarVariance;
use crate::types::variance::VarianceInferable;

View file

@ -260,7 +260,7 @@ impl<'db> NominalInstanceType<'db> {
///
/// I.e., for the type `tuple[int, str]`, this will return the tuple spec `[int, str]`.
/// For a subclass of `tuple[int, str]`, it will return the same tuple spec.
pub(super) fn tuple_spec(&self, db: &'db dyn Db) -> Option<Cow<'db, TupleSpec<'db>>> {
pub(crate) fn tuple_spec(&self, db: &'db dyn Db) -> Option<Cow<'db, TupleSpec<'db>>> {
match self.0 {
NominalInstanceInner::ExactTuple(tuple) => Some(Cow::Borrowed(tuple.tuple(db))),
NominalInstanceInner::NonTuple(class) => {

View file

@ -26,6 +26,7 @@ use rustc_hash::FxHashMap;
use std::collections::hash_map::Entry;
use super::UnionType;
use super::tuple::TupleSpec;
/// Return the type constraint that `test` (if true) would place on `symbol`, if any.
///
@ -421,6 +422,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
PatternPredicateKind::As(pattern, _) => pattern
.as_deref()
.and_then(|p| self.evaluate_pattern_predicate_kind(p, subject, is_positive)),
PatternPredicateKind::Sequence(element_patterns) => {
self.evaluate_match_pattern_sequence(subject, element_patterns, is_positive)
}
PatternPredicateKind::Unsupported => None,
}
}
@ -1153,6 +1157,179 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
})
}
/// Evaluate a sequence pattern like `case (x, y, z):` or `case [a, b]:`.
///
/// For each element pattern, we narrow the corresponding element of the tuple subject.
fn evaluate_match_pattern_sequence(
&mut self,
subject: Expression<'db>,
element_patterns: &[PatternPredicateKind<'db>],
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
// Get the subject expression's place.
let place_expr = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&place_expr);
// Get the subject's type.
let subject_ty =
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
// Get the tuple spec, if it's a tuple type.
let tuple_spec = match subject_ty {
Type::NominalInstance(instance) => instance.tuple_spec(self.db)?,
_ => return None,
};
// Check if any element pattern provides narrowing constraints.
let has_any_constraint = element_patterns
.iter()
.any(|pattern| self.pattern_to_type_constraint(pattern).is_some());
// If no element pattern provides constraints (e.g., all wildcards), don't narrow.
if !has_any_constraint {
return None;
}
// Negative narrowing for sequences is not supported. It would produce types like
// `tuple[int | str, int | str] & ~tuple[int, int]` which the type system can't
// simplify, making them impractical for actual use.
if !is_positive {
return None;
}
// Positive narrowing: narrow each element based on its pattern.
let narrowed_elements: Vec<Type<'db>> = match tuple_spec.as_ref() {
TupleSpec::Fixed(fixed) => {
// Require exact length match for fixed-length tuples.
if fixed.len() != element_patterns.len() {
return None;
}
let elements = fixed.elements().collect::<Vec<_>>();
// Narrow each element based on its pattern.
elements
.iter()
.zip(element_patterns.iter())
.map(|(element_ty, pattern)| {
if let Some(constraint_ty) = self.pattern_to_type_constraint(pattern) {
// Positive case: intersect element type with pattern constraint.
return IntersectionBuilder::new(self.db)
.add_positive(**element_ty)
.add_positive(constraint_ty)
.build();
}
// No constraint from this pattern (e.g., wildcard).
**element_ty
})
.collect()
}
TupleSpec::Variable(variable) => {
// For variable-length tuples like `tuple[int | str, ...]`, a pattern like
// `(x, str())` narrows to a fixed-length tuple with the pattern's length.
//
// The tuple structure is: prefix + variable* + suffix.
let pattern_len = element_patterns.len();
let prefix_len = variable.prefix.len();
let suffix_len = variable.suffix.len();
// Pattern must have at least as many elements as prefix + suffix.
if pattern_len < prefix_len + suffix_len {
return None;
}
// Build element types for a fixed-length tuple matching the pattern.
element_patterns
.iter()
.enumerate()
.map(|(i, pattern)| {
// Determine which part of the tuple this element comes from.
let element_ty = if i < prefix_len {
variable.prefix[i]
} else if i >= pattern_len - suffix_len {
variable.suffix[i - (pattern_len - suffix_len)]
} else {
variable.variable
};
// Apply pattern constraint if present.
if let Some(constraint_ty) = self.pattern_to_type_constraint(pattern) {
return IntersectionBuilder::new(self.db)
.add_positive(element_ty)
.add_positive(constraint_ty)
.build();
}
element_ty
})
.collect()
}
};
// Build the narrowed tuple type.
let narrowed_tuple = Type::heterogeneous_tuple(self.db, narrowed_elements);
Some(NarrowingConstraints::from_iter([(place, narrowed_tuple)]))
}
/// Convert a pattern kind to the type it constrains to.
///
/// Returns `None` for patterns that don't constrain the type (like wildcards or name patterns).
fn pattern_to_type_constraint(&self, pattern: &PatternPredicateKind<'db>) -> Option<Type<'db>> {
match pattern {
PatternPredicateKind::Singleton(singleton) => Some(match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
}),
PatternPredicateKind::Class(cls, _) => {
let class_ty = infer_same_file_expression_type(
self.db,
*cls,
TypeContext::default(),
self.module,
);
match class_ty {
Type::ClassLiteral(class) => {
Some(Type::instance(self.db, class.top_materialization(self.db)))
}
dynamic @ Type::Dynamic(_) => Some(dynamic),
Type::SpecialForm(SpecialFormType::Any) => Some(Type::any()),
_ => None,
}
}
PatternPredicateKind::Value(expr) => Some(infer_same_file_expression_type(
self.db,
*expr,
TypeContext::default(),
self.module,
)),
PatternPredicateKind::Or(patterns) => {
// Union of all pattern constraints.
let elements: Vec<_> = patterns
.iter()
.filter_map(|p| self.pattern_to_type_constraint(p))
.collect();
if elements.is_empty() {
None
} else {
Some(UnionType::from_elements(self.db, elements))
}
}
PatternPredicateKind::As(inner, _) => inner
.as_deref()
.and_then(|p| self.pattern_to_type_constraint(p)),
PatternPredicateKind::Sequence(patterns) => {
// For nested sequences, create a tuple type.
let elements: Vec<_> = patterns
.iter()
.map(|p| self.pattern_to_type_constraint(p).unwrap_or(Type::object()))
.collect();
Some(Type::heterogeneous_tuple(self.db, elements))
}
PatternPredicateKind::Unsupported => None,
}
}
fn evaluate_bool_op(
&mut self,
expr_bool_op: &ExprBoolOp,