Narrowing for class patterns in match statements (#15223)

We now support class patterns in a match statement, adding a narrowing
constraint that within the body of that match arm, we can assume that
the subject is an instance of that class.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
Douglas Creager 2025-01-07 15:58:12 -05:00 committed by GitHub
parent f2a86fcfda
commit b2a0d68d70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 151 additions and 43 deletions

View file

@ -16,3 +16,48 @@ def _(flag: bool):
reveal_type(y) # revealed: Literal[0] | None reveal_type(y) # revealed: Literal[0] | None
``` ```
## Class patterns
```py
def get_object() -> object: ...
class A: ...
class B: ...
x = get_object()
reveal_type(x) # revealed: object
match x:
case A():
reveal_type(x) # revealed: A
case B():
# TODO could be `B & ~A`
reveal_type(x) # revealed: B
reveal_type(x) # revealed: object
```
## Class pattern with guard
```py
def get_object() -> object: ...
class A:
def y() -> int: ...
class B: ...
x = get_object()
reveal_type(x) # revealed: object
match x:
case A() if reveal_type(x): # revealed: A
pass
case B() if reveal_type(x): # revealed: B
pass
reveal_type(x) # revealed: object
```

View file

@ -404,6 +404,17 @@ impl<'db> SemanticIndexBuilder<'db> {
pattern: &ast::Pattern, pattern: &ast::Pattern,
guard: Option<&ast::Expr>, guard: Option<&ast::Expr>,
) -> Constraint<'db> { ) -> Constraint<'db> {
// This is called for the top-level pattern of each match arm. We need to create a
// standalone expression for each arm of a match statement, since they can introduce
// constraints on the match subject. (Or more accurately, for the match arm's pattern,
// since its the pattern that introduces any constraints, not the body.) Ideally, that
// standalone expression would wrap the match arm's pattern as a whole. But a standalone
// expression can currently only wrap an ast::Expr, which patterns are not. So, we need to
// choose an Expr that can “stand in” for the pattern, which we can wrap in a standalone
// expression.
//
// See the comment in TypeInferenceBuilder::infer_match_pattern for more details.
let guard = guard.map(|guard| self.add_standalone_expression(guard)); let guard = guard.map(|guard| self.add_standalone_expression(guard));
let kind = match pattern { let kind = match pattern {
@ -414,6 +425,10 @@ impl<'db> SemanticIndexBuilder<'db> {
ast::Pattern::MatchSingleton(singleton) => { ast::Pattern::MatchSingleton(singleton) => {
PatternConstraintKind::Singleton(singleton.value, guard) PatternConstraintKind::Singleton(singleton.value, guard)
} }
ast::Pattern::MatchClass(pattern) => {
let cls = self.add_standalone_expression(&pattern.cls);
PatternConstraintKind::Class(cls, guard)
}
_ => PatternConstraintKind::Unsupported, _ => PatternConstraintKind::Unsupported,
}; };
@ -1089,37 +1104,35 @@ where
cases, cases,
range: _, range: _,
}) => { }) => {
debug_assert_eq!(self.current_match_case, None);
let subject_expr = self.add_standalone_expression(subject); let subject_expr = self.add_standalone_expression(subject);
self.visit_expr(subject); self.visit_expr(subject);
if cases.is_empty() {
let after_subject = self.flow_snapshot();
let Some((first, remaining)) = cases.split_first() else {
return; return;
}; };
let first_constraint_id = self.add_pattern_constraint( let after_subject = self.flow_snapshot();
subject_expr, let mut vis_constraints = vec![];
&first.pattern,
first.guard.as_deref(),
);
self.visit_match_case(first);
let first_vis_constraint_id =
self.record_visibility_constraint(first_constraint_id);
let mut vis_constraints = vec![first_vis_constraint_id];
let mut post_case_snapshots = vec![]; let mut post_case_snapshots = vec![];
for case in remaining { for (i, case) in cases.iter().enumerate() {
post_case_snapshots.push(self.flow_snapshot()); if i != 0 {
self.flow_restore(after_subject.clone()); post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());
}
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
self.visit_pattern(&case.pattern);
self.current_match_case = None;
let constraint_id = self.add_pattern_constraint( let constraint_id = self.add_pattern_constraint(
subject_expr, subject_expr,
&case.pattern, &case.pattern,
case.guard.as_deref(), case.guard.as_deref(),
); );
self.visit_match_case(case); if let Some(expr) = &case.guard {
self.visit_expr(expr);
}
self.visit_body(&case.body);
for id in &vis_constraints { for id in &vis_constraints {
self.record_negated_visibility_constraint(*id); self.record_negated_visibility_constraint(*id);
} }
@ -1538,18 +1551,6 @@ where
} }
} }
fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) {
debug_assert!(self.current_match_case.is_none());
self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern));
self.visit_pattern(&match_case.pattern);
self.current_match_case = None;
if let Some(expr) = &match_case.guard {
self.visit_expr(expr);
}
self.visit_body(&match_case.body);
}
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
if let ast::Pattern::MatchStar(ast::PatternMatchStar { if let ast::Pattern::MatchStar(ast::PatternMatchStar {
name: Some(name), name: Some(name),
@ -1636,6 +1637,7 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
} }
} }
#[derive(Debug, PartialEq)]
struct CurrentMatchCase<'a> { struct CurrentMatchCase<'a> {
/// The pattern that's part of the current match case. /// The pattern that's part of the current match case.
pattern: &'a ast::Pattern, pattern: &'a ast::Pattern,

View file

@ -22,6 +22,7 @@ pub(crate) enum ConstraintNode<'db> {
pub(crate) enum PatternConstraintKind<'db> { pub(crate) enum PatternConstraintKind<'db> {
Singleton(Singleton, Option<Expression<'db>>), Singleton(Singleton, Option<Expression<'db>>),
Value(Expression<'db>, Option<Expression<'db>>), Value(Expression<'db>, Option<Expression<'db>>),
Class(Expression<'db>, Option<Expression<'db>>),
Unsupported, Unsupported,
} }

View file

@ -1780,26 +1780,62 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
// We need to create a standalone expression for each arm of a match statement, since they
// can introduce constraints on the match subject. (Or more accurately, for the match arm's
// pattern, since its the pattern that introduces any constraints, not the body.) Ideally,
// that standalone expression would wrap the match arm's pattern as a whole. But a
// standalone expression can currently only wrap an ast::Expr, which patterns are not. So,
// we need to choose an Expr that can “stand in” for the pattern, which we can wrap in a
// standalone expression.
//
// That said, when inferring the type of a standalone expression, we don't have access to
// its parent or sibling nodes. That means, for instance, that in a class pattern, where
// we are currently using the class name as the standalone expression, we do not have
// access to the class pattern's arguments in the standalone expression inference scope.
// At the moment, we aren't trying to do anything with those arguments when creating a
// narrowing constraint for the pattern. But in the future, if we do, we will have to
// either wrap those arguments in their own standalone expressions, or update Expression to
// be able to wrap other AST node types besides just ast::Expr.
//
// This function is only called for the top-level pattern of a match arm, and is
// responsible for inferring the standalone expression for each supported pattern type. It
// then hands off to `infer_nested_match_pattern` for any subexpressions and subpatterns,
// where we do NOT have any additional standalone expressions to infer through.
//
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against // TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
match pattern { match pattern {
ast::Pattern::MatchValue(match_value) => { ast::Pattern::MatchValue(match_value) => {
self.infer_standalone_expression(&match_value.value); self.infer_standalone_expression(&match_value.value);
} }
ast::Pattern::MatchClass(match_class) => {
let ast::PatternMatchClass {
range: _,
cls,
arguments,
} = match_class;
for pattern in &arguments.patterns {
self.infer_nested_match_pattern(pattern);
}
for keyword in &arguments.keywords {
self.infer_nested_match_pattern(&keyword.pattern);
}
self.infer_standalone_expression(cls);
}
_ => { _ => {
self.infer_match_pattern_impl(pattern); self.infer_nested_match_pattern(pattern);
} }
} }
} }
fn infer_match_pattern_impl(&mut self, pattern: &ast::Pattern) { fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) {
match pattern { match pattern {
ast::Pattern::MatchValue(match_value) => { ast::Pattern::MatchValue(match_value) => {
self.infer_expression(&match_value.value); self.infer_expression(&match_value.value);
} }
ast::Pattern::MatchSequence(match_sequence) => { ast::Pattern::MatchSequence(match_sequence) => {
for pattern in &match_sequence.patterns { for pattern in &match_sequence.patterns {
self.infer_match_pattern_impl(pattern); self.infer_nested_match_pattern(pattern);
} }
} }
ast::Pattern::MatchMapping(match_mapping) => { ast::Pattern::MatchMapping(match_mapping) => {
@ -1813,7 +1849,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(key); self.infer_expression(key);
} }
for pattern in patterns { for pattern in patterns {
self.infer_match_pattern_impl(pattern); self.infer_nested_match_pattern(pattern);
} }
} }
ast::Pattern::MatchClass(match_class) => { ast::Pattern::MatchClass(match_class) => {
@ -1823,21 +1859,21 @@ impl<'db> TypeInferenceBuilder<'db> {
arguments, arguments,
} = match_class; } = match_class;
for pattern in &arguments.patterns { for pattern in &arguments.patterns {
self.infer_match_pattern_impl(pattern); self.infer_nested_match_pattern(pattern);
} }
for keyword in &arguments.keywords { for keyword in &arguments.keywords {
self.infer_match_pattern_impl(&keyword.pattern); self.infer_nested_match_pattern(&keyword.pattern);
} }
self.infer_expression(cls); self.infer_expression(cls);
} }
ast::Pattern::MatchAs(match_as) => { ast::Pattern::MatchAs(match_as) => {
if let Some(pattern) = &match_as.pattern { if let Some(pattern) = &match_as.pattern {
self.infer_match_pattern_impl(pattern); self.infer_nested_match_pattern(pattern);
} }
} }
ast::Pattern::MatchOr(match_or) => { ast::Pattern::MatchOr(match_or) => {
for pattern in &match_or.patterns { for pattern in &match_or.patterns {
self.infer_match_pattern_impl(pattern); self.infer_nested_match_pattern(pattern);
} }
} }
ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {} ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {}

View file

@ -233,6 +233,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
PatternConstraintKind::Singleton(singleton, _guard) => { PatternConstraintKind::Singleton(singleton, _guard) => {
self.evaluate_match_pattern_singleton(*subject, *singleton) self.evaluate_match_pattern_singleton(*subject, *singleton)
} }
PatternConstraintKind::Class(cls, _guard) => {
self.evaluate_match_pattern_class(*subject, *cls)
}
// TODO: support more pattern kinds // TODO: support more pattern kinds
PatternConstraintKind::Value(..) | PatternConstraintKind::Unsupported => None, PatternConstraintKind::Value(..) | PatternConstraintKind::Unsupported => None,
} }
@ -486,6 +489,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
} }
} }
fn evaluate_match_pattern_class(
&mut self,
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let scope = self.scope();
let inference = infer_expression_types(self.db, cls);
let ty = inference
.expression_ty(cls.node_ref(self.db).scoped_expression_id(self.db, scope))
.to_instance(self.db);
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty);
Some(constraints)
} else {
None
}
}
fn evaluate_bool_op( fn evaluate_bool_op(
&mut self, &mut self,
expr_bool_op: &ExprBoolOp, expr_bool_op: &ExprBoolOp,

View file

@ -329,9 +329,9 @@ impl<'db> VisibilityConstraints<'db> {
Truthiness::Ambiguous Truthiness::Ambiguous
} }
} }
PatternConstraintKind::Singleton(..) | PatternConstraintKind::Unsupported => { PatternConstraintKind::Singleton(..)
Truthiness::Ambiguous | PatternConstraintKind::Class(..)
} | PatternConstraintKind::Unsupported => Truthiness::Ambiguous,
}, },
} }
} }