mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-03 18:28:24 +00:00
[red-knot] more type-narrowing in match statements (#17302)
## Summary Add more narrowing analysis for match statements: * add narrowing constraints from guard expressions * add negated constraints from previous predicates and guards to subsequent cases This PR doesn't address that guards can mutate your subject, and so theoretically invalidate some of these narrowing constraints that you've previously accumulated. Some prior art on this issue [here][mutable guards]. [mutable guards]: https://www.irif.fr/~scherer/research/mutable-patterns/mutable-patterns-mlworkshop2024-abstract.pdf ## Test Plan Add some new tests, and update some existing ones --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
edfa03a692
commit
de8f4e62e2
3 changed files with 131 additions and 37 deletions
|
@ -39,8 +39,7 @@ match x:
|
|||
case A():
|
||||
reveal_type(x) # revealed: A
|
||||
case B():
|
||||
# TODO could be `B & ~A`
|
||||
reveal_type(x) # revealed: B
|
||||
reveal_type(x) # revealed: B & ~A
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
@ -88,7 +87,7 @@ match x:
|
|||
case 6.0:
|
||||
reveal_type(x) # revealed: float
|
||||
case 1j:
|
||||
reveal_type(x) # revealed: complex
|
||||
reveal_type(x) # revealed: complex & ~float
|
||||
case b"foo":
|
||||
reveal_type(x) # revealed: Literal[b"foo"]
|
||||
|
||||
|
@ -134,11 +133,11 @@ match x:
|
|||
case "foo" | 42 | None:
|
||||
reveal_type(x) # revealed: Literal["foo", 42] | None
|
||||
case "foo" | tuple():
|
||||
reveal_type(x) # revealed: Literal["foo"] | tuple
|
||||
reveal_type(x) # revealed: tuple
|
||||
case True | False:
|
||||
reveal_type(x) # revealed: bool
|
||||
case 3.14 | 2.718 | 1.414:
|
||||
reveal_type(x) # revealed: float
|
||||
reveal_type(x) # revealed: float & ~tuple
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
@ -165,3 +164,49 @@ match x:
|
|||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
||||
## Narrowing due to guard
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case str() | float() if type(x) is str:
|
||||
reveal_type(x) # revealed: str
|
||||
case "foo" | 42 | None if isinstance(x, int):
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
case False if x:
|
||||
reveal_type(x) # revealed: Never
|
||||
case "foo" if x := "bar":
|
||||
reveal_type(x) # revealed: Literal["bar"]
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
||||
## Guard and reveal_type in guard
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
|
||||
pass
|
||||
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
|
||||
pass
|
||||
case False if x and reveal_type(x): # revealed: Never
|
||||
pass
|
||||
case "foo" if (x := "bar") and reveal_type(x): # revealed: Literal["bar"]
|
||||
pass
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
|
|
@ -1572,54 +1572,76 @@ where
|
|||
return;
|
||||
}
|
||||
|
||||
let after_subject = self.flow_snapshot();
|
||||
let mut vis_constraints = vec![];
|
||||
let mut post_case_snapshots = vec![];
|
||||
for (i, case) in cases.iter().enumerate() {
|
||||
if i != 0 {
|
||||
post_case_snapshots.push(self.flow_snapshot());
|
||||
self.flow_restore(after_subject.clone());
|
||||
}
|
||||
let mut no_case_matched = self.flow_snapshot();
|
||||
|
||||
let has_catchall = cases
|
||||
.last()
|
||||
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard());
|
||||
|
||||
let mut post_case_snapshots = vec![];
|
||||
let mut match_predicate;
|
||||
|
||||
for (i, case) in cases.iter().enumerate() {
|
||||
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
|
||||
self.visit_pattern(&case.pattern);
|
||||
self.current_match_case = None;
|
||||
let predicate = self.add_pattern_narrowing_constraint(
|
||||
// unlike in [Stmt::If], we don't reset [no_case_matched]
|
||||
// here because the effects of visiting a pattern is binding
|
||||
// symbols, and this doesn't occur unless the pattern
|
||||
// actually matches
|
||||
match_predicate = self.add_pattern_narrowing_constraint(
|
||||
subject_expr,
|
||||
&case.pattern,
|
||||
case.guard.as_deref(),
|
||||
);
|
||||
self.record_reachability_constraint(predicate);
|
||||
if let Some(expr) = &case.guard {
|
||||
self.visit_expr(expr);
|
||||
}
|
||||
let vis_constraint_id = self.record_reachability_constraint(match_predicate);
|
||||
|
||||
let match_success_guard_failure = case.guard.as_ref().map(|guard| {
|
||||
let guard_expr = self.add_standalone_expression(guard);
|
||||
self.visit_expr(guard);
|
||||
let post_guard_eval = self.flow_snapshot();
|
||||
let predicate = Predicate {
|
||||
node: PredicateNode::Expression(guard_expr),
|
||||
is_positive: true,
|
||||
};
|
||||
self.record_negated_narrowing_constraint(predicate);
|
||||
let match_success_guard_failure = self.flow_snapshot();
|
||||
self.flow_restore(post_guard_eval);
|
||||
self.record_narrowing_constraint(predicate);
|
||||
match_success_guard_failure
|
||||
});
|
||||
|
||||
self.record_visibility_constraint_id(vis_constraint_id);
|
||||
|
||||
self.visit_body(&case.body);
|
||||
for id in &vis_constraints {
|
||||
self.record_negated_visibility_constraint(*id);
|
||||
}
|
||||
let vis_constraint_id = self.record_visibility_constraint(predicate);
|
||||
vis_constraints.push(vis_constraint_id);
|
||||
}
|
||||
|
||||
// If there is no final wildcard match case, pretend there is one. This is similar to how
|
||||
// we add an implicit `else` block in if-elif chains, in case it's not present.
|
||||
if !cases
|
||||
.last()
|
||||
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard())
|
||||
{
|
||||
post_case_snapshots.push(self.flow_snapshot());
|
||||
self.flow_restore(after_subject.clone());
|
||||
|
||||
for id in &vis_constraints {
|
||||
self.record_negated_visibility_constraint(*id);
|
||||
if i != cases.len() - 1 || !has_catchall {
|
||||
// We need to restore the state after each case, but not after the last
|
||||
// one. The last one will just become the state that we merge the other
|
||||
// snapshots into.
|
||||
self.flow_restore(no_case_matched.clone());
|
||||
self.record_negated_narrowing_constraint(match_predicate);
|
||||
if let Some(match_success_guard_failure) = match_success_guard_failure {
|
||||
self.flow_merge(match_success_guard_failure);
|
||||
} else {
|
||||
assert!(case.guard.is_none());
|
||||
}
|
||||
} else {
|
||||
debug_assert!(match_success_guard_failure.is_none());
|
||||
debug_assert!(case.guard.is_none());
|
||||
}
|
||||
|
||||
self.record_negated_visibility_constraint(vis_constraint_id);
|
||||
no_case_matched = self.flow_snapshot();
|
||||
}
|
||||
|
||||
for post_clause_state in post_case_snapshots {
|
||||
self.flow_merge(post_clause_state);
|
||||
}
|
||||
|
||||
self.simplify_visibility_constraints(after_subject);
|
||||
self.simplify_visibility_constraints(no_case_matched);
|
||||
}
|
||||
ast::Stmt::Try(ast::StmtTry {
|
||||
body,
|
||||
|
|
|
@ -50,7 +50,13 @@ pub(crate) fn infer_narrowing_constraint<'db>(
|
|||
all_negative_narrowing_constraints_for_expression(db, expression)
|
||||
}
|
||||
}
|
||||
PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern),
|
||||
PredicateNode::Pattern(pattern) => {
|
||||
if predicate.is_positive {
|
||||
all_narrowing_constraints_for_pattern(db, pattern)
|
||||
} else {
|
||||
all_negative_narrowing_constraints_for_pattern(db, pattern)
|
||||
}
|
||||
}
|
||||
PredicateNode::StarImportPlaceholder(_) => return None,
|
||||
};
|
||||
if let Some(constraints) = constraints {
|
||||
|
@ -95,6 +101,15 @@ fn all_negative_narrowing_constraints_for_expression<'db>(
|
|||
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish()
|
||||
}
|
||||
|
||||
#[allow(clippy::ref_option)]
|
||||
#[salsa::tracked(return_ref)]
|
||||
fn all_negative_narrowing_constraints_for_pattern<'db>(
|
||||
db: &'db dyn Db,
|
||||
pattern: PatternPredicate<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish()
|
||||
}
|
||||
|
||||
#[allow(clippy::ref_option)]
|
||||
fn constraints_for_expression_cycle_recover<'db>(
|
||||
_db: &'db dyn Db,
|
||||
|
@ -217,6 +232,12 @@ fn merge_constraints_or<'db>(
|
|||
}
|
||||
}
|
||||
|
||||
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
|
||||
for (_symbol, ty) in constraints.iter_mut() {
|
||||
*ty = ty.negate_if(db, yes);
|
||||
}
|
||||
}
|
||||
|
||||
struct NarrowingConstraintsBuilder<'db> {
|
||||
db: &'db dyn Db,
|
||||
predicate: PredicateNode<'db>,
|
||||
|
@ -237,7 +258,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
PredicateNode::Expression(expression) => {
|
||||
self.evaluate_expression_predicate(expression, self.is_positive)
|
||||
}
|
||||
PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern),
|
||||
PredicateNode::Pattern(pattern) => {
|
||||
self.evaluate_pattern_predicate(pattern, self.is_positive)
|
||||
}
|
||||
PredicateNode::StarImportPlaceholder(_) => return None,
|
||||
};
|
||||
if let Some(mut constraints) = constraints {
|
||||
|
@ -301,10 +324,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
fn evaluate_pattern_predicate(
|
||||
&mut self,
|
||||
pattern: PatternPredicate<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject = pattern.subject(self.db);
|
||||
|
||||
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
|
||||
.map(|mut constraints| {
|
||||
negate_if(&mut constraints, self.db, !is_positive);
|
||||
constraints
|
||||
})
|
||||
}
|
||||
|
||||
fn symbols(&self) -> Arc<SymbolTable> {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue