[red-knot] support narrowing on or patterns in matches (#17030)

## Summary

Part of #13694

Narrow in or-patterns by taking the type union of the type constraints
in each disjunct pattern.

## Test Plan

Add new tests to narrow/match.md

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin 2025-03-28 10:27:09 -04:00 committed by GitHub
parent 2e56cd3737
commit 64171744dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 122 additions and 33 deletions

View file

@ -114,3 +114,49 @@ match x:
reveal_type(x) # revealed: object
```
## Or patterns
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
match x:
case "foo" | 42 | None:
reveal_type(x) # revealed: Literal["foo", 42] | None
case "foo" | tuple():
reveal_type(x) # revealed: Literal["foo"] | tuple
case True | False:
reveal_type(x) # revealed: bool
case 3.14 | 2.718 | 1.414:
reveal_type(x) # revealed: float
reveal_type(x) # revealed: object
```
## Or patterns with guard
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
match x:
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
pass
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple
pass
case True | False if reveal_type(x): # revealed: bool
pass
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
pass
reveal_type(x) # revealed: object
```

View file

@ -589,6 +589,31 @@ impl<'db> SemanticIndexBuilder<'db> {
}
}
fn predicate_kind(&mut self, pattern: &ast::Pattern) -> PatternPredicateKind<'db> {
match pattern {
ast::Pattern::MatchValue(pattern) => {
let value = self.add_standalone_expression(&pattern.value);
PatternPredicateKind::Value(value)
}
ast::Pattern::MatchSingleton(singleton) => {
PatternPredicateKind::Singleton(singleton.value)
}
ast::Pattern::MatchClass(pattern) => {
let cls = self.add_standalone_expression(&pattern.cls);
PatternPredicateKind::Class(cls)
}
ast::Pattern::MatchOr(pattern) => {
let predicates = pattern
.patterns
.iter()
.map(|pattern| self.predicate_kind(pattern))
.collect();
PatternPredicateKind::Or(predicates)
}
_ => PatternPredicateKind::Unsupported,
}
}
fn add_pattern_narrowing_constraint(
&mut self,
subject: Expression<'db>,
@ -606,29 +631,16 @@ impl<'db> SemanticIndexBuilder<'db> {
//
// See the comment in TypeInferenceBuilder::infer_match_pattern for more details.
let kind = self.predicate_kind(pattern);
let guard = guard.map(|guard| self.add_standalone_expression(guard));
let kind = match pattern {
ast::Pattern::MatchValue(pattern) => {
let value = self.add_standalone_expression(&pattern.value);
PatternPredicateKind::Value(value, guard)
}
ast::Pattern::MatchSingleton(singleton) => {
PatternPredicateKind::Singleton(singleton.value, guard)
}
ast::Pattern::MatchClass(pattern) => {
let cls = self.add_standalone_expression(&pattern.cls);
PatternPredicateKind::Class(cls, guard)
}
_ => PatternPredicateKind::Unsupported,
};
let pattern_predicate = PatternPredicate::new(
self.db,
self.file,
self.current_scope(),
subject,
kind,
guard,
countme::Count::default(),
);
let predicate = Predicate {

View file

@ -57,9 +57,10 @@ pub(crate) enum PredicateNode<'db> {
/// Pattern kinds for which we support type narrowing and/or static visibility analysis.
#[derive(Debug, Clone, Hash, PartialEq, salsa::Update)]
pub(crate) enum PatternPredicateKind<'db> {
Singleton(Singleton, Option<Expression<'db>>),
Value(Expression<'db>, Option<Expression<'db>>),
Class(Expression<'db>, Option<Expression<'db>>),
Singleton(Singleton),
Value(Expression<'db>),
Or(Vec<PatternPredicateKind<'db>>),
Class(Expression<'db>),
Unsupported,
}
@ -74,6 +75,8 @@ pub(crate) struct PatternPredicate<'db> {
#[return_ref]
pub(crate) kind: PatternPredicateKind<'db>,
pub(crate) guard: Option<Expression<'db>>,
count: countme::Count<PatternPredicate<'static>>,
}

View file

@ -560,7 +560,7 @@ impl VisibilityConstraints {
ty.bool(db).negate_if(!predicate.is_positive)
}
PredicateNode::Pattern(inner) => match inner.kind(db) {
PatternPredicateKind::Value(value, guard) => {
PatternPredicateKind::Value(value) => {
let subject_expression = inner.subject(db);
let subject_ty = infer_expression_type(db, subject_expression);
let value_ty = infer_expression_type(db, *value);
@ -569,7 +569,7 @@ impl VisibilityConstraints {
let truthiness =
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty));
if truthiness.is_always_true() && guard.is_some() {
if truthiness.is_always_true() && inner.guard(db).is_some() {
// Fall back to ambiguous, the guard might change the result.
Truthiness::Ambiguous
} else {
@ -581,6 +581,7 @@ impl VisibilityConstraints {
}
PatternPredicateKind::Singleton(..)
| PatternPredicateKind::Class(..)
| PatternPredicateKind::Or(..)
| PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
},
}

View file

@ -2119,6 +2119,11 @@ impl<'db> TypeInferenceBuilder<'db> {
}
self.infer_standalone_expression(cls);
}
ast::Pattern::MatchOr(match_or) => {
for pattern in &match_or.patterns {
self.infer_match_pattern(pattern);
}
}
_ => {
self.infer_nested_match_pattern(pattern);
}

View file

@ -225,25 +225,31 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
fn evaluate_pattern_predicate_kind(
&mut self,
pattern_predicate_kind: &PatternPredicateKind<'db>,
subject: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton)
}
PatternPredicateKind::Class(cls) => self.evaluate_match_pattern_class(subject, *cls),
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
PatternPredicateKind::Or(predicates) => {
self.evaluate_match_pattern_or(subject, predicates)
}
PatternPredicateKind::Unsupported => None,
}
}
fn evaluate_pattern_predicate(
&mut self,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);
match pattern.kind(self.db) {
PatternPredicateKind::Singleton(singleton, _guard) => {
self.evaluate_match_pattern_singleton(subject, *singleton)
}
PatternPredicateKind::Class(cls, _guard) => {
self.evaluate_match_pattern_class(subject, *cls)
}
PatternPredicateKind::Value(expr, _guard) => {
self.evaluate_match_pattern_value(subject, *expr)
}
// TODO: support more pattern kinds
PatternPredicateKind::Unsupported => None,
}
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
}
fn symbols(&self) -> Arc<SymbolTable> {
@ -505,6 +511,22 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_match_pattern_or(
&mut self,
subject: Expression<'db>,
predicates: &Vec<PatternPredicateKind<'db>>,
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
predicates
.iter()
.filter_map(|predicate| self.evaluate_pattern_predicate_kind(predicate, subject))
.reduce(|mut constraints, constraints_| {
merge_constraints_or(&mut constraints, &constraints_, db);
constraints
})
}
fn evaluate_bool_op(
&mut self,
expr_bool_op: &ExprBoolOp,