mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 21:35:58 +00:00
Narrowing for class patterns in match statements (#15223)
We now support class patterns in a match statement, adding a narrowing constraint that within the body of that match arm, we can assume that the subject is an instance of that class. --------- Co-authored-by: Carl Meyer <carl@astral.sh> Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
parent
f2a86fcfda
commit
b2a0d68d70
6 changed files with 151 additions and 43 deletions
|
@ -16,3 +16,48 @@ def _(flag: bool):
|
||||||
|
|
||||||
reveal_type(y) # revealed: Literal[0] | None
|
reveal_type(y) # revealed: Literal[0] | None
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Class patterns
|
||||||
|
|
||||||
|
```py
|
||||||
|
def get_object() -> object: ...
|
||||||
|
|
||||||
|
class A: ...
|
||||||
|
class B: ...
|
||||||
|
|
||||||
|
x = get_object()
|
||||||
|
|
||||||
|
reveal_type(x) # revealed: object
|
||||||
|
|
||||||
|
match x:
|
||||||
|
case A():
|
||||||
|
reveal_type(x) # revealed: A
|
||||||
|
case B():
|
||||||
|
# TODO could be `B & ~A`
|
||||||
|
reveal_type(x) # revealed: B
|
||||||
|
|
||||||
|
reveal_type(x) # revealed: object
|
||||||
|
```
|
||||||
|
|
||||||
|
## Class pattern with guard
|
||||||
|
|
||||||
|
```py
|
||||||
|
def get_object() -> object: ...
|
||||||
|
|
||||||
|
class A:
|
||||||
|
def y() -> int: ...
|
||||||
|
|
||||||
|
class B: ...
|
||||||
|
|
||||||
|
x = get_object()
|
||||||
|
|
||||||
|
reveal_type(x) # revealed: object
|
||||||
|
|
||||||
|
match x:
|
||||||
|
case A() if reveal_type(x): # revealed: A
|
||||||
|
pass
|
||||||
|
case B() if reveal_type(x): # revealed: B
|
||||||
|
pass
|
||||||
|
|
||||||
|
reveal_type(x) # revealed: object
|
||||||
|
```
|
||||||
|
|
|
@ -404,6 +404,17 @@ impl<'db> SemanticIndexBuilder<'db> {
|
||||||
pattern: &ast::Pattern,
|
pattern: &ast::Pattern,
|
||||||
guard: Option<&ast::Expr>,
|
guard: Option<&ast::Expr>,
|
||||||
) -> Constraint<'db> {
|
) -> Constraint<'db> {
|
||||||
|
// This is called for the top-level pattern of each match arm. We need to create a
|
||||||
|
// standalone expression for each arm of a match statement, since they can introduce
|
||||||
|
// constraints on the match subject. (Or more accurately, for the match arm's pattern,
|
||||||
|
// since its the pattern that introduces any constraints, not the body.) Ideally, that
|
||||||
|
// standalone expression would wrap the match arm's pattern as a whole. But a standalone
|
||||||
|
// expression can currently only wrap an ast::Expr, which patterns are not. So, we need to
|
||||||
|
// choose an Expr that can “stand in” for the pattern, which we can wrap in a standalone
|
||||||
|
// expression.
|
||||||
|
//
|
||||||
|
// See the comment in TypeInferenceBuilder::infer_match_pattern for more details.
|
||||||
|
|
||||||
let guard = guard.map(|guard| self.add_standalone_expression(guard));
|
let guard = guard.map(|guard| self.add_standalone_expression(guard));
|
||||||
|
|
||||||
let kind = match pattern {
|
let kind = match pattern {
|
||||||
|
@ -414,6 +425,10 @@ impl<'db> SemanticIndexBuilder<'db> {
|
||||||
ast::Pattern::MatchSingleton(singleton) => {
|
ast::Pattern::MatchSingleton(singleton) => {
|
||||||
PatternConstraintKind::Singleton(singleton.value, guard)
|
PatternConstraintKind::Singleton(singleton.value, guard)
|
||||||
}
|
}
|
||||||
|
ast::Pattern::MatchClass(pattern) => {
|
||||||
|
let cls = self.add_standalone_expression(&pattern.cls);
|
||||||
|
PatternConstraintKind::Class(cls, guard)
|
||||||
|
}
|
||||||
_ => PatternConstraintKind::Unsupported,
|
_ => PatternConstraintKind::Unsupported,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1089,37 +1104,35 @@ where
|
||||||
cases,
|
cases,
|
||||||
range: _,
|
range: _,
|
||||||
}) => {
|
}) => {
|
||||||
|
debug_assert_eq!(self.current_match_case, None);
|
||||||
|
|
||||||
let subject_expr = self.add_standalone_expression(subject);
|
let subject_expr = self.add_standalone_expression(subject);
|
||||||
self.visit_expr(subject);
|
self.visit_expr(subject);
|
||||||
|
if cases.is_empty() {
|
||||||
let after_subject = self.flow_snapshot();
|
|
||||||
let Some((first, remaining)) = cases.split_first() else {
|
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
let first_constraint_id = self.add_pattern_constraint(
|
let after_subject = self.flow_snapshot();
|
||||||
subject_expr,
|
let mut vis_constraints = vec![];
|
||||||
&first.pattern,
|
|
||||||
first.guard.as_deref(),
|
|
||||||
);
|
|
||||||
|
|
||||||
self.visit_match_case(first);
|
|
||||||
|
|
||||||
let first_vis_constraint_id =
|
|
||||||
self.record_visibility_constraint(first_constraint_id);
|
|
||||||
let mut vis_constraints = vec![first_vis_constraint_id];
|
|
||||||
|
|
||||||
let mut post_case_snapshots = vec![];
|
let mut post_case_snapshots = vec![];
|
||||||
for case in remaining {
|
for (i, case) in cases.iter().enumerate() {
|
||||||
|
if i != 0 {
|
||||||
post_case_snapshots.push(self.flow_snapshot());
|
post_case_snapshots.push(self.flow_snapshot());
|
||||||
self.flow_restore(after_subject.clone());
|
self.flow_restore(after_subject.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
|
||||||
|
self.visit_pattern(&case.pattern);
|
||||||
|
self.current_match_case = None;
|
||||||
let constraint_id = self.add_pattern_constraint(
|
let constraint_id = self.add_pattern_constraint(
|
||||||
subject_expr,
|
subject_expr,
|
||||||
&case.pattern,
|
&case.pattern,
|
||||||
case.guard.as_deref(),
|
case.guard.as_deref(),
|
||||||
);
|
);
|
||||||
self.visit_match_case(case);
|
if let Some(expr) = &case.guard {
|
||||||
|
self.visit_expr(expr);
|
||||||
|
}
|
||||||
|
self.visit_body(&case.body);
|
||||||
for id in &vis_constraints {
|
for id in &vis_constraints {
|
||||||
self.record_negated_visibility_constraint(*id);
|
self.record_negated_visibility_constraint(*id);
|
||||||
}
|
}
|
||||||
|
@ -1538,18 +1551,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) {
|
|
||||||
debug_assert!(self.current_match_case.is_none());
|
|
||||||
self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern));
|
|
||||||
self.visit_pattern(&match_case.pattern);
|
|
||||||
self.current_match_case = None;
|
|
||||||
|
|
||||||
if let Some(expr) = &match_case.guard {
|
|
||||||
self.visit_expr(expr);
|
|
||||||
}
|
|
||||||
self.visit_body(&match_case.body);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||||
if let ast::Pattern::MatchStar(ast::PatternMatchStar {
|
if let ast::Pattern::MatchStar(ast::PatternMatchStar {
|
||||||
name: Some(name),
|
name: Some(name),
|
||||||
|
@ -1636,6 +1637,7 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
struct CurrentMatchCase<'a> {
|
struct CurrentMatchCase<'a> {
|
||||||
/// The pattern that's part of the current match case.
|
/// The pattern that's part of the current match case.
|
||||||
pattern: &'a ast::Pattern,
|
pattern: &'a ast::Pattern,
|
||||||
|
|
|
@ -22,6 +22,7 @@ pub(crate) enum ConstraintNode<'db> {
|
||||||
pub(crate) enum PatternConstraintKind<'db> {
|
pub(crate) enum PatternConstraintKind<'db> {
|
||||||
Singleton(Singleton, Option<Expression<'db>>),
|
Singleton(Singleton, Option<Expression<'db>>),
|
||||||
Value(Expression<'db>, Option<Expression<'db>>),
|
Value(Expression<'db>, Option<Expression<'db>>),
|
||||||
|
Class(Expression<'db>, Option<Expression<'db>>),
|
||||||
Unsupported,
|
Unsupported,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1780,26 +1780,62 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
|
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
|
||||||
|
// We need to create a standalone expression for each arm of a match statement, since they
|
||||||
|
// can introduce constraints on the match subject. (Or more accurately, for the match arm's
|
||||||
|
// pattern, since its the pattern that introduces any constraints, not the body.) Ideally,
|
||||||
|
// that standalone expression would wrap the match arm's pattern as a whole. But a
|
||||||
|
// standalone expression can currently only wrap an ast::Expr, which patterns are not. So,
|
||||||
|
// we need to choose an Expr that can “stand in” for the pattern, which we can wrap in a
|
||||||
|
// standalone expression.
|
||||||
|
//
|
||||||
|
// That said, when inferring the type of a standalone expression, we don't have access to
|
||||||
|
// its parent or sibling nodes. That means, for instance, that in a class pattern, where
|
||||||
|
// we are currently using the class name as the standalone expression, we do not have
|
||||||
|
// access to the class pattern's arguments in the standalone expression inference scope.
|
||||||
|
// At the moment, we aren't trying to do anything with those arguments when creating a
|
||||||
|
// narrowing constraint for the pattern. But in the future, if we do, we will have to
|
||||||
|
// either wrap those arguments in their own standalone expressions, or update Expression to
|
||||||
|
// be able to wrap other AST node types besides just ast::Expr.
|
||||||
|
//
|
||||||
|
// This function is only called for the top-level pattern of a match arm, and is
|
||||||
|
// responsible for inferring the standalone expression for each supported pattern type. It
|
||||||
|
// then hands off to `infer_nested_match_pattern` for any subexpressions and subpatterns,
|
||||||
|
// where we do NOT have any additional standalone expressions to infer through.
|
||||||
|
//
|
||||||
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
|
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
|
||||||
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
|
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
|
||||||
match pattern {
|
match pattern {
|
||||||
ast::Pattern::MatchValue(match_value) => {
|
ast::Pattern::MatchValue(match_value) => {
|
||||||
self.infer_standalone_expression(&match_value.value);
|
self.infer_standalone_expression(&match_value.value);
|
||||||
}
|
}
|
||||||
|
ast::Pattern::MatchClass(match_class) => {
|
||||||
|
let ast::PatternMatchClass {
|
||||||
|
range: _,
|
||||||
|
cls,
|
||||||
|
arguments,
|
||||||
|
} = match_class;
|
||||||
|
for pattern in &arguments.patterns {
|
||||||
|
self.infer_nested_match_pattern(pattern);
|
||||||
|
}
|
||||||
|
for keyword in &arguments.keywords {
|
||||||
|
self.infer_nested_match_pattern(&keyword.pattern);
|
||||||
|
}
|
||||||
|
self.infer_standalone_expression(cls);
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
self.infer_match_pattern_impl(pattern);
|
self.infer_nested_match_pattern(pattern);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_match_pattern_impl(&mut self, pattern: &ast::Pattern) {
|
fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) {
|
||||||
match pattern {
|
match pattern {
|
||||||
ast::Pattern::MatchValue(match_value) => {
|
ast::Pattern::MatchValue(match_value) => {
|
||||||
self.infer_expression(&match_value.value);
|
self.infer_expression(&match_value.value);
|
||||||
}
|
}
|
||||||
ast::Pattern::MatchSequence(match_sequence) => {
|
ast::Pattern::MatchSequence(match_sequence) => {
|
||||||
for pattern in &match_sequence.patterns {
|
for pattern in &match_sequence.patterns {
|
||||||
self.infer_match_pattern_impl(pattern);
|
self.infer_nested_match_pattern(pattern);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::Pattern::MatchMapping(match_mapping) => {
|
ast::Pattern::MatchMapping(match_mapping) => {
|
||||||
|
@ -1813,7 +1849,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
self.infer_expression(key);
|
self.infer_expression(key);
|
||||||
}
|
}
|
||||||
for pattern in patterns {
|
for pattern in patterns {
|
||||||
self.infer_match_pattern_impl(pattern);
|
self.infer_nested_match_pattern(pattern);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::Pattern::MatchClass(match_class) => {
|
ast::Pattern::MatchClass(match_class) => {
|
||||||
|
@ -1823,21 +1859,21 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
arguments,
|
arguments,
|
||||||
} = match_class;
|
} = match_class;
|
||||||
for pattern in &arguments.patterns {
|
for pattern in &arguments.patterns {
|
||||||
self.infer_match_pattern_impl(pattern);
|
self.infer_nested_match_pattern(pattern);
|
||||||
}
|
}
|
||||||
for keyword in &arguments.keywords {
|
for keyword in &arguments.keywords {
|
||||||
self.infer_match_pattern_impl(&keyword.pattern);
|
self.infer_nested_match_pattern(&keyword.pattern);
|
||||||
}
|
}
|
||||||
self.infer_expression(cls);
|
self.infer_expression(cls);
|
||||||
}
|
}
|
||||||
ast::Pattern::MatchAs(match_as) => {
|
ast::Pattern::MatchAs(match_as) => {
|
||||||
if let Some(pattern) = &match_as.pattern {
|
if let Some(pattern) = &match_as.pattern {
|
||||||
self.infer_match_pattern_impl(pattern);
|
self.infer_nested_match_pattern(pattern);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::Pattern::MatchOr(match_or) => {
|
ast::Pattern::MatchOr(match_or) => {
|
||||||
for pattern in &match_or.patterns {
|
for pattern in &match_or.patterns {
|
||||||
self.infer_match_pattern_impl(pattern);
|
self.infer_nested_match_pattern(pattern);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {}
|
ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {}
|
||||||
|
|
|
@ -233,6 +233,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||||
PatternConstraintKind::Singleton(singleton, _guard) => {
|
PatternConstraintKind::Singleton(singleton, _guard) => {
|
||||||
self.evaluate_match_pattern_singleton(*subject, *singleton)
|
self.evaluate_match_pattern_singleton(*subject, *singleton)
|
||||||
}
|
}
|
||||||
|
PatternConstraintKind::Class(cls, _guard) => {
|
||||||
|
self.evaluate_match_pattern_class(*subject, *cls)
|
||||||
|
}
|
||||||
// TODO: support more pattern kinds
|
// TODO: support more pattern kinds
|
||||||
PatternConstraintKind::Value(..) | PatternConstraintKind::Unsupported => None,
|
PatternConstraintKind::Value(..) | PatternConstraintKind::Unsupported => None,
|
||||||
}
|
}
|
||||||
|
@ -486,6 +489,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn evaluate_match_pattern_class(
|
||||||
|
&mut self,
|
||||||
|
subject: Expression<'db>,
|
||||||
|
cls: Expression<'db>,
|
||||||
|
) -> Option<NarrowingConstraints<'db>> {
|
||||||
|
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
|
||||||
|
// SAFETY: we should always have a symbol for every Name node.
|
||||||
|
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||||
|
let scope = self.scope();
|
||||||
|
let inference = infer_expression_types(self.db, cls);
|
||||||
|
let ty = inference
|
||||||
|
.expression_ty(cls.node_ref(self.db).scoped_expression_id(self.db, scope))
|
||||||
|
.to_instance(self.db);
|
||||||
|
let mut constraints = NarrowingConstraints::default();
|
||||||
|
constraints.insert(symbol, ty);
|
||||||
|
Some(constraints)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn evaluate_bool_op(
|
fn evaluate_bool_op(
|
||||||
&mut self,
|
&mut self,
|
||||||
expr_bool_op: &ExprBoolOp,
|
expr_bool_op: &ExprBoolOp,
|
||||||
|
|
|
@ -329,9 +329,9 @@ impl<'db> VisibilityConstraints<'db> {
|
||||||
Truthiness::Ambiguous
|
Truthiness::Ambiguous
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PatternConstraintKind::Singleton(..) | PatternConstraintKind::Unsupported => {
|
PatternConstraintKind::Singleton(..)
|
||||||
Truthiness::Ambiguous
|
| PatternConstraintKind::Class(..)
|
||||||
}
|
| PatternConstraintKind::Unsupported => Truthiness::Ambiguous,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue