[red-knot] more type-narrowing in match statements (#17302)

## Summary

Add more narrowing analysis for match statements:
* add narrowing constraints from guard expressions
* add negated constraints from previous predicates and guards to
subsequent cases

This PR doesn't address that guards can mutate your subject, and so
theoretically invalidate some of these narrowing constraints that you've
previously accumulated. Some prior art on this issue [here][mutable
guards].

[mutable guards]:
https://www.irif.fr/~scherer/research/mutable-patterns/mutable-patterns-mlworkshop2024-abstract.pdf

## Test Plan

Add some new tests, and update some existing ones


---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin 2025-04-17 21:18:34 -04:00 committed by GitHub
parent edfa03a692
commit de8f4e62e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 131 additions and 37 deletions

View file

@ -39,8 +39,7 @@ match x:
case A():
reveal_type(x) # revealed: A
case B():
# TODO could be `B & ~A`
reveal_type(x) # revealed: B
reveal_type(x) # revealed: B & ~A
reveal_type(x) # revealed: object
```
@ -88,7 +87,7 @@ match x:
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
reveal_type(x) # revealed: complex & ~float
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]
@ -134,11 +133,11 @@ match x:
case "foo" | 42 | None:
reveal_type(x) # revealed: Literal["foo", 42] | None
case "foo" | tuple():
reveal_type(x) # revealed: Literal["foo"] | tuple
reveal_type(x) # revealed: tuple
case True | False:
reveal_type(x) # revealed: bool
case 3.14 | 2.718 | 1.414:
reveal_type(x) # revealed: float
reveal_type(x) # revealed: float & ~tuple
reveal_type(x) # revealed: object
```
@ -165,3 +164,49 @@ match x:
reveal_type(x) # revealed: object
```
## Narrowing due to guard
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
match x:
case str() | float() if type(x) is str:
reveal_type(x) # revealed: str
case "foo" | 42 | None if isinstance(x, int):
reveal_type(x) # revealed: Literal[42]
case False if x:
reveal_type(x) # revealed: Never
case "foo" if x := "bar":
reveal_type(x) # revealed: Literal["bar"]
reveal_type(x) # revealed: object
```
## Guard and reveal_type in guard
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
match x:
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
pass
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
pass
case False if x and reveal_type(x): # revealed: Never
pass
case "foo" if (x := "bar") and reveal_type(x): # revealed: Literal["bar"]
pass
reveal_type(x) # revealed: object
```

View file

@ -1572,54 +1572,76 @@ where
return;
}
let after_subject = self.flow_snapshot();
let mut vis_constraints = vec![];
let mut post_case_snapshots = vec![];
for (i, case) in cases.iter().enumerate() {
if i != 0 {
post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());
}
let mut no_case_matched = self.flow_snapshot();
let has_catchall = cases
.last()
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard());
let mut post_case_snapshots = vec![];
let mut match_predicate;
for (i, case) in cases.iter().enumerate() {
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
self.visit_pattern(&case.pattern);
self.current_match_case = None;
let predicate = self.add_pattern_narrowing_constraint(
// unlike in [Stmt::If], we don't reset [no_case_matched]
// here because the effects of visiting a pattern is binding
// symbols, and this doesn't occur unless the pattern
// actually matches
match_predicate = self.add_pattern_narrowing_constraint(
subject_expr,
&case.pattern,
case.guard.as_deref(),
);
self.record_reachability_constraint(predicate);
if let Some(expr) = &case.guard {
self.visit_expr(expr);
}
let vis_constraint_id = self.record_reachability_constraint(match_predicate);
let match_success_guard_failure = case.guard.as_ref().map(|guard| {
let guard_expr = self.add_standalone_expression(guard);
self.visit_expr(guard);
let post_guard_eval = self.flow_snapshot();
let predicate = Predicate {
node: PredicateNode::Expression(guard_expr),
is_positive: true,
};
self.record_negated_narrowing_constraint(predicate);
let match_success_guard_failure = self.flow_snapshot();
self.flow_restore(post_guard_eval);
self.record_narrowing_constraint(predicate);
match_success_guard_failure
});
self.record_visibility_constraint_id(vis_constraint_id);
self.visit_body(&case.body);
for id in &vis_constraints {
self.record_negated_visibility_constraint(*id);
}
let vis_constraint_id = self.record_visibility_constraint(predicate);
vis_constraints.push(vis_constraint_id);
}
// If there is no final wildcard match case, pretend there is one. This is similar to how
// we add an implicit `else` block in if-elif chains, in case it's not present.
if !cases
.last()
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard())
{
post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());
for id in &vis_constraints {
self.record_negated_visibility_constraint(*id);
if i != cases.len() - 1 || !has_catchall {
// We need to restore the state after each case, but not after the last
// one. The last one will just become the state that we merge the other
// snapshots into.
self.flow_restore(no_case_matched.clone());
self.record_negated_narrowing_constraint(match_predicate);
if let Some(match_success_guard_failure) = match_success_guard_failure {
self.flow_merge(match_success_guard_failure);
} else {
assert!(case.guard.is_none());
}
} else {
debug_assert!(match_success_guard_failure.is_none());
debug_assert!(case.guard.is_none());
}
self.record_negated_visibility_constraint(vis_constraint_id);
no_case_matched = self.flow_snapshot();
}
for post_clause_state in post_case_snapshots {
self.flow_merge(post_clause_state);
}
self.simplify_visibility_constraints(after_subject);
self.simplify_visibility_constraints(no_case_matched);
}
ast::Stmt::Try(ast::StmtTry {
body,

View file

@ -50,7 +50,13 @@ pub(crate) fn infer_narrowing_constraint<'db>(
all_negative_narrowing_constraints_for_expression(db, expression)
}
}
PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern),
PredicateNode::Pattern(pattern) => {
if predicate.is_positive {
all_narrowing_constraints_for_pattern(db, pattern)
} else {
all_negative_narrowing_constraints_for_pattern(db, pattern)
}
}
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(constraints) = constraints {
@ -95,6 +101,15 @@ fn all_negative_narrowing_constraints_for_expression<'db>(
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish()
}
#[allow(clippy::ref_option)]
#[salsa::tracked(return_ref)]
fn all_negative_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish()
}
#[allow(clippy::ref_option)]
fn constraints_for_expression_cycle_recover<'db>(
_db: &'db dyn Db,
@ -217,6 +232,12 @@ fn merge_constraints_or<'db>(
}
}
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_symbol, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}
struct NarrowingConstraintsBuilder<'db> {
db: &'db dyn Db,
predicate: PredicateNode<'db>,
@ -237,7 +258,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
PredicateNode::Expression(expression) => {
self.evaluate_expression_predicate(expression, self.is_positive)
}
PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern),
PredicateNode::Pattern(pattern) => {
self.evaluate_pattern_predicate(pattern, self.is_positive)
}
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(mut constraints) = constraints {
@ -301,10 +324,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
fn evaluate_pattern_predicate(
&mut self,
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
.map(|mut constraints| {
negate_if(&mut constraints, self.db, !is_positive);
constraints
})
}
fn symbols(&self) -> Arc<SymbolTable> {