mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 10:49:50 +00:00
[red-knot] support narrowing on or patterns in matches (#17030)
## Summary Part of #13694 Narrow in or-patterns by taking the type union of the type constraints in each disjunct pattern. ## Test Plan Add new tests to narrow/match.md --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
2e56cd3737
commit
64171744dc
6 changed files with 122 additions and 33 deletions
|
@ -114,3 +114,49 @@ match x:
|
|||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
||||
## Or patterns
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case "foo" | 42 | None:
|
||||
reveal_type(x) # revealed: Literal["foo", 42] | None
|
||||
case "foo" | tuple():
|
||||
reveal_type(x) # revealed: Literal["foo"] | tuple
|
||||
case True | False:
|
||||
reveal_type(x) # revealed: bool
|
||||
case 3.14 | 2.718 | 1.414:
|
||||
reveal_type(x) # revealed: float
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
||||
## Or patterns with guard
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
|
||||
pass
|
||||
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple
|
||||
pass
|
||||
case True | False if reveal_type(x): # revealed: bool
|
||||
pass
|
||||
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
|
||||
pass
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
|
|
@ -589,6 +589,31 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn predicate_kind(&mut self, pattern: &ast::Pattern) -> PatternPredicateKind<'db> {
|
||||
match pattern {
|
||||
ast::Pattern::MatchValue(pattern) => {
|
||||
let value = self.add_standalone_expression(&pattern.value);
|
||||
PatternPredicateKind::Value(value)
|
||||
}
|
||||
ast::Pattern::MatchSingleton(singleton) => {
|
||||
PatternPredicateKind::Singleton(singleton.value)
|
||||
}
|
||||
ast::Pattern::MatchClass(pattern) => {
|
||||
let cls = self.add_standalone_expression(&pattern.cls);
|
||||
PatternPredicateKind::Class(cls)
|
||||
}
|
||||
ast::Pattern::MatchOr(pattern) => {
|
||||
let predicates = pattern
|
||||
.patterns
|
||||
.iter()
|
||||
.map(|pattern| self.predicate_kind(pattern))
|
||||
.collect();
|
||||
PatternPredicateKind::Or(predicates)
|
||||
}
|
||||
_ => PatternPredicateKind::Unsupported,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_pattern_narrowing_constraint(
|
||||
&mut self,
|
||||
subject: Expression<'db>,
|
||||
|
@ -606,29 +631,16 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
//
|
||||
// See the comment in TypeInferenceBuilder::infer_match_pattern for more details.
|
||||
|
||||
let kind = self.predicate_kind(pattern);
|
||||
let guard = guard.map(|guard| self.add_standalone_expression(guard));
|
||||
|
||||
let kind = match pattern {
|
||||
ast::Pattern::MatchValue(pattern) => {
|
||||
let value = self.add_standalone_expression(&pattern.value);
|
||||
PatternPredicateKind::Value(value, guard)
|
||||
}
|
||||
ast::Pattern::MatchSingleton(singleton) => {
|
||||
PatternPredicateKind::Singleton(singleton.value, guard)
|
||||
}
|
||||
ast::Pattern::MatchClass(pattern) => {
|
||||
let cls = self.add_standalone_expression(&pattern.cls);
|
||||
PatternPredicateKind::Class(cls, guard)
|
||||
}
|
||||
_ => PatternPredicateKind::Unsupported,
|
||||
};
|
||||
|
||||
let pattern_predicate = PatternPredicate::new(
|
||||
self.db,
|
||||
self.file,
|
||||
self.current_scope(),
|
||||
subject,
|
||||
kind,
|
||||
guard,
|
||||
countme::Count::default(),
|
||||
);
|
||||
let predicate = Predicate {
|
||||
|
|
|
@ -57,9 +57,10 @@ pub(crate) enum PredicateNode<'db> {
|
|||
/// Pattern kinds for which we support type narrowing and/or static visibility analysis.
|
||||
#[derive(Debug, Clone, Hash, PartialEq, salsa::Update)]
|
||||
pub(crate) enum PatternPredicateKind<'db> {
|
||||
Singleton(Singleton, Option<Expression<'db>>),
|
||||
Value(Expression<'db>, Option<Expression<'db>>),
|
||||
Class(Expression<'db>, Option<Expression<'db>>),
|
||||
Singleton(Singleton),
|
||||
Value(Expression<'db>),
|
||||
Or(Vec<PatternPredicateKind<'db>>),
|
||||
Class(Expression<'db>),
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
|
@ -74,6 +75,8 @@ pub(crate) struct PatternPredicate<'db> {
|
|||
#[return_ref]
|
||||
pub(crate) kind: PatternPredicateKind<'db>,
|
||||
|
||||
pub(crate) guard: Option<Expression<'db>>,
|
||||
|
||||
count: countme::Count<PatternPredicate<'static>>,
|
||||
}
|
||||
|
||||
|
|
|
@ -560,7 +560,7 @@ impl VisibilityConstraints {
|
|||
ty.bool(db).negate_if(!predicate.is_positive)
|
||||
}
|
||||
PredicateNode::Pattern(inner) => match inner.kind(db) {
|
||||
PatternPredicateKind::Value(value, guard) => {
|
||||
PatternPredicateKind::Value(value) => {
|
||||
let subject_expression = inner.subject(db);
|
||||
let subject_ty = infer_expression_type(db, subject_expression);
|
||||
let value_ty = infer_expression_type(db, *value);
|
||||
|
@ -569,7 +569,7 @@ impl VisibilityConstraints {
|
|||
let truthiness =
|
||||
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty));
|
||||
|
||||
if truthiness.is_always_true() && guard.is_some() {
|
||||
if truthiness.is_always_true() && inner.guard(db).is_some() {
|
||||
// Fall back to ambiguous, the guard might change the result.
|
||||
Truthiness::Ambiguous
|
||||
} else {
|
||||
|
@ -581,6 +581,7 @@ impl VisibilityConstraints {
|
|||
}
|
||||
PatternPredicateKind::Singleton(..)
|
||||
| PatternPredicateKind::Class(..)
|
||||
| PatternPredicateKind::Or(..)
|
||||
| PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -2119,6 +2119,11 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
self.infer_standalone_expression(cls);
|
||||
}
|
||||
ast::Pattern::MatchOr(match_or) => {
|
||||
for pattern in &match_or.patterns {
|
||||
self.infer_match_pattern(pattern);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
|
|
|
@ -225,25 +225,31 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn evaluate_pattern_predicate_kind(
|
||||
&mut self,
|
||||
pattern_predicate_kind: &PatternPredicateKind<'db>,
|
||||
subject: Expression<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
match pattern_predicate_kind {
|
||||
PatternPredicateKind::Singleton(singleton) => {
|
||||
self.evaluate_match_pattern_singleton(subject, *singleton)
|
||||
}
|
||||
PatternPredicateKind::Class(cls) => self.evaluate_match_pattern_class(subject, *cls),
|
||||
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
|
||||
PatternPredicateKind::Or(predicates) => {
|
||||
self.evaluate_match_pattern_or(subject, predicates)
|
||||
}
|
||||
PatternPredicateKind::Unsupported => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_pattern_predicate(
|
||||
&mut self,
|
||||
pattern: PatternPredicate<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject = pattern.subject(self.db);
|
||||
|
||||
match pattern.kind(self.db) {
|
||||
PatternPredicateKind::Singleton(singleton, _guard) => {
|
||||
self.evaluate_match_pattern_singleton(subject, *singleton)
|
||||
}
|
||||
PatternPredicateKind::Class(cls, _guard) => {
|
||||
self.evaluate_match_pattern_class(subject, *cls)
|
||||
}
|
||||
PatternPredicateKind::Value(expr, _guard) => {
|
||||
self.evaluate_match_pattern_value(subject, *expr)
|
||||
}
|
||||
// TODO: support more pattern kinds
|
||||
PatternPredicateKind::Unsupported => None,
|
||||
}
|
||||
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
|
||||
}
|
||||
|
||||
fn symbols(&self) -> Arc<SymbolTable> {
|
||||
|
@ -505,6 +511,22 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
|
||||
}
|
||||
|
||||
fn evaluate_match_pattern_or(
|
||||
&mut self,
|
||||
subject: Expression<'db>,
|
||||
predicates: &Vec<PatternPredicateKind<'db>>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let db = self.db;
|
||||
|
||||
predicates
|
||||
.iter()
|
||||
.filter_map(|predicate| self.evaluate_pattern_predicate_kind(predicate, subject))
|
||||
.reduce(|mut constraints, constraints_| {
|
||||
merge_constraints_or(&mut constraints, &constraints_, db);
|
||||
constraints
|
||||
})
|
||||
}
|
||||
|
||||
fn evaluate_bool_op(
|
||||
&mut self,
|
||||
expr_bool_op: &ExprBoolOp,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue