mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 05:15:12 +00:00
Narrowing for class patterns in match statements (#15223)
We now support class patterns in a match statement, adding a narrowing constraint that within the body of that match arm, we can assume that the subject is an instance of that class. --------- Co-authored-by: Carl Meyer <carl@astral.sh> Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
parent
f2a86fcfda
commit
b2a0d68d70
6 changed files with 151 additions and 43 deletions
|
@ -404,6 +404,17 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
pattern: &ast::Pattern,
|
||||
guard: Option<&ast::Expr>,
|
||||
) -> Constraint<'db> {
|
||||
// This is called for the top-level pattern of each match arm. We need to create a
|
||||
// standalone expression for each arm of a match statement, since they can introduce
|
||||
// constraints on the match subject. (Or more accurately, for the match arm's pattern,
|
||||
// since its the pattern that introduces any constraints, not the body.) Ideally, that
|
||||
// standalone expression would wrap the match arm's pattern as a whole. But a standalone
|
||||
// expression can currently only wrap an ast::Expr, which patterns are not. So, we need to
|
||||
// choose an Expr that can “stand in” for the pattern, which we can wrap in a standalone
|
||||
// expression.
|
||||
//
|
||||
// See the comment in TypeInferenceBuilder::infer_match_pattern for more details.
|
||||
|
||||
let guard = guard.map(|guard| self.add_standalone_expression(guard));
|
||||
|
||||
let kind = match pattern {
|
||||
|
@ -414,6 +425,10 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
ast::Pattern::MatchSingleton(singleton) => {
|
||||
PatternConstraintKind::Singleton(singleton.value, guard)
|
||||
}
|
||||
ast::Pattern::MatchClass(pattern) => {
|
||||
let cls = self.add_standalone_expression(&pattern.cls);
|
||||
PatternConstraintKind::Class(cls, guard)
|
||||
}
|
||||
_ => PatternConstraintKind::Unsupported,
|
||||
};
|
||||
|
||||
|
@ -1089,37 +1104,35 @@ where
|
|||
cases,
|
||||
range: _,
|
||||
}) => {
|
||||
debug_assert_eq!(self.current_match_case, None);
|
||||
|
||||
let subject_expr = self.add_standalone_expression(subject);
|
||||
self.visit_expr(subject);
|
||||
|
||||
let after_subject = self.flow_snapshot();
|
||||
let Some((first, remaining)) = cases.split_first() else {
|
||||
if cases.is_empty() {
|
||||
return;
|
||||
};
|
||||
|
||||
let first_constraint_id = self.add_pattern_constraint(
|
||||
subject_expr,
|
||||
&first.pattern,
|
||||
first.guard.as_deref(),
|
||||
);
|
||||
|
||||
self.visit_match_case(first);
|
||||
|
||||
let first_vis_constraint_id =
|
||||
self.record_visibility_constraint(first_constraint_id);
|
||||
let mut vis_constraints = vec![first_vis_constraint_id];
|
||||
|
||||
let after_subject = self.flow_snapshot();
|
||||
let mut vis_constraints = vec![];
|
||||
let mut post_case_snapshots = vec![];
|
||||
for case in remaining {
|
||||
post_case_snapshots.push(self.flow_snapshot());
|
||||
self.flow_restore(after_subject.clone());
|
||||
for (i, case) in cases.iter().enumerate() {
|
||||
if i != 0 {
|
||||
post_case_snapshots.push(self.flow_snapshot());
|
||||
self.flow_restore(after_subject.clone());
|
||||
}
|
||||
|
||||
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
|
||||
self.visit_pattern(&case.pattern);
|
||||
self.current_match_case = None;
|
||||
let constraint_id = self.add_pattern_constraint(
|
||||
subject_expr,
|
||||
&case.pattern,
|
||||
case.guard.as_deref(),
|
||||
);
|
||||
self.visit_match_case(case);
|
||||
|
||||
if let Some(expr) = &case.guard {
|
||||
self.visit_expr(expr);
|
||||
}
|
||||
self.visit_body(&case.body);
|
||||
for id in &vis_constraints {
|
||||
self.record_negated_visibility_constraint(*id);
|
||||
}
|
||||
|
@ -1538,18 +1551,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) {
|
||||
debug_assert!(self.current_match_case.is_none());
|
||||
self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern));
|
||||
self.visit_pattern(&match_case.pattern);
|
||||
self.current_match_case = None;
|
||||
|
||||
if let Some(expr) = &match_case.guard {
|
||||
self.visit_expr(expr);
|
||||
}
|
||||
self.visit_body(&match_case.body);
|
||||
}
|
||||
|
||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||
if let ast::Pattern::MatchStar(ast::PatternMatchStar {
|
||||
name: Some(name),
|
||||
|
@ -1636,6 +1637,7 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct CurrentMatchCase<'a> {
|
||||
/// The pattern that's part of the current match case.
|
||||
pattern: &'a ast::Pattern,
|
||||
|
|
|
@ -22,6 +22,7 @@ pub(crate) enum ConstraintNode<'db> {
|
|||
pub(crate) enum PatternConstraintKind<'db> {
|
||||
Singleton(Singleton, Option<Expression<'db>>),
|
||||
Value(Expression<'db>, Option<Expression<'db>>),
|
||||
Class(Expression<'db>, Option<Expression<'db>>),
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
|
|
|
@ -1780,26 +1780,62 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
|
||||
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
|
||||
// We need to create a standalone expression for each arm of a match statement, since they
|
||||
// can introduce constraints on the match subject. (Or more accurately, for the match arm's
|
||||
// pattern, since its the pattern that introduces any constraints, not the body.) Ideally,
|
||||
// that standalone expression would wrap the match arm's pattern as a whole. But a
|
||||
// standalone expression can currently only wrap an ast::Expr, which patterns are not. So,
|
||||
// we need to choose an Expr that can “stand in” for the pattern, which we can wrap in a
|
||||
// standalone expression.
|
||||
//
|
||||
// That said, when inferring the type of a standalone expression, we don't have access to
|
||||
// its parent or sibling nodes. That means, for instance, that in a class pattern, where
|
||||
// we are currently using the class name as the standalone expression, we do not have
|
||||
// access to the class pattern's arguments in the standalone expression inference scope.
|
||||
// At the moment, we aren't trying to do anything with those arguments when creating a
|
||||
// narrowing constraint for the pattern. But in the future, if we do, we will have to
|
||||
// either wrap those arguments in their own standalone expressions, or update Expression to
|
||||
// be able to wrap other AST node types besides just ast::Expr.
|
||||
//
|
||||
// This function is only called for the top-level pattern of a match arm, and is
|
||||
// responsible for inferring the standalone expression for each supported pattern type. It
|
||||
// then hands off to `infer_nested_match_pattern` for any subexpressions and subpatterns,
|
||||
// where we do NOT have any additional standalone expressions to infer through.
|
||||
//
|
||||
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
|
||||
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
|
||||
match pattern {
|
||||
ast::Pattern::MatchValue(match_value) => {
|
||||
self.infer_standalone_expression(&match_value.value);
|
||||
}
|
||||
ast::Pattern::MatchClass(match_class) => {
|
||||
let ast::PatternMatchClass {
|
||||
range: _,
|
||||
cls,
|
||||
arguments,
|
||||
} = match_class;
|
||||
for pattern in &arguments.patterns {
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
for keyword in &arguments.keywords {
|
||||
self.infer_nested_match_pattern(&keyword.pattern);
|
||||
}
|
||||
self.infer_standalone_expression(cls);
|
||||
}
|
||||
_ => {
|
||||
self.infer_match_pattern_impl(pattern);
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_match_pattern_impl(&mut self, pattern: &ast::Pattern) {
|
||||
fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) {
|
||||
match pattern {
|
||||
ast::Pattern::MatchValue(match_value) => {
|
||||
self.infer_expression(&match_value.value);
|
||||
}
|
||||
ast::Pattern::MatchSequence(match_sequence) => {
|
||||
for pattern in &match_sequence.patterns {
|
||||
self.infer_match_pattern_impl(pattern);
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
}
|
||||
ast::Pattern::MatchMapping(match_mapping) => {
|
||||
|
@ -1813,7 +1849,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
self.infer_expression(key);
|
||||
}
|
||||
for pattern in patterns {
|
||||
self.infer_match_pattern_impl(pattern);
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
}
|
||||
ast::Pattern::MatchClass(match_class) => {
|
||||
|
@ -1823,21 +1859,21 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
arguments,
|
||||
} = match_class;
|
||||
for pattern in &arguments.patterns {
|
||||
self.infer_match_pattern_impl(pattern);
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
for keyword in &arguments.keywords {
|
||||
self.infer_match_pattern_impl(&keyword.pattern);
|
||||
self.infer_nested_match_pattern(&keyword.pattern);
|
||||
}
|
||||
self.infer_expression(cls);
|
||||
}
|
||||
ast::Pattern::MatchAs(match_as) => {
|
||||
if let Some(pattern) = &match_as.pattern {
|
||||
self.infer_match_pattern_impl(pattern);
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
}
|
||||
ast::Pattern::MatchOr(match_or) => {
|
||||
for pattern in &match_or.patterns {
|
||||
self.infer_match_pattern_impl(pattern);
|
||||
self.infer_nested_match_pattern(pattern);
|
||||
}
|
||||
}
|
||||
ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {}
|
||||
|
|
|
@ -233,6 +233,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
PatternConstraintKind::Singleton(singleton, _guard) => {
|
||||
self.evaluate_match_pattern_singleton(*subject, *singleton)
|
||||
}
|
||||
PatternConstraintKind::Class(cls, _guard) => {
|
||||
self.evaluate_match_pattern_class(*subject, *cls)
|
||||
}
|
||||
// TODO: support more pattern kinds
|
||||
PatternConstraintKind::Value(..) | PatternConstraintKind::Unsupported => None,
|
||||
}
|
||||
|
@ -486,6 +489,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn evaluate_match_pattern_class(
|
||||
&mut self,
|
||||
subject: Expression<'db>,
|
||||
cls: Expression<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, cls);
|
||||
let ty = inference
|
||||
.expression_ty(cls.node_ref(self.db).scoped_expression_id(self.db, scope))
|
||||
.to_instance(self.db);
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
constraints.insert(symbol, ty);
|
||||
Some(constraints)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_bool_op(
|
||||
&mut self,
|
||||
expr_bool_op: &ExprBoolOp,
|
||||
|
|
|
@ -329,9 +329,9 @@ impl<'db> VisibilityConstraints<'db> {
|
|||
Truthiness::Ambiguous
|
||||
}
|
||||
}
|
||||
PatternConstraintKind::Singleton(..) | PatternConstraintKind::Unsupported => {
|
||||
Truthiness::Ambiguous
|
||||
}
|
||||
PatternConstraintKind::Singleton(..)
|
||||
| PatternConstraintKind::Class(..)
|
||||
| PatternConstraintKind::Unsupported => Truthiness::Ambiguous,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue